TWIRLSConv
- class dgl.nn.pytorch.conv.TWIRLSConv(input_d, output_d, hidden_d, prop_step, num_mlp_before=1, num_mlp_after=1, norm='none', precond=True, alp=0, lam=1, attention=False, tau=0.2, T=-1, p=1, use_eta=False, attn_bef=False, dropout=0.0, attn_dropout=0.0, inp_dropout=0.0)
Bases: Module
Convolution combined with iteratively reweighted least squares, from Graph Neural Networks Inspired by Classical Iterative Algorithms.
- Parameters:
input_d (int) – Number of input features.
output_d (int) – Number of output features.
hidden_d (int) – Size of hidden layers.
prop_step (int) – Number of propagation steps.
num_mlp_before (int) – Number of MLP layers before propagation. Default: 1.
num_mlp_after (int) – Number of MLP layers after propagation. Default: 1.
norm (str) – The type of norm layers inside the MLP layers. Can be 'batch', 'layer' or 'none'. Default: 'none'.
precond (bool) – If True, use preconditioning and the unnormalized Laplacian; otherwise, use no preconditioning and the normalized Laplacian. Default: True.
alp (float) – The \(\alpha\) in the paper. If equal to \(0\), it is chosen automatically based on the other hyperparameters. Default: 0.
lam (float) – The \(\lambda\) in the paper. Default: 1.
attention (bool) – If True, add an attention layer inside the propagation steps. Default: False.
tau (float) – The \(\tau\) in the paper. Default: 0.2.
T (float) – The \(T\) in the paper. If < 0, \(T\) is set to infinity. Default: -1.
p (float) – The \(p\) in the paper. Default: 1.
use_eta (bool) – If True, add a learnable weight on each dimension in attention. Default: False.
attn_bef (bool) – If True, add another attention layer before propagation. Default: False.
dropout (float) – The dropout rate in the MLP layers. Default: 0.0.
attn_dropout (float) – The dropout rate of attention values. Default: 0.0.
inp_dropout (float) – The dropout rate on input features. Default: 0.0.
Note
add_self_loop will be automatically called before propagation.
Example
>>> import dgl
>>> from dgl.nn import TWIRLSConv
>>> import torch as th

>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = TWIRLSConv(10, 2, 128, prop_step=64)
>>> res = conv(g, feat)
>>> res.size()
torch.Size([6, 2])
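The note above says that self-loops are added before propagation. As a rough, library-free illustration of what that does to the example graph's edge list, the sketch below appends one i → i edge per node. The helper name `with_self_loops` is hypothetical and this is not DGL's implementation; in DGL itself, `dgl.add_self_loop` performs this step on a graph object.

```python
def with_self_loops(src, dst, num_nodes):
    """Return new (src, dst) lists with one self-loop appended per node.

    Hypothetical helper for illustration only; it does not check whether
    a node already has a self-loop.
    """
    nodes = list(range(num_nodes))
    return list(src) + nodes, list(dst) + nodes


# Edge list of the example graph: 6 nodes, 6 directed edges.
src = [0, 1, 2, 3, 2, 5]
dst = [1, 2, 3, 4, 0, 3]

new_src, new_dst = with_self_loops(src, dst, num_nodes=6)
print(len(new_src))  # 12: the 6 original edges plus 6 self-loops
```

Every node then has at least one incoming edge (its own), including node 5, which has no incoming edge in the original graph.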
- forward(graph, feat)
Description
Run TWIRLS forward.
- Parameters:
graph (DGLGraph) – The graph.
feat (torch.Tensor) – The initial node features.
- Returns:
The output features.
- Return type:
torch.Tensor
Note
Input shape: \((N, \text{input_d})\) where \(N\) is the number of nodes.
Output shape: \((N, \text{output_d})\).
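For intuition about how `prop_step`, `lam`, and `alp` could interact, the sketch below runs a generic preconditioned gradient descent on a Laplacian-regularized least-squares energy, \(\|Y - X\|^2 + \lambda \sum_{(i,j)} \|y_i - y_j\|^2\). This energy and update rule are assumptions made for illustration: the code is plain Python with scalar node features, it omits the MLPs, attention, reweighting, and self-loops, and it should not be read as DGL's actual implementation. The function name `propagate` and its arguments are hypothetical.

```python
def propagate(edges, x, prop_step, lam=1.0, alp=1.0):
    """Preconditioned gradient descent on the assumed energy
        E(y) = sum_i (y_i - x_i)^2 + lam * sum_{(i,j) in edges} (y_i - y_j)^2

    edges: undirected edges as (i, j) pairs with i != j.
    x: list of scalar node features (one float per node).
    """
    n = len(x)
    neighbors = [[] for _ in range(n)]
    for i, j in edges:
        neighbors[i].append(j)
        neighbors[j].append(i)

    y = list(x)  # start from the input features
    for _ in range(prop_step):
        new_y = []
        for i in range(n):
            # Diagonal preconditioner entry: 1 + lam * degree(i).
            d_tilde = 1.0 + lam * len(neighbors[i])
            # Neighbor aggregation plus the input feature term.
            agg = lam * sum(y[j] for j in neighbors[i]) + x[i]
            # Blend the previous value with the preconditioned update.
            new_y.append((1.0 - alp) * y[i] + alp * agg / d_tilde)
        y = new_y
    return y


# Two nodes joined by one edge, scalar features 1.0 and 0.0.
y = propagate(edges=[(0, 1)], x=[1.0, 0.0], prop_step=64)
print(y)
```

For this two-node graph with `lam = 1`, the iterates approach \([2/3, 1/3]\), the minimizer of the assumed energy; fewer propagation steps give a coarser approximation.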