"""Node embedding optimizers"""importabcfromabcimportabstractmethodimporttorchasthfrom...cudaimportncclfrom...nn.pytorchimportNodeEmbeddingfrom...partitionimportNDArrayPartitionfrom...utilsimport(create_shared_mem_array,gather_pinned_tensor_rows,get_shared_mem_array,pin_memory_inplace,scatter_pinned_tensor_rows,)classSparseGradOptimizer(abc.ABC):r"""The abstract sparse optimizer. Note: dgl sparse optimizer only work with dgl.NodeEmbedding Parameters ---------- params : list of NodeEmbedding The list of NodeEmbeddings. lr : float The learning rate. """def__init__(self,params,lr):self._params=paramsself._lr=lrself._rank=Noneself._world_size=Noneself._shared_cache={}self._clean_grad=Falseself._opt_meta={}self._comm=Noneself._first_step=Trueself._device=None# hold released shared memory to let other process to munmap it first# otherwise it will crash the trainingself.shmem_buffer_holder=[]assertlen(params)>0,"Empty parameters"# if we are using shared memory for communicationforembinparams:assertisinstance(emb,NodeEmbedding),"DGL SparseOptimizer only supports dgl.nn.NodeEmbedding"ifself._rankisNone:self._rank=emb.rankself._world_size=emb.world_sizeelse:assert(self._rank==emb.rank),"MultiGPU rank for each embedding should be same."assert(self._world_size==emb.world_size),"MultiGPU world_size for each embedding should be same."assertnotself._rankisNoneassertnotself._world_sizeisNonedefstep(self):"""The step function. The step function is invoked at the end of every batch to update embeddings """# on the first step, check to see if the grads are on the GPUifself._first_step:forembinself._params:for_,datainemb._trace:ifdata.grad.device.type=="cuda":# create a communicatorifself._device:assert(self._device==data.grad.device),"All gradients must be on the same device"else:self._device=data.grad.deviceelse:assert(notself._device),"All gradients must be on the same device"# distributed backend use ncclifself._deviceand(notth.distributed.is_initialized()orth.distributed.get_backend()=="nccl"):# device is only set if the grads are on a GPUself._comm_setup()else:self._shared_setup()self._first_step=Falseifself._comm:self._comm_step()else:self._shared_step()@abstractmethoddefsetup(self,params):"""This is function where subclasses can perform any setup they need to. It will be called during the first step, and communicators or shared memory will have been setup before this call. Parameters ---------- params : list of NodeEmbedding The list of NodeEmbeddings. 
"""def_comm_setup(self):self._comm=Truedef_shared_setup(self):forembinself._params:emb_name=emb.nameifself._rank==0:# the master gpu processopt_meta=create_shared_mem_array(emb_name+"_opt_meta",(self._world_size,self._world_size),th.int32,).zero_()ifself._rank==0:emb.store.set(emb_name+"_opt_meta",emb_name)self._opt_meta[emb_name]=opt_metaelifself._rank>0:# receiveemb.store.wait([emb_name+"_opt_meta"])opt_meta=get_shared_mem_array(emb_name+"_opt_meta",(self._world_size,self._world_size),th.int32,)self._opt_meta[emb_name]=opt_metadef_comm_step(self):withth.no_grad():idx_in={}grad_in={}forembinself._params:# pylint: disable=too-many-nested-blocksemb_name=emb.namepartition=emb.partitionifnotpartition:# use default partitioningpartition=NDArrayPartition(emb.num_embeddings,self._world_sizeifself._world_size>0else1,mode="remainder",)# we need to combine gradients from multiple forward pathsiflen(emb._trace)==0:idx=th.zeros((0,),dtype=th.long,device=self._device)grad=th.zeros((0,emb.embedding_dim),dtype=th.float32,device=self._device,)eliflen(emb._trace)==1:# the special case where we can use the tensors as is# without any memcpy'sidx,grad=emb._trace[0]grad=grad.grad.dataelse:idx=[]grad=[]fori,datainemb._trace:idx.append(i)grad.append(data.grad.data)idx=th.cat(idx,dim=0)grad=th.cat(grad,dim=0)(idx_in[emb_name],grad_in[emb_name],)=nccl.sparse_all_to_all_push(idx,grad,partition=partition)ifemb.partition:# if the embedding is partitioned, map back to indexes# into the local tensoridx_in[emb_name]=partition.map_to_local(idx_in[emb_name])ifself._clean_grad:# clean gradient trackforembinself._params:emb.reset_trace()self._clean_grad=Falseforembinself._params:emb_name=emb.nameidx=idx_in[emb_name]grad=grad_in[emb_name]self.update(idx,grad,emb)def_shared_step(self):withth.no_grad():# Frequently alloc and free shared memory to hold intermediate tensor is expensive# We cache shared memory buffers in shared_emb.shared_emb={emb.name:([],[])forembinself._params}# Go through all sparse embeddingsforembinself._params:# pylint: disable=too-many-nested-blocksemb_name=emb.name# we need to combine gradients from multiple forward pathsidx=[]grad=[]fori,datainemb._trace:idx.append(i)grad.append(data.grad.data)# If the sparse embedding is not used in the previous forward step# The idx and grad will be empty, initialize them as empty tensors to# avoid crashing the optimizer step logic.## Note: we cannot skip the gradient exchange and update steps as other# working processes may send gradient update requests corresponding# to certain embedding to this process.idx=(th.cat(idx,dim=0)iflen(idx)!=0elseth.zeros((0,),dtype=th.long,device=th.device("cpu")))grad=(th.cat(grad,dim=0)iflen(grad)!=0elseth.zeros((0,emb.embedding_dim),dtype=th.float32,device=th.device("cpu"),))device=grad.deviceidx_dtype=idx.dtypegrad_dtype=grad.dtypegrad_dim=grad.shape[1]ifself._world_size>1:ifemb_namenotinself._shared_cache:self._shared_cache[emb_name]={}# Each training process takes the resposibility of updating a range# of node embeddings, thus we can parallel the gradient update.# The overall progress includes:# 1. In each training process:# 1.a Deciding which process a node embedding belongs to according# to the formula: process_id = node_idx mod num_of_process(N)# 1.b Split the node index tensor and gradient tensor into N parts# according to step 1.# 1.c Write each node index sub-tensor and gradient sub-tensor into# different DGL shared memory buffers.# 2. Cross training process synchronization# 3. 
                    #     3.a Collect node index sub-tensors and gradient
                    #         sub-tensors.
                    #     3.b Do the gradient update.
                    #   4. Done.
                    idx_split = th.remainder(idx, self._world_size).long()
                    for i in range(self._world_size):
                        mask = idx_split == i
                        idx_i = idx[mask]
                        grad_i = grad[mask]

                        if i == self._rank:
                            shared_emb[emb_name][0].append(idx_i)
                            shared_emb[emb_name][1].append(grad_i)
                        else:
                            # currently nccl does not support the Alltoallv
                            # operation, so we need to use CPU shared memory
                            # to share gradients across processes
                            idx_i = idx_i.to(th.device("cpu"))
                            grad_i = grad_i.to(th.device("cpu"))
                            idx_shmem_name = "idx_{}_{}_{}".format(
                                emb_name, self._rank, i
                            )
                            grad_shmem_name = "grad_{}_{}_{}".format(
                                emb_name, self._rank, i
                            )

                            # Create shared memory to hold temporary index and
                            # gradient tensors for cross-process send and recv.
                            if (
                                idx_shmem_name
                                not in self._shared_cache[emb_name]
                                or self._shared_cache[emb_name][
                                    idx_shmem_name
                                ].shape[0]
                                < idx_i.shape[0]
                            ):
                                if (
                                    idx_shmem_name
                                    in self._shared_cache[emb_name]
                                ):
                                    self.shmem_buffer_holder.append(
                                        self._shared_cache[emb_name][
                                            idx_shmem_name
                                        ]
                                    )
                                    self.shmem_buffer_holder.append(
                                        self._shared_cache[emb_name][
                                            grad_shmem_name
                                        ]
                                    )

                                # The total number of buffers is the number of
                                # NodeEmbeddings * world_size * (world_size - 1).
                                # The minimum buffer size is 128.
                                #
                                # We extend the buffer to idx_i.shape[0] * 2 to
                                # avoid frequent shared memory allocation.
                                # The overall buffer cost will be smaller than
                                # three times the maximum memory requirement
                                # for sharing gradients.
                                buffer_size = (
                                    128
                                    if idx_i.shape[0] < 128
                                    else idx_i.shape[0] * 2
                                )
                                idx_shmem = create_shared_mem_array(
                                    "{}_{}".format(
                                        idx_shmem_name, buffer_size
                                    ),
                                    (buffer_size,),
                                    idx_dtype,
                                )
                                grad_shmem = create_shared_mem_array(
                                    "{}_{}".format(
                                        grad_shmem_name, buffer_size
                                    ),
                                    (buffer_size, grad_dim),
                                    grad_dtype,
                                )
                                self._shared_cache[emb_name][
                                    idx_shmem_name
                                ] = idx_shmem
                                self._shared_cache[emb_name][
                                    grad_shmem_name
                                ] = grad_shmem

                            # Fill shared memory with the temporary index and
                            # gradient tensors
                            self._shared_cache[emb_name][idx_shmem_name][
                                : idx_i.shape[0]
                            ] = idx_i
                            self._shared_cache[emb_name][grad_shmem_name][
                                : idx_i.shape[0]
                            ] = grad_i
                            self._opt_meta[emb_name][self._rank][
                                i
                            ] = idx_i.shape[0]
                else:
                    shared_emb[emb_name][0].append(idx)
                    shared_emb[emb_name][1].append(grad)

            # make sure the idx shape is passed to each process through opt_meta
            if self._world_size > 1:
                th.distributed.barrier()

            for emb in self._params:  # pylint: disable=too-many-nested-blocks
                emb_name = emb.name
                if self._world_size > 1:
                    # The first element in shared_emb[emb_name][0] is the local idx
                    device = shared_emb[emb_name][0][0].device
                    # gather gradients from all other processes
                    for i in range(self._world_size):
                        if i != self._rank:
                            idx_shmem_name = "idx_{}_{}_{}".format(
                                emb_name, i, self._rank
                            )
                            grad_shmem_name = "grad_{}_{}_{}".format(
                                emb_name, i, self._rank
                            )
                            size = self._opt_meta[emb_name][i][self._rank]

                            # Retrieve the shared memory holding the temporary
                            # index and gradient tensors sent to the current
                            # training process
                            if (
                                idx_shmem_name
                                not in self._shared_cache[emb_name]
                                or self._shared_cache[emb_name][
                                    idx_shmem_name
                                ].shape[0]
                                < size
                            ):
                                buffer_size = 128 if size < 128 else size * 2
                                idx_shmem = get_shared_mem_array(
                                    "{}_{}".format(
                                        idx_shmem_name, buffer_size
                                    ),
                                    (buffer_size,),
                                    idx_dtype,
                                )
                                grad_shmem = get_shared_mem_array(
                                    "{}_{}".format(
                                        grad_shmem_name, buffer_size
                                    ),
                                    (buffer_size, grad_dim),
                                    grad_dtype,
                                )
                                self._shared_cache[emb_name][
                                    idx_shmem_name
                                ] = idx_shmem
                                self._shared_cache[emb_name][
                                    grad_shmem_name
                                ] = grad_shmem

                            idx_i = self._shared_cache[emb_name][
                                idx_shmem_name
                            ][:size]
                            grad_i = self._shared_cache[emb_name][
                                grad_shmem_name
                            ][:size]
                            shared_emb[emb_name][0].append(
                                idx_i.to(device, non_blocking=True)
                            )
                            shared_emb[emb_name][1].append(
                                grad_i.to(device, non_blocking=True)
                            )

            if self._clean_grad:
                # clean gradient track
                for emb in self._params:
                    emb.reset_trace()
                self._clean_grad = False

            for emb in self._params:
                emb_name = emb.name
                idx = th.cat(shared_emb[emb_name][0], dim=0)
                grad = th.cat(shared_emb[emb_name][1], dim=0)
                self.update(idx, grad, emb)

            # synchronized gradient update
            if self._world_size > 1:
                th.distributed.barrier()

    @abstractmethod
    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner.

        Sparse embeddings are updated in mini-batches. We maintain gradient
        states for each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.nn.NodeEmbedding
            Sparse node embedding to update.
        """

    def zero_grad(self):
        """Clean the gradient cache."""
        self._clean_grad = True

    def state_dict(self, **kwargs):  # pylint: disable=unused-argument
        """Return a copy of the whole optimizer state stored in CPU memory.

        If this is a multi-processing instance, the states will be returned in
        shared memory. If the underlying embedding is currently stored on
        multiple GPUs, all processes must call this method in the same order.

        NOTE: This method must be called by all processes sharing the
        underlying embedding, or it may result in a deadlock.

        Returns
        -------
        dictionary of optimizer states
            The optimizer states stored in CPU memory.
        """
        return {
            "state": {
                emb.name: emb._all_get_optm_state() for emb in self._params
            },
            "param_groups": self.param_groups,
        }

    def load_state_dict(self, state_dict, **kwargs):  # pylint: disable=unused-argument
        """Load the optimizer states.

        This method must be called by all processes sharing the underlying
        embedding with an identical :attr:`state_dict`.

        NOTE: This method must be called by all processes sharing the
        underlying embedding, or it may result in a deadlock.

        Parameters
        ----------
        state_dict : dictionary of optimizer states
            The global states to pull values from.
        """
        for emb in self._params:
            emb._all_set_optm_state(state_dict["state"][emb.name])
        self._set_param_groups(state_dict["param_groups"])

    @property
    @abstractmethod
    def param_groups(self):
        """Emulate 'param_groups' of torch.optim.Optimizer.

        Different from that, the returned 'param_groups' does not contain
        parameters, because getting the whole embedding is very expensive.
        It contains other attributes, e.g., lr and eps, for debugging.
        """

    @abstractmethod
    def _set_param_groups(self, groups):
        """A helper method to load param_groups from a saved state_dict."""
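

# --- Illustrative sketch (not part of the DGL API) --------------------------
# The hypothetical class below only documents the contract a
# SparseGradOptimizer subclass has to satisfy: allocate any per-embedding
# state in ``setup``, apply the gathered (index, gradient) pairs in
# ``update``, and expose ``param_groups``/``_set_param_groups`` for
# checkpointing. The name ``_ExampleSparseSGD`` and its plain SGD rule are
# assumptions made for this sketch, not functionality shipped by DGL.
class _ExampleSparseSGD(SparseGradOptimizer):
    """A minimal sketch of a SparseGradOptimizer subclass (plain SGD)."""

    def __init__(self, params, lr):
        super(_ExampleSparseSGD, self).__init__(params, lr)
        self.setup(self._params)

    def setup(self, params):
        # plain SGD keeps no extra optimizer state
        pass

    def update(self, idx, grad, emb):
        with th.no_grad():
            # deduplicate indices and average their gradients, mirroring the
            # pattern used by SparseAdagrad and SparseAdam below
            uniq, inverse, cnt = th.unique(
                idx, return_inverse=True, return_counts=True
            )
            acc = th.zeros((uniq.shape[0], grad.shape[1]), device=grad.device)
            acc.index_add_(0, inverse, grad)
            acc = acc / cnt.unsqueeze(1)
            dev = emb.weight.device
            emb.weight[uniq.to(dev)] -= (self._lr * acc).to(dev)

    @property
    def param_groups(self):
        return [{"lr": self._lr}]

    def _set_param_groups(self, groups):
        self._lr = groups[0]["lr"]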


class SparseAdagrad(SparseGradOptimizer):
    r"""Node embedding optimizer using the Adagrad algorithm.

    This optimizer implements a sparse version of the Adagrad algorithm for
    optimizing :class:`dgl.nn.NodeEmbedding`. Being sparse means it only
    updates the embeddings whose gradients have updates, which are usually
    a very small portion of the total embeddings.

    Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the
    embeddings, where :math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and
    :math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding
    :math:`i` at step :math:`t`.

    NOTE: Support for the sparse Adagrad optimizer is experimental.

    Parameters
    ----------
    params : list[dgl.nn.NodeEmbedding]
        The list of dgl.nn.NodeEmbedding.
    lr : float
        The learning rate.
    eps : float, Optional
        The term added to the denominator to improve numerical stability.
        Default: 1e-10

    Examples
    --------
    >>> def initializer(emb):
            th.nn.init.xavier_uniform_(emb)
            return emb
    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
    >>> optimizer = dgl.optim.SparseAdagrad([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     ...
    ...     feats = emb(nids, gpu_0)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()
    """

    def __init__(self, params, lr, eps=1e-10):
        super(SparseAdagrad, self).__init__(params, lr)
        self._eps = eps
        # setup tensors for optimizer states
        self.setup(self._params)

    def setup(self, params):
        # We need to register a state sum for each embedding in the kvstore.
        for emb in params:
            assert isinstance(
                emb, NodeEmbedding
            ), "SparseAdagrad only supports dgl.nn.NodeEmbedding"

            emb_name = emb.name
            if th.device(emb.weight.device) == th.device("cpu"):
                # if our embedding is on the CPU, our state also has to be
                if self._rank < 0:
                    state = th.empty(
                        emb.weight.shape,
                        dtype=th.float32,
                        device=th.device("cpu"),
                    ).zero_()
                elif self._rank == 0:
                    state = create_shared_mem_array(
                        emb_name + "_state", emb.weight.shape, th.float32
                    ).zero_()

                    if self._world_size > 1:
                        emb.store.set(emb_name + "_opt", emb_name)
                elif self._rank > 0:
                    # receive
                    emb.store.wait([emb_name + "_opt"])
                    state = get_shared_mem_array(
                        emb_name + "_state", emb.weight.shape, th.float32
                    )
            else:
                # distributed state on GPU
                state = th.empty(
                    emb.weight.shape,
                    dtype=th.float32,
                    device=emb.weight.device,
                ).zero_()
            emb.set_optm_state((state,))

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner.

        Sparse embeddings are updated in mini-batches. We maintain gradient
        states for each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.nn.NodeEmbedding
            Sparse embedding to update.
        """
        eps = self._eps
        clr = self._lr

        # the update is non-linear so indices must be unique
        grad_indices, inverse, cnt = th.unique(
            idx, return_inverse=True, return_counts=True
        )
        grad_values = th.zeros(
            (grad_indices.shape[0], grad.shape[1]), device=grad.device
        )
        grad_values.index_add_(0, inverse, grad)
        grad_values = grad_values / cnt.unsqueeze(1)

        grad_sum = grad_values * grad_values
        (state,) = emb.optm_state
        state_dev = state.device
        state_idx = grad_indices.to(state_dev)
        grad_state = state[state_idx].to(grad.device)
        grad_state += grad_sum
        state[state_idx] = grad_state.to(state_dev)

        std_values = grad_state.add_(eps).sqrt_()
        tmp = clr * grad_values / std_values
        emb.weight[state_idx] -= tmp.to(state_dev)

    @property
    def param_groups(self):
        """Emulate 'param_groups' of torch.optim.Optimizer.

        Different from that, the returned 'param_groups' does not contain
        parameters, because getting the whole embedding is very expensive.
        It contains other attributes, e.g., lr and eps, for debugging.
"""return[{"lr":self._lr,"eps":self._eps}]def_set_param_groups(self,groups):"""A helper method to load param_groups from saved state_dict."""self._lr=groups[0]["lr"]self._eps=groups[0]["eps"]


class SparseAdam(SparseGradOptimizer):
    r"""Node embedding optimizer using the Adam algorithm.

    This optimizer implements a sparse version of the Adam algorithm for
    optimizing :class:`dgl.nn.NodeEmbedding`. Being sparse means it only
    updates the embeddings whose gradients have updates, which are usually
    a very small portion of the total embeddings.

    Adam maintains a :math:`Gm_{t,i,j}` and a :math:`Gp_{t,i,j}` for every
    parameter in the embeddings, where
    :math:`Gm_{t,i,j}=\beta_1 \cdot Gm_{t-1,i,j} + (1-\beta_1) \cdot g_{t,i,j}`,
    :math:`Gp_{t,i,j}=\beta_2 \cdot Gp_{t-1,i,j} + (1-\beta_2) \cdot g_{t,i,j}^2`,
    the applied update is
    :math:`lr \cdot Gm_{t,i,j} / (1 - \beta_1^t) / \sqrt{Gp_{t,i,j} / (1 - \beta_2^t)}`,
    and :math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding
    :math:`i` at step :math:`t`.

    NOTE: Support for the sparse Adam optimizer is experimental.

    Parameters
    ----------
    params : list[dgl.nn.NodeEmbedding]
        The list of dgl.nn.NodeEmbeddings.
    lr : float
        The learning rate.
    betas : tuple[float, float], Optional
        Coefficients used for computing running averages of the gradient and
        its square. Default: (0.9, 0.999)
    eps : float, Optional
        The term added to the denominator to improve numerical stability.
        Default: 1e-8
    use_uva : bool, Optional
        Whether to use pinned memory for storing the 'mem' and 'power'
        parameters when the embedding is stored on the CPU. This will improve
        training speed, but will require locking a large number of virtual
        memory pages. For embeddings which are stored in GPU memory, this
        setting will have no effect.
        Default: True if the gradients are generated on the GPU, and False
        if the gradients are on the CPU.
    dtype : torch.dtype, Optional
        The type to store optimizer state with. Default: th.float32.

    Examples
    --------
    >>> def initializer(emb):
            th.nn.init.xavier_uniform_(emb)
            return emb
    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     ...
    ...     feats = emb(nids, gpu_0)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()
    """

    def __init__(
        self,
        params,
        lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        use_uva=None,
        dtype=th.float32,
    ):
        super(SparseAdam, self).__init__(params, lr)
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._use_uva = use_uva
        self._nd_handle = {}
        self._is_using_uva = {}
        assert dtype in [th.float16, th.float32], (
            "Unsupported dtype {}. Valid choices are th.float16 "
            "and th.float32".format(dtype)
        )
        self._dtype = dtype
        # setup tensors for optimizer states
        self.setup(self._params)

    def _setup_uva(self, name, mem, power):
        self._is_using_uva[name] = True
        mem_nd = pin_memory_inplace(mem)
        power_nd = pin_memory_inplace(power)
        self._nd_handle[name] = [mem_nd, power_nd]

    def setup(self, params):
        # We need to register a state sum for each embedding in the kvstore.
        for emb in params:
            assert isinstance(
                emb, NodeEmbedding
            ), "SparseAdam only supports dgl.nn.NodeEmbedding"

            emb_name = emb.name
            self._is_using_uva[emb_name] = self._use_uva
            if th.device(emb.weight.device) == th.device("cpu"):
                # if our embedding is on the CPU, our state also has to be
                if self._rank < 0:
                    state_step = th.empty(
                        (emb.weight.shape[0],),
                        dtype=th.int32,
                        device=th.device("cpu"),
                    ).zero_()
                    state_mem = th.empty(
                        emb.weight.shape,
                        dtype=self._dtype,
                        device=th.device("cpu"),
                    ).zero_()
                    state_power = th.empty(
                        emb.weight.shape,
                        dtype=self._dtype,
                        device=th.device("cpu"),
                    ).zero_()
                elif self._rank == 0:
                    state_step = create_shared_mem_array(
                        emb_name + "_step", (emb.weight.shape[0],), th.int32
                    ).zero_()
                    state_mem = create_shared_mem_array(
                        emb_name + "_mem", emb.weight.shape, self._dtype
                    ).zero_()
                    state_power = create_shared_mem_array(
                        emb_name + "_power", emb.weight.shape, self._dtype
                    ).zero_()

                    if self._world_size > 1:
                        emb.store.set(emb_name + "_opt", emb_name)
                elif self._rank > 0:
                    # receive
                    emb.store.wait([emb_name + "_opt"])
                    state_step = get_shared_mem_array(
                        emb_name + "_step", (emb.weight.shape[0],), th.int32
                    )
                    state_mem = get_shared_mem_array(
                        emb_name + "_mem", emb.weight.shape, self._dtype
                    )
                    state_power = get_shared_mem_array(
                        emb_name + "_power", emb.weight.shape, self._dtype
                    )

                if self._is_using_uva[emb_name]:
                    # use_uva was explicitly set to True; otherwise wait
                    # until the first step to decide
                    self._setup_uva(emb_name, state_mem, state_power)
            else:
                # make sure we don't use UVA when data is on the GPU
                self._is_using_uva[emb_name] = False
                # distributed state on GPU
                state_step = th.empty(
                    [emb.weight.shape[0]],
                    dtype=th.int32,
                    device=emb.weight.device,
                ).zero_()
                state_mem = th.empty(
                    emb.weight.shape,
                    dtype=self._dtype,
                    device=emb.weight.device,
                ).zero_()
                state_power = th.empty(
                    emb.weight.shape,
                    dtype=self._dtype,
                    device=emb.weight.device,
                ).zero_()
            state = (state_step, state_mem, state_power)
            emb.set_optm_state(state)

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner.

        Sparse embeddings are updated in mini-batches. We maintain gradient
        states for each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.nn.NodeEmbedding
            Sparse embedding to update.
"""withth.no_grad():state_step,state_mem,state_power=emb.optm_stateexec_dtype=grad.dtypeexec_dev=grad.devicestate_dev=state_step.device# whether or not we need to transfer data from the GPU to the CPU# while updating the weightsis_d2h=state_dev.type=="cpu"andexec_dev.type=="cuda"# only perform async copies cpu -> gpu, or gpu-> gpu, but block# when copying to the cpu, so as to ensure the copy is finished# before operating on the data on the cpustate_block=is_d2hifself._is_using_uva[emb.name]isNoneandis_d2h:# we should use UVA going forwardself._setup_uva(emb.name,state_mem,state_power)elifself._is_using_uva[emb.name]isNone:# we shouldn't use UVA going forwardself._is_using_uva[emb.name]=Falseuse_uva=self._is_using_uva[emb.name]beta1=self._beta1beta2=self._beta2eps=self._epsclr=self._lr# There can be duplicated indices due to sampling.# Thus unique them here and average the gradient here.grad_indices,inverse,cnt=th.unique(idx,return_inverse=True,return_counts=True)state_idx=grad_indices.to(state_dev)state_step[state_idx]+=1state_step=state_step[state_idx].to(exec_dev)ifuse_uva:orig_mem=gather_pinned_tensor_rows(state_mem,grad_indices)orig_power=gather_pinned_tensor_rows(state_power,grad_indices)else:orig_mem=state_mem[state_idx].to(exec_dev)orig_power=state_power[state_idx].to(exec_dev)# convert to exec dtypeorig_mem=orig_mem.to(dtype=exec_dtype)orig_power=orig_power.to(dtype=exec_dtype)grad_values=th.zeros((grad_indices.shape[0],grad.shape[1]),device=exec_dev)grad_values.index_add_(0,inverse,grad)grad_values=grad_values/cnt.unsqueeze(1)grad_mem=grad_valuesgrad_power=grad_values*grad_valuesupdate_mem=beta1*orig_mem+(1.0-beta1)*grad_memupdate_power=beta2*orig_power+(1.0-beta2)*grad_powerifuse_uva:scatter_pinned_tensor_rows(state_mem,grad_indices,update_mem.to(dtype=self._dtype))scatter_pinned_tensor_rows(state_power,grad_indices,update_power.to(dtype=self._dtype),)else:update_mem_dst=update_mem.to(dtype=self._dtype).to(state_dev,non_blocking=True)update_power_dst=update_power.to(dtype=self._dtype).to(state_dev,non_blocking=True)ifstate_block:# use events to try and overlap CPU and GPU as much as possibleupdate_event=th.cuda.Event()update_event.record()update_mem_corr=update_mem/(1.0-th.pow(th.tensor(beta1,device=exec_dev),state_step)).unsqueeze(1)update_power_corr=update_power/(1.0-th.pow(th.tensor(beta2,device=exec_dev),state_step)).unsqueeze(1)std_values=(clr*update_mem_corr/(th.sqrt(update_power_corr)+eps))std_values_dst=std_values.to(state_dev,non_blocking=True)ifstate_block:std_event=th.cuda.Event()std_event.record()ifnotuse_uva:ifstate_block:# wait for our transfers from exec_dev to state_dev to finish# before we can use themupdate_event.wait()state_mem[state_idx]=update_mem_dststate_power[state_idx]=update_power_dstifstate_block:# wait for the transfer of std_values to finish before we# can use itstd_event.wait()emb.weight[state_idx]-=std_values_dst@propertydefparam_groups(self):"""Emulate 'param_groups' of torch.optim.Optimizer. Different from that, the returned 'param_groups' doesn't contain parameters because getting the whole embedding is very expensive. It contains other attributes, e.g., lr, betas, eps, for debugging. """return[{"lr":self._lr,"betas":(self._beta1,self._beta2),"eps":self._eps,}]def_set_param_groups(self,groups):"""A helper method to load param_groups from saved state_dict."""self._lr=groups[0]["lr"]self._beta1,self._beta2=groups[0]["betas"]self._eps=groups[0]["eps"]