JumpingKnowledge
- class dgl.nn.pytorch.utils.JumpingKnowledge(mode='cat', in_feats=None, num_layers=None)[source]
Bases: Module
The Jumping Knowledge aggregation module from Representation Learning on Graphs with Jumping Knowledge Networks (https://arxiv.org/abs/1806.03536).
It aggregates the output representations of multiple GNN layers with

concatenation

    h_i^{(1)} \| \cdots \| h_i^{(T)}

or max pooling

    \max\left(h_i^{(1)}, \ldots, h_i^{(T)}\right)

or LSTM

    \sum_{t=1}^{T} \alpha_i^{(t)} h_i^{(t)}

with attention scores \alpha_i^{(t)} obtained from a BiLSTM, where h_i^{(t)} is the output representation of node i from the t-th GNN layer and T is the number of layers.
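For intuition, here is a minimal sketch of how the 'lstm' aggregation can be computed: the T per-layer outputs are stacked into a length-T sequence per node, scored by a BiLSTM, and combined as a softmax-weighted sum. The BiLSTM hidden size and the linear scoring layer below are illustrative assumptions, not necessarily the module's actual internals.

    import torch as th
    import torch.nn as nn

    # Illustrative-only re-implementation of the 'lstm' mode.
    num_nodes, in_feats = 3, 4
    feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

    # Stack the T per-layer outputs: (num_nodes, T, in_feats)
    stacked = th.stack(feat_list, dim=1)

    # Assumed hidden size and scoring layer for demonstration.
    bilstm = nn.LSTM(in_feats, in_feats, batch_first=True, bidirectional=True)
    score = nn.Linear(2 * in_feats, 1)

    out, _ = bilstm(stacked)                # (num_nodes, T, 2 * in_feats)
    alpha = th.softmax(score(out), dim=1)   # attention over the T layers
    agg = (alpha * stacked).sum(dim=1)      # (num_nodes, in_feats)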
Parameters:
mode (str) – The aggregation to apply. It can be 'cat', 'max', or 'lstm', corresponding to the equations above in order.
in_feats (int, optional) – This argument is only required if mode is 'lstm'. The output representation size of a single GNN layer. Note that all GNN layers need to have the same output representation size.
num_layers (int, optional) – This argument is only required if mode is 'lstm'. The number of GNN layers for output aggregation.
Examples
>>> import dgl
>>> import torch as th
>>> from dgl.nn import JumpingKnowledge

>>> # Output representations of two GNN layers
>>> num_nodes = 3
>>> in_feats = 4
>>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

>>> # Case1
>>> model = JumpingKnowledge()
>>> model(feat_list).shape
torch.Size([3, 8])

>>> # Case2
>>> model = JumpingKnowledge(mode='max')
>>> model(feat_list).shape
torch.Size([3, 4])
>>> # Case3
>>> model = JumpingKnowledge(mode='lstm', in_feats=in_feats, num_layers=len(feat_list))
>>> model(feat_list).shape
torch.Size([3, 4])
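Beyond aggregating precomputed feature lists, JumpingKnowledge is typically placed after a stack of GNN layers. The sketch below wires it into a small two-layer GCN; the JKGCN class, the layer sizes, and the random graph are illustrative assumptions, while GraphConv, JumpingKnowledge, dgl.rand_graph, and dgl.add_self_loop are existing DGL APIs.

    import dgl
    import torch as th
    import torch.nn as nn
    from dgl.nn import GraphConv, JumpingKnowledge

    # Hypothetical wrapper for demonstration; not part of DGL.
    class JKGCN(nn.Module):
        def __init__(self, in_feats, hidden_feats):
            super().__init__()
            self.conv1 = GraphConv(in_feats, hidden_feats)
            self.conv2 = GraphConv(hidden_feats, hidden_feats)
            self.jk = JumpingKnowledge(mode='max')

        def forward(self, g, feat):
            h1 = th.relu(self.conv1(g, feat))
            h2 = th.relu(self.conv2(g, h1))
            # Fuse the output representations of both layers.
            return self.jk([h1, h2])

    # Self-loops avoid zero-in-degree nodes, which GraphConv rejects by default.
    g = dgl.add_self_loop(dgl.rand_graph(3, 6))
    model = JKGCN(in_feats=4, hidden_feats=4)
    out = model(g, th.randn(3, 4))
    print(out.shape)  # torch.Size([3, 4])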