"""Matmul ops for SparseMatrix"""# pylint: disable=invalid-namefromtypingimportUnionimporttorchfrom.sparse_matriximportSparseMatrix__all__=["spmm","bspmm","spspmm","matmul"]
def spmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
    """Sparse-dense matrix product, equivalent to ``A @ X``.

    Both arguments are type-checked with ``assert`` (so the checks vanish
    under ``python -O``). The product is then delegated to the
    ``torch.ops.dgl_sparse.spmm`` operator, which receives the
    ``c_sparse_matrix`` handle of ``A`` together with ``X``.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)`` holding scalar values.
    X : torch.Tensor
        Dense matrix of shape ``(M, N)``, or a vector of shape ``(M)``.

    Returns
    -------
    torch.Tensor
        Dense result of shape ``(L, N)``, or ``(L)`` for a vector ``X``.

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val = torch.randn(indices.shape[1])
    >>> A = dglsp.spmatrix(indices, val)
    >>> X = torch.randn(2, 3)
    >>> result = dglsp.spmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([2, 3])
    """
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        X, torch.Tensor
    ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."

    # Resolve the registered operator first, then hand it A's backing handle.
    spmm_kernel = torch.ops.dgl_sparse.spmm
    backing = A.c_sparse_matrix
    return spmm_kernel(backing, X)
def bspmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
    """Multiplies a sparse matrix by a dense matrix by batches, equivalent to
    ``A @ X``.

    After validating its arguments (with ``assert``, so the checks are
    stripped under ``python -O``), this delegates to :func:`spmm`.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
    X : torch.Tensor
        Dense matrix of shape ``(M, N, K)``

    Returns
    -------
    torch.Tensor
        Dense matrix of shape ``(L, N, K)``

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
    >>> val = torch.randn(indices.shape[1], 2)
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 3))
    >>> X = torch.randn(3, 3, 2)
    >>> result = dglsp.bspmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([3, 3, 2])
    """
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        X, torch.Tensor
    ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
    # The batched case shares the same entry point as the scalar one.
    return spmm(A, X)
def spspmm(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
    """Multiplies a sparse matrix by a sparse matrix, equivalent to
    ``A @ B``.

    The non-zero values of the two sparse matrices must be 1D.

    Both arguments are validated with ``assert`` (stripped under
    ``python -O``). The product is computed by
    ``torch.ops.dgl_sparse.spspmm`` on the ``c_sparse_matrix`` handles of
    ``A`` and ``B``, and the result is wrapped in a new ``SparseMatrix``.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)``
    B : SparseMatrix
        Sparse matrix of shape ``(M, N)``

    Returns
    -------
    SparseMatrix
        Sparse matrix of shape ``(L, N)``.

    Examples
    --------

    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val1 = torch.ones(indices1.shape[1])
    >>> A = dglsp.spmatrix(indices1, val1)
    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
    >>> val2 = torch.ones(indices2.shape[1])
    >>> B = dglsp.spmatrix(indices2, val2)
    >>> dglsp.spspmm(A, B)
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
                                 [1, 2, 0, 1, 2]]),
                 values=tensor([1., 1., 1., 1., 1.]),
                 shape=(2, 3), nnz=5)
    """
    # Messages name the arguments by position ("arg1"/"arg2"), consistent
    # with the other operators in this module.
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        B, SparseMatrix
    ), f"Expect arg2 to be a SparseMatrix object, got {type(B)}."
    return SparseMatrix(
        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
    )
def matmul(
    A: Union[torch.Tensor, SparseMatrix], B: Union[torch.Tensor, SparseMatrix]
) -> Union[torch.Tensor, SparseMatrix]:
    """Multiplies two dense/sparse matrices, equivalent to ``A @ B``.

    This function does not support the case where :attr:`A` is a \
    ``torch.Tensor`` and :attr:`B` is a ``SparseMatrix``.

    * If both matrices are torch.Tensor, it calls \
        :func:`torch.matmul()`. The result is a dense matrix.

    * If both matrices are sparse, it calls :func:`dgl.sparse.spspmm`. The \
        result is a sparse matrix.

    * If :attr:`A` is sparse while :attr:`B` is dense, it calls \
        :func:`dgl.sparse.spmm`. The result is a dense matrix.

    * The operator supports batched sparse-dense matrix multiplication. In \
        this case, the sparse matrix :attr:`A` should have shape ``(L, M)``, \
        where the non-zero values have a batch dimension ``K``. The dense \
        matrix :attr:`B` should have shape ``(M, N, K)``. The output \
        is a dense matrix of shape ``(L, N, K)``.

    * Sparse-sparse matrix multiplication does not support batched computation.

    Argument types are checked with ``assert`` and are therefore not
    enforced under ``python -O``.

    Parameters
    ----------
    A : torch.Tensor or SparseMatrix
        The first matrix.
    B : torch.Tensor or SparseMatrix
        The second matrix.

    Returns
    -------
    torch.Tensor or SparseMatrix
        The result matrix

    Examples
    --------

    Multiplies a diagonal matrix with a dense matrix.

    >>> val = torch.randn(3)
    >>> A = dglsp.diag(val)
    >>> B = torch.randn(3, 2)
    >>> result = dglsp.matmul(A, B)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([3, 2])

    Multiplies a sparse matrix with a dense matrix.

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val = torch.randn(indices.shape[1])
    >>> A = dglsp.spmatrix(indices, val)
    >>> X = torch.randn(2, 3)
    >>> result = dglsp.matmul(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([2, 3])

    Multiplies a sparse matrix with a sparse matrix.

    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val1 = torch.ones(indices1.shape[1])
    >>> A = dglsp.spmatrix(indices1, val1)
    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
    >>> val2 = torch.ones(indices2.shape[1])
    >>> B = dglsp.spmatrix(indices2, val2)
    >>> result = dglsp.matmul(A, B)
    >>> type(result)
    <class 'dgl.sparse.sparse_matrix.SparseMatrix'>
    >>> result.shape
    (2, 3)
    """
    assert isinstance(
        A, (torch.Tensor, SparseMatrix)
    ), f"Expect arg1 to be a torch.Tensor or SparseMatrix, got {type(A)}."
    # The two f-string pieces are concatenated implicitly; the trailing space
    # keeps "SparseMatrix" and "object" from running together.
    assert isinstance(B, (torch.Tensor, SparseMatrix)), (
        f"Expect arg2 to be a torch Tensor or SparseMatrix "
        f"object, got {type(B)}."
    )
    # Dense @ dense: defer entirely to PyTorch.
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return torch.matmul(A, B)
    # Dense @ sparse is the one unsupported combination.
    assert not isinstance(A, torch.Tensor), (
        f"Expect arg2 to be a torch Tensor if arg 1 is torch Tensor, "
        f"got {type(B)}."
    )
    # From here on A is sparse; dispatch on the type of B.
    if isinstance(B, torch.Tensor):
        return spmm(A, B)
    return spspmm(A, B)