Source code for dgl.nn.pytorch.conv.cugraph_gatconv
"""Torch Module for graph attention network layer using the aggregationprimitives in cugraph-ops"""# pylint: disable=no-member, arguments-differ, invalid-name, too-many-argumentsimporttorchfromtorchimportnnfrom.cugraph_baseimportCuGraphBaseConvtry:frompylibcugraphops.pytorchimportSampledCSC,StaticCSCfrompylibcugraphops.pytorch.operatorsimportmha_gat_n2nasGATConvAggHAS_PYLIBCUGRAPHOPS=TrueexceptImportError:HAS_PYLIBCUGRAPHOPS=False
class CuGraphGATConv(CuGraphBaseConv):
    r"""Graph attention layer from `Graph Attention Networks
    <https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation
    accelerated by cugraph-ops.

    See :class:`dgl.nn.pytorch.conv.GATConv` for the mathematical model.

    This module depends on the :code:`pylibcugraphops` package, which can be
    installed via :code:`conda install -c nvidia pylibcugraphops=23.04`.
    :code:`pylibcugraphops` 23.04 requires Python 3.8.x or 3.10.x.

    .. note::
        This is an **experimental** feature.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Default: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Default: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Default: ``False``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node
        features. Default: ``None``.
    bias : bool, optional
        If True, learns a bias term. Default: ``True``.

    Examples
    --------
    >>> import dgl
    >>> import torch
    >>> from dgl.nn import CuGraphGATConv
    >>> device = 'cuda'
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
    >>> g = dgl.add_self_loop(g)
    >>> feat = torch.ones(6, 10).to(device)
    >>> conv = CuGraphGATConv(10, 2, num_heads=3).to(device)
    >>> res = conv(g, feat)
    >>> res
    tensor([[[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]],
            [[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]],
            [[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]],
            [[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]],
            [[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]],
            [[ 0.2340,  1.9226],
             [ 1.6477, -1.9986],
             [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
    """

    MAX_IN_DEGREE_MFG = 200

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        feat_drop=0.0,
        negative_slope=0.2,
        residual=False,
        activation=None,
        bias=True,
    ):
        if HAS_PYLIBCUGRAPHOPS is False:
            raise ModuleNotFoundError(
                f"{self.__class__.__name__} requires pylibcugraphops=23.04. "
                f"Install via `conda install -c nvidia 'pylibcugraphops=23.04'`."
                f" pylibcugraphops requires Python 3.8 or 3.10."
            )
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.num_heads = num_heads
        self.feat_drop = nn.Dropout(feat_drop)
        self.negative_slope = negative_slope
        self.activation = activation

        self.fc = nn.Linear(in_feats, out_feats * num_heads, bias=False)
        self.attn_weights = nn.Parameter(
            torch.Tensor(2 * num_heads * out_feats)
        )

        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_heads * out_feats))
        else:
            self.register_buffer("bias", None)

        if residual:
            if in_feats == out_feats * num_heads:
                self.res_fc = nn.Identity()
            else:
                self.res_fc = nn.Linear(
                    in_feats, out_feats * num_heads, bias=False
                )
        else:
            self.register_buffer("res_fc", None)

        self.reset_parameters()
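    # Editorial note: the constructor above calls ``self.reset_parameters()``,
    # but its body is not shown in this extract. The method below is a hedged
    # sketch, assuming the Xavier initialization scheme used by
    # :class:`dgl.nn.pytorch.conv.GATConv`; consult the upstream DGL source
    # for the authoritative implementation.
    def reset_parameters(self):
        r"""Reinitialize learnable parameters (sketch; see note above)."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        # attn_weights is stored flat; view it as (2, num_heads, out_feats)
        # so Xavier fan-in/fan-out is computed per attention head.
        nn.init.xavier_normal_(
            self.attn_weights.view(2, self.num_heads, self.out_feats),
            gain=gain,
        )
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)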
    def forward(self, g, feat, max_in_degree=None):
        r"""Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The graph.
        feat : torch.Tensor
            Input features of shape :math:`(N, D_{in})`.
        max_in_degree : int
            Maximum in-degree of destination nodes. It is only effective when
            :attr:`g` is a :class:`DGLBlock`, i.e., a bipartite graph. When
            :attr:`g` is generated from a neighbor sampler, the value should
            be set to the corresponding :attr:`fanout`. If not given,
            :attr:`max_in_degree` will be calculated on-the-fly.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})`, where
            :math:`H` is the number of heads and :math:`D_{out}` is the size
            of the output feature.
        """
        offsets, indices, _ = g.adj_tensors("csc")

        if g.is_block:
            if max_in_degree is None:
                max_in_degree = g.in_degrees().max().item()

            if max_in_degree < self.MAX_IN_DEGREE_MFG:
                _graph = SampledCSC(
                    offsets,
                    indices,
                    max_in_degree,
                    g.num_src_nodes(),
                )
            else:
                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
                _graph = StaticCSC(offsets_fg, indices)
        else:
            _graph = StaticCSC(offsets, indices)

        feat = self.feat_drop(feat)
        feat_transformed = self.fc(feat)
        out = GATConvAgg(
            feat_transformed,
            self.attn_weights,
            _graph,
            self.num_heads,
            "LeakyReLU",
            self.negative_slope,
            concat_heads=True,
        )[: g.num_dst_nodes()].view(-1, self.num_heads, self.out_feats)

        feat_dst = feat[: g.num_dst_nodes()]
        if self.res_fc is not None:
            out = out + self.res_fc(feat_dst).view(
                -1, self.num_heads, self.out_feats
            )

        if self.bias is not None:
            out = out + self.bias.view(-1, self.num_heads, self.out_feats)

        if self.activation is not None:
            out = self.activation(out)

        return out
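The forward docstring ties ``max_in_degree`` to the sampler fanout. The
following is a minimal, hedged usage sketch of that block (MFG) path, not
part of the module itself. It assumes a CUDA device with pylibcugraphops
available; the random graph, fanout of 10, and batch size are illustrative
values, not taken from this source.

import dgl
import torch
from dgl.dataloading import DataLoader, NeighborSampler
from dgl.nn import CuGraphGATConv

device = "cuda"
# Illustrative random graph; self-loops keep every destination reachable.
g = dgl.add_self_loop(dgl.rand_graph(100, 500)).to(device)
feat = torch.randn(100, 10, device=device)
conv = CuGraphGATConv(10, 2, num_heads=3).to(device)

# One-hop neighbor sampling with fanout 10. Passing the same value as
# max_in_degree lets forward() take the SampledCSC fast path, since
# 10 < MAX_IN_DEGREE_MFG (200).
sampler = NeighborSampler([10])
loader = DataLoader(
    g,
    torch.arange(g.num_nodes(), device=device),
    sampler,
    batch_size=32,
    device=device,
)
for input_nodes, output_nodes, blocks in loader:
    h = conv(blocks[0], feat[input_nodes], max_in_degree=10)
    # h has shape (len(output_nodes), num_heads, out_feats) == (*, 3, 2)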