"""Torch NodeEmbedding."""
from datetime import timedelta
import torch as th
from ...backend import pytorch as F
from ...cuda import nccl
from ...partition import NDArrayPartition
from ...utils import create_shared_mem_array, get_shared_mem_array
_STORE = None
class NodeEmbedding:  # NodeEmbedding
    """Class for storing node embeddings.
    The class is optimized for training large-scale node embeddings. It updates the embedding in
    a sparse way and can scale to graphs with millions of nodes. It also supports partitioning
    to multiple GPUs (on a single machine) for more acceleration. It does not support partitioning
    across machines.
    Currently, DGL provides two optimizers that work with this NodeEmbedding
    class: ``SparseAdagrad`` and ``SparseAdam``.
    The implementation is based on torch.distributed package. It depends on the pytorch
    default distributed process group to collect multi-process information and uses
    ``torch.distributed.TCPStore`` to share meta-data information across multiple gpu processes.
    It use the local address of '127.0.0.1:12346' to initialize the TCPStore.
    NOTE: The support of NodeEmbedding is experimental.
    Parameters
    ----------
    num_embeddings : int
        The number of embeddings. Currently, the number of embeddings has to be the same as
        the number of nodes.
    embedding_dim : int
        The dimension size of embeddings.
    name : str
        The name of the embeddings. The name should uniquely identify the embeddings in the system.
    init_func : callable, optional
        The function to create the initial data. If the init function is not provided,
        the values of the embeddings are initialized to zero.
    device : th.device, optional
        Device to store the embeddings on. Defaults to CPU.
    partition : NDArrayPartition, optional
        The partition to use to distribute the embeddings between
        processes.

    Examples
    --------
    Before launching multiple GPU processes

    >>> def initializer(emb):
    ...     th.nn.init.xavier_uniform_(emb)
    ...     return emb

    In each training process

    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb', init_func=initializer)
    >>> optimizer = dgl.optim.SparseAdam([emb], lr=0.001)
    >>> for blocks in dataloader:
    ...     ...
    ...     feats = emb(nids, gpu_0)
    ...     loss = F.sum(feats + 1, 0)
    ...     loss.backward()
    ...     optimizer.step()
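
    To store the embeddings on GPUs and partition them across the training
    processes, pass a ``device`` (and, optionally, an ``NDArrayPartition``).
    The sketch below is illustrative only; it assumes ``torch.distributed``
    is already initialized and that ``dev_id`` is the GPU index of the
    current process.

    >>> from dgl.partition import NDArrayPartition
    >>> part = NDArrayPartition(g.num_nodes(),
    ...                         th.distributed.get_world_size(),
    ...                         mode='remainder')
    >>> emb = dgl.nn.NodeEmbedding(g.num_nodes(), 10, 'emb',
    ...                            init_func=initializer,
    ...                            device=th.device(f'cuda:{dev_id}'),
    ...                            partition=part)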
    """
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        name,
        init_func=None,
        device=None,
        partition=None,
    ):
        global _STORE
        if device is None:
            device = th.device("cpu")
        # Check whether it is multi-gpu training or not.
        if th.distributed.is_initialized():
            rank = th.distributed.get_rank()
            world_size = th.distributed.get_world_size()
        else:
            rank = -1
            world_size = 0
        self._rank = rank
        self._world_size = world_size
        self._store = None
        self._comm = None
        self._partition = partition
        host_name = "127.0.0.1"
        port = 12346
        if rank >= 0:
            # For multi-GPU training, set up a TCPStore for
            # embedding status synchronization across GPU processes.
            if _STORE is None:
                _STORE = th.distributed.TCPStore(
                    host_name,
                    port,
                    world_size,
                    rank == 0,
                    timedelta(seconds=10 * 60),
                )
            self._store = _STORE
        # The embeddings are stored in CPU memory.
        if th.device(device) == th.device("cpu"):
            if rank <= 0:
                emb = create_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
                if init_func is not None:
                    emb = init_func(emb)
            if rank == 0:  # the master GPU process
                # signal the other processes that the shared-memory
                # array `name` is ready to be attached to
                for _ in range(1, world_size):
                    self._store.set(name, name)
            elif rank > 0:
                # wait for the master, then attach to its shared-memory array
                self._store.wait([name])
                emb = get_shared_mem_array(
                    name, (num_embeddings, embedding_dim), th.float32
                )
            self._tensor = emb
        else:  # The embeddings are stored in GPU memory.
            self._comm = True
            if not self._partition:
                # for communication we need a partition
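                # NOTE: mode="remainder" assigns global row i to partition
                # i % world_size, i.e., a round-robin split of the node IDs.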
                self._partition = NDArrayPartition(
                    num_embeddings,
                    self._world_size if self._world_size > 0 else 1,
                    mode="remainder",
                )
            # create local tensors for the weights
            local_size = self._partition.local_size(max(self._rank, 0))
            # TODO(dlasalle): support 16-bit/half embeddings
            emb = th.empty(
                [local_size, embedding_dim],
                dtype=th.float32,
                requires_grad=False,
                device=device,
            )
            if init_func:
                emb = init_func(emb)
            self._tensor = emb
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._name = name
        self._optm_state = None  # track optimizer state
        self._trace = []  # track minibatch
    def __call__(self, node_ids, device=th.device("cpu")):
        """
        node_ids : th.tensor
            Index of the embeddings to collect.
        device : th.device
            Target device to put the collected embeddings.
        """
        if not self._comm:
            # For embeddings stored on the CPU.
            emb = self._tensor[node_ids].to(device)
        else:
            # For embeddings stored on the GPU.
            # The following method performs the communication across
            # multiple GPUs, and gracefully handles the case where only one
            # GPU is present, i.e., self._world_size == 1 or 0 (when
            # th.distributed.is_initialized() is False).
            emb = nccl.sparse_all_to_all_pull(
                node_ids, self._tensor, self._partition
            )
            emb = emb.to(device)
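        # While autograd is recording, stash the pulled rows together with
        # their node IDs so that DGL's sparse optimizers (e.g., SparseAdam)
        # can later apply gradients only to the rows touched in this minibatch.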
        if F.is_recording():
            emb = F.attach_grad(emb)
            self._trace.append((node_ids.to(device), emb))
        return emb
    @property
    def store(self):
        """Return torch.distributed.TCPStore for
        meta data sharing across processes.
        Returns
        -------
        torch.distributed.TCPStore
            KVStore used for meta data sharing.
        """
        return self._store
    @property
    def partition(self):
        """Return the partition identifying how the tensor is split across
        processes.
        Returns
        -------
        String
            The mode.
        """
        return self._partition
    @property
    def rank(self):
        """Return rank of current process.
        Returns
        -------
        int
            The rank of current process.
        """
        return self._rank
    @property
    def world_size(self):
        """Return world size of the pytorch distributed training env.
        Returns
        -------
        int
            The world size of the pytorch distributed training env.
        """
        return self._world_size
    @property
    def name(self):
        """Return the name of NodeEmbedding.
        Returns
        -------
        str
            The name of NodeEmbedding.
        """
        return self._name
    @property
    def num_embeddings(self):
        """Return the number of embeddings.
        Returns
        -------
        int
            The number of embeddings.
        """
        return self._num_embeddings
    @property
    def embedding_dim(self):
        """Return the dimension of embeddings.
        Returns
        -------
        int
            The dimension of embeddings.
        """
        return self._embedding_dim
    def set_optm_state(self, state):
        """Store the optimizer related state tensor.
        Parameters
        ----------
        state : tuple of torch.Tensor
            Optimizer related state.
        """
        self._optm_state = state
    @property
    def optm_state(self):
        """Return the optimizer related state tensor.
        Returns
        -------
        tuple of torch.Tensor
            The optimizer related state.
        """
        return self._optm_state
    @property
    def trace(self):
        """Return a trace of the indices of embeddings
        used in the training step(s).
        Returns
        -------
        [torch.Tensor]
            The indices of embeddings used in the training step(s).
        """
        return self._trace
    def reset_trace(self):
        """Clean up the trace of the indices of embeddings
        used in the training step(s).
        """
        self._trace = []
    @property
    def weight(self):
        """Return the tensor storing the node embeddings
        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings
        """
        return self._tensor
    def all_set_embedding(self, values):
        """Set the values of the embedding. This method must be called by all
        processes sharing the embedding with identical tensors for
        :attr:`values`.
        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.
        Parameters
        ----------
        values : Tensor
            The global tensor to pull values from.
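
        Examples
        --------
        A minimal sketch of a collective update; it assumes ``g`` and ``emb``
        were created as in the class-level example and that every process
        sharing the embedding builds the same ``values`` tensor.

        >>> values = th.zeros((g.num_nodes(), emb.embedding_dim))
        >>> emb.all_set_embedding(values)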
        """
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    max(self._rank, 0),
                    ctx=F.context(self._tensor),
                ),
                F.context(values),
            )
            self._tensor[:] = F.copy_to(
                F.gather_row(values, idxs), ctx=F.context(self._tensor)
            )[:]
        else:
            if self._rank == 0:
                self._tensor[:] = F.copy_to(
                    values, ctx=F.context(self._tensor)
                )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()
    def _all_get_tensor(self, shared_name, tensor, shape):
        """A helper function to get model-parallel tensors.
        This method must be called in multi-GPU DDP training, and only needs
        to be called in that setting. For now, it is only used in
        ``all_get_embedding`` and ``_all_get_optm_state``.
        """
        # create a shared memory tensor
        if self._rank == 0:
            # root process creates shared memory
            val = create_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )
            self._store.set(shared_name, shared_name)
        else:
            self._store.wait([shared_name])
            val = get_shared_mem_array(
                shared_name,
                shape,
                tensor.dtype,
            )
        # need to map indices and slice into existing tensor
        idxs = self._partition.map_to_global(
            F.arange(0, tensor.shape[0], ctx=F.context(tensor)),
            self._rank,
        ).to(val.device)
        val[idxs] = tensor.to(val.device)
        self._store.delete_key(shared_name)
        # wait for all processes to finish
        th.distributed.barrier()
        return val
    def all_get_embedding(self):
        """Return a copy of the embedding stored in CPU memory. If this is a
        multi-processing instance, the tensor will be returned in shared
        memory. If the embedding is currently stored on multiple GPUs, all
        processes must call this method in the same order.
        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.
        Returns
        -------
        torch.Tensor
            The tensor storing the node embeddings.
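
        Examples
        --------
        A minimal sketch; it assumes ``emb`` was created as in the class-level
        example and that every process sharing the embedding makes this call.

        >>> cpu_emb = emb.all_get_embedding()  # shape (num_embeddings, embedding_dim), on CPU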
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return self._tensor.to(th.device("cpu"))
            else:
                return self._all_get_tensor(
                    f"{self._name}_gather",
                    self._tensor,
                    (self._num_embeddings, self._embedding_dim),
                )
        else:
            # already stored in CPU memory
            return self._tensor
    def _all_get_optm_state(self):
        """Return a copy of the whole optimizer states stored in CPU memory.
        If this is a multi-processing instance, the states will be returned in
        shared memory. If the embedding is currently stored on multiple GPUs,
        all processes must call this method in the same order.
        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.
        Returns
        -------
        tuple of torch.Tensor
            The optimizer states stored in CPU memory.
        """
        if self._partition:
            if self._world_size == 0:
                # non-multiprocessing
                return tuple(
                    state.to(th.device("cpu")) for state in self._optm_state
                )
            else:
                return tuple(
                    self._all_get_tensor(
                        f"state_gather_{self._name}_{i}",
                        state,
                        (self._num_embeddings, *state.shape[1:]),
                    )
                    for i, state in enumerate(self._optm_state)
                )
        else:
            # already stored in CPU memory
            return self._optm_state
    def _all_set_optm_state(self, states):
        """Set the optimizer states of the embedding. This method must be
        called by all processes sharing the embedding with identical
        :attr:`states`.
        NOTE: This method must be called by all processes sharing the
        embedding, or it may result in a deadlock.
        Parameters
        ----------
        states : tuple of torch.Tensor
            The global states to pull values from.
        """
        if self._partition:
            idxs = F.copy_to(
                self._partition.get_local_indices(
                    max(self._rank, 0), ctx=F.context(self._tensor)
                ),
                F.context(states[0]),
            )
            for state, new_state in zip(self._optm_state, states):
                state[:] = F.copy_to(
                    F.gather_row(new_state, idxs), ctx=F.context(self._tensor)
                )[:]
        else:
            # stored in CPU memory
            if self._rank <= 0:
                for state, new_state in zip(self._optm_state, states):
                    state[:] = F.copy_to(
                        new_state, ctx=F.context(self._tensor)
                    )[:]
        if th.distributed.is_initialized():
            th.distributed.barrier()