class NodeEmbedding:  # NodeEmbedding
    """Class for storing node embeddings.

    The class is optimized for training large-scale node embeddings. It updates the
    embedding in a sparse way and can scale to graphs with millions of nodes. It also
    supports partitioning to multiple GPUs (on a single machine) for more acceleration.
    It does not support partitioning across machines.

    Currently, DGL provides two optimizers that work with this NodeEmbedding
    class: ``SparseAdagrad`` and ``SparseAdam``.

    The implementation is based on the torch.distributed package. It depends on the
    PyTorch default distributed process group to collect multi-process information and
    uses ``torch.distributed.TCPStore`` to share meta-data information across multiple
    gpu processes. It uses the local address '127.0.0.1:12346' to initialize the
    TCPStore.

    NOTE: The support of NodeEmbedding is experimental.

    Parameters
    ----------
    num_embeddings : int
        The number of embeddings. Currently, the number of embeddings has to be the
        same as the number of nodes.
    embedding_dim : int
        The dimension size of embeddings.
    name : str
        The name of the embeddings. The name should uniquely identify the embeddings
        in the system.
    init_func : callable, optional
        The function to create the initial data. If the init function is not provided,
        the values of the embeddings are initialized to zero.
    device : th.device
        Device to store the embeddings on.
    partition : NDArrayPartition
        The partition to use to distribute the embeddings between processes.

    Examples
    --------
    Before launching multiple gpu processes

    >>> def initializer(emb):
    ...     th.nn.init.xavier_uniform_(emb)
    ...     return emb

    In each training process

    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     ...
    ...     feats = emb(nids, gpu_0)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        name,
        init_func=None,
        device=None,
        partition=None,
    ):
        global _STORE

        if device is None:
            device = th.device("cpu")

        # Check whether it is multi-gpu training or not.
        if th.distributed.is_initialized():
            rank = th.distributed.get_rank()
            world_size = th.distributed.get_world_size()
        else:
            rank = -1
            world_size = 0
        self._rank = rank
        self._world_size = world_size
        self._store = None
        self._comm = None
        self._partition = partition

        host_name = "127.0.0.1"
        port = 12346

        if rank >= 0:
            # for multi-gpu training, set up a TCPStore for
            # embedding status synchronization across GPU processes
            if _STORE is None:
                _STORE = th.distributed.TCPStore(
                    host_name,
                    port,
                    world_size,
                    rank == 0,
                    timedelta(seconds=10 * 60),
                )
            self._store = _STORE

        if th.device(device) == th.device("cpu"):
            # embeddings are stored in CPU memory.
            if rank <= 0:
                emb = create_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
                if init_func is not None:
                    emb = init_func(emb)
            if rank == 0:  # the master gpu process
                for _ in range(1, world_size):
                    # send embs
                    self._store.set(name, name)
            elif rank > 0:
                # receive
                self._store.wait([name])
                emb = get_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
            self._tensor = emb
        else:
            # embeddings are stored in GPU memory.
            self._comm = True

            if not self._partition:
                # for communication we need a partition
                self._partition = NDArrayPartition(
                    num_embeddings,
                    self._world_size if self._world_size > 0 else 1,
                    mode="remainder",
                )

            # create local tensors for the weights
            local_size = self._partition.local_size(max(self._rank, 0))

            # TODO(dlasalle): support 16-bit/half embeddings
            emb = th.empty(
                [local_size, embedding_dim],
                dtype=th.float32,
                requires_grad=False,
                device=device,
            )
            if init_func:
                emb = init_func(emb)
            self._tensor = emb

        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._name = name
        self._optm_state = None  # track optimizer state
        self._trace = []  # track minibatch

    def __call__(self, node_ids, device=th.device("cpu")):
        """Return the embeddings of the given node IDs.

        Parameters
        ----------
        node_ids : th.tensor
            Index of the embeddings to collect.
        device : th.device
            Target device to put the collected embeddings.
        """
        if not self._comm:
            # For embeddings stored on the CPU.
            emb = self._tensor[node_ids].to(device)
        else:
            # For embeddings stored on the GPU. The following method performs
            # communication across multiple GPUs and gracefully handles the
            # single-GPU case, i.e. self._world_size == 1 or 0 (when
            # th.distributed.is_initialized() is false).
            emb = nccl.sparse_all_to_all_pull(
                node_ids, self._tensor, self._partition
            )
            emb = emb.to(device)
        if F.is_recording():
            emb = F.attach_grad(emb)
            self._trace.append((node_ids.to(device), emb))
        return emb

    @property
    def store(self):
        """Return the torch.distributed.TCPStore used for
        meta-data sharing across processes.

        Returns
        -------
        torch.distributed.TCPStore
            KVStore used for meta-data sharing.
        """
        return self._store

    @property
    def partition(self):
        """Return the partition identifying how the tensor is split across processes.

        Returns
        -------
        NDArrayPartition
            The partition of the embedding tensor.
        """
        return self._partition

    @property
    def rank(self):
        """Return the rank of the current process.

        Returns
        -------
        int
            The rank of the current process.
        """
        return self._rank

    @property
    def world_size(self):
        """Return the world size of the pytorch distributed training env.

        Returns
        -------
        int
            The world size of the pytorch distributed training env.
        """
        return self._world_size

    @property
    def name(self):
        """Return the name of the NodeEmbedding.

        Returns
        -------
        str
            The name of the NodeEmbedding.
        """
        return self._name

    @property
    def num_embeddings(self):
        """Return the number of embeddings.

        Returns
        -------
        int
            The number of embeddings.
        """
        return self._num_embeddings

    @property
    def embedding_dim(self):
        """Return the dimension of embeddings.

        Returns
        -------
        int
            The dimension of embeddings.
        """
        return self._embedding_dim
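
    # The trace/optimizer-state accessors below are intended for the sparse
    # optimizers (``dgl.optim.SparseAdagrad`` / ``dgl.optim.SparseAdam``):
    # ``trace`` records which rows were gathered in the current training step
    # so that only those rows receive gradient updates, while
    # ``set_optm_state`` / ``optm_state`` attach and expose the per-row
    # optimizer state.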

    def set_optm_state(self, state):
        """Store the optimizer related state tensor.

        Parameters
        ----------
        state : tuple of torch.Tensor
            Optimizer related state.
        """
        self._optm_state = state

    @property
    def optm_state(self):
        """Return the optimizer related state tensor.

        Returns
        -------
        tuple of torch.Tensor
            The optimizer related state.
        """
        return self._optm_state

    @property
    def trace(self):
        """Return a trace of the indices of embeddings
        used in the training step(s).

        Returns
        -------
        [torch.Tensor]
            The indices of embeddings used in the training step(s).
        """
        return self._trace

    def reset_trace(self):
        """Clean up the trace of the indices of embeddings
        used in the training step(s).
        """
        self._trace = []

    @property
    def weight(self):
        """Return the tensor storing the node embeddings.

        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings.
        """
        return self._tensor

    def all_set_embedding(self, values):
        """Set the values of the embedding. This method must be called by all
        processes sharing the embedding with identical tensors for
        :attr:`values`.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Parameters
        ----------
        values : Tensor
            The global tensor to pull values from.
        """
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    max(self._rank, 0),
                    ctx=F.context(self._tensor),
                ),
                F.context(values),
            )
            self._tensor[:] = F.copy_to(
                F.gather_row(values, idxs), ctx=F.context(self._tensor)
            )[:]
        else:
            if self._rank == 0:
                self._tensor[:] = F.copy_to(
                    values, ctx=F.context(self._tensor)
                )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()

    def _all_get_tensor(self, shared_name, tensor, shape):
        """A helper function to gather model-parallel tensors.

        This method must be called, and only needs to be called, in multi-GPU
        DDP training. For now, it is only used in ``all_get_embedding`` and
        ``_all_get_optm_state``.
        """
        # create a shared memory tensor
        if self._rank == 0:
            # root process creates shared memory
            val = create_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )
            self._store.set(shared_name, shared_name)
        else:
            self._store.wait([shared_name])
            val = get_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )

        # need to map indices and slice into existing tensor
        idxs = self._partition.map_to_global(
            F.arange(0, tensor.shape[0], ctx=F.context(tensor)),
            self._rank,
        ).to(val.device)
        val[idxs] = tensor.to(val.device)

        self._store.delete_key(shared_name)
        # wait for all processes to finish
        th.distributed.barrier()
        return val

    def all_get_embedding(self):
        """Return a copy of the embedding stored in CPU memory. If this is a
        multi-processing instance, the tensor will be returned in shared
        memory. If the embedding is currently stored on multiple GPUs, all
        processes must call this method in the same order.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings.
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return self._tensor.to(th.device("cpu"))
            else:
                return self._all_get_tensor(
                    f"{self._name}_gather",
                    self._tensor,
                    (self._num_embeddings, self._embedding_dim),
                )
        else:
            # already stored in CPU memory
            return self._tensor
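
    # The two private helpers below mirror ``all_get_embedding`` /
    # ``all_set_embedding`` for the optimizer state, reusing
    # ``_all_get_tensor`` to gather the partitioned tensors, and follow the
    # same contract: every process sharing the embedding must participate in
    # the call, or it may deadlock.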

    def _all_get_optm_state(self):
        """Return a copy of the whole optimizer states stored in CPU memory.
        If this is a multi-processing instance, the states will be returned in
        shared memory. If the embedding is currently stored on multiple GPUs,
        all processes must call this method in the same order.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Returns
        -------
        tuple of torch.Tensor
            The optimizer states stored in CPU memory.
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return tuple(
                    state.to(th.device("cpu")) for state in self._optm_state
                )
            else:
                return tuple(
                    self._all_get_tensor(
                        f"state_gather_{self._name}_{i}",
                        state,
                        (self._num_embeddings, *state.shape[1:]),
                    )
                    for i, state in enumerate(self._optm_state)
                )
        else:
            # already stored in CPU memory
            return self._optm_state

    def _all_set_optm_state(self, states):
        """Set the optimizer states of the embedding. This method must be
        called by all processes sharing the embedding with identical
        :attr:`states`.

        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.

        Parameters
        ----------
        states : tuple of torch.Tensor
            The global states to pull values from.
        """
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    max(self._rank, 0), ctx=F.context(self._tensor)
                ),
                F.context(states[0]),
            )
            for state, new_state in zip(self._optm_state, states):
                state[:] = F.copy_to(
                    F.gather_row(new_state, idxs), ctx=F.context(self._tensor)
                )[:]
        else:
            # stored in CPU memory
            if self._rank <= 0:
                for state, new_state in zip(self._optm_state, states):
                    state[:] = F.copy_to(
                        new_state, ctx=F.context(self._tensor)
                    )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()
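

# Illustrative usage sketch (not part of the NodeEmbedding API). It shows one
# training process pulling minibatch rows through ``__call__`` and gathering
# the trained embedding at the end. The graph ``g``, the ``dataloader``
# yielding ``(input_nodes, output_nodes, blocks)`` tuples, and the helper name
# ``_node_embedding_example`` are assumptions for illustration only.
def _node_embedding_example(g, dataloader, device):
    from dgl.optim import SparseAdam  # public path of the sparse optimizer

    emb = NodeEmbedding(
        g.num_nodes(),
        embedding_dim=16,
        name="example_emb",
        init_func=lambda t: th.nn.init.xavier_uniform_(t),
        device=device,  # pass a GPU device to shard the weights across GPUs
    )
    optimizer = SparseAdam([emb], lr=0.001)
    for input_nodes, _, _ in dataloader:
        # Gather only the rows used by this minibatch; the lookup is traced so
        # the sparse optimizer updates just those rows in step().
        feats = emb(input_nodes, device)
        loss = feats.sum()
        loss.backward()
        optimizer.step()
    # Every process sharing the embedding must make this call, or it may
    # result in a deadlock (see ``all_get_embedding`` above).
    return emb.all_get_embedding()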