EGATConv

class dgl.nn.pytorch.conv.EGATConv(in_node_feats, in_edge_feats, out_node_feats, out_edge_feats, num_heads, bias=True)

Bases: Module

Graph attention layer that handles edge features, from Rossmann-Toolbox (see its supplementary data).

It differs from the regular GAT layer in how the unnormalized attention scores $e_{ij}$ are obtained:

$$e_{ij} = \vec{F}\,(f_{ij}^{\prime})$$
$$f_{ij}^{\prime} = \mathrm{LeakyReLU}\left(A\,[h_{i} \,\|\, f_{ij} \,\|\, h_{j}]\right)$$

where $f_{ij}^{\prime}$ are the updated edge features, $A$ is a weight matrix, and $\vec{F}$ is a weight vector. The resulting node features $h_i^{\prime}$ are then computed in the same way as in the regular GAT.
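To make the two equations concrete, here is a minimal single-head sketch in plain Python. It is not DGL's implementation, and every name in it (`edge_update`, `egat_layer`, `A`, `F`, `W`) is hypothetical. It assumes that $h_i$ belongs to the source node and $h_j$ to the destination node of an edge, that LeakyReLU uses a negative slope of 0.01, and that "the same way as in the regular GAT" means a softmax of $e_{ij}$ over the edges entering each destination node followed by an attention-weighted sum of linearly projected source features (projection `W`, no bias). A multi-head layer would repeat this with separate weights per head.

```python
import math

def leaky_relu(x, slope=0.01):
    # Assumed negative slope of 0.01 (PyTorch's default for LeakyReLU).
    return x if x > 0 else slope * x

def matvec(M, x):
    # Multiply a matrix (list of rows) by a vector (list).
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def edge_update(A, F, h_i, f_ij, h_j):
    """Single head: return (f'_ij, e_ij) for an edge from node i to node j."""
    z = h_i + f_ij + h_j                      # list concatenation [h_i || f_ij || h_j]
    f_new = [leaky_relu(v) for v in matvec(A, z)]
    e = sum(w * v for w, v in zip(F, f_new))  # weight vector F applied to f'_ij
    return f_new, e

def egat_layer(edges, h, f, A, F, W):
    """edges: list of (src, dst); h: per-node features; f: per-edge features."""
    new_f, scores = [], []
    for (i, j), f_ij in zip(edges, f):
        f_new, e = edge_update(A, F, h[i], f_ij, h[j])
        new_f.append(f_new)
        scores.append(e)
    # GAT-style normalization: softmax of e_ij over the edges entering each node j.
    alpha = [0.0] * len(edges)
    for j in {dst for _, dst in edges}:
        idx = [k for k, (_, dst) in enumerate(edges) if dst == j]
        top = max(scores[k] for k in idx)     # subtract the max for numerical stability
        exps = {k: math.exp(scores[k] - top) for k in idx}
        total = sum(exps.values())
        for k in idx:
            alpha[k] = exps[k] / total
    # New node features: attention-weighted sum of projected source features.
    # Nodes without incoming edges keep all-zero features in this sketch.
    new_h = [[0.0] * len(W) for _ in h]
    for k, (i, j) in enumerate(edges):
        for d, value in enumerate(matvec(W, h[i])):
            new_h[j][d] += alpha[k] * value
    return new_h, new_f, alpha

# Toy graph: 3 nodes with 2 features each, 3 edges with 1 feature each.
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 2), (1, 2), (2, 0)]
f = [[0.5], [-0.5], [1.0]]
A = [[1, 0, 0, 0, 0], [0, 0, 1, 0, 0]]  # maps the 5-dim concatenation to 2 edge features
F = [1.0, 1.0]
W = [[1.0, 0.0], [0.0, 1.0]]            # identity node projection
new_h, new_f, alpha = egat_layer(edges, h, f, A, F, W)
print(new_f, alpha, new_h)
```

In the toy graph, node 2 has two incoming edges, so its two attention weights share a total of 1, while node 0 has a single incoming edge whose weight is exactly 1.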

Parameters:
  • in_node_feats (int, or pair of ints) – Input node feature size, i.e., the number of dimensions of $h_i$. EGATConv can be applied to a homogeneous graph or a unidirectional bipartite graph. If the layer is applied to a unidirectional bipartite graph, a pair of ints specifies the input feature sizes of the source and destination nodes, respectively. If a scalar is given, the source and destination node feature sizes take the same value.

  • in_edge_feats (int) – Input edge feature size, i.e., the number of dimensions of $f_{ij}$.

  • out_node_feats (int) – Output node feature size.

  • out_edge_feats (int) – Output edge feature size, i.e., the number of dimensions of $f_{ij}^{\prime}$.

  • num_heads (int) – Number of attention heads.

  • bias (bool, optional) – If True, add a bias term to $f_{ij}^{\prime}$. Default: True.
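The rule for in_node_feats (a scalar applies to both node types, a pair gives source and destination sizes) can be expressed as a small normalization helper. `split_in_feats` is a hypothetical name used only for illustration; it is not part of the DGL API.

```python
from typing import Sequence, Tuple, Union

def split_in_feats(in_node_feats: Union[int, Sequence[int]]) -> Tuple[int, int]:
    """Return (source size, destination size) following the in_node_feats rule."""
    if isinstance(in_node_feats, int):
        # A scalar: source and destination nodes share the same feature size.
        return in_node_feats, in_node_feats
    # A pair: (source node feature size, destination node feature size).
    src, dst = in_node_feats
    return int(src), int(dst)

print(split_in_feats(20), split_in_feats((25, 30)))
```

Each returned size should match the last dimension of the corresponding node feature tensor. For instance, in_node_feats=(25, 30) pairs with source features of shape (2, 25) and destination features of shape (4, 30), as in the bipartite example.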

Examples

>>> import dgl
>>> import torch as th
>>> from dgl.nn import EGATConv
>>> # Case 1: Homogeneous graph
>>> num_nodes, num_edges = 8, 30
>>> # generate a graph
>>> graph = dgl.rand_graph(num_nodes, num_edges)
>>> node_feats = th.rand((num_nodes, 20))
>>> edge_feats = th.rand((num_edges, 12))
>>> egat = EGATConv(in_node_feats=20,
...                 in_edge_feats=12,
...                 out_node_feats=15,
...                 out_edge_feats=10,
...                 num_heads=3)
>>> # forward pass
>>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
>>> new_node_feats.shape, new_edge_feats.shape
(torch.Size([8, 3, 15]), torch.Size([30, 3, 10]))
>>> # Case 2: Unidirectional bipartite graph
>>> import numpy as np
>>> u = [0, 1, 0, 0, 1]
>>> v = [0, 1, 2, 3, 2]
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
>>> u_feat = th.tensor(np.random.rand(2, 25).astype(np.float32))
>>> v_feat = th.tensor(np.random.rand(4, 30).astype(np.float32))
>>> nfeats = (u_feat, v_feat)
>>> efeats = th.tensor(np.random.rand(5, 15).astype(np.float32))
>>> in_node_feats = (25, 30)
>>> in_edge_feats = 15
>>> out_node_feats = 10
>>> out_edge_feats = 5
>>> num_heads = 3
>>> egat_model = EGATConv(in_node_feats,
...                       in_edge_feats,
...                       out_node_feats,
...                       out_edge_feats,
...                       num_heads,
...                       bias=True)
>>> # forward pass
>>> new_node_feats, new_edge_feats, attentions = egat_model(g, nfeats, efeats, get_attention=True)
>>> new_node_feats.shape, new_edge_feats.shape, attentions.shape
(torch.Size([4, 3, 10]), torch.Size([5, 3, 5]), torch.Size([5, 3, 1]))
forward(graph, nfeats, efeats, edge_weight=None, get_attention=False)

Compute new node and edge features.

Parameters:
  • graph (DGLGraph) – The graph.

  • nfeats (torch.Tensor or pair of torch.Tensor) –

    If a torch.Tensor is given, it is the input node feature of shape $(N, D_{in})$, where $D_{in}$ is the size of the input node feature and $N$ is the number of nodes.

    If a pair of torch.Tensor is given, the pair must contain the source and destination node features, of shape $(N_{in}, D_{in_{src}})$ and $(N_{out}, D_{in_{dst}})$ respectively.

  • efeats (torch.Tensor) –

    The input edge feature of shape $(E, F_{in})$, where $F_{in}$ is the size of the input edge feature and $E$ is the number of edges.

  • edge_weight (torch.Tensor, optional) – A 1D tensor of edge weight values. Shape: (|E|,).

  • get_attention (bool, optional) – Whether to return the attention values. Defaults to False.

Returns:

  • pair of torch.Tensor – The node output features followed by the edge output features. The node output feature is of shape $(N, H, D_{out})$ and the edge output feature is of shape $(E, H, F_{out})$, where:

    $H$ is the number of heads, $D_{out}$ is the size of the output node feature, and $F_{out}$ is the size of the output edge feature. For a bipartite graph, $N$ is the number of destination nodes.

  • torch.Tensor, optional – The attention values of shape $(E, H, 1)$. This is returned only when get_attention is True.
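The shape contract in the Returns list can be restated as a small helper, which is handy when writing shape assertions in tests. `egat_output_shapes` is a hypothetical name, not part of DGL; it only encodes the documented shapes, and the usage below checks it against the two examples on this page.

```python
from typing import Dict, Tuple

def egat_output_shapes(num_dst_nodes: int, num_edges: int, num_heads: int,
                       out_node_feats: int, out_edge_feats: int,
                       get_attention: bool = False) -> Dict[str, Tuple[int, int, int]]:
    """Expected output shapes of the forward pass, as documented above."""
    shapes = {
        "node": (num_dst_nodes, num_heads, out_node_feats),  # (N, H, D_out)
        "edge": (num_edges, num_heads, out_edge_feats),      # (E, H, F_out)
    }
    if get_attention:
        shapes["attention"] = (num_edges, num_heads, 1)      # (E, H, 1)
    return shapes

# Case 1: homogeneous graph with 8 nodes and 30 edges.
print(egat_output_shapes(8, 30, 3, 15, 10))
# Case 2: bipartite graph with 4 destination nodes and 5 edges.
print(egat_output_shapes(4, 5, 3, 10, 5, get_attention=True))
```

Because torch.Size is a subclass of tuple, a check such as `tuple(new_node_feats.shape) == shapes["node"]` works directly on the layer's outputs.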

reset_parameters()

Reinitialize learnable parameters.