SpatialEncoder3d
- class dgl.nn.pytorch.gt.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=100)
Bases: torch.nn.Module
3D Spatial Encoder, as introduced in "One Transformer Can Understand Both 2D & 3D Molecular Data".
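This module encodes the pair-wise relation between every pair of nodes in 3D geometric space using the Gaussian Basis Kernel function:

$$\psi_{(i,j)}^k = \frac{1}{\sqrt{2\pi}\,\lvert\sigma^k\rvert}\exp\left(-\frac{1}{2}\left(\frac{\gamma_{(i,j)}\lVert r_i - r_j\rVert + \beta_{(i,j)} - \mu^k}{\lvert\sigma^k\rvert}\right)^2\right), \quad k = 1, \dots, K,$$

where $K$ is the number of Gaussian Basis kernels, $r_i$ is the Cartesian coordinate of node $i$, $\gamma_{(i,j)}$ and $\beta_{(i,j)}$ are learnable scaling factors and biases determined by node types, and $\mu^k$ and $\sigma^k$ are the learnable centers and standard deviations of the Gaussian Basis kernels.
- Parameters: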
num_kernels (int) - Number of Gaussian Basis kernels ($K$) to apply. Each kernel has a learnable center and a learnable standard deviation.
num_heads (int, optional) - Number of attention heads if a multi-head attention mechanism is used; this is also the size of the last dimension of the output. Default: 1.
max_node_type (int, optional) - Maximum number of node types. Each node type has a corresponding learnable scaling factor and bias. Default: 100.
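The Gaussian Basis Kernel formula above can be evaluated by hand for a single node pair. The pure-Python sketch below is illustrative only and is not DGL's implementation: the function name is hypothetical, and gamma_ij, beta_ij, mus, and sigmas are passed as plain numbers instead of learnable parameters. Because forward() returns num_heads values per node pair rather than $K$, the module must also map the $K$ kernel values to num_heads outputs; that step is not described here and is not shown.

```python
import math


def gaussian_basis_kernels(r_i, r_j, gamma_ij, beta_ij, mus, sigmas):
    """Evaluate psi^k_(i,j) for k = 1..K for one node pair (illustrative sketch).

    r_i, r_j: 3D Cartesian coordinates (x, y, z) of nodes i and j.
    gamma_ij, beta_ij: scaling factor and bias for this pair (learnable in the module).
    mus, sigmas: length-K sequences of kernel centers and standard deviations.
    """
    dist = math.dist(r_i, r_j)       # Euclidean distance ||r_i - r_j||
    x = gamma_ij * dist + beta_ij    # affine transform by the pair's scale and bias
    values = []
    for mu, sigma in zip(mus, sigmas):
        s = abs(sigma)               # the formula divides by |sigma^k|
        z = (x - mu) / s
        values.append(math.exp(-0.5 * z * z) / (math.sqrt(2 * math.pi) * s))
    return values


# Nodes 5 units apart (a 3-4-5 triangle), two kernels centered at 5 and 4
psi = gaussian_basis_kernels((0.0, 0.0, 0.0), (3.0, 4.0, 0.0),
                             gamma_ij=1.0, beta_ij=0.0,
                             mus=[5.0, 4.0], sigmas=[1.0, -2.0])
```

The first kernel is centered exactly on the transformed distance, so it takes its peak value $1/\sqrt{2\pi}$; the negative sigma in the second kernel shows why the formula uses $\lvert\sigma^k\rvert$.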
Examples
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder3d
>>> coordinate = th.rand(1, 4, 3)
>>> node_type = th.tensor([[1, 0, 2, 1]])
>>> spatial_encoder = SpatialEncoder3d(num_kernels=4,
...                                    num_heads=8,
...                                    max_node_type=3)
>>> out = spatial_encoder(coordinate, node_type=node_type)
>>> print(out.shape)
torch.Size([1, 4, 4, 8])
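In the example, node_type assigns one type id to each of the 4 nodes, and the forward() documentation states that omitting it is equivalent to passing zeros. The pure-Python sketch below illustrates how per-node types could be expanded into the per-pair $\gamma_{(i,j)}$ and $\beta_{(i,j)}$ that appear in the kernel formula. Both helper names and the dict tables keyed by (type_i, type_j) are hypothetical stand-ins for the module's learnable parameters; DGL's internal parameterization is not described in this documentation and may differ.

```python
def resolve_node_type(coord, node_type=None):
    """Return node type ids of shape (B, N); default to zeros as forward() documents."""
    if node_type is None:
        # coord has shape (B, N, 3): one zero per node in each graph of the batch
        return [[0] * len(graph) for graph in coord]
    return node_type


def pair_scale_bias(node_type, gamma_table, beta_table):
    """Expand per-node types into per-pair gamma_(i,j) and beta_(i,j), shape (B, N, N).

    gamma_table and beta_table are hypothetical dicts keyed by (type_i, type_j);
    they stand in for learnable parameters whose real layout may differ.
    """
    gammas, betas = [], []
    for types in node_type:
        gammas.append([[gamma_table[(ti, tj)] for tj in types] for ti in types])
        betas.append([[beta_table[(ti, tj)] for tj in types] for ti in types])
    return gammas, betas


# Usage mirroring the example: one graph, four nodes, types in range(3)
node_type = [[1, 0, 2, 1]]
gamma_table = {(a, b): 1.0 for a in range(3) for b in range(3)}
beta_table = {(a, b): 0.0 for a in range(3) for b in range(3)}
gammas, betas = pair_scale_bias(node_type, gamma_table, beta_table)
```

Keying the tables on ordered pairs lets $\gamma_{(i,j)}$ differ from $\gamma_{(j,i)}$; a symmetric variant would key on the sorted pair instead.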
- forward(coord, node_type=None)
- Parameters:
coord (torch.Tensor) - 3D coordinates of nodes in shape $(B, N, 3)$, where $B$ is the batch size and $N$ is the maximum number of nodes.
node_type (torch.Tensor, optional) - Node type ids of nodes. Default: None.
If specified, node_type should be a tensor of shape $(B, N)$; the scaling factors and biases in the Gaussian kernels of each node pair are determined by the types of its two nodes. Otherwise, node_type is set to zeros of shape $(B, N)$.
- Returns:
Attention bias serving as the 3D spatial encoding, of shape $(B, N, N, H)$, where $H$ is num_heads.
- Return type:
torch.Tensor
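Assuming the returned bias is used in the usual additive way, that is, added to each head's pre-softmax attention logits, the main bookkeeping is axis alignment: the bias is laid out as $(B, N, N, H)$, while per-head logits are commonly laid out as $(B, H, N, N)$. The pure-Python sketch below shows that alignment followed by a row-wise softmax. The function name and the nested-list representation are illustrative assumptions; with PyTorch tensors, the same alignment could be done with out.permute(0, 3, 1, 2).

```python
import math


def biased_attention_weights(logits, bias):
    """Add a (B, N, N, H) bias to (B, H, N, N) logits, then softmax over keys.

    logits: nested lists indexed [b][h][i][j] (e.g. scaled query-key dot products).
    bias:   nested lists indexed [b][i][j][h], the layout documented for forward().
    """
    weights = []
    for b, heads in enumerate(logits):
        per_head = []
        for h, rows in enumerate(heads):
            per_row = []
            for i, row in enumerate(rows):
                # bias[b][i][j][h] is aligned with logits[b][h][i][j]
                scores = [s + bias[b][i][j][h] for j, s in enumerate(row)]
                m = max(scores)  # subtract the max for numerical stability
                exps = [math.exp(s - m) for s in scores]
                total = sum(exps)
                per_row.append([e / total for e in exps])
            per_head.append(per_row)
        weights.append(per_head)
    return weights


# B=1, H=1, N=2: equal logits, bias favoring key 1 for query 0
logits = [[[[0.0, 0.0], [0.0, 0.0]]]]
bias = [[[[0.0], [math.log(3.0)]], [[0.0], [0.0]]]]
w = biased_attention_weights(logits, bias)  # w[0][0][0] is approx [0.25, 0.75]
```

With equal logits, a bias of $\ln 3$ on one key makes that key three times as likely as the other, which is how a learned geometric bias can steer attention toward spatially related nodes.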