import torch as th
import torch.nn as nn
import torch.nn.functional as F


class BiasedMHA(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.

    Compute attention between nodes with attention bias obtained from graph
    structures, as introduced in `Do Transformers Really Perform Bad for
    Graph Representation? <https://arxiv.org/pdf/2106.05234>`__

    .. math::

        \text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)

    :math:`Q` and :math:`K` are feature representations of nodes. :math:`d`
    is the corresponding :attr:`feat_size`. :math:`b` is the attention bias,
    which can be additive or multiplicative according to the operator
    :math:`\circ`.

    Parameters
    ----------
    feat_size : int
        Feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    bias : bool, optional
        If True, it uses bias for linear projection. Default: True.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    attn_drop : float, optional
        Dropout probability on attention weights. Default: 0.1.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import BiasedMHA

    >>> ndata = th.rand(16, 100, 512)
    >>> bias = th.rand(16, 100, 100, 8)
    >>> net = BiasedMHA(feat_size=512, num_heads=8)
    >>> out = net(ndata, bias)
    """

    def __init__(
        self,
        feat_size,
        num_heads,
        bias=True,
        attn_bias_type="add",
        attn_drop=0.1,
    ):
        super().__init__()
        self.feat_size = feat_size
        self.num_heads = num_heads
        self.head_dim = feat_size // num_heads
        assert (
            self.head_dim * num_heads == feat_size
        ), "feat_size must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.attn_bias_type = attn_bias_type

        self.q_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.k_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.v_proj = nn.Linear(feat_size, feat_size, bias=bias)
        self.out_proj = nn.Linear(feat_size, feat_size, bias=bias)

        self.dropout = nn.Dropout(p=attn_drop)

        self.reset_parameters()
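    # A worked example of the head-splitting arithmetic in ``__init__`` above,
    # using the illustrative values from the class docstring: with
    # feat_size=512 and num_heads=8, each head operates on
    # head_dim = 512 // 8 = 64 channels, and the query scaling factor is
    # 64 ** -0.5 = 0.125, i.e. the 1 / sqrt(d) term in the attention formula.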
    def reset_parameters(self):
        """Initialize the parameters of the projection matrices, using the
        same settings as in the original implementation of the paper.
        """
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-0.5)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
    def forward(self, ndata, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        ndata : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`),
            where N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where invalid positions are indicated by `True` values.
            Shape: (batch_size, N, N). Note: For rows corresponding to
            nonexistent nodes, make sure at least one entry is set to `False`
            to prevent obtaining NaNs from softmax.

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        q_h = self.q_proj(ndata).transpose(0, 1)
        k_h = self.k_proj(ndata).transpose(0, 1)
        v_h = self.v_proj(ndata).transpose(0, 1)
        bsz, N, _ = ndata.shape

        # Split the projections into heads and pre-scale the queries by
        # 1 / sqrt(head_dim): q_h and v_h become (bsz * num_heads, N, head_dim),
        # while k_h becomes (bsz * num_heads, head_dim, N) for batched matmul.
        q_h = (
            q_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(0, 1)
            * self.scaling
        )
        k_h = k_h.reshape(N, bsz * self.num_heads, self.head_dim).permute(
            1, 2, 0
        )
        v_h = v_h.reshape(N, bsz * self.num_heads, self.head_dim).transpose(
            0, 1
        )

        # Raw attention scores, reshaped to (bsz, N, N, num_heads) so they
        # line up with the shape of attn_bias.
        attn_weights = (
            th.bmm(q_h, k_h)
            .transpose(0, 2)
            .reshape(N, N, bsz, self.num_heads)
            .transpose(0, 2)
        )

        if attn_bias is not None:
            if self.attn_bias_type == "add":
                attn_weights += attn_bias
            else:
                attn_weights *= attn_bias
        if attn_mask is not None:
            # Exclude masked (invalid) positions before the softmax.
            attn_weights[attn_mask.to(th.bool)] = float("-inf")

        # Softmax over the key dimension, computed per head on the
        # (bsz * num_heads, N, N) layout.
        attn_weights = F.softmax(
            attn_weights.transpose(0, 2)
            .reshape(N, N, bsz * self.num_heads)
            .transpose(0, 2),
            dim=2,
        )

        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, merged back to (bsz, N, feat_size).
        attn = th.bmm(attn_weights, v_h).transpose(0, 1)

        attn = self.out_proj(
            attn.reshape(N, bsz, self.feat_size).transpose(0, 1)
        )

        return attn
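
# A minimal usage sketch beyond the docstring example above: it assumes a
# batch padded to N = 100 nodes, of which only the first 80 per graph are
# real, and builds a boolean ``attn_mask`` that blocks attention to the
# padded positions. The names below (``num_real``, ``invalid``, ...) are
# illustrative only, not part of the DGL API.
if __name__ == "__main__":
    batch_size, max_nodes, feat_size, num_heads = 16, 100, 512, 8
    num_real = 80

    net = BiasedMHA(feat_size=feat_size, num_heads=num_heads)
    ndata = th.rand(batch_size, max_nodes, feat_size)
    bias = th.rand(batch_size, max_nodes, max_nodes, num_heads)

    # Invalid (padded) key positions are marked True, as required by
    # ``forward``; every row keeps at least one False entry (the real nodes),
    # so the softmax stays finite even for padded query rows.
    invalid = th.zeros(batch_size, max_nodes, dtype=th.bool)
    invalid[:, num_real:] = True
    attn_mask = invalid.unsqueeze(1).expand(batch_size, max_nodes, max_nodes)

    out = net(ndata, attn_bias=bias, attn_mask=attn_mask)
    assert out.shape == (batch_size, max_nodes, feat_size)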