TypedLinear

class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)

Bases: torch.nn.Module

Linear transformation according to types.

For each sample $x \in X$ of the input batch, apply the linear transformation $x W_t$, where $t$ is the type of $x$.
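As a minimal sketch of this per-sample rule, assume the unregularized weights are stored as one tensor W of shape (num_types, in_size, out_size). That name and layout are assumptions for illustration, not DGL's internals. Gathering W[x_type] and using a batched matrix multiply computes every $x W_t$ at once:

```python
import torch

def typed_linear_reference(x, x_type, W):
    """Reference sketch (hypothetical helper): y[i] = x[i] @ W[x_type[i]].

    x:      (N, in_size) float tensor
    x_type: (N,) integer tensor with values in [0, num_types)
    W:      (num_types, in_size, out_size) assumed per-type weights
    """
    W_per_sample = W[x_type]                      # (N, in_size, out_size)
    # bmm: (N, 1, in_size) @ (N, in_size, out_size) -> (N, 1, out_size)
    return torch.bmm(x.unsqueeze(1), W_per_sample).squeeze(1)

torch.manual_seed(0)
x = torch.randn(100, 32)
x_type = torch.randint(0, 5, (100,))
W = torch.randn(5, 32, 64)
y = typed_linear_reference(x, x_type, W)
print(y.shape)  # torch.Size([100, 64])
```

Gathering a full weight matrix for every sample needs memory proportional to N · in_size · out_size, which is one reason a practical implementation may prefer to group samples by type instead.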

The module supports two regularization methods, basis-decomposition and block-diagonal-decomposition, proposed in “Modeling Relational Data with Graph Convolutional Networks”.

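The basis regularization decomposes $W_t$ as:

$$W_t^{(l)} = \sum_{b=1}^{B} a_{tb}^{(l)} V_b^{(l)}$$

where $B$ is the number of bases, and the basis matrices $V_b^{(l)}$, shared across all types, are linearly combined with type-specific coefficients $a_{tb}^{(l)}$.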
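The basis composition can be sketched with torch.einsum. The tensor names V (shape (B, in_size, out_size)) and a (shape (num_types, B)) are assumptions chosen to mirror the symbols above; the sketch shows the math, not DGL's internal parameter layout.

```python
import torch

num_types, num_bases, in_size, out_size = 5, 4, 32, 64

torch.manual_seed(0)
V = torch.randn(num_bases, in_size, out_size)   # shared bases V_b (assumed layout)
a = torch.randn(num_types, num_bases)           # per-type coefficients a_tb

# W[t] = sum_b a[t, b] * V[b]  ->  (num_types, in_size, out_size)
W = torch.einsum("tb,bio->tio", a, V)

# Explicit loop form of the same sum, for comparison
W_loop = torch.stack([
    sum(a[t, b] * V[b] for b in range(num_bases))
    for t in range(num_types)
])
print(torch.allclose(W, W_loop, atol=1e-5))
```

Only the coefficient table a grows with the number of types; the bases V are shared, which is where the parameter savings come from when B is small.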

The block-diagonal-decomposition regularization builds $W_t$ as a block-diagonal matrix made of $B$ type-specific blocks; as above, $B$ is called the number of bases:

$$W_t^{(l)} = \bigoplus_{b=1}^{B} Q_{tb}^{(l)}$$

where the block bases $Q_{tb}^{(l)} \in \mathbb{R}^{(d^{(l+1)}/B) \times (d^{(l)}/B)}$.
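For the block-diagonal case, torch.block_diag can assemble $W_t$ from its blocks. Because this module multiplies the row vector $x$ on the left ($x W_t$), the sketch below gives each block the shape (in_size // B, out_size // B), the transpose of the column-vector shape given above, and it assumes both sizes are divisible by B. The tensor name Q is illustrative only.

```python
import torch

num_types, num_bases, in_size, out_size = 5, 4, 32, 64
assert in_size % num_bases == 0 and out_size % num_bases == 0

bi, bo = in_size // num_bases, out_size // num_bases
torch.manual_seed(0)
Q = torch.randn(num_types, num_bases, bi, bo)    # blocks Q_tb (assumed layout)

# W[t] = Q[t, 0] ⊕ Q[t, 1] ⊕ ... ⊕ Q[t, B-1]
W = torch.stack([torch.block_diag(*Q[t]) for t in range(num_types)])
print(W.shape)  # torch.Size([5, 32, 64])

# Block structure: input chunk b only feeds output chunk b
x = torch.randn(3, in_size)
t = 2
y_full = x @ W[t]
y_blocks = torch.cat(
    [x[:, b * bi:(b + 1) * bi] @ Q[t, b] for b in range(num_bases)], dim=1
)
```

Each type then stores B blocks of (in_size/B) × (out_size/B) weights, i.e. 1/B of a dense matrix, and the multiplication never needs the zero off-diagonal entries.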

Parameters:
  • in_size (int) – Input feature size.

  • out_size (int) – Output feature size.

  • num_types (int) – Total number of types.

  • regularizer (str, optional) –

    Which weight regularizer to use, “basis” or “bdd”:

    • “basis” is short for basis-decomposition.

    • “bdd” is short for block-diagonal-decomposition.

    Default: None, which applies no regularization.

  • num_bases (int, optional) – Number of bases. Required when regularizer is specified; typically smaller than num_types. Default: None.
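To make “typically smaller than num_types” concrete, the helper below counts the weight entries implied by the two decompositions described above, ignoring any bias. count_typed_weights is a hypothetical helper written for this illustration, not part of DGL, and DGL's actual parameter layout may differ.

```python
def count_typed_weights(in_size, out_size, num_types, regularizer=None, num_bases=None):
    """Count weight entries implied by each scheme (hypothetical helper)."""
    if regularizer is None:
        # One dense (in_size x out_size) matrix per type
        return num_types * in_size * out_size
    if regularizer == "basis":
        # B shared bases plus a (num_types x B) coefficient table
        return num_bases * in_size * out_size + num_types * num_bases
    if regularizer == "bdd":
        # B blocks of (in_size/B x out_size/B) per type
        if in_size % num_bases or out_size % num_bases:
            raise ValueError("in_size and out_size must be divisible by num_bases")
        return num_types * num_bases * (in_size // num_bases) * (out_size // num_bases)
    raise ValueError(f"unknown regularizer: {regularizer!r}")

for reg, nb in [(None, None), ("basis", 4), ("bdd", 4)]:
    print(reg, count_typed_weights(32, 64, 5, reg, nb))
# None 10240
# basis 8212
# bdd 2560
```

With only 5 types the basis scheme saves little; its savings grow as num_types increases while num_bases stays fixed, since only the coefficient table scales with the number of types.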

Examples

No regularization.

>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])

With basis regularization.

>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])

forward(x, x_type, sorted_by_type=False)

Forward computation.

Parameters:
  • x (torch.Tensor) – A 2D input tensor. Shape: (N, D1).

  • x_type (torch.Tensor) – A 1D integer tensor holding the type of each row of x, in one-to-one correspondence. Shape: (N,).

  • sorted_by_type (bool, optional) – Whether x and x_type are already sorted by type. Forward computation on pre-sorted inputs may be faster. Default: False.
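The sketch below illustrates one plausible reason sorted inputs can help: once rows are grouped by type, each type's rows form one contiguous slice that is multiplied by a single weight matrix, with no per-sample weight gather. It sorts with torch.sort when needed, processes contiguous segments, and restores the original order. The function segment_typed_linear and the per-type weight tensor W are hypothetical, and this is not a description of DGL's kernel.

```python
import torch

def segment_typed_linear(x, x_type, W, sorted_by_type=False):
    """Hypothetical sketch: y[i] = x[i] @ W[x_type[i]] via per-type segments."""
    if not sorted_by_type:
        x_type, perm = torch.sort(x_type)     # group rows of the same type
        x = x[perm]
    # Number of rows per type; zero-length segments are allowed
    counts = torch.bincount(x_type, minlength=W.shape[0]).tolist()
    segments = torch.split(x, counts)
    # Each contiguous segment shares one weight matrix
    y = torch.cat([seg @ W[t] for t, seg in enumerate(segments)])
    if not sorted_by_type:
        out = torch.empty_like(y)
        out[perm] = y                          # undo the sort
        y = out
    return y

torch.manual_seed(0)
x = torch.randn(100, 32)
x_type = torch.randint(0, 5, (100,))
W = torch.randn(5, 32, 64)

y = segment_typed_linear(x, x_type, W)
# Gather-based reference for comparison
ref = torch.bmm(x.unsqueeze(1), W[x_type]).squeeze(1)
print(torch.allclose(y, ref, atol=1e-4))
```

When the caller already keeps rows grouped by type and passes sorted_by_type=True, the sort and the final scatter are skipped entirely.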

Returns:

y – The transformed output tensor. Shape: (N, D2).

Return type:

torch.Tensor

reset_parameters()

Reinitialize the module's learnable parameters.