"""Heterograph NN modules"""fromfunctoolsimportpartialimporttorchasthimporttorch.nnasnnfrom...baseimportDGLError__all__=["HeteroGraphConv","HeteroLinear","HeteroEmbedding"]
class HeteroGraphConv(nn.Module):
    r"""A generic module for computing convolution on heterogeneous graphs.

    The heterograph convolution applies sub-modules on their associating
    relation graphs, which reads the features from source nodes and writes the
    updated ones to destination nodes. If multiple relations have the same
    destination node types, their results are aggregated by the specified
    method. A relation is skipped when no input features are given for its
    source node type; the number of edges of the relation graph is not
    inspected.

    Pseudo-code:

    .. code::

        outputs = {nty : [] for nty in g.dsttypes}
        # Apply sub-modules on their associating relation graphs in parallel
        for relation in g.canonical_etypes:
            stype, etype, dtype = relation
            dstdata = relation_submodule(g[relation], ...)
            outputs[dtype].append(dstdata)

        # Aggregate the results for each destination node type
        rsts = {}
        for ntype, ntype_outputs in outputs.items():
            if len(ntype_outputs) != 0:
                rsts[ntype] = aggregate(ntype_outputs)
        return rsts

    Examples
    --------

    Create a heterograph with three types of relations and nodes.

    >>> import dgl
    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user') : edges1,
    ...     ('user', 'plays', 'game') : edges2,
    ...     ('store', 'sells', 'game')  : edges3})

    Create a ``HeteroGraphConv`` that applies different convolution modules to
    different relations. Note that the modules for ``'follows'`` and ``'plays'``
    do not share weights.

    >>> import dgl.nn.pytorch as dglnn
    >>> conv = dglnn.HeteroGraphConv({
    ...     'follows' : dglnn.GraphConv(...),
    ...     'plays' : dglnn.GraphConv(...),
    ...     'sells' : dglnn.SAGEConv(...)},
    ...     aggregate='sum')

    Call forward with some ``'user'`` features. This computes new features for
    both ``'user'`` and ``'game'`` nodes.

    >>> import torch as th
    >>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))}
    >>> h2 = conv(g, h1)
    >>> print(h2.keys())
    dict_keys(['user', 'game'])

    Call forward with both ``'user'`` and ``'store'`` features. Because both the
    ``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,
    their results are aggregated by the specified method (i.e., summation here).

    >>> f1 = {'user' : ..., 'store' : ...}
    >>> f2 = conv(g, f1)
    >>> print(f2.keys())
    dict_keys(['user', 'game'])

    Call forward with some ``'store'`` features. This only computes new features
    for ``'game'`` nodes.

    >>> g1 = {'store' : ...}
    >>> g2 = conv(g, g1)
    >>> print(g2.keys())
    dict_keys(['game'])

    Call forward with a pair of inputs is allowed and each submodule will also
    be invoked with a pair of inputs.

    >>> x_src = {'user' : ..., 'store' : ...}
    >>> x_dst = {'user' : ..., 'game' : ...}
    >>> y_dst = conv(g, (x_src, x_dst))
    >>> print(y_dst.keys())
    dict_keys(['user', 'game'])

    Parameters
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge types. The forward function of each
        module must have a `DGLGraph` object as the first argument, and
        its second argument is either a tensor object representing the node
        features or a pair of tensor object representing the source and
        destination node features.
    aggregate : str, callable, optional
        Method for aggregating node features generated by different relations.
        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
        The 'stack' aggregation is performed along the second dimension, whose
        order is deterministic.
        User can also customize the aggregator by providing a callable
        instance. For example, aggregation by summation is equivalent to the
        follows:

        .. code::

            def my_agg_func(tensors, dsttype):
                # tensors: is a list of tensors to aggregate
                # dsttype: string name of the destination node type for which
                #          the aggregation is performed
                stacked = torch.stack(tensors, dim=0)
                return torch.sum(stacked, dim=0)

    Attributes
    ----------
    mods : nn.ModuleDict
        Modules associated with every edge types, registered as child modules
        under the string form of their keys.
    mod_dict : dict
        The dictionary passed in as ``mods``, used for lookups by edge type.
    """

    def __init__(self, mods, aggregate="sum"):
        super().__init__()
        # PyTorch ModuleDict doesn't have a get() method, so two dictionaries
        # are stored: ``mod_dict`` keeps the caller's keys for lookups with
        # get(), while ``mods`` registers the sub-modules as children.
        self.mod_dict = mods
        self.mods = nn.ModuleDict({str(k): v for k, v in mods.items()})
        # Do not break if graph has 0-in-degree nodes, because there is no
        # general rule to add self-loop for heterograph.
        for sub_module in self.mods.values():
            allow_zero_in_degree = getattr(
                sub_module, "set_allow_zero_in_degree", None
            )
            if callable(allow_zero_in_degree):
                allow_zero_in_degree(True)
        if isinstance(aggregate, str):
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

    def _get_module(self, etype):
        """Look up the sub-module registered for an edge type.

        Parameters
        ----------
        etype : str or tuple
            Edge type name, or canonical ``(stype, etype, dtype)`` triplet.
            The exact key is tried first; a canonical triplet then falls back
            to its middle element.

        Returns
        -------
        nn.Module
            The matching sub-module.

        Raises
        ------
        KeyError
            If no module is registered for the edge type.
        """
        mod = self.mod_dict.get(etype, None)
        if mod is not None:
            return mod
        if isinstance(etype, tuple):
            # etype is canonical: retry with the bare edge type name.
            _, short_etype, _ = etype
            if short_etype in self.mod_dict:
                return self.mod_dict[short_etype]
        # Wrap in a 1-tuple so that a tuple-valued etype formats correctly.
        raise KeyError("Cannot find module with edge type %s" % (etype,))

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        """Forward computation

        Invoke the forward function with each module and aggregate their
        results.

        Parameters
        ----------
        g : DGLGraph
            Graph data.
        inputs : dict[str, Tensor] or pair of dict[str, Tensor]
            Input node features. A pair is interpreted as
            ``(source features, destination features)``. For a block, a single
            dict is split into source features and the leading
            ``g.number_of_dst_nodes(ntype)`` rows as destination features.
        mod_args : dict[str, tuple[any]], optional
            Extra positional arguments for the sub-modules, keyed by edge type.
        mod_kwargs : dict[str, dict[str, any]], optional
            Extra key-word arguments for the sub-modules, keyed by edge type.

        Returns
        -------
        dict[str, Tensor]
            Output representations for every types of nodes that received at
            least one result.

        Notes
        -----
        A relation is skipped if its source type has no input features. With a
        pair of inputs or a block, it is also skipped if its destination type
        has no features. With a single dict on a non-block graph, a relation
        whose destination type has no features is invoked with the source
        tensor alone instead of a ``(src, dst)`` pair.
        """
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}

        if isinstance(inputs, tuple):
            src_inputs, dst_inputs = inputs
            require_dst = True
        elif g.is_block:
            src_inputs = inputs
            dst_inputs = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
            require_dst = True
        else:
            src_inputs = dst_inputs = inputs
            require_dst = False

        outputs = {nty: [] for nty in g.dsttypes}
        for stype, etype, dtype in g.canonical_etypes:
            if stype not in src_inputs:
                continue
            if dtype in dst_inputs:
                feat = (src_inputs[stype], dst_inputs[dtype])
            elif require_dst:
                continue
            else:
                # No destination features were supplied for this node type:
                # hand the sub-module the source features only rather than
                # failing with a KeyError on ``inputs[dtype]``.
                feat = src_inputs[stype]
            # Slice the relation graph only for relations actually computed.
            rel_graph = g[stype, etype, dtype]
            dstdata = self._get_module((stype, etype, dtype))(
                rel_graph,
                feat,
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {})
            )
            outputs[dtype].append(dstdata)

        return {
            nty: self.agg_fn(results, nty)
            for nty, results in outputs.items()
            if results
        }
def _max_reduce_func(inputs, dim):
    """Reduce ``inputs`` along ``dim`` by taking the maximum values."""
    return th.max(inputs, dim=dim).values


def _min_reduce_func(inputs, dim):
    """Reduce ``inputs`` along ``dim`` by taking the minimum values."""
    return th.min(inputs, dim=dim).values


def _sum_reduce_func(inputs, dim):
    """Reduce ``inputs`` along ``dim`` by summation."""
    return th.sum(inputs, dim=dim)


def _mean_reduce_func(inputs, dim):
    """Reduce ``inputs`` along ``dim`` by averaging."""
    return th.mean(inputs, dim=dim)


def _stack_agg_func(inputs, dsttype):  # pylint: disable=unused-argument
    """Stack the per-relation tensors along dimension 1.

    Returns ``None`` when ``inputs`` is empty.
    """
    return None if len(inputs) == 0 else th.stack(inputs, dim=1)


def _agg_func(inputs, dsttype, fn):  # pylint: disable=unused-argument
    """Stack the per-relation tensors along dimension 0 and reduce with ``fn``.

    Returns ``None`` when ``inputs`` is empty.
    """
    if len(inputs) == 0:
        return None
    return fn(th.stack(inputs, dim=0), dim=0)


def get_aggregate_fn(agg):
    """Internal function to get the aggregation function for node data
    generated from different relations.

    Parameters
    ----------
    agg : str
        Method for aggregating node features generated by different relations.
        Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.

    Returns
    -------
    callable
        Aggregator function that takes a list of tensors to aggregate
        and returns one aggregated tensor.
    """
    reducers = {
        "sum": _sum_reduce_func,
        "max": _max_reduce_func,
        "min": _min_reduce_func,
        "mean": _mean_reduce_func,
    }
    for name, reducer in reducers.items():
        if agg == name:
            return partial(_agg_func, fn=reducer)
    # 'stack' needs no reducer: the tensors are only stacked.
    if agg == "stack":
        return _stack_agg_func
    raise DGLError(
        "Invalid cross type aggregator. Must be one of "
        '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
    )
class HeteroLinear(nn.Module):
    """Apply linear transformations on heterogeneous inputs.

    One ``torch.nn.Linear`` is created per key of ``in_size`` and stored under
    the string form of that key.

    Parameters
    ----------
    in_size : dict[key, int]
        Input feature size for heterogeneous inputs. A key can be a string or a
        tuple of strings.
    out_size : int
        Output feature size.
    bias : bool, optional
        If True, learns a bias term. Defaults: ``True``.

    Examples
    --------

    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroLinear

    >>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)
    >>> in_feats = {'user': torch.randn(2, 1),
    ...             ('user', 'follows', 'user'): torch.randn(3, 2)}
    >>> out_feats = layer(in_feats)
    >>> print(out_feats['user'].shape)
    torch.Size([2, 3])
    >>> print(out_feats[('user', 'follows', 'user')].shape)
    torch.Size([3, 3])
    """

    def __init__(self, in_size, out_size, bias=True):
        super().__init__()
        # ModuleDict only accepts string keys, hence the str() conversion.
        self.linears = nn.ModuleDict(
            {
                str(key): nn.Linear(key_in_size, out_size, bias=bias)
                for key, key_in_size in in_size.items()
            }
        )

    def forward(self, feat):
        """Forward function

        Parameters
        ----------
        feat : dict[key, Tensor]
            Heterogeneous input features. It maps keys to features.

        Returns
        -------
        dict[key, Tensor]
            Transformed features, under the same keys as ``feat``.
        """
        return {
            key: self.linears[str(key)](key_feat)
            for key, key_feat in feat.items()
        }
class HeteroEmbedding(nn.Module):
    """Create a heterogeneous embedding table.

    It internally contains multiple ``torch.nn.Embedding`` with different
    dictionary sizes.

    Parameters
    ----------
    num_embeddings : dict[key, int]
        Size of the dictionaries. A key can be a string or a tuple of strings.
    embedding_dim : int
        Size of each embedding vector.

    Examples
    --------

    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroEmbedding

    >>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
    >>> # Get the heterogeneous embedding table
    >>> embeds = layer.weight
    >>> print(embeds['user'].shape)
    torch.Size([2, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([3, 4])

    >>> # Get the embeddings for a subset
    >>> input_ids = {'user': torch.LongTensor([0]),
    ...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
    >>> embeds = layer(input_ids)
    >>> print(embeds['user'].shape)
    torch.Size([1, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([2, 4])
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # ModuleDict only accepts string keys; ``raw_keys`` maps each string
        # key back to the key originally supplied by the caller.
        self.embeds = nn.ModuleDict(
            {
                str(key): nn.Embedding(num_rows, embedding_dim)
                for key, num_rows in num_embeddings.items()
            }
        )
        self.raw_keys = {str(key): key for key in num_embeddings}

    @property
    def weight(self):
        """Get the heterogeneous embedding table

        Returns
        -------
        dict[key, Tensor]
            Heterogeneous embedding table
        """
        table = dict()
        for str_key, embedding in self.embeds.items():
            table[self.raw_keys[str_key]] = embedding.weight
        return table

    def reset_parameters(self):
        """
        Use the xavier method in nn.init module to make the parameters
        uniformly distributed
        """
        for embedding in self.embeds.values():
            nn.init.xavier_uniform_(embedding.weight)

    def forward(self, input_ids):
        """Forward function

        Parameters
        ----------
        input_ids : dict[key, Tensor]
            The row IDs to retrieve embeddings. It maps a key to key-specific
            IDs.

        Returns
        -------
        dict[key, Tensor]
            The retrieved embeddings.
        """
        return {
            key: self.embeds[str(key)](key_ids)
            for key, key_ids in input_ids.items()
        }