Source code for dgl.distributed.optim.pytorch.sparse_optim
"""Node embedding optimizers for distributed training"""importabcimportwarningsfromabcimportabstractmethodfromos.pathimportexistsimporttorchasthimportdglfrom....importbackendasFfrom...dist_tensorimportDistTensorfrom...graph_partition_bookimportEDGE_PART_POLICY,NODE_PART_POLICYfrom...nn.pytorchimportDistEmbeddingfrom.utilsimportalltoall,alltoallvEMB_STATES="emb_states"WORLD_SIZE="world_size"IDS="ids"PARAMS="params"STATES="states"classDistSparseGradOptimizer(abc.ABC):r"""The abstract dist sparse optimizer. Note: dgl dist sparse optimizer only work with dgl.distributed.DistEmbedding Parameters ---------- params : list of DistEmbedding The list of DistEmbedding. lr : float The learning rate. """def__init__(self,params,lr):self._params=paramsself._lr=lrself._rank=Noneself._world_size=Noneself._shared_cache={}self._clean_grad=Falseself._opt_meta={}self._state={}## collect all hyper parameters for saveself._defaults={}ifth.distributed.is_initialized():self._rank=th.distributed.get_rank()self._world_size=th.distributed.get_world_size()else:self._rank=0self._world_size=1deflocal_state_dict(self):"""Return the state pertaining to current rank of the optimizer. Returns ------- dict Local state dict Example Dict of Adagrad Optimizer: .. code-block:: json { "params": { "_lr": 0.01, "_eps": "1e-8", "world_size": 2 }, "emb_states": { "emb_name1": { "ids": [0, 2, 4, 6 ,8 ,10], ## tensor, "emb_name1_sum": [0.1 , 0.2, 0.5, 0.1, 0.2] ## tensor, }, "emb_name2": { "ids": [0, 2, 4, 6 ,8 ,10], ## tensor, "emb_name2_sum": [0.3 , 0.2, 0.4, 0.5, 0.2] ## tensor, } } } :param json: json object See Also -------- load_local_state_dict """local_state_dict={}local_state_dict[EMB_STATES]={}local_state_dict[PARAMS]={WORLD_SIZE:self._world_size}forembinself._params:trainers_per_machine=self._world_size//max(1,dgl.distributed.get_num_machines())emb_state_dict={}part_policy=(emb.part_policyifemb.part_policyelseemb.weight.part_policy)idx=self._get_local_ids(part_policy)iftrainers_per_machine>1:kv_idx_split=(idx%trainers_per_machine).long()local_rank=self._rank%trainers_per_machinemask=kv_idx_split==local_rankidx=F.boolean_mask(idx,mask)emb_state_dict.update({IDS:idx})emb_state={}states=(list(self._state[emb.name])ifisinstance(self._state[emb.name],tuple)else[self._state[emb.name]])emb_state={state.name:state[idx]forstateinstates}emb_state_dict.update({STATES:emb_state})local_state_dict[EMB_STATES].update({emb.name:emb_state_dict})local_state_dict[PARAMS].update(self._defaults)returnlocal_state_dictdefload_local_state_dict(self,local_state_dict):"""Load the local state from the input state_dict, updating the optimizer as needed. Parameters ---------- local_state_dict : dict Optimizer state; should be an object returned from a call to local_state_dict(). See Also -------- local_state_dict """foremb_name,emb_stateinlocal_state_dict[EMB_STATES].items():idx=emb_state[IDS]# As state of an embedding of different optimizers can be a single# DistTensor(Adagrad) or a tuple(Adam) of that, converting it to list for# consistency. 
            # The list contains reference(s) to the original DistTensor(s).
            states = (
                list(self._state[emb_name])
                if isinstance(self._state[emb_name], tuple)
                else [self._state[emb_name]]
            )
            if len(emb_state[STATES]) != len(states):
                raise ValueError(
                    f"loaded state dict has a different number of states"
                    f" for embedding {emb_name}"
                )
            name_to_index = {
                state.name: index for index, state in enumerate(states)
            }
            for name, state in emb_state[STATES].items():
                if name not in name_to_index:
                    raise ValueError(
                        f"loaded state dict contains a state {name} that"
                        " can't be found in the optimizer states"
                    )
                state_idx = name_to_index[name]
                state = state.to(th.device("cpu"), states[state_idx].dtype)
                states[state_idx][idx] = state
        self._defaults.update(local_state_dict[PARAMS])
        self.__dict__.update(local_state_dict[PARAMS])

    def save(self, f):
        """Save the local state_dict to disk, one file per rank.

        The saved dict contains 2 parts:

        * 'params': hyper parameters of the optimizer.
        * 'emb_states': partial optimizer states; each embedding contains
          2 items:

          1. ``ids``: global ids of the nodes/edges stored on this rank.
          2. ``states``: state data corresponding to ``ids``.

        NOTE: This needs to be called on all ranks.

        Parameters
        ----------
        f : Union[str, os.PathLike]
            The path of the file to save to.

        See Also
        --------
        load
        """
        if self._world_size > 1:
            th.distributed.barrier()

        # Accept str, bytes, or os.PathLike and normalize to str.
        f = os.fsdecode(f)
        f = f"{f}_{self._rank}"
        th.save(self.local_state_dict(), f)

        if self._world_size > 1:
            th.distributed.barrier()

    def load(self, f):
        """Load the local state of the optimizer from a per-rank file.

        NOTE: This needs to be called on all ranks.

        Parameters
        ----------
        f : Union[str, os.PathLike]
            The path of the file to load from.

        See Also
        --------
        save
        """
        if self._world_size > 1:
            th.distributed.barrier()

        # Accept str, bytes, or os.PathLike and normalize to str.
        f = os.fsdecode(f)
        f_attach_rank = f"{f}_{self._rank}"
        # Don't raise an error here, so that the number of devices can be
        # scaled out after reloading. But make sure the hyper parameters are
        # the same as before, because newly added local optimizers will be
        # left empty.
        if not exists(f_attach_rank):
            warnings.warn(
                f"File {f_attach_rank} can't be found, load nothing."
            )
        else:
            old_world_size = self._load_state_from(f_attach_rank)
            # Device number scale-in.
            if self._world_size < old_world_size:
                for rank in range(
                    self._rank + self._world_size,
                    old_world_size,
                    self._world_size,
                ):
                    self._load_state_from(f"{f}_{rank}")

        if self._world_size > 1:
            th.distributed.barrier()

    def _load_state_from(self, f):
        local_state_dict = th.load(f)
        world_size = local_state_dict[PARAMS].pop(WORLD_SIZE)
        self.load_local_state_dict(local_state_dict)
        return world_size

    def _get_local_ids(self, part_policy):
        if EDGE_PART_POLICY in part_policy.policy_str:
            return part_policy.partition_book.partid2eids(
                part_policy.part_id, part_policy.type_name
            )
        elif NODE_PART_POLICY in part_policy.policy_str:
            return part_policy.partition_book.partid2nids(
                part_policy.part_id, part_policy.type_name
            )
        raise RuntimeError(
            "Cannot support policy: %s" % part_policy.policy_str
        )

    def step(self):
        """The step function.

        The step function is invoked at the end of every batch to push the
        gradients of the embeddings involved in a mini-batch to DGL's servers
        and update the embeddings.
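
        For instance, in a typical training loop, ``step`` is called right
        after the backward pass. A minimal sketch, where ``emb``, ``nids``,
        and ``optimizer`` are illustrative names::

            feats = emb(nids)
            loss = (feats + 1).sum()
            loss.backward()
            optimizer.step()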
"""withth.no_grad():device=(th.device(f"cuda:{self._rank}")ifth.distributed.get_backend()=="nccl"elseth.device("cpu"))local_indics={emb.name:[]forembinself._params}local_grads={emb.name:[]forembinself._params}forembinself._params:name=emb.weight.namekvstore=emb.weight.kvstoretrainers_per_server=self._world_size//kvstore.num_serversidics=[]grads=[]fortraceinemb._trace:iftrace[1].gradisnotNone:idics.append(trace[0])grads.append(trace[1].grad.data)else:assertlen(trace[0])==0# If the sparse embedding is not used in the previous forward step# The idx and grad will be empty, initialize them as empty tensors to# avoid crashing the optimizer step logic.## Note: we cannot skip the gradient exchange and update steps as other# working processes may send gradient update requests corresponding# to certain embedding to this process.idics=(th.cat(idics,dim=0)iflen(idics)!=0elseth.zeros((0,),dtype=th.long,device=th.device("cpu")))grads=(th.cat(grads,dim=0)iflen(grads)!=0elseth.zeros((0,emb.embedding_dim),dtype=th.float32,device=th.device("cpu"),))device=grads.device# will send grad to each corresponding trainerifself._world_size>1:# get idx split from kvstoreidx_split=kvstore.get_partid(emb.data_name,idics)idx_split_size=[]idics_list=[]grad_list=[]# split idx and grad firstforiinrange(kvstore.num_servers):mask=idx_split==iidx_i=idics[mask]grad_i=grads[mask]iftrainers_per_server<=1:idx_split_size.append(th.tensor([idx_i.shape[0]],dtype=th.int64,device=device,))idics_list.append(idx_i)grad_list.append(grad_i)else:kv_idx_split=th.remainder(idx_i,trainers_per_server).long()forjinrange(trainers_per_server):mask=kv_idx_split==jidx_j=idx_i[mask]grad_j=grad_i[mask]idx_split_size.append(th.tensor([idx_j.shape[0]],dtype=th.int64,device=device,))idics_list.append(idx_j)grad_list.append(grad_j)# if one machine launch multiple KVServer, they share the same storage.# For each machine, the pytorch rank is num_trainers *# machine_id + i# use scatter to sync across trainers about the p2p tensor size# Note: If we have GPU nccl support, we can use all_to_all to# sync information heregather_list=list(th.empty([self._world_size],dtype=th.int64,device=device).chunk(self._world_size))alltoall(self._rank,self._world_size,gather_list,idx_split_size,device,)idx_gather_list=[th.empty((int(num_emb),),dtype=idics.dtype,device=device)fornum_embingather_list]alltoallv(self._rank,self._world_size,idx_gather_list,idics_list,device,)local_indics[name]=idx_gather_listgrad_gather_list=[th.empty((int(num_emb),grads.shape[1]),dtype=grads.dtype,device=device,)fornum_embingather_list]alltoallv(self._rank,self._world_size,grad_gather_list,grad_list,device,)local_grads[name]=grad_gather_listelse:local_indics[name]=[idics]local_grads[name]=[grads]ifself._clean_grad:# clean gradient trackforembinself._params:emb.reset_trace()self._clean_grad=False# do local updateforembinself._params:name=emb.weight.nameidx=th.cat(local_indics[name],dim=0)grad=th.cat(local_grads[name],dim=0)self.update(idx.to(device,non_blocking=True),grad.to(device,non_blocking=True),emb,)# synchronized gradient updateifself._world_size>1:th.distributed.barrier()@abstractmethoddefupdate(self,idx,grad,emb):"""Update embeddings in a sparse manner Sparse embeddings are updated in mini batches. We maintain gradient states for each embedding so they can be updated separately. Parameters ---------- idx : tensor Index of the embeddings to be updated. grad : tensor Gradient of each embedding. emb : dgl.distributed.DistEmbedding Sparse node embedding to update. 
"""defzero_grad(self):"""clean grad cache"""self._clean_grad=Truedefinitializer(shape,dtype):"""Sparse optimizer state initializer Parameters ---------- shape : tuple of ints The shape of the state tensor dtype : torch dtype The data type of the state tensor """arr=th.zeros(shape,dtype=dtype)returnarr


class SparseAdagrad(DistSparseGradOptimizer):
    r"""Distributed node embedding optimizer using the Adagrad algorithm.

    This optimizer implements a distributed sparse version of the Adagrad
    algorithm for optimizing :class:`dgl.distributed.DistEmbedding`. Being
    sparse means it only updates the embeddings whose gradients have updates,
    which are usually a very small portion of the total embeddings.

    Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the
    embeddings, where :math:`G_{t,i,j} = G_{t-1,i,j} + g_{t,i,j}^2` and
    :math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding
    :math:`i` at step :math:`t`.

    NOTE: Support for the sparse Adagrad optimizer is experimental.

    Parameters
    ----------
    params : list[dgl.distributed.DistEmbedding]
        The list of dgl.distributed.DistEmbedding.
    lr : float
        The learning rate.
    eps : float, Optional
        The term added to the denominator to improve numerical stability.
        Default: 1e-10
    """

    def __init__(self, params, lr, eps=1e-10):
        super(SparseAdagrad, self).__init__(params, lr)
        self._eps = eps
        self._defaults = {"_lr": lr, "_eps": eps}
        # We need to register a state sum for each embedding in the kvstore.
        for emb in params:
            assert isinstance(
                emb, DistEmbedding
            ), "SparseAdagrad only supports dgl.distributed.DistEmbedding"

            name = emb.name + "_sum"
            state = DistTensor(
                (emb.num_embeddings, emb.embedding_dim),
                th.float32,
                name,
                init_func=initializer,
                part_policy=emb.part_policy,
                is_gdata=False,
            )
            assert (
                emb.name not in self._state
            ), "{} already registered in the optimizer".format(emb.name)
            self._state[emb.name] = state

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner.

        Sparse embeddings are updated in mini-batches. We maintain gradient
        states for each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.distributed.DistEmbedding
            Sparse embedding to update.
        """
        eps = self._eps
        clr = self._lr

        state_dev = th.device("cpu")
        exec_dev = grad.device

        # Only perform async copies cpu -> gpu or gpu -> gpu, but block when
        # copying to the cpu, so as to ensure the copy is finished before
        # operating on the data on the cpu.
        state_block = state_dev == th.device("cpu") and exec_dev != state_dev

        # The update is non-linear, so indices must be unique.
        grad_indices, inverse, cnt = th.unique(
            idx, return_inverse=True, return_counts=True
        )
        grad_values = th.zeros(
            (grad_indices.shape[0], grad.shape[1]), device=exec_dev
        )
        grad_values.index_add_(0, inverse, grad)
        grad_values = grad_values / cnt.unsqueeze(1)
        grad_sum = grad_values * grad_values

        # Update the grad state.
        grad_state = self._state[emb.name][grad_indices].to(exec_dev)
        grad_state += grad_sum
        grad_state_dst = grad_state.to(state_dev, non_blocking=True)
        if state_block:
            # Use events to try and overlap CPU and GPU as much as possible.
            update_event = th.cuda.Event()
            update_event.record()

        # Update the embedding.
        std_values = grad_state.sqrt_().add_(eps)
        tmp = clr * grad_values / std_values
        tmp_dst = tmp.to(state_dev, non_blocking=True)
        if state_block:
            std_event = th.cuda.Event()
            std_event.record()
            # Wait for our transfers from exec_dev to state_dev to finish
            # before we can use them.
            update_event.wait()
        self._state[emb.name][grad_indices] = grad_state_dst

        if state_block:
            # Wait for the transfer of std_values to finish before we can
            # use it.
            std_event.wait()
        emb._tensor[grad_indices] -= tmp_dst
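

# Usage sketch, not part of the original module: training a DistEmbedding
# with SparseAdagrad. It assumes an initialized DGL distributed context;
# ``g``, ``dataloader``, and ``nids`` are hypothetical names.
#
#     emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10, name="emb")
#     optimizer = SparseAdagrad([emb], lr=0.001)
#     for batch in dataloader:
#         feats = emb(nids)
#         loss = (feats + 1).sum()
#         loss.backward()
#         optimizer.step()  # exchange gradients and apply the update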


class SparseAdam(DistSparseGradOptimizer):
    r"""Distributed node embedding optimizer using the Adam algorithm.

    This optimizer implements a distributed sparse version of the Adam
    algorithm for optimizing :class:`dgl.distributed.DistEmbedding`. Being
    sparse means it only updates the embeddings whose gradients have updates,
    which are usually a very small portion of the total embeddings.

    Adam maintains a :math:`Gm_{t,i,j}` and a :math:`Gp_{t,i,j}` for every
    parameter in the embeddings, where
    :math:`Gm_{t,i,j} = \beta_1 \cdot Gm_{t-1,i,j} + (1-\beta_1) \cdot g_{t,i,j}`,
    :math:`Gp_{t,i,j} = \beta_2 \cdot Gp_{t-1,i,j} + (1-\beta_2) \cdot g_{t,i,j}^2`,
    and :math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding
    :math:`i` at step :math:`t`. The applied update is
    :math:`lr \cdot Gm_{t,i,j} / (1-\beta_1^t) / (\sqrt{Gp_{t,i,j} / (1-\beta_2^t)} + \epsilon)`.

    NOTE: Support for the sparse Adam optimizer is experimental.

    Parameters
    ----------
    params : list[dgl.distributed.DistEmbedding]
        The list of dgl.distributed.DistEmbedding.
    lr : float
        The learning rate.
    betas : tuple[float, float], Optional
        Coefficients used for computing running averages of the gradient and
        its square. Default: (0.9, 0.999)
    eps : float, Optional
        The term added to the denominator to improve numerical stability.
        Default: 1e-8
    """

    def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
        super(SparseAdam, self).__init__(params, lr)
        self._eps = eps
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._defaults = {
            "_lr": lr,
            "_eps": eps,
            "_beta1": betas[0],
            "_beta2": betas[1],
        }

        # We need to register the states of each embedding in the kvstore.
        for emb in params:
            assert isinstance(
                emb, DistEmbedding
            ), "SparseAdam only supports dgl.distributed.DistEmbedding"

            state_step = DistTensor(
                (emb.num_embeddings,),
                th.float32,
                emb.name + "_step",
                init_func=initializer,
                part_policy=emb.part_policy,
                is_gdata=False,
            )
            state_mem = DistTensor(
                (emb.num_embeddings, emb.embedding_dim),
                th.float32,
                emb.name + "_mem",
                init_func=initializer,
                part_policy=emb.part_policy,
                is_gdata=False,
            )
            state_power = DistTensor(
                (emb.num_embeddings, emb.embedding_dim),
                th.float32,
                emb.name + "_power",
                init_func=initializer,
                part_policy=emb.part_policy,
                is_gdata=False,
            )
            state = (state_step, state_mem, state_power)
            assert (
                emb.name not in self._state
            ), "{} already registered in the optimizer".format(emb.name)
            self._state[emb.name] = state

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner.

        Sparse embeddings are updated in mini-batches. We maintain gradient
        states for each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.distributed.DistEmbedding
            Sparse embedding to update.
"""beta1=self._beta1beta2=self._beta2eps=self._epsclr=self._lrstate_step,state_mem,state_power=self._state[emb.name]state_dev=th.device("cpu")exec_dev=grad.device# only perform async copies cpu -> gpu, or gpu-> gpu, but block# when copying to the cpu, so as to ensure the copy is finished# before operating on the data on the cpustate_block=state_dev==th.device("cpu")andexec_dev!=state_dev# the update is non-linear so indices must be uniquegrad_indices,inverse,cnt=th.unique(idx,return_inverse=True,return_counts=True)# update grad statestate_idx=grad_indices.to(state_dev)# The original implementation will cause read/write contension.# state_step[state_idx] += 1# state_step = state_step[state_idx].to(exec_dev, non_blocking=True)# In a distributed environment, the first line of code will send write requests to# kvstore servers to update the state_step which is asynchronous and the second line# of code will also send read requests to kvstore servers. The write and read requests# may be handled by different kvstore servers managing the same portion of the# state_step dist tensor in the same node. So that, the read request may read an old# value (i.e., 0 in the first iteration) which will cause# update_power_corr to be NaNstate_val=state_step[state_idx]+1state_step[state_idx]=state_valstate_step=state_val.to(exec_dev)orig_mem=state_mem[state_idx].to(exec_dev)orig_power=state_power[state_idx].to(exec_dev)grad_values=th.zeros((grad_indices.shape[0],grad.shape[1]),device=exec_dev)grad_values.index_add_(0,inverse,grad)grad_values=grad_values/cnt.unsqueeze(1)grad_mem=grad_valuesgrad_power=grad_values*grad_valuesupdate_mem=beta1*orig_mem+(1.0-beta1)*grad_memupdate_power=beta2*orig_power+(1.0-beta2)*grad_powerupdate_mem_dst=update_mem.to(state_dev,non_blocking=True)update_power_dst=update_power.to(state_dev,non_blocking=True)ifstate_block:# use events to try and overlap CPU and GPU as much as possibleupdate_event=th.cuda.Event()update_event.record()update_mem_corr=update_mem/(1.0-th.pow(th.tensor(beta1,device=exec_dev),state_step)).unsqueeze(1)update_power_corr=update_power/(1.0-th.pow(th.tensor(beta2,device=exec_dev),state_step)).unsqueeze(1)std_values=clr*update_mem_corr/(th.sqrt(update_power_corr)+eps)std_values_dst=std_values.to(state_dev,non_blocking=True)ifstate_block:std_event=th.cuda.Event()std_event.record()# wait for our transfers from exec_dev to state_dev to finish# before we can use themupdate_event.wait()state_mem[state_idx]=update_mem_dststate_power[state_idx]=update_power_dstifstate_block:# wait for the transfer of std_values to finish before we# can use itstd_event.wait()emb._tensor[state_idx]-=std_values_dst