dgl.ops.gather_mm
- dgl.ops.gather_mm(a, b, *, idx_b)
Gather data according to the given indices and perform matrix multiplication.
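Let the result tensor be c. The operator conducts the following computation:

c[i] = a[i] @ b[idx_b[i]], where len(c) == len(idx_b)

That is, each row a[i] (length D1) is multiplied by the (D1, D2) matrix that idx_b[i] selects from b, giving row c[i] of length D2.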
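To make the per-row semantics concrete, here is a minimal plain-Python sketch of the formula above, using nested lists in place of tensors. It is only a reference for what gets computed and is not DGL's implementation; the helper names `matvec_row` and `gather_mm_reference` are hypothetical.

```python
# Plain-Python reference for c[i] = a[i] @ b[idx_b[i]] (hypothetical helpers, not DGL's kernel).
# a is an N x D1 nested list, b is an R x D1 x D2 nested list, idx_b has length N.

def matvec_row(row, mat):
    """Multiply a length-D1 row vector by a D1 x D2 matrix, giving a length-D2 row."""
    d2 = len(mat[0])
    return [sum(row[k] * mat[k][j] for k in range(len(row))) for j in range(d2)]


def gather_mm_reference(a, b, idx_b):
    """Return c with c[i] = a[i] @ b[idx_b[i]]; len(c) == len(idx_b)."""
    if len(a) != len(idx_b):
        raise ValueError("a and idx_b must have the same length N")
    # For each row, pick the matrix named by idx_b[i] and multiply the row by it.
    return [matvec_row(a[i], b[idx_b[i]]) for i in range(len(idx_b))]


a = [[1, 2], [3, 4], [5, 6]]      # N = 3, D1 = 2
b = [
    [[1, 0, 0], [0, 1, 0]],       # b[0]: maps (x, y) to (x, y, 0)
    [[0, 0, 1], [1, 0, 0]],       # b[1]: maps (x, y) to (y, 0, x)
]
idx_b = [0, 1, 0]                 # rows 0 and 2 use b[0], row 1 uses b[1]
print(gather_mm_reference(a, b, idx_b))  # [[1, 2, 0], [4, 0, 3], [5, 6, 0]]
```

The result has one row per entry of idx_b, and several rows may share the same matrix from b.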
- Parameters:
a (Tensor) – A 2-D tensor of shape (N, D1).
b (Tensor) – A 3-D tensor of shape (R, D1, D2).
idx_b (Tensor) – A 1-D integer tensor of shape (N,), passed as a keyword-only argument. Its entries index the first dimension of b, so each must lie in [0, R).
- Returns:
The output dense matrix of shape (N, D2).
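The shape contract in the parameter and return descriptions can be summarized as a small check. The sketch below is a hypothetical helper, not part of DGL, that validates the documented shapes of a, b, and idx_b and returns the expected output shape (N, D2); it does not check the index values themselves.

```python
# Hypothetical helper (not part of DGL) encoding the documented shape contract of gather_mm.

def gather_mm_output_shape(a_shape, b_shape, idx_b_shape):
    """Return (N, D2) given a: (N, D1), b: (R, D1, D2), idx_b: (N,); raise ValueError otherwise."""
    if len(a_shape) != 2:
        raise ValueError(f"a must be 2-D, got shape {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"b must be 3-D, got shape {b_shape}")
    if len(idx_b_shape) != 1:
        raise ValueError(f"idx_b must be 1-D, got shape {idx_b_shape}")
    n, d1 = a_shape
    _r, b_d1, d2 = b_shape
    # The inner dimension D1 must agree for each row-by-matrix product.
    if b_d1 != d1:
        raise ValueError(f"inner dimensions differ: a has D1={d1}, b has D1={b_d1}")
    # One index per row of a.
    if idx_b_shape[0] != n:
        raise ValueError(f"idx_b must have length N={n}, got {idx_b_shape[0]}")
    return (n, d2)


print(gather_mm_output_shape((5, 16), (3, 16, 8), (5,)))  # (5, 8)
```

Note that R, the number of matrices in b, does not appear in the output shape; it only bounds the valid values of idx_b.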
- Return type:
Tensor