"""Various commonly used linear modules"""# pylint: disable= no-member, arguments-differ, invalid-name, W0235importmathimporttorchimporttorch.nnasnnfrom...opsimportgather_mm,segment_mm__all__=["TypedLinear"]
class TypedLinear(nn.Module):
    r"""Linear transformation according to types.

    For each sample of the input batch :math:`x \in X`, apply the linear
    transformation :math:`xW_t`, where :math:`t` is the type of :math:`x`.

    The module supports two regularization methods (basis-decomposition and
    block-diagonal-decomposition) proposed by "`Modeling Relational Data with
    Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`__".

    The basis regularization decomposes :math:`W_t` by:

    .. math::

        W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)} V_b^{(l)}

    where :math:`B` is the number of bases and the basis matrices
    :math:`V_b^{(l)}` are linearly combined with coefficients
    :math:`a_{tb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_t`
    into :math:`B` block-diagonal matrices. We refer to :math:`B` as the
    number of bases:

    .. math::

        W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}

    where :math:`B` is the number of bases and :math:`Q_{tb}^{(l)}` are block
    matrices with shape :math:`\mathbb{R}^{(d^{(l+1)}/B) \times (d^{(l)}/B)}`.

    Parameters
    ----------
    in_size : int
        Input feature size.
    out_size : int
        Output feature size.
    num_types : int
        Total number of types.
    regularizer : str, optional
        Which weight regularizer to use ("basis" or "bdd"):

        - "basis" is short for basis-decomposition.
        - "bdd" is short for block-diagonal-decomposition.

        Default applies no regularization.
    num_bases : int, optional
        Number of bases. Required when ``regularizer`` is specified. Typically
        smaller than ``num_types``. Default: ``None``.

    Examples
    --------
    No regularization.

    >>> from dgl.nn import TypedLinear
    >>> import torch
    >>>
    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])

    With basis regularization.

    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])
    """

    def __init__(
        self, in_size, out_size, num_types, regularizer=None, num_bases=None
    ):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.num_types = num_types
        if regularizer is None:
            # One full (in_size, out_size) weight matrix per type.
            self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
        elif regularizer == "basis":
            if num_bases is None:
                raise ValueError(
                    'Missing "num_bases" for basis regularization.'
                )
            # num_bases shared bases plus per-type combination coefficients.
            self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
            self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
            self.num_bases = num_bases
        elif regularizer == "bdd":
            if num_bases is None:
                raise ValueError('Missing "num_bases" for bdd regularization.')
            if in_size % num_bases != 0 or out_size % num_bases != 0:
                raise ValueError(
                    "Input and output sizes must be divisible by num_bases."
                )
            self.submat_in = in_size // num_bases
            self.submat_out = out_size // num_bases
            # Each type's weight is stored as num_bases flattened
            # diagonal blocks.
            self.W = nn.Parameter(
                torch.Tensor(
                    num_types, num_bases * self.submat_in * self.submat_out
                )
            )
            self.num_bases = num_bases
        else:
            raise ValueError(
                f'Supported regularizer options: "basis", "bdd", '
                f"but got {regularizer}"
            )
        self.regularizer = regularizer
        self.reset_parameters()
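    # Illustrative parameter counts for the sizes used in the class docstring
    # (in_size=32, out_size=64, num_types=5, num_bases=4):
    #   no regularizer:       5 * 32 * 64          = 10240 parameters
    #   regularizer="basis":  4 * 32 * 64 + 5 * 4  = 8212 parameters
    #   regularizer="bdd":    5 * 4 * 8 * 16       = 2560 parameters
    #                         (4 diagonal blocks of shape 8 x 16 per type)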
    def reset_parameters(self):
        """Reset parameters."""
        with torch.no_grad():
            # Follow torch.nn.Linear's initialization (kaiming_uniform_),
            # which amounts to U(-1/sqrt(fan_in), 1/sqrt(fan_in)); the fan-in
            # is in_size, or submat_in for "bdd", where each output element
            # only sees one diagonal block.
            if self.regularizer is None:
                nn.init.uniform_(
                    self.W,
                    -1 / math.sqrt(self.in_size),
                    1 / math.sqrt(self.in_size),
                )
            elif self.regularizer == "basis":
                nn.init.uniform_(
                    self.W,
                    -1 / math.sqrt(self.in_size),
                    1 / math.sqrt(self.in_size),
                )
                nn.init.xavier_uniform_(
                    self.coeff, gain=nn.init.calculate_gain("relu")
                )
            elif self.regularizer == "bdd":
                nn.init.uniform_(
                    self.W,
                    -1 / math.sqrt(self.submat_in),
                    1 / math.sqrt(self.submat_in),
                )
            else:
                raise ValueError(
                    f'Supported regularizer options: "basis", "bdd", '
                    f"but got {self.regularizer}"
                )
    def get_weight(self):
        """Get the type-wise weight."""
        if self.regularizer is None:
            return self.W
        elif self.regularizer == "basis":
            # Combine the shared bases with per-type coefficients:
            # W_t = sum_b a_{tb} V_b.
            W = self.W.view(self.num_bases, self.in_size * self.out_size)
            return (self.coeff @ W).view(
                self.num_types, self.in_size, self.out_size
            )
        elif self.regularizer == "bdd":
            return self.W
        else:
            raise ValueError(
                f'Supported regularizer options: "basis", "bdd", '
                f"but got {self.regularizer}"
            )
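    # Equivalently, the "basis" branch of get_weight above computes
    # torch.einsum("tb,bio->tio", self.coeff, self.W): each type's weight is
    # a coefficient-weighted sum of the shared basis matrices.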
    def forward(self, x, x_type, sorted_by_type=False):
        """Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            A 2D input tensor. Shape: (N, D1)
        x_type : torch.Tensor
            A 1D integer tensor storing the type of each element in ``x``,
            with one-to-one correspondence. Shape: (N,)
        sorted_by_type : bool, optional
            Whether the inputs have been sorted by type. Forward on
            pre-sorted inputs may be faster.

        Returns
        -------
        y : torch.Tensor
            The transformed output tensor. Shape: (N, D2)
        """
        w = self.get_weight()
        if self.regularizer == "bdd":
            # Block-diagonal product: multiply each (1, submat_in) slice of x
            # with the (submat_in, submat_out) block selected by its type.
            w = w.index_select(0, x_type).view(
                -1, self.submat_in, self.submat_out
            )
            x = x.view(-1, 1, self.submat_in)
            return torch.bmm(x, w).view(-1, self.out_size)
        elif sorted_by_type:
            # Inputs are grouped by type: find each type's segment boundaries
            # and use one matrix multiplication per contiguous segment.
            pos_l = torch.searchsorted(
                x_type, torch.arange(self.num_types, device=x.device)
            )
            pos_r = torch.cat(
                [pos_l[1:], torch.tensor([len(x_type)], device=x.device)]
            )
            # XXX(minjie): .cpu() causes a device synchronization.
            seglen = (pos_r - pos_l).cpu()
            return segment_mm(x, w, seglen_a=seglen)
        else:
            return gather_mm(x, w, idx_b=x_type)
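if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module. It
    # assumes this file runs inside the dgl package (e.g.
    # ``python -m dgl.nn.pytorch.linear``) so the relative ``...ops`` import
    # resolves, and checks that the unsorted (gather_mm) and pre-sorted
    # (segment_mm) paths agree for the "basis" regularizer.
    x = torch.randn(100, 32)
    x_type = torch.randint(0, 5, (100,))
    m = TypedLinear(32, 64, 5, regularizer="basis", num_bases=4)

    y = m(x, x_type)  # unsorted path uses gather_mm
    order = torch.argsort(x_type)
    y_sorted = m(x[order], x_type[order], sorted_by_type=True)  # segment_mm
    assert torch.allclose(y[order], y_sorted, atol=1e-4)
    print(y.shape)  # torch.Size([100, 64])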