class RelGraphConv(nn.Module):
    r"""Relational graph convolution layer from `Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__

    It can be described as below:

    .. math::

       h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
       \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})

    where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
    :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\sigma` is an activation
    function. :math:`W_0` is the self-loop weight.

    The basis regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}

    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
    with coefficients :math:`a_{rb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_r` into
    :math:`B` block diagonal matrices. We refer to :math:`B` as the number of bases.

    The block regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}

    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block
    bases with shape :math:`\mathbb{R}^{(d^{(l+1)}/B) \times (d^{(l)}/B)}`.

    Parameters
    ----------
    in_feat : int
        Input feature size; i.e., the number of dimensions of :math:`h_j^{(l)}`.
    out_feat : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_rels : int
        Number of relations.
    regularizer : str, optional
        Which weight regularizer to use ("basis", "bdd" or ``None``):

        - "basis" is for basis-decomposition.
        - "bdd" is for block-diagonal-decomposition.
        - ``None`` applies no regularization.

        Default: ``None``.
    num_bases : int, optional
        Number of bases. It comes into effect when a regularizer is applied.
        If ``None``, it uses the number of relations (``num_rels``). Default: ``None``.
        Note that ``in_feat`` and ``out_feat`` must be divisible by ``num_bases``
        when applying the "bdd" regularizer.
    bias : bool, optional
        True if bias is added. Default: ``True``.
    activation : callable, optional
        Activation function. Default: ``None``.
    self_loop : bool, optional
        True to include self loop message. Default: ``True``.
    dropout : float, optional
        Dropout rate. Default: ``0.0``.
    layer_norm : bool, optional
        True to add layer norm. Default: ``False``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import RelGraphConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 10)
    >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
    >>> etype = th.tensor([0,1,2,0,1,2])
    >>> res = conv(g, feat, etype)
    >>> res
    tensor([[ 0.3996, -2.3303],
            [-0.4323, -0.1440],
            [ 0.3996, -2.3303],
            [ 2.1046, -2.8654],
            [-0.4323, -0.1440],
            [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        regularizer=None,
        num_bases=None,
        bias=True,
        activation=None,
        self_loop=True,
        dropout=0.0,
        layer_norm=False,
    ):
        super().__init__()
        # A regularizer without an explicit basis count uses one basis per relation.
        if regularizer is not None and num_bases is None:
            num_bases = num_rels
        self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.layer_norm = layer_norm
        # Overwritten on every forward() call and read by message(); initialised
        # here so message() cannot fail with AttributeError before forward() runs.
        self.presorted = False

        # bias
        if self.bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # TODO(minjie): consider remove those options in the future to make
        # the module only about graph convolution.
        # layer norm
        if self.layer_norm:
            self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        self.dropout = nn.Dropout(dropout)

    def message(self, edges):
        """Message function.

        Applies the relation-typed linear transform ``self.linear_r`` to the
        source features ``edges.src["h"]``, keyed by ``edges.data["etype"]``.
        If the edges carry a ``"norm"`` feature, each edge's message is scaled
        by it.
        """
        m = self.linear_r(edges.src["h"], edges.data["etype"], self.presorted)
        if "norm" in edges.data:
            norm = edges.data["norm"]
            # A 1D norm of shape (|E|,) would broadcast against the trailing
            # (feature) axis of m, not the edge axis. Add a trailing axis so
            # each edge's message is scaled by its own normalizer.
            if norm.dim() == 1:
                norm = norm.unsqueeze(-1)
            m = m * norm
        return {"m": m}

    def forward(self, g, feat, etypes, norm=None, *, presorted=False):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        feat : torch.Tensor
            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
        etypes : torch.Tensor or list[int]
            A 1D integer tensor (or list) of edge types. Shape: :math:`(|E|,)`.
            A list is converted to an int64 tensor on the graph's device.
        norm : torch.Tensor, optional
            A tensor of edge norm values. Shape: :math:`(|E|,)` or
            :math:`(|E|, 1)`.
        presorted : bool, optional
            Whether the edges of the input graph have been sorted by their types.
            Forward on a pre-sorted graph may be faster. Graphs created
            by :func:`~dgl.to_homogeneous` automatically satisfy the condition.
            Also see :func:`~dgl.reorder_graph` for sorting edges manually.

        Returns
        -------
        torch.Tensor
            New node features. Shape: :math:`(|V|, D_{out})`.
        """
        # Read by message() during update_all below.
        self.presorted = presorted
        if not th.is_tensor(etypes):
            # The documented list[int] form must become a tensor before it can
            # be stored as edge data.
            etypes = th.as_tensor(etypes, dtype=th.int64, device=g.device)
        with g.local_scope():
            g.srcdata["h"] = feat
            if norm is not None:
                g.edata["norm"] = norm
            g.edata["etype"] = etypes
            # message passing
            g.update_all(self.message, fn.sum("m", "h"))
            # apply bias and activation
            h = g.dstdata["h"]
            if self.layer_norm:
                h = self.layer_norm_weight(h)
            if self.bias:
                h = h + self.h_bias
            if self.self_loop:
                # Slice to destination nodes so this also works when the graph
                # has fewer destination than source nodes.
                h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
            if self.activation:
                h = self.activation(h)
            h = self.dropout(h)
            return h