TAGConv

class dgl.nn.pytorch.conv.TAGConv(in_feats, out_feats, k=2, bias=True, activation=None)

Bases: Module

Topology Adaptive Graph Convolutional layer from the paper "Topology Adaptive Graph Convolutional Networks".

$$H^{K} = \sum_{k=0}^{K} \left(D^{-1/2} A D^{-1/2}\right)^{k} X \Theta_{k},$$

where $A$ denotes the adjacency matrix, $D_{ii} = \sum_{j=0} A_{ij}$ is its diagonal degree matrix, and $\Theta_{k}$ denotes the linear weights used to sum the results of the different hops together.
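The equation can be checked on a small dense example. The following is a minimal NumPy sketch, not the layer's actual implementation: the helper name `tag_conv_dense`, the 4-node path graph, and the random weights are all assumptions made for illustration, and it presumes an unweighted, symmetric adjacency matrix.

```python
import numpy as np

def tag_conv_dense(adj, x, thetas):
    """Dense reference for H = sum_k (D^-1/2 A D^-1/2)^k X Theta_k."""
    deg = adj.sum(axis=1)
    # Clamp degrees to at least 1 so isolated nodes do not divide by zero
    # (assumes an unweighted 0/1 adjacency matrix).
    inv_sqrt = 1.0 / np.sqrt(np.clip(deg, 1.0, None))
    norm_adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]

    hop = x                    # k = 0 term: the normalized adjacency to the power 0 is I
    out = hop @ thetas[0]
    for theta in thetas[1:]:
        hop = norm_adj @ hop   # propagate one more hop
        out = out + hop @ theta
    return out

# Hypothetical 4-node undirected path graph 0-1-2-3.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))                            # N = 4 nodes, 3 input features
thetas = [rng.normal(size=(3, 2)) for _ in range(3)]   # K = 2, so three weight blocks
h = tag_conv_dense(adj, x, thetas)
print(h.shape)  # (4, 2)
```

Reusing the previous hop's result keeps the cost at K sparse-style products instead of forming each matrix power explicitly.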

Parameters:
  • in_feats (int) – Input feature size, i.e., the number of dimensions of $X$.

  • out_feats (int) – Output feature size, i.e., the number of dimensions of $H^{K}$.

  • k (int, optional) – Number of hops $K$. Default: 2.

  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.

  • activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.

lin

The learnable linear module.

Type:

torch.nn.Module
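A single linear module can carry all of the per-hop weights $\Theta_{k}$ if the hop features are concatenated along the feature axis. Whether `lin` is arranged this way is an assumption, not something this page states; the NumPy sketch below only shows that the two forms are mathematically equivalent, using hypothetical sizes and random stand-in data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, in_feats, out_feats, k = 5, 4, 3, 2   # hypothetical sizes

# Stand-ins for the per-hop features: hops[i] plays the role of the
# normalized adjacency raised to the power i, applied to X.
hops = [rng.normal(size=(n, in_feats)) for _ in range(k + 1)]
# One weight block Theta_i per hop.
thetas = [rng.normal(size=(in_feats, out_feats)) for _ in range(k + 1)]

# Form 1: explicit sum over hops, as written in the equation.
summed = sum(h @ t for h, t in zip(hops, thetas))

# Form 2: one linear map applied to the concatenated hop features.
stacked_feats = np.concatenate(hops, axis=1)      # (n, in_feats * (k + 1))
stacked_weight = np.concatenate(thetas, axis=0)   # (in_feats * (k + 1), out_feats)
fused = stacked_feats @ stacked_weight

print(np.allclose(summed, fused))  # True
```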

Example

>>> import dgl
>>> import torch as th
>>> from dgl.nn import TAGConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = TAGConv(10, 2, k=2)
>>> res = conv(g, feat)
>>> res
tensor([[ 0.5490, -1.6373],
        [ 0.5490, -1.6373],
        [ 0.5490, -1.6373],
        [ 0.5513, -1.8208],
        [ 0.5215, -1.6044],
        [ 0.3304, -1.9927]], grad_fn=<AddmmBackward>)

forward(graph, feat, edge_weight=None)

Description

Compute topology adaptive graph convolution.

Parameters:
  • graph (DGLGraph) – The graph.

  • feat (torch.Tensor) – The input feature of shape $(N, D_{in})$, where $D_{in}$ is the size of the input feature and $N$ is the number of nodes.

  • edge_weight (torch.Tensor, optional) – Edge weights to use in the message passing process. This is equivalent to using a weighted adjacency matrix in the equation above, and $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ is based on dgl.nn.pytorch.conv.graphconv.EdgeWeightNorm.

Returns:
  The output feature of shape $(N, D_{out})$, where $D_{out}$ is the size of the output feature.

Return type:
  torch.Tensor
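To make the effect of edge weights concrete, the sketch below builds a dense weighted adjacency matrix from an edge list and applies the symmetric scaling $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$. It is an illustration in plain NumPy under stated assumptions, not the code path used by `forward` or `EdgeWeightNorm`: the helper name `normalized_weighted_adjacency` is hypothetical, the graph is assumed to store each undirected edge in both directions with the same positive weight, and degrees are taken as weighted row sums.

```python
import numpy as np

def normalized_weighted_adjacency(src, dst, weight, num_nodes):
    """Dense weighted adjacency, scaled symmetrically by weighted degrees."""
    adj = np.zeros((num_nodes, num_nodes))
    # Row = destination, column = source, so adj @ X sums incoming messages.
    np.add.at(adj, (dst, src), weight)
    deg = adj.sum(axis=1)                     # weighted degree per node
    safe_deg = np.where(deg > 0, deg, 1.0)    # leave isolated nodes untouched
    inv_sqrt = 1.0 / np.sqrt(safe_deg)
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]

# Hypothetical 3-node graph: edge 0-1 with weight 2.0, edge 1-2 with weight 0.5,
# each stored in both directions.
src = np.array([0, 1, 1, 2])
dst = np.array([1, 0, 2, 1])
weight = np.array([2.0, 2.0, 0.5, 0.5])
norm_adj = normalized_weighted_adjacency(src, dst, weight, num_nodes=3)
print(norm_adj)
```

With all weights set to 1, this reduces to the unweighted normalization in the equation for the layer.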

reset_parameters()

Description

Reinitialize learnable parameters.

Note

The model parameters are initialized using Glorot uniform initialization.
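For reference, Glorot uniform initialization draws each weight from $U(-a, a)$ with $a = \text{gain} \cdot \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. The NumPy sketch below illustrates that rule only; the helper name `glorot_uniform`, the default gain of 1.0, and the layer sizes are assumptions, since this page does not state the gain or the shapes the layer actually uses.

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, gain=1.0, rng=None):
    """Sample a (fan_out, fan_in) weight matrix from U(-a, a)."""
    rng = np.random.default_rng() if rng is None else rng
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=(fan_out, fan_in))

# Hypothetical sizes: 30 inputs, 2 outputs.
w = glorot_uniform(fan_in=30, fan_out=2, rng=np.random.default_rng(0))
print(w.shape)  # (2, 30)
```

The bound shrinks as the layer gets wider, which keeps the variance of activations roughly constant across layers.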