from .. import backend as F  # DGL's backend-agnostic tensor API


def gather_mm(a, b, *, idx_b):
    r"""Gather data according to the given indices and perform matrix multiplication.

    Let the result tensor be ``c``. The operator computes::

        c[i] = a[i] @ b[idx_b[i]]

    where ``len(c) == len(idx_b)``.

    Parameters
    ----------
    a : Tensor
        A 2-D tensor of shape ``(N, D1)``.
    b : Tensor
        A 3-D tensor of shape ``(R, D1, D2)``.
    idx_b : Tensor
        A 1-D integer tensor of shape ``(N,)`` (keyword-only). Each entry
        ``idx_b[i]`` selects the matrix ``b[idx_b[i]]`` and must lie in ``[0, R)``.

    Returns
    -------
    Tensor
        The output dense matrix of shape ``(N, D2)``.
    """
    N, D1 = F.shape(a)
    R, _, D2 = F.shape(b)
    if N > 1000000 or D1 > 8 or D2 > 8:
        # Large workload: group rows by relation id and use segment_mm.
        import torch
        # Sort rows so that rows sharing the same matrix in ``b`` are contiguous.
        sorted_idx_b, perm = torch.sort(idx_b)
        # Inverse permutation, used to restore the original row order at the end.
        _, rev_perm = torch.sort(perm)
        sorted_a = torch.index_select(a, 0, perm)
        # pos_l[r] is the first sorted position whose relation id is >= r.
        pos_l = torch.searchsorted(sorted_idx_b, torch.arange(R, device=a.device))
        pos_r = torch.cat([pos_l[1:], torch.tensor([len(idx_b)], device=a.device)])
        # Number of rows assigned to each of the R matrices.
        seglen = (pos_r - pos_l).cpu()  # XXX(minjie): cause device synchronize
        return torch.index_select(F.segment_mm(sorted_a, b, seglen), 0, rev_perm)
    else:
        return F.gather_mm(a, b, None, idx_b)
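The large-workload branch relies on the fact that sorting rows by `idx_b` turns the gather into a sequence of contiguous segments, one dense matmul per matrix in `b`. The following is a minimal NumPy sketch of that idea, not DGL's implementation: `gather_mm_reference` and `gather_mm_segmented` are hypothetical names, NumPy stands in for PyTorch, and the per-segment loop is an assumed stand-in for `F.segment_mm`, whose internals are not shown here. It checks that the sort, segment, and unsort steps reproduce `c[i] = a[i] @ b[idx_b[i]]`.

```python
import numpy as np


def gather_mm_reference(a, b, idx_b):
    """Naive definition: c[i] = a[i] @ b[idx_b[i]] for every row i."""
    # b[idx_b] gathers one (D1, D2) matrix per row -> shape (N, D1, D2);
    # einsum multiplies each row vector a[i] by its own gathered matrix.
    return np.einsum("nd,nde->ne", a, b[idx_b])


def gather_mm_segmented(a, b, idx_b):
    """Sort-then-segment strategy (hypothetical helper mirroring the torch branch)."""
    R = b.shape[0]
    # Stable sort so rows that share a relation id become one contiguous segment.
    perm = np.argsort(idx_b, kind="stable")
    sorted_idx_b = idx_b[perm]
    # argsort of a permutation is its inverse; used to undo the sort at the end.
    rev_perm = np.argsort(perm)
    sorted_a = a[perm]
    # Segment r spans sorted positions [pos_l[r], pos_r[r]); empty if unused.
    pos_l = np.searchsorted(sorted_idx_b, np.arange(R))
    pos_r = np.append(pos_l[1:], len(idx_b))
    out = np.empty((a.shape[0], b.shape[2]), dtype=np.result_type(a, b))
    # One dense matmul per segment (assumed stand-in for segment_mm).
    for r in range(R):
        out[pos_l[r]:pos_r[r]] = sorted_a[pos_l[r]:pos_r[r]] @ b[r]
    # Row i of the result must correspond to row i of the unsorted input.
    return out[rev_perm]
```

For example, with `a` of shape `(7, 3)`, `b` of shape `(4, 3, 2)` and `idx_b = [2, 0, 2, 3, 0, 0, 2]`, both functions return a `(7, 2)` array with equal entries, even though relation `1` receives no rows (its segment has length zero). The trade-off the sketch illustrates is that the segmented path needs a sort and an inverse permutation, but replaces `N` small products with at most `R` larger ones.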