"""DGL broadcast operator module."""
import operator
import torch
from .sparse_matrix import SparseMatrix, val_like
[docs]
def sp_broadcast_v(A: SparseMatrix, v: torch.Tensor, op: str) -> SparseMatrix:
    """Broadcast operator for sparse matrix and vector.

    :attr:`v` is broadcasted to the shape of :attr:`A` and then the operator is
    applied on the non-zero values of :attr:`A`.

    There are two cases regarding the shape of v:

    1. :attr:`v` is a vector of shape ``(1, A.shape[1])`` or ``(A.shape[1])``.
       In this case, :attr:`v` is broadcasted on the row dimension of
       :attr:`A`.

    2. :attr:`v` is a vector of shape ``(A.shape[0], 1)``. In this case,
       :attr:`v` is broadcasted on the column dimension of :attr:`A`.

    If ``A.val`` takes shape ``(nnz, D)``, then :attr:`v` will be broadcasted on
    the ``D`` dimension. Any further trailing dimensions of ``A.val`` are
    broadcast over in the same way.

    Parameters
    ----------
    A: SparseMatrix
        Sparse matrix
    v: torch.Tensor
        Vector
    op: str
        Operator in ["add", "sub", "mul", "truediv"]. The name is looked up
        in the standard :mod:`operator` module.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    AssertionError
        If the shape of :attr:`v` cannot be broadcast to the shape of
        :attr:`A`.
    AttributeError
        If :attr:`op` does not name an attribute of the :mod:`operator`
        module.

    Examples
    --------

    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
    >>> val = torch.tensor([10, 20, 30])
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
    >>> v = torch.tensor([1, 2, 3, 4])
    >>> dglsp.sp_broadcast_v(A, v, "add")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([11, 24, 33]),
                 shape=(3, 4), nnz=3)

    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)
    >>> dglsp.sp_broadcast_v(A, v, "add")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([12, 21, 33]),
                 shape=(3, 4), nnz=3)

    >>> indices = torch.tensor([[1, 0, 2], [0, 3, 2]])
    >>> val = torch.tensor([[10, 20], [30, 40], [50, 60]])
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 4))
    >>> v = torch.tensor([1, 2, 3]).view(-1, 1)
    >>> dglsp.sp_broadcast_v(A, v, "sub")
    SparseMatrix(indices=tensor([[1, 0, 2],
                                 [0, 3, 2]]),
                 values=tensor([[ 8, 18],
                                [29, 39],
                                [47, 57]]),
                 shape=(3, 4), nnz=3, val_size=(2,))
    """
    # Resolve the operator name to a callable without rebinding the ``op``
    # parameter.
    op_fn = getattr(operator, op)

    # A 1-D vector is treated as a single row of shape (1, N).
    if v.dim() == 1:
        v = v.view(1, -1)

    # Built after the reshape above, so the reported v.shape is the 2-D one.
    shape_error_message = (
        f"Dimension mismatch for broadcasting. Got A.shape = {A.shape} and "
        f"v.shape = {v.shape}."
    )
    assert v.dim() <= 2 and (1 in v.shape), shape_error_message

    broadcast_dim = None
    # v can be broadcasted to A if exactly one dimension of v is 1 and the other
    # is the same as A.
    for d, (dim1, dim2) in enumerate(zip(A.shape, v.shape)):
        assert dim2 in (1, dim1), shape_error_message
        if dim1 != dim2:
            # At most one dimension may differ from A.
            assert broadcast_dim is None, shape_error_message
            broadcast_dim = d

    # A and v have the same shape of (1, *) or (*, 1).
    if broadcast_dim is None:
        broadcast_dim = 0 if A.shape[0] == 1 else 1

    # Gather one entry of v per non-zero element: index by column when v is
    # broadcast along the rows, and by row otherwise.
    flat_v = v.view(-1)
    v = flat_v[A.col] if broadcast_dim == 0 else flat_v[A.row]

    # Append one singleton axis per trailing value dimension so each gathered
    # scalar broadcasts over the whole value of its non-zero element. For 1-D
    # values this is a no-op and for 2-D values it is ``v.view(-1, 1)``.
    v = v.view((-1,) + (1,) * (A.val.dim() - 1))

    return val_like(A, op_fn(A.val, v))
[docs]
def sp_add_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Add a broadcast vector to the non-zero values of a sparse matrix.

    Thin wrapper that delegates to :func:`sp_broadcast_v` with the ``"add"``
    operator; see that function for the accepted shapes of :attr:`v`.

    Parameters
    ----------
    A: SparseMatrix
        Sparse matrix
    v: torch.Tensor
        Vector

    Returns
    -------
    SparseMatrix
        Result of ``sp_broadcast_v(A, v, "add")``
    """
    return sp_broadcast_v(A=A, v=v, op="add")
[docs]
def sp_sub_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Subtract a broadcast vector from the non-zero values of a sparse matrix.

    Thin wrapper that delegates to :func:`sp_broadcast_v` with the ``"sub"``
    operator; see that function for the accepted shapes of :attr:`v`.

    Parameters
    ----------
    A: SparseMatrix
        Sparse matrix
    v: torch.Tensor
        Vector

    Returns
    -------
    SparseMatrix
        Result of ``sp_broadcast_v(A, v, "sub")``
    """
    return sp_broadcast_v(A=A, v=v, op="sub")
[docs]
def sp_mul_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Multiply the non-zero values of a sparse matrix by a broadcast vector.

    Thin wrapper that delegates to :func:`sp_broadcast_v` with the ``"mul"``
    operator; see that function for the accepted shapes of :attr:`v`.

    Parameters
    ----------
    A: SparseMatrix
        Sparse matrix
    v: torch.Tensor
        Vector

    Returns
    -------
    SparseMatrix
        Result of ``sp_broadcast_v(A, v, "mul")``
    """
    return sp_broadcast_v(A=A, v=v, op="mul")
[docs]
def sp_div_v(A: SparseMatrix, v: torch.Tensor) -> SparseMatrix:
    """Divide the non-zero values of a sparse matrix by a broadcast vector.

    Thin wrapper that delegates to :func:`sp_broadcast_v` with the
    ``"truediv"`` operator; see that function for the accepted shapes of
    :attr:`v`.

    Parameters
    ----------
    A: SparseMatrix
        Sparse matrix
    v: torch.Tensor
        Vector

    Returns
    -------
    SparseMatrix
        Result of ``sp_broadcast_v(A, v, "truediv")``
    """
    return sp_broadcast_v(A=A, v=v, op="truediv")