SpatialEncoder3d

class dgl.nn.pytorch.gt.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=100)[source]

Bases: Module

3D Spatial Encoder, as introduced in One Transformer Can Understand Both 2D & 3D Molecular Data

This module encodes the pair-wise relation between each node pair (i, j) in 3D geometric space, according to the Gaussian Basis Kernel function:

ψ(i,j)k=12Ο€|Οƒk|exp⁑(βˆ’12(Ξ³(i,j)||riβˆ’rj||+Ξ²(i,j)βˆ’ΞΌk|Οƒk|)2),k=1,...,K,

where K is the number of Gaussian Basis kernels and r_i is the Cartesian coordinate of node i. γ_{(i,j)} and β_{(i,j)} are learnable scaling factors and biases determined by the types of nodes i and j. μ_k and σ_k are the learnable center and standard deviation of the k-th Gaussian Basis kernel.
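The kernel function above can be sketched directly. This is an illustrative numpy snippet, not the DGL implementation; the function name, the fixed centers, and the constant scaling factor and bias are all assumptions made for the example (in the real module they are learnable parameters).

```python
import numpy as np

def gaussian_basis(r_i, r_j, gamma_ij, beta_ij, mu, sigma):
    # Scaled and shifted Euclidean distance between the two coordinates:
    # x = gamma_(i,j) * ||r_i - r_j|| + beta_(i,j)
    x = gamma_ij * np.linalg.norm(r_i - r_j) + beta_ij
    # psi_k = exp(-0.5 * ((x - mu_k) / |sigma_k|)^2) / (sqrt(2*pi) * |sigma_k|)
    return np.exp(-0.5 * ((x - mu) / np.abs(sigma)) ** 2) \
        / (np.sqrt(2 * np.pi) * np.abs(sigma))

K = 4
mu = np.linspace(0.0, 3.0, K)   # kernel centers (learnable in the module)
sigma = np.full(K, 1.0)         # kernel standard deviations (learnable)
psi = gaussian_basis(np.zeros(3), np.ones(3),
                     gamma_ij=1.0, beta_ij=0.0, mu=mu, sigma=sigma)
print(psi.shape)  # (4,) -- one value per kernel
```

Each node pair thus yields a K-dimensional feature vector, one entry per kernel.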

Parameters:
  • num_kernels (int) – Number of Gaussian Basis Kernels to be applied. Each Gaussian Basis Kernel contains a learnable kernel center and a learnable standard deviation.

  • num_heads (int, optional) – Number of attention heads if multi-head attention mechanism is applied. Default: 1.

  • max_node_type (int, optional) – Maximum number of node types. Each node type has a corresponding learnable scaling factor and a bias. Default: 100.

Examples

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d
>>> coordinate = th.rand(1, 4, 3)
>>> node_type = th.tensor([[1, 0, 2, 1]])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    num_heads=8,
...                                    max_node_type=3)
>>> out = spatial_encoder(coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
forward(coord, node_type=None)[source]
Parameters:
  • coord (torch.Tensor) – 3D coordinates of nodes, in shape (B, N, 3), where B is the batch size and N is the maximum number of nodes.

  • node_type (torch.Tensor, optional) –

    Node type ids of nodes. Default: None.

    • If specified, node_type should be a tensor of shape (B, N). The scaling factors and biases in the Gaussian kernels of each node pair are determined by their node types.

    • Otherwise, node_type defaults to zeros of shape (B, N).

Returns:

The attention bias as a 3D spatial encoding, of shape (B, N, N, H), where H is num_heads.

Return type:

torch.Tensor
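The end-to-end shape flow of the forward pass can be traced with a hedged numpy sketch: pairwise distances over the coordinates, K Gaussian kernel features per pair, then a projection to H attention heads. The constant gamma/beta, fixed kernel parameters, and the single random projection matrix are assumptions for illustration only; the actual module derives gamma and beta from node-type embeddings and learns all parameters.

```python
import numpy as np

B, N, K, H = 1, 4, 4, 8  # batch, nodes, kernels, heads
rng = np.random.default_rng(0)
coord = rng.random((B, N, 3))

# Pairwise Euclidean distances, shape (B, N, N).
diff = coord[:, :, None, :] - coord[:, None, :, :]
dist = np.linalg.norm(diff, axis=-1)

# Per-pair scaling gamma and bias beta would come from node-type
# embeddings; constants stand in for them here.
gamma, beta = 1.0, 0.0
mu = np.linspace(0.0, 3.0, K)   # (K,) kernel centers
sigma = np.full(K, 1.0)         # (K,) standard deviations
x = gamma * dist + beta         # (B, N, N)
psi = np.exp(-0.5 * ((x[..., None] - mu) / np.abs(sigma)) ** 2) \
    / (np.sqrt(2 * np.pi) * np.abs(sigma))  # (B, N, N, K)

# A learned projection maps the K kernel features to H attention heads.
W = rng.standard_normal((K, H))
bias = psi @ W                  # (B, N, N, H)
print(bias.shape)               # (1, 4, 4, 8)
```

The resulting (B, N, N, H) tensor matches the shape printed in the Examples section above and is added as a bias to the attention logits of each head.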