"""PyTorch multiprocessing wrapper."""importrandomimporttracebackfrom_threadimportstart_new_threadfromfunctoolsimportwrapsimporttorchimporttorch.multiprocessingasmpfrom..utilsimportcreate_shared_mem_array,get_shared_mem_arraydefthread_wrapped_func(func):""" Wraps a process entry point to make it work with OpenMP. """@wraps(func)defdecorated_function(*args,**kwargs):queue=mp.Queue()def_queue_result():exception,trace,res=None,None,Nonetry:res=func(*args,**kwargs)exceptExceptionase:# pylint: disable=broad-exceptexception=etrace=traceback.format_exc()queue.put((res,exception,trace))start_new_thread(_queue_result,())result,exception,trace=queue.get()ifexceptionisNone:returnresultelse:assertisinstance(exception,Exception)raiseexception.__class__(trace)returndecorated_function# pylint: disable=missing-docstringclassProcess(mp.Process):# pylint: disable=dangerous-default-valuedef__init__(self,group=None,target=None,name=None,args=(),kwargs={},*,daemon=None):target=thread_wrapped_func(target)super().__init__(group,target,name,args,kwargs,daemon=daemon)def_get_shared_mem_name(id_):return"shared"+str(id_)
def call_once_and_share(func, shape, dtype, rank=0):
    """Invoke the function in a single process of the PyTorch distributed process
    group, and share the result with other processes.

    Parameters
    ----------
    func : callable
        Any callable that accepts no arguments and returns an arbitrary object.
    shape : tuple[int]
        The shape of the shared tensor.  Must match the output of :attr:`func`.
    dtype : torch.dtype
        The data type of the shared tensor.  Must match the output of :attr:`func`.
    rank : int, optional
        The process ID to actually execute the function.
    """
    current_rank = torch.distributed.get_rank()
    dist_buf = torch.LongTensor([1])
    if torch.distributed.get_backend() == "nccl":
        # Use .cuda() to transfer it to the correct device. Should be OK since
        # PyTorch recommends the users to call set_device() after getting inside
        # torch.multiprocessing.spawn()
        dist_buf = dist_buf.cuda()

    # Process with the given rank creates and populates the shared memory array.
    if current_rank == rank:
        # PyTorch Lightning 1.6+ seems to set the random seed during process
        # spawning to the same seed value.
        random_ = random.Random()
        id_ = random_.getrandbits(32)
        name = _get_shared_mem_name(id_)
        result = create_shared_mem_array(name, shape, dtype)
        result[:] = func()
        dist_buf[0] = id_

    # Broadcasts the name of the shared array to other processes.
    torch.distributed.broadcast(dist_buf, rank)

    # If no exceptions, other processes open the same shared memory object.
    if current_rank != rank:
        id_ = dist_buf.item()
        name = _get_shared_mem_name(id_)
        result = get_shared_mem_array(name, shape, dtype)
    return result
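
# Illustrative usage sketch (not part of the original module): run
# `call_once_and_share` inside workers launched with torch.multiprocessing.spawn()
# so that only one rank evaluates the (possibly expensive) callable while every
# rank receives a view of the same shared-memory tensor. The backend, address and
# world size below are assumptions made for this example only.
def _example_call_once_and_share(rank, world_size):
    torch.distributed.init_process_group(
        "gloo",
        init_method="tcp://127.0.0.1:12345",
        rank=rank,
        world_size=world_size,
    )
    # Rank 0 computes the tensor once; the other ranks attach to the same buffer.
    features = call_once_and_share(
        lambda: torch.arange(16, dtype=torch.float32), (16,), torch.float32
    )
    return features.sum()
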
def shared_tensor(shape, dtype=torch.float32):
    """Create a tensor in shared memory accessible by all processes within the same
    ``torch.distributed`` process group.

    The content is uninitialized.

    Parameters
    ----------
    shape : tuple[int]
        The shape of the tensor.
    dtype : torch.dtype, optional
        The dtype of the tensor.

    Returns
    -------
    Tensor
        The shared tensor.
    """
    return call_once_and_share(
        lambda: torch.empty(*shape, dtype=dtype), shape, dtype
    )
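
# Illustrative usage sketch (not part of the original module): `shared_tensor`
# hands every rank the same uninitialized shared-memory buffer, which the ranks
# can then fill in place without any copies. Assumes torch.distributed has already
# been initialized for this process group; the row-per-rank layout is an
# assumption made for this example only.
def _example_shared_tensor(rank, world_size):
    buf = shared_tensor((world_size, 4))
    # Each rank writes its own row directly into shared memory.
    buf[rank] = torch.full((4,), float(rank))
    torch.distributed.barrier()  # wait until every rank has written its row
    return buf.sum()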