JumpingKnowledge

class dgl.nn.pytorch.utils.JumpingKnowledge(mode='cat', in_feats=None, num_layers=None)[source]

Bases: Module

The Jumping Knowledge aggregation module from the paper Representation Learning on Graphs with Jumping Knowledge Networks (https://arxiv.org/abs/1806.03536).

It aggregates the output representations of multiple GNN layers with

concatenation

h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}

or max pooling

\max\left(h_i^{(1)}, \ldots, h_i^{(T)}\right)

or LSTM

\sum_{t=1}^{T} \alpha_i^{(t)} h_i^{(t)}

with attention scores \alpha_i^{(t)} obtained from a BiLSTM.
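For intuition, the following is a minimal sketch of the 'lstm' aggregation in plain PyTorch: a BiLSTM reads a node's T per-layer representations as a length-T sequence, a linear head scores each step, and a softmax over the T scores yields the weights \alpha_i^{(t)}. The hidden size and the scoring head below are illustrative assumptions, not necessarily the module's exact internals.

>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> N, T, D = 3, 2, 4  # nodes, GNN layers, feature size (illustrative)
>>> feat_list = [th.randn(N, D) for _ in range(T)]
>>> stacked = th.stack(feat_list, dim=1)  # (N, T, D), one sequence per node
>>> bilstm = nn.LSTM(D, D, batch_first=True, bidirectional=True)
>>> score = nn.Linear(2 * D, 1)  # one scalar score per layer (assumed head)
>>> out, _ = bilstm(stacked)  # (N, T, 2 * D)
>>> alpha = F.softmax(score(out).squeeze(-1), dim=-1)  # (N, T) attention weights
>>> (alpha.unsqueeze(-1) * stacked).sum(dim=1).shape  # weighted sum over layers
torch.Size([3, 4])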

Parameters:
  • mode (str) – The aggregation to apply. It can be ‘cat’, ‘max’, or ‘lstm’, corresponding to the equations above in order.

  • in_feats (int, optional) – This argument is only required if mode is 'lstm'. The output representation size of a single GNN layer. Note that all GNN layers need to have the same output representation size.

  • num_layers (int, optional) – This argument is only required if mode is 'lstm'. The number of GNN layers for output aggregation.

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge
>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]
>>> # Case 1: concatenation (default mode 'cat')
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])
>>> # Case 2: max pooling
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case 3: LSTM attention; in_feats and num_layers are required
>>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
forward(feat_list)[source]

Aggregate output representations across multiple GNN layers.

Parameters:
  • feat_list (list[Tensor]) – feat_list[i] is the output representations of the i-th GNN layer.

Returns:
  The aggregated representations.

Return type:
  Tensor
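In practice, feat_list is typically collected from a stack of GNN layers. A hedged end-to-end sketch, with a toy graph and illustrative GraphConv sizes:

>>> import dgl
>>> import torch as th
>>> from dgl.nn import GraphConv, JumpingKnowledge
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # toy 3-node cycle
>>> h = th.randn(g.num_nodes(), 4)
>>> layers = th.nn.ModuleList([GraphConv(4, 4), GraphConv(4, 4)])
>>> jk = JumpingKnowledge(mode='cat')
>>> feat_list = []
>>> for layer in layers:
...     h = th.relu(layer(g, h))
...     feat_list.append(h)
>>> jk(feat_list).shape  # two 4-dim layer outputs concatenated
torch.Size([3, 8])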

reset_parameters()[source]

Reinitialize learnable parameters. This comes into effect only for the 'lstm' mode.
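For example, one might reinitialize an 'lstm'-mode module before a fresh training run; 'cat' and 'max' involve no learnable parameters, so the call has no effect for them:

>>> model = JumpingKnowledge(mode='lstm', in_feats=4, num_layers=2)
>>> model.reset_parameters()  # reinitializes the BiLSTM aggregator's parameters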