"""Laplacian Positional Encoder"""
import torch as th
import torch.nn as nn
[docs]
class LapPosEncoder(nn.Module):
    r"""Laplacian Positional Encoder (LPE), as introduced in
    `GraphGPS: General Powerful Scalable Graph Transformers
    <https://arxiv.org/abs/2205.12454>`__
    This module is a learned laplacian positional encoding module using
    Transformer or DeepSet.
    Parameters
    ----------
    model_type : str
        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
    num_layer : int
        Number of layers in Transformer/DeepSet Encoder.
    k : int
        Number of smallest non-trivial eigenvectors.
    dim : int
        Output size of final laplacian encoding.
    n_head : int, optional
        Number of heads in Transformer Encoder.
        Default : 1.
    batch_norm : bool, optional
        If True, apply batch normalization on raw laplacian positional
        encoding. Default : False.
    num_post_layer : int, optional
        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after
        pooling. Default : 0.
    Example
    -------
    >>> import dgl
    >>> from dgl import LapPE
    >>> from dgl.nn import LapPosEncoder
    >>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g = transform(g)
    >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']
    >>> transformer_encoder = LapPosEncoder(
            model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4
        )
    >>> pos_encoding = transformer_encoder(eigvals, eigvecs)
    >>> deepset_encoder = LapPosEncoder(
            model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2
        )
    >>> pos_encoding = deepset_encoder(eigvals, eigvecs)
    """
    def __init__(
        self,
        model_type,
        num_layer,
        k,
        dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super(LapPosEncoder, self).__init__()
        self.model_type = model_type
        self.linear = nn.Linear(2, dim)
        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                layers.append(nn.ReLU())
            else:
                self.linear = nn.Linear(2, 2 * dim)
                layers.append(nn.ReLU())
                for _ in range(num_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be "
                "'Transformer' or 'DeepSet'."
            )
        if batch_norm:
            self.raw_norm = nn.BatchNorm1d(k)
        else:
            self.raw_norm = None
        if num_post_layer > 0:
            layers = []
            if num_post_layer == 1:
                layers.append(nn.Linear(dim, dim))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim, 2 * dim))
                layers.append(nn.ReLU())
                for _ in range(num_post_layer - 2):
                    layers.append(nn.Linear(2 * dim, 2 * dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim, dim))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None
[docs]
    def forward(self, eigvals, eigvecs):
        r"""
        Parameters
        ----------
        eigvals : Tensor
            Laplacian Eigenvalues of shape :math:`(N, k)`, k different
            eigenvalues repeat N times, can be obtained by using `LaplacianPE`.
        eigvecs : Tensor
            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by
            using `LaplacianPE`.
        Returns
        -------
        Tensor
            Return the laplacian positional encodings of shape :math:`(N, d)`,
            where :math:`N` is the number of nodes in the input graph,
            :math:`d` is :attr:`dim`.
        """
        pos_encoding = th.cat(
            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
        ).float()
        empty_mask = th.isnan(pos_encoding)
        pos_encoding[empty_mask] = 0
        if self.raw_norm:
            pos_encoding = self.raw_norm(pos_encoding)
        pos_encoding = self.linear(pos_encoding)
        if self.model_type == "Transformer":
            pos_encoding = self.pe_encoder(
                src=pos_encoding, src_key_padding_mask=empty_mask[:, :, 1]
            )
        else:
            pos_encoding = self.pe_encoder(pos_encoding)
        # Remove masked sequences.
        pos_encoding[empty_mask[:, :, 1]] = 0
        # Sum pooling.
        pos_encoding = th.sum(pos_encoding, 1, keepdim=False)
        # MLP post pooling.
        if self.post_mlp:
            pos_encoding = self.post_mlp(pos_encoding)
        return pos_encoding