GroupRevRes
- class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)
Bases: torch.nn.Module
Grouped reversible residual connections for GNNs, as introduced in "Training Graph Neural Networks with 1000 Layers".
It uniformly partitions an input node feature $X$ into $C$ groups across the channel dimension. In addition, it makes $C$ copies of the input GNN module, $f_{w_1}, \ldots, f_{w_C}$. In the forward pass, each GNN module takes only the corresponding group of node features.
The output node representations $X'$ are computed as follows:

$$
\begin{aligned}
X_0' &= \sum_{i=2}^{C} X_i \\
X_i' &= f_{w_i}(X_{i-1}', A, U) + X_i, \quad i \in \{1, \ldots, C\} \\
X' &= X_1' \,\|\, \ldots \,\|\, X_C'
\end{aligned}
$$

where $A$ is the input graph, $U$ is arbitrary additional input arguments such as edge features, and $\|$ denotes concatenation.
- Parameters:
gnn_module (nn.Module) – GNN module for message passing. GroupRevRes will clone the module groups - 1 times, yielding groups copies in total. The input and output node representation sizes need to be the same; since each copy operates on a single group of channels, this size should equal the input feature size divided by groups. Its forward function needs to take a DGLGraph and the associated input node features, in that order, optionally followed by additional arguments such as edge features.
groups (int, optional) – The number of groups. Default: 2.
Examples
>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes

>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to prevent gradient vanishing
...         # In particular if you use a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)

>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
- forward(g, x, *args)
Apply the GNN module with grouped reversible residual connections.
- Parameters:
g (DGLGraph) – The graph.
x (torch.Tensor) – The input feature of shape $(N, D_{in})$, where $D_{in}$ is the size of the input feature and $N$ is the number of nodes.
args – Additional arguments to pass to gnn_module.
- Returns:
The output feature of shape $(N, D_{in})$.
- Return type:
torch.Tensor
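Reversibility means the input of each group can be recomputed from the outputs, which is what lets reversible architectures recompute activations during the backward pass instead of storing them. Solving $X_i' = f_{w_i}(X_{i-1}', A, U) + X_i$ for $X_i$ group by group in reverse order gives (writing $f_{w_i}(\cdot)$ for $f_{w_i}(\cdot, A, U)$):

$$
\begin{aligned}
X_i &= X_i' - f_{w_i}(X_{i-1}'), \quad i = C, \ldots, 2 \\
X_0' &= \sum_{i=2}^{C} X_i \\
X_1 &= X_1' - f_{w_1}(X_0')
\end{aligned}
$$

The sketch below checks this numerically in plain PyTorch. It is an illustration, not DGL's implementation: DenseGCNLayer, grouped_rev_forward, and grouped_rev_inverse are hypothetical names, a dense row-normalized adjacency matrix stands in for the DGLGraph, and the extra arguments $U$ are omitted.

```python
import copy

import torch
import torch.nn as nn


class DenseGCNLayer(nn.Module):
    """Hypothetical GNN stand-in: f(X) = A_hat @ (X W) on a dense adjacency."""

    def __init__(self, feats):
        super().__init__()
        self.linear = nn.Linear(feats, feats, bias=False)

    def forward(self, adj, x):
        return adj @ self.linear(x)


def grouped_rev_forward(modules, adj, x):
    xs = torch.chunk(x, len(modules), dim=-1)
    y = sum(xs[1:])                              # X_0'
    ys = []
    for f, x_i in zip(modules, xs):
        y = f(adj, y) + x_i                      # X_i' = f_{w_i}(X_{i-1}') + X_i
        ys.append(y)
    return torch.cat(ys, dim=-1)


def grouped_rev_inverse(modules, adj, out):
    """Recover X from X' by undoing the groups in reverse order."""
    ys = torch.chunk(out, len(modules), dim=-1)  # X_1', ..., X_C'
    xs = [None] * len(modules)
    # Groups C..2: X_i = X_i' - f_{w_i}(X_{i-1}')
    for i in range(len(modules) - 1, 0, -1):
        xs[i] = ys[i] - modules[i](adj, ys[i - 1])
    # Group 1 needs X_0' = X_2 + ... + X_C, which is known at this point.
    x0_prime = sum(xs[1:])
    xs[0] = ys[0] - modules[0](adj, x0_prime)
    return torch.cat(xs, dim=-1)


torch.manual_seed(0)
num_nodes, feats, groups = 5, 32, 4
adj = torch.rand(num_nodes, num_nodes)
adj = adj / adj.sum(dim=1, keepdim=True)
x = torch.randn(num_nodes, feats)
layer = DenseGCNLayer(feats // groups)
modules = nn.ModuleList([layer] + [copy.deepcopy(layer) for _ in range(groups - 1)])

with torch.no_grad():
    out = grouped_rev_forward(modules, adj, x)
    x_rec = grouped_rev_inverse(modules, adj, out)

print(torch.allclose(x, x_rec, atol=1e-4))  # True
```

The groups are undone in reverse order because recovering $X_1$ requires $X_0'$, which depends on $X_2, \ldots, X_C$.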