"""Torch modules for interaction blocks in SchNet"""# pylint: disable= no-member, arguments-differ, invalid-nameimportnumpyasnpimporttorch.nnasnnfrom....importfunctionasfnclassShiftedSoftplus(nn.Module):r"""Applies the element-wise function: .. math:: \text{SSP}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - \log(\text{shift}) Attributes ---------- beta : int :math:`\beta` value for the mathematical formulation. Default to 1. shift : int :math:`\text{shift}` value for the mathematical formulation. Default to 2. """def__init__(self,beta=1,shift=2,threshold=20):super(ShiftedSoftplus,self).__init__()self.shift=shiftself.softplus=nn.Softplus(beta=beta,threshold=threshold)defforward(self,inputs):""" Description ----------- Applies the activation function. Parameters ---------- inputs : float32 tensor of shape (N, *) * denotes any number of additional dimensions. Returns ------- float32 tensor of shape (N, *) Result of applying the activation function to the input. """returnself.softplus(inputs)-np.log(float(self.shift))
[docs]classCFConv(nn.Module):r"""CFConv from `SchNet: A continuous-filter convolutional neural network for modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__ It combines node and edge features in message passing and updates node representations. .. math:: h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} h_j^{l} \circ W^{(l)}e_ij where :math:`\circ` represents element-wise multiplication and for :math:`\text{SPP}` : .. math:: \text{SSP}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - \log(\text{shift}) Parameters ---------- node_in_feats : int Size for the input node features :math:`h_j^{(l)}`. edge_in_feats : int Size for the input edge features :math:`e_ij`. hidden_feats : int Size for the hidden representations. out_feats : int Size for the output representations :math:`h_j^{(l+1)}`. Example ------- >>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import CFConv >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> nfeat = th.ones(6, 10) >>> efeat = th.ones(6, 5) >>> conv = CFConv(10, 5, 3, 2) >>> res = conv(g, nfeat, efeat) >>> res tensor([[-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1209, -0.2289], [-0.1135, -0.2338], [-0.1209, -0.2289], [-0.1283, -0.2240]], grad_fn=<SubBackward0>) """def__init__(self,node_in_feats,edge_in_feats,hidden_feats,out_feats):super(CFConv,self).__init__()self.project_edge=nn.Sequential(nn.Linear(edge_in_feats,hidden_feats),ShiftedSoftplus(),nn.Linear(hidden_feats,hidden_feats),ShiftedSoftplus(),)self.project_node=nn.Linear(node_in_feats,hidden_feats)self.project_out=nn.Sequential(nn.Linear(hidden_feats,out_feats),ShiftedSoftplus())
[docs]defforward(self,g,node_feats,edge_feats):""" Description ----------- Performs message passing and updates node representations. Parameters ---------- g : DGLGraph The graph. node_feats : torch.Tensor or pair of torch.Tensor The input node features. If a torch.Tensor is given, it represents the input node feature of shape :math:`(N, D_{in})` where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes. If a pair of torch.Tensor is given, which is the case for bipartite graph, the pair must contain two tensors of shape :math:`(N_{src}, D_{in_{src}})` and :math:`(N_{dst}, D_{in_{dst}})` separately for the source and destination nodes. edge_feats : torch.Tensor The input edge feature of shape :math:`(E, edge_in_feats)` where :math:`E` is the number of edges. Returns ------- torch.Tensor The output node feature of shape :math:`(N_{out}, out_feats)` where :math:`N_{out}` is the number of destination nodes. """withg.local_scope():ifisinstance(node_feats,tuple):node_feats_src,_=node_featselse:node_feats_src=node_featsg.srcdata["hv"]=self.project_node(node_feats_src)g.edata["he"]=self.project_edge(edge_feats)g.update_all(fn.u_mul_e("hv","he","m"),fn.sum("m","h"))returnself.project_out(g.dstdata["h"])