"""Torch Module for Principal Neighbourhood Aggregation Convolution Layer"""# pylint: disable= no-member, arguments-differ, invalid-nameimportnumpyasnpimporttorchimporttorch.nnasnndefaggregate_mean(h):"""mean aggregation"""returntorch.mean(h,dim=1)defaggregate_max(h):"""max aggregation"""returntorch.max(h,dim=1)[0]defaggregate_min(h):"""min aggregation"""returntorch.min(h,dim=1)[0]defaggregate_sum(h):"""sum aggregation"""returntorch.sum(h,dim=1)defaggregate_std(h):"""standard deviation aggregation"""returntorch.sqrt(aggregate_var(h)+1e-30)defaggregate_var(h):"""variance aggregation"""h_mean_squares=torch.mean(h*h,dim=1)h_mean=torch.mean(h,dim=1)var=torch.relu(h_mean_squares-h_mean*h_mean)returnvardef_aggregate_moment(h,n):"""moment aggregation: for each node (E[(X-E[X])^n])^{1/n}"""h_mean=torch.mean(h,dim=1,keepdim=True)h_n=torch.mean(torch.pow(h-h_mean,n),dim=1)rooted_h_n=torch.sign(h_n)*torch.pow(torch.abs(h_n)+1e-30,1.0/n)returnrooted_h_ndefaggregate_moment_3(h):"""moment aggregation with n=3"""return_aggregate_moment(h,n=3)defaggregate_moment_4(h):"""moment aggregation with n=4"""return_aggregate_moment(h,n=4)defaggregate_moment_5(h):"""moment aggregation with n=5"""return_aggregate_moment(h,n=5)defscale_identity(h):"""identity scaling (no scaling operation)"""returnhdefscale_amplification(h,D,delta):"""amplification scaling"""returnh*(np.log(D+1)/delta)defscale_attenuation(h,D,delta):"""attenuation scaling"""returnh*(delta/np.log(D+1))AGGREGATORS={"mean":aggregate_mean,"sum":aggregate_sum,"max":aggregate_max,"min":aggregate_min,"std":aggregate_std,"var":aggregate_var,"moment3":aggregate_moment_3,"moment4":aggregate_moment_4,"moment5":aggregate_moment_5,}SCALERS={"identity":scale_identity,"amplification":scale_amplification,"attenuation":scale_attenuation,}classPNAConvTower(nn.Module):"""A single PNA tower in PNA layers"""def__init__(self,in_size,out_size,aggregators,scalers,delta,dropout=0.0,edge_feat_size=0,):super(PNAConvTower,self).__init__()self.in_size=in_sizeself.out_size=out_sizeself.aggregators=aggregatorsself.scalers=scalersself.delta=deltaself.edge_feat_size=edge_feat_sizeself.M=nn.Linear(2*in_size+edge_feat_size,in_size)self.U=nn.Linear((len(aggregators)*len(scalers)+1)*in_size,out_size)self.dropout=nn.Dropout(dropout)self.batchnorm=nn.BatchNorm1d(out_size)defreduce_func(self,nodes):"""reduce function for PNA layer: tensordot of multiple aggregation and scaling operations"""msg=nodes.mailbox["msg"]degree=msg.size(1)h=torch.cat([AGGREGATORS[agg](msg)foragginself.aggregators],dim=1)h=torch.cat([SCALERS[scaler](h,D=degree,delta=self.delta)ifscaler!="identity"elsehforscalerinself.scalers],dim=1,)return{"h_neigh":h}defmessage(self,edges):"""message function for PNA layer"""ifself.edge_feat_size>0:f=torch.cat([edges.src["h"],edges.dst["h"],edges.data["a"]],dim=-1)else:f=torch.cat([edges.src["h"],edges.dst["h"]],dim=-1)return{"msg":self.M(f)}defforward(self,graph,node_feat,edge_feat=None):"""compute the forward pass of a single tower in PNA convolution layer"""# calculate graph normalization factorssnorm_n=torch.cat([torch.ones(N,1).to(node_feat)/NforNingraph.batch_num_nodes()],dim=0,).sqrt()withgraph.local_scope():graph.ndata["h"]=node_featifself.edge_feat_size>0:assertedge_featisnotNone,"Edge features must be provided."graph.edata["a"]=edge_featgraph.update_all(self.message,self.reduce_func)h=self.U(torch.cat([node_feat,graph.ndata["h_neigh"]],dim=-1))h=h*snorm_nreturnself.dropout(self.batchnorm(h))
class PNAConv(nn.Module):
    r"""Principal Neighbourhood Aggregation Layer from `Principal Neighbourhood
    Aggregation for Graph Nets <https://arxiv.org/abs/2004.05718>`__

    A PNA layer is composed of multiple PNA towers. Each tower takes as input a
    split of the input features, and computes the message passing as below.

    .. math::
        h_i^{(l+1)} = U(h_i^{(l)},
        \oplus_{(i,j)\in E}M(h_i^{(l)}, e_{i,j}, h_j^{(l)}))

    where :math:`h_i` and :math:`e_{i,j}` are node features and edge features,
    respectively. :math:`M` and :math:`U` are MLPs, taking the concatenation of
    input for computing output features. :math:`\oplus` represents the
    combination of various aggregators and scalers. Aggregators aggregate
    messages from neighbours and scalers scale the aggregated messages in
    different ways. :math:`\oplus` concatenates the output features of each
    combination.

    The outputs of multiple towers are concatenated and fed into a linear
    mixing layer for the final output.

    Parameters
    ----------
    in_size : int
        Input feature size; i.e. the size of :math:`h_i^l`.
    out_size : int
        Output feature size; i.e. the size of :math:`h_i^{l+1}`.
    aggregators : list of str
        List of aggregation function names (each aggregator specifies a way to
        aggregate messages from neighbours), selected from:

        * ``mean``: the mean of neighbour messages
        * ``max``: the maximum of neighbour messages
        * ``min``: the minimum of neighbour messages
        * ``std``: the standard deviation of neighbour messages
        * ``var``: the variance of neighbour messages
        * ``sum``: the sum of neighbour messages
        * ``moment3``, ``moment4``, ``moment5``: the normalized moments
          aggregation :math:`(E[(X-E[X])^n])^{1/n}`
    scalers : list of str
        List of scaler function names, selected from:

        * ``identity``: no scaling
        * ``amplification``: multiply the aggregated message by
          :math:`\log(d+1)/\delta`, where :math:`d` is the degree of the node.
        * ``attenuation``: multiply the aggregated message by
          :math:`\delta/\log(d+1)`
    delta : float
        The degree-related normalization factor computed over the training set,
        used by scalers for normalization: :math:`E[\log(d+1)]`, where
        :math:`d` is the degree of each node in the training set.
    dropout : float, optional
        The dropout ratio. Default: 0.0.
    num_towers : int, optional
        The number of towers used. Default: 1. Note that in_size and out_size
        must be divisible by num_towers.
    edge_feat_size : int, optional
        The edge feature size. Default: 0.
    residual : bool, optional
        The bool flag that determines whether to add a residual connection for
        the output. Default: True. If in_size and out_size of the PNA conv
        layer are not the same, this flag will be set to False forcibly.

    Example
    -------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import PNAConv
    >>>
    >>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    >>> feat = th.ones(6, 10)
    >>> conv = PNAConv(10, 10, ['mean', 'max', 'sum'], ['identity', 'amplification'], 2.5)
    >>> ret = conv(g, feat)
    """

    def __init__(
        self,
        in_size,
        out_size,
        aggregators,
        scalers,
        delta,
        dropout=0.0,
        num_towers=1,
        edge_feat_size=0,
        residual=True,
    ):
        super(PNAConv, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        assert (
            in_size % num_towers == 0
        ), "in_size must be divisible by num_towers"
        assert (
            out_size % num_towers == 0
        ), "out_size must be divisible by num_towers"
        self.tower_in_size = in_size // num_towers
        self.tower_out_size = out_size // num_towers
        self.edge_feat_size = edge_feat_size
        self.residual = residual
        if self.in_size != self.out_size:
            self.residual = False

        self.towers = nn.ModuleList(
            [
                PNAConvTower(
                    self.tower_in_size,
                    self.tower_out_size,
                    aggregators,
                    scalers,
                    delta,
                    dropout=dropout,
                    edge_feat_size=edge_feat_size,
                )
                for _ in range(num_towers)
            ]
        )

        self.mixing_layer = nn.Sequential(
            nn.Linear(out_size, out_size), nn.LeakyReLU()
        )
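
    # Shape bookkeeping for the tower split (an illustrative note, not part of
    # the original code): with in_size=16, out_size=16 and num_towers=4, each
    # tower maps an (N, 4) column slice of the input to (N, 4); concatenating
    # the four tower outputs restores (N, 16) before the (16 -> 16) mixing
    # layer below.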
    def forward(self, graph, node_feat, edge_feat=None):
        r"""
        Description
        -----------
        Compute PNA layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        node_feat : torch.Tensor
            The input feature of shape :math:`(N, h_n)`. :math:`N` is the
            number of nodes, and :math:`h_n` must be the same as in_size.
        edge_feat : torch.Tensor, optional
            The edge feature of shape :math:`(M, h_e)`. :math:`M` is the number
            of edges, and :math:`h_e` must be the same as edge_feat_size.

        Returns
        -------
        torch.Tensor
            The output node feature of shape :math:`(N, h_n')` where
            :math:`h_n'` should be the same as out_size.
        """
        h_cat = torch.cat(
            [
                tower(
                    graph,
                    node_feat[
                        :, ti * self.tower_in_size : (ti + 1) * self.tower_in_size
                    ],
                    edge_feat,
                )
                for ti, tower in enumerate(self.towers)
            ],
            dim=1,
        )
        h_out = self.mixing_layer(h_cat)
        # add residual connection
        if self.residual:
            h_out = h_out + node_feat

        return h_out
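

if __name__ == "__main__":
    # A minimal smoke-test sketch, assuming dgl and torch are installed; the
    # toy graph and the sizes below are made up for illustration and are not
    # part of the library module.
    import dgl

    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3]))
    node_feat = torch.ones(6, 16)
    edge_feat = torch.randn(6, 4)
    conv = PNAConv(
        16,
        16,
        aggregators=["mean", "max", "sum"],
        scalers=["identity", "amplification"],
        delta=2.5,
        num_towers=4,
        edge_feat_size=4,
    )
    # in_size == out_size, so the residual connection stays enabled
    out = conv(g, node_feat, edge_feat)
    print(out.shape)  # torch.Size([6, 16])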