oneflow.gather
oneflow.gather(input, dim, index, sparse_grad=False) → Tensor

Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
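As a reading aid, the three formulas above can be transcribed into a plain-Python reference loop. This is a minimal sketch, not OneFlow's implementation: it assumes nested lists stand in for a 3-D tensor, and the helper name gather_3d is hypothetical. Each output position (i, j, k) copies the input element whose coordinate along dim has been replaced by the index value stored at (i, j, k).

```python
def gather_3d(input, dim, index):
    """Hypothetical reference helper: 3-D gather on nested lists (not OneFlow's code)."""
    if dim not in (0, 1, 2):
        raise ValueError("dim must be 0, 1 or 2 for a 3-D input")
    out = []
    # The loops run over index, so out always has the shape of index.
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                # Python lists would silently wrap negative indices, so reject them.
                if idx < 0:
                    raise IndexError(f"negative index {idx} at ({i}, {j}, {k})")
                # Replace coordinate `dim` of (i, j, k) with the gathered index.
                src = [i, j, k]
                src[dim] = idx
                a, b, c = src
                out_row.append(input[a][b][c])
            out_plane.append(out_row)
        out.append(out_plane)
    return out


# input[i][j][k] = 100*i + 10*j + k, so each value spells out its own coordinates.
inp = [[[0, 1], [10, 11]], [[100, 101], [110, 111]]]
idx = [[[1, 0], [0, 1]], [[0, 0], [1, 1]]]
print(gather_3d(inp, 0, idx))  # [[[100, 1], [10, 111]], [[0, 1], [110, 111]]]
```

Reading the printed values digit by digit shows which coordinate was taken from idx: with dim == 0 the hundreds digit (the i coordinate) follows idx, while the tens and units digits stay equal to the output position's j and k.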
input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other.

Parameters
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
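The shape rules stated above (same number of dimensions, index.size(d) <= input.size(d) for every d != dim, output shaped like index, no broadcasting) can be expressed as a small pre-flight check on plain tuples. The helper gather_output_shape below is hypothetical and only illustrates those rules; it is not a OneFlow function, and it assumes a non-negative dim.

```python
def gather_output_shape(input_shape, dim, index_shape):
    """Hypothetical helper: validate gather's documented shape rules and return out's shape."""
    ndim = len(input_shape)
    # Rule 1: input and index must have the same number of dimensions.
    if len(index_shape) != ndim:
        raise ValueError("input and index must have the same number of dimensions")
    # This sketch only accepts a non-negative dim (an assumption, not taken from the docs).
    if not 0 <= dim < ndim:
        raise ValueError(f"dim must be in [0, {ndim}), got {dim}")
    # Rule 2: index.size(d) <= input.size(d) for every d != dim; nothing is broadcast.
    for d in range(ndim):
        if d != dim and index_shape[d] > input_shape[d]:
            raise ValueError(
                f"index.size({d}) = {index_shape[d]} exceeds input.size({d}) = {input_shape[d]}"
            )
    # Rule 3: out has the same shape as index.
    return tuple(index_shape)


print(gather_output_shape((3, 4, 3, 5), 1, (3, 4, 3, 5)))  # (3, 4, 3, 5)
print(gather_output_shape((3, 4), 1, (1, 4)))  # (1, 4): not broadcast up to (3, 4)
```

The loop deliberately skips d == dim, because the documented rule places no size bound on index.size(dim). Whether the index values themselves are in range is a separate, value-level check that this shape-only sketch does not model.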
For example:
>>> import oneflow as flow
>>> import numpy as np
>>> input = np.random.randn(3, 4, 3, 5)
>>> # 3 * 4 * 3 * 5 = 180 random indices in [0, 3); all are valid along dim 1, where input.size(1) == 4
>>> index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5))
>>> output = flow.gather(flow.Tensor(input), 1, flow.tensor(index, dtype=flow.int64))
>>> output.shape
oneflow.Size([3, 4, 3, 5])
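A frequent use of gather with dim=1 is picking one entry per row, for example the score of each sample's target class. The sketch below reproduces that pattern in plain Python, assuming nested lists as a stand-in for tensors; gather_dim1_2d is a hypothetical helper, not part of OneFlow. With OneFlow itself the analogous call would be flow.gather(scores, 1, targets), where targets is an int64 tensor of shape (N, 1).

```python
def gather_dim1_2d(input, index):
    """Hypothetical helper: 2-D gather along dim=1, i.e. out[i][j] = input[i][index[i][j]]."""
    out = []
    for i, row in enumerate(index):
        out_row = []
        for idx in row:
            # Python lists would silently wrap negative indices, so reject them explicitly.
            if idx < 0:
                raise IndexError(f"negative index {idx} in row {i}")
            out_row.append(input[i][idx])
        out.append(out_row)
    return out


scores = [
    [0.1, 0.7, 0.2],  # sample 0
    [0.5, 0.3, 0.2],  # sample 1
    [0.0, 0.1, 0.9],  # sample 2
]
targets = [[1], [0], [2]]  # one class index per row, shape (3, 1)
print(gather_dim1_2d(scores, targets))  # [[0.7], [0.5], [0.9]]
```

Because targets has shape (3, 1), the result also has shape (3, 1): one value per row, exactly as the rule that out takes the shape of index predicts.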