oneflow.gather

oneflow.gather(input, dim, index, sparse_grad=False) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
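The indexing rule above can be sketched in plain Python on nested lists. This is only an illustrative reference, not how OneFlow implements the operator; `gather_3d` is a hypothetical helper name, and it assumes a non-negative dim and in-range index values.

```python
def gather_3d(inp, dim, index):
    """Reference gather for 3-D nested lists (hypothetical helper, not a OneFlow API)."""
    out = []
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                pos = [i, j, k]
                pos[dim] = idx  # replace the coordinate along `dim` with the index value
                out_row.append(inp[pos[0]][pos[1]][pos[2]])
            out_plane.append(out_row)
        out.append(out_plane)
    return out


# 2 x 2 x 2 input whose value at [i][j][k] encodes its coordinates as 100*i + 10*j + k
inp = [[[100 * i + 10 * j + k for k in range(2)] for j in range(2)] for i in range(2)]
index = [[[1, 0], [0, 1]]]  # shape 1 x 2 x 2

print(gather_3d(inp, 0, index))  # [[[100, 1], [10, 111]]]
print(gather_3d(inp, 2, index))  # [[[1, 0], [10, 11]]]
```

Because each input value encodes its own coordinates, the printed results show directly which position was read: with dim == 0 only the first coordinate is taken from index, and with dim == 2 only the last.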

input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other.
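The shape constraints in this paragraph can be expressed as a small check. The sketch below is an assumption-laden illustration in plain Python: `gather_output_shape` is a hypothetical helper, not part of OneFlow, and it handles only non-negative dim values.

```python
def gather_output_shape(input_shape, index_shape, dim):
    """Validate the documented constraints and return the expected output shape.

    Hypothetical helper for illustration only; not a OneFlow function.
    """
    if len(input_shape) != len(index_shape):
        raise ValueError("input and index must have the same number of dimensions")
    for d, (in_size, idx_size) in enumerate(zip(input_shape, index_shape)):
        # Every dimension except `dim` must satisfy index.size(d) <= input.size(d)
        if d != dim and idx_size > in_size:
            raise ValueError(
                f"index.size({d}) = {idx_size} exceeds input.size({d}) = {in_size}"
            )
    # The output takes the shape of index, not of input
    return tuple(index_shape)


shape = gather_output_shape((3, 4, 3, 5), (3, 2, 3, 5), dim=1)
print(shape)  # (3, 2, 3, 5)
```

No broadcasting is attempted: an index of shape (1, 2, 3, 5) would yield an output of shape (1, 2, 3, 5) rather than being expanded to match input.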

Parameters
  • input (Tensor) – the source tensor

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to gather

For example:

>>> import oneflow as flow
>>> import numpy as np
>>> input = np.random.randn(3, 4, 3, 5)
>>> index = np.random.choice(np.arange(3), size=180, replace=True).reshape((3, 4, 3, 5))
>>> output = flow.gather(flow.Tensor(input), 1, flow.tensor(index, dtype=flow.int64))
>>> output.shape
oneflow.Size([3, 4, 3, 5])