EGNNConv

class dgl.nn.pytorch.conv.EGNNConv(in_size, hidden_size, out_size, edge_feat_size=0)

Bases: Module

Equivariant Graph Convolutional Layer from the paper E(n) Equivariant Graph Neural Networks.

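$$
\begin{aligned}
m_{ij} &= \phi_e\left(h_i^l, h_j^l, \lVert x_i^l - x_j^l \rVert^2, a_{ij}\right) \\
x_i^{l+1} &= x_i^l + C \sum_{j \in \mathcal{N}(i)} \left(x_i^l - x_j^l\right) \phi_x(m_{ij}) \\
m_i &= \sum_{j \in \mathcal{N}(i)} m_{ij} \\
h_i^{l+1} &= \phi_h\left(h_i^l, m_i\right)
\end{aligned}
$$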

where $h_i$, $x_i$, and $a_{ij}$ are the node features, coordinate features, and edge features, respectively, and the superscript $l$ indexes the layer. $\phi_e$, $\phi_h$, and $\phi_x$ are two-layer MLPs, and $C$ is a normalization constant computed as $1/|\mathcal{N}(i)|$.
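
To make the update rules concrete, here is a minimal sketch of one layer written directly from the equations above in plain PyTorch, on an edge list where edge k carries a message from j = src[k] to i = dst[k] (DGL's source-to-destination convention, so $\mathcal{N}(i)$ is the set of in-neighbors of i). It is not DGL's implementation, whose internals may differ: the names `mlp` and `egnn_update`, the SiLU activations, and the MLP widths are assumptions for illustration. The final lines check E(n) equivariance by rotating/reflecting and translating the input coordinates.

```python
import torch
import torch.nn as nn


def mlp(in_dim, hidden_dim, out_dim):
    # Two-layer MLP; the SiLU activation is an assumption of this sketch.
    return nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, out_dim))


def egnn_update(h, x, src, dst, a, phi_e, phi_x, phi_h):
    """One EGNN step from the equations; edge k sends a message from j=src[k] to i=dst[k]."""
    n = h.shape[0]
    diff = x[dst] - x[src]                          # x_i - x_j per edge
    dist2 = (diff ** 2).sum(dim=-1, keepdim=True)   # ||x_i - x_j||^2 per edge
    parts = [h[dst], h[src], dist2] + ([a] if a is not None else [])
    m_ij = phi_e(torch.cat(parts, dim=-1))          # edge messages m_ij
    # C = 1/|N(i)|; nodes without neighbors get degree 1 so their zero update is not divided by 0.
    deg = torch.bincount(dst, minlength=n).clamp(min=1).unsqueeze(-1).to(x.dtype)
    x_sum = x.new_zeros(x.shape).index_add_(0, dst, diff * phi_x(m_ij))
    x_new = x + x_sum / deg
    m_i = m_ij.new_zeros(n, m_ij.shape[-1]).index_add_(0, dst, m_ij)  # m_i = sum_j m_ij
    h_new = phi_h(torch.cat([h, m_i], dim=-1))
    return h_new, x_new


torch.manual_seed(0)
in_size, hidden, out_size, e_size = 4, 8, 5, 2
phi_e = mlp(2 * in_size + 1 + e_size, hidden, hidden)
phi_x = mlp(hidden, hidden, 1)   # one scalar weight per edge
phi_h = mlp(in_size + hidden, hidden, out_size)

src = torch.tensor([0, 1, 2, 3, 2, 5])
dst = torch.tensor([1, 2, 3, 4, 0, 3])
h, x, a = torch.randn(6, in_size), torch.randn(6, 3), torch.randn(6, e_size)
h1, x1 = egnn_update(h, x, src, dst, a, phi_e, phi_x, phi_h)

# Rotate/reflect with a random orthogonal Q and translate by t.
Q, _ = torch.linalg.qr(torch.randn(3, 3))
t = torch.randn(3)
h2, x2 = egnn_update(h, x @ Q.T + t, src, dst, a, phi_e, phi_x, phi_h)
# Both checks are expected to hold up to floating-point tolerance.
print(torch.allclose(h2, h1, atol=1e-4), torch.allclose(x2, x1 @ Q.T + t, atol=1e-4))
```

Since $\phi_e$ only sees the squared distance, the messages are invariant, so $h$ is unchanged while $x$ moves with the input frame.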

Parameters:
  • in_size (int) – Input feature size, i.e., the size of $h_i^l$.

  • hidden_size (int) – Hidden feature size, i.e., the size of the hidden layer in each of the two-layer MLPs $\phi_e$, $\phi_x$, and $\phi_h$.

  • out_size (int) – Output feature size, i.e., the size of $h_i^{l+1}$.

  • edge_feat_size (int, optional) – Edge feature size, i.e., the size of $a_{ij}$. Default: 0.
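
The constructor arguments determine the widths of the three MLPs. Reading them off the equations ($\phi_e$ consumes $h_i$, $h_j$, the scalar squared distance, and $a_{ij}$; $\phi_x$ must produce one scalar weight per edge; $\phi_h$ consumes $h_i$ and $m_i$) gives the layout below. The helper `egnn_mlp_widths` is hypothetical, and the assumption that $m_{ij}$ has width hidden_size is ours rather than something stated above.

```python
def egnn_mlp_widths(in_size, hidden_size, out_size, edge_feat_size=0):
    """Hypothetical helper: (input, hidden, output) widths of phi_e, phi_x, phi_h.

    Assumes m_ij has width hidden_size and phi_x outputs one scalar per edge.
    """
    return {
        # h_i, h_j (in_size each) + ||x_i - x_j||^2 (1) + a_ij (edge_feat_size)
        "phi_e": (2 * in_size + 1 + edge_feat_size, hidden_size, hidden_size),
        # m_ij -> scalar weight applied to (x_i - x_j)
        "phi_x": (hidden_size, hidden_size, 1),
        # [h_i, m_i] -> h_i^{l+1}
        "phi_h": (in_size + hidden_size, hidden_size, out_size),
    }


print(egnn_mlp_widths(10, 10, 10, 2))
# {'phi_e': (23, 10, 10), 'phi_x': (10, 10, 1), 'phi_h': (20, 10, 10)}
```

The "+1" in the input of $\phi_e$ is the single squared-distance scalar, which is why the coordinate dimension does not appear among the constructor arguments.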

Example

>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGNNConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> node_feat, coord_feat, edge_feat = th.ones(6, 10), th.ones(6, 3), th.ones(6, 2)
>>> conv = EGNNConv(10, 10, 10, 2)
>>> h, x = conv(g, node_feat, coord_feat, edge_feat)

forward(graph, node_feat, coord_feat, edge_feat=None)

Description

Compute the output of the EGNN layer, i.e., the updated node features and coordinates.

param graph:

The graph.

type graph:

DGLGraph

param node_feat:

The input node features of shape $(N, h_n)$, where $N$ is the number of nodes and $h_n$ must equal in_size.

type node_feat:

torch.Tensor

param coord_feat:

The input coordinates of shape $(N, h_x)$, where $N$ is the number of nodes and $h_x$ can be any positive integer (e.g., 3 for points in 3D space).

type coord_feat:

torch.Tensor

param edge_feat:

The edge features of shape $(M, h_e)$, where $M$ is the number of edges and $h_e$ must equal edge_feat_size. Must be provided when edge_feat_size is greater than 0.

type edge_feat:

torch.Tensor, optional
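
The shape requirements above can be checked before calling the layer. The function below is a hypothetical helper, not part of DGL; it takes the node and edge counts (for a DGLGraph, `graph.num_nodes()` and `graph.num_edges()`) and raises `ValueError` when a tensor does not match the documented shapes. Requiring edge_feat whenever edge_feat_size > 0 reflects that $\phi_e$ expects $a_{ij}$ of that width.

```python
import torch


def check_egnn_inputs(num_nodes, num_edges, in_size, edge_feat_size,
                      node_feat, coord_feat, edge_feat=None):
    """Hypothetical validator for the input shapes documented for forward()."""
    if node_feat.dim() != 2 or node_feat.shape[0] != num_nodes:
        raise ValueError(f"node_feat must have shape (N, h_n) with N = {num_nodes}")
    if node_feat.shape[1] != in_size:
        raise ValueError(f"h_n = {node_feat.shape[1]} must equal in_size = {in_size}")
    if coord_feat.dim() != 2 or coord_feat.shape[0] != num_nodes or coord_feat.shape[1] < 1:
        raise ValueError("coord_feat must have shape (N, h_x) with h_x >= 1")
    if edge_feat_size > 0:
        if edge_feat is None:
            raise ValueError("edge_feat is required when edge_feat_size > 0")
        if tuple(edge_feat.shape) != (num_edges, edge_feat_size):
            raise ValueError(f"edge_feat must have shape ({num_edges}, {edge_feat_size})")


# Passes: 2-D coordinates are allowed because h_x may be any positive integer.
check_egnn_inputs(6, 6, 10, 2, torch.ones(6, 10), torch.ones(6, 2), torch.ones(6, 2))

# Fails: edge features are missing although edge_feat_size = 2.
try:
    check_egnn_inputs(6, 6, 10, 2, torch.ones(6, 10), torch.ones(6, 3))
    message = None
except ValueError as err:
    message = str(err)
print(message)  # edge_feat is required when edge_feat_size > 0
```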

returns:
  • node_feat_out (torch.Tensor) – The output node features of shape $(N, h_n')$, where $h_n'$ equals out_size.

  • coord_feat_out (torch.Tensor) – The output coordinates of shape $(N, h_x)$, where $h_x$ is the same as the input coordinate dimension.