GroupRevRes

class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)

Bases: Module

Grouped reversible residual connections for GNNs, as introduced in Training Graph Neural Networks with 1000 Layers.

It uniformly partitions the input node features $X$ into $C$ groups $X_1, X_2, \cdots, X_C$ along the channel dimension. It also makes $C$ copies of the input GNN module, $f_{w_1}, \cdots, f_{w_C}$. In the forward pass, each GNN copy processes only group-sized features rather than the full feature matrix.
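The uniform split along the channel dimension can be sketched with NumPy. This is an illustration, not DGL's code: `partition_features` is a hypothetical helper, and `x` is a plain array of shape $(N, D)$ standing in for the node feature tensor. `np.split` raises a `ValueError` when $D$ is not divisible by $C$, which makes the uniform-partition requirement explicit.

```python
import numpy as np


def partition_features(x: np.ndarray, groups: int) -> list:
    """Split features of shape (N, D) into `groups` arrays of shape (N, D // groups).

    Hypothetical helper for illustration; np.split raises ValueError
    if D is not divisible by `groups`.
    """
    return np.split(x, groups, axis=1)


num_nodes, feats, groups = 5, 32, 2
x = np.random.randn(num_nodes, feats)
chunks = partition_features(x, groups)

print([c.shape for c in chunks])  # [(5, 16), (5, 16)]
# Concatenating the groups back along the channel axis restores X exactly.
print(np.array_equal(np.concatenate(chunks, axis=1), x))  # True
```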

The output node representations X′ are computed as follows.

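$$
\begin{aligned}
X_0' &= \sum_{i=2}^{C} X_i \\
X_i' &= f_{w_i}(X_{i-1}', g, U) + X_i, \quad i \in \{1, \cdots, C\} \\
X' &= X_1' \,\|\, \cdots \,\|\, X_C'
\end{aligned}
$$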

where $g$ is the input graph, $U$ denotes any additional input arguments (such as edge features), and $\|$ denotes concatenation along the channel dimension.
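The recurrence above, and why it is reversible, can be sketched in NumPy. This is a sketch under stated assumptions, not DGL's implementation: each $f_{w_i}$ is replaced by a hypothetical stand-in `make_f` (a fixed random linear map followed by tanh) that ignores $g$ and $U$, and `rev_forward` / `rev_inverse` are illustrative names. Python lists are 0-indexed, so `xs[0]` holds $X_1$. Walking the recurrence backwards from $i = C$ recovers every input group from the outputs alone, which is the property reversible networks rely on to recompute activations during the backward pass instead of storing them.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_f(dim: int):
    """Hypothetical stand-in for one GNN copy f_{w_i}.

    A fixed random linear map followed by tanh; a real f_{w_i} would also
    use the graph g and extra inputs U, which this sketch omits.
    """
    w = rng.standard_normal((dim, dim))
    return lambda h: np.tanh(h @ w)


def rev_forward(xs, fs):
    """Compute [X'_1, ..., X'_C] from groups [X_1, ..., X_C] (needs C >= 2)."""
    prev = sum(xs[1:])  # X'_0 = X_2 + ... + X_C
    outs = []
    for x_i, f_i in zip(xs, fs):
        prev = f_i(prev) + x_i  # X'_i = f_{w_i}(X'_{i-1}) + X_i
        outs.append(prev)
    return outs


def rev_inverse(outs, fs):
    """Recover [X_1, ..., X_C] from [X'_1, ..., X'_C] alone."""
    groups = len(outs)
    xs = [None] * groups
    # Walk backwards: X_i = X'_i - f_{w_i}(X'_{i-1}) for i = C, ..., 2.
    for i in range(groups - 1, 0, -1):
        xs[i] = outs[i] - fs[i](outs[i - 1])
    # X'_0 depends only on X_2, ..., X_C, which are now known.
    xs[0] = outs[0] - fs[0](sum(xs[1:]))
    return xs


num_nodes, dim, groups = 5, 16, 3
xs = [rng.standard_normal((num_nodes, dim)) for _ in range(groups)]
fs = [make_f(dim) for _ in range(groups)]

outs = rev_forward(xs, fs)
out = np.concatenate(outs, axis=1)  # X' = X'_1 || X'_2 || X'_3
print(out.shape)  # (5, 48)

recovered = rev_inverse(outs, fs)
print(all(np.allclose(a, b) for a, b in zip(xs, recovered)))  # True
```

$X_1$ is recovered last because $X_0'$ is built from the other groups $X_2, \ldots, X_C$; seeding the recurrence that way is what keeps every step invertible.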

Parameters:
  • gnn_module (nn.Module) – GNN module for message passing. GroupRevRes clones the module groups - 1 times, yielding groups copies in total. Its input and output node representation sizes must be the same; since each copy only processes one group, both equal the input feature size divided by groups. Its forward function must take a DGLGraph and the associated input node features, in that order, optionally followed by additional arguments such as edge features.

  • groups (int, optional) – The number of groups. Default: 2.
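The cloning step described for gnn_module can be sketched in plain PyTorch. This assumes the cloning behaves like `copy.deepcopy`, so every copy owns independent parameters $w_i$; `clone_module` is a hypothetical helper, and `nn.Linear` is only a stand-in that keeps the per-group size unchanged (a real gnn_module would also take the graph).

```python
import copy

import torch
import torch.nn as nn


def clone_module(module: nn.Module, groups: int) -> nn.ModuleList:
    """Return `groups` modules: the original plus `groups - 1` deep copies.

    Hypothetical helper for illustration; it assumes cloning behaves
    like copy.deepcopy, so every copy owns its own parameters.
    """
    copies = [copy.deepcopy(module) for _ in range(groups - 1)]
    return nn.ModuleList([module] + copies)


feats, groups = 32, 2
# Stand-in for a GNN layer: maps feats // groups features to the same size,
# but unlike a real gnn_module it does not take a DGLGraph.
layer = nn.Linear(feats // groups, feats // groups)
modules = clone_module(layer, groups)

print(len(modules))  # 2
# The copy starts with the same values as the original ...
print(torch.equal(modules[0].weight, modules[1].weight))  # True
# ... but updating it leaves the original untouched.
with torch.no_grad():
    modules[1].weight.add_(1.0)
print(torch.equal(modules[0].weight, modules[1].weight))  # False
```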

Examples

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to mitigate vanishing gradients,
...         # in particular when stacking a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)
>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)

forward(g, x, *args)

Apply the GNN module with grouped reversible residual connections.

Parameters:
  • g (DGLGraph) – The graph.

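  • x (torch.Tensor) – The input node features of shape $(N, D_{in})$, where $N$ is the number of nodes and $D_{in}$ is the input feature size. Since the features are split uniformly into groups, $D_{in}$ should be divisible by groups.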

  • args – Additional arguments to pass to gnn_module.

Returns:

The output node features of shape $(N, D_{in})$, the same shape as x.

Return type:

torch.Tensor