oneflow.batch_gather

oneflow.batch_gather(in, indices)

Gathers slices of the input tensor along its batch dimensions, using indices to select which slices to take.

Parameters
  • in (Tensor) – the input tensor.

  • indices (Tensor) – the indices tensor; its dtype must be int32 or int64.
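The two examples on this page suggest the following behavior: every dimension of indices except the last is a batch dimension shared with in, and the last dimension of indices picks entries along the next axis of in. The sketch below states that reading on plain nested Python lists. It is an assumption inferred from the examples, not OneFlow's implementation, and batch_gather_ref is a hypothetical helper name.

```python
def batch_gather_ref(params, indices):
    """Hypothetical reference semantics for batch_gather on nested lists.

    Assumption (inferred from the examples): all dimensions of `indices`
    except the last are batch dimensions shared with `params`; the last
    dimension selects entries along the next axis of `params`.
    """
    if indices and isinstance(indices[0], list):
        # Leading axis of `indices` is a batch axis: pair it with `params`.
        if len(params) != len(indices):
            raise ValueError("batch sizes of params and indices differ")
        return [batch_gather_ref(p, i) for p, i in zip(params, indices)]
    # `indices` is 1-D here: pick entries of `params` along its leading axis.
    return [params[i] for i in indices]


# Same inputs as Example 1 and Example 2 (integers instead of floats).
print(batch_gather_ref([[1, 2, 3], [4, 5, 6]], [1, 0]))
# → [[4, 5, 6], [1, 2, 3]]
print(batch_gather_ref([[[1, 2, 3], [4, 5, 6]],
                        [[1, 2, 3], [4, 5, 6]]],
                       [[1, 0], [0, 1]]))
# → [[[4, 5, 6], [1, 2, 3]], [[1, 2, 3], [4, 5, 6]]]
```

Under this reading the output shape is indices.shape followed by the remaining dimensions of in, which matches both examples.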

For example:

Example 1:

>>> import oneflow as flow
>>> import numpy as np

>>> x = flow.Tensor(np.array([[1, 2, 3],
...                           [4, 5, 6]]))
>>> indices = flow.tensor(np.array([1, 0]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
>>> out
tensor([[4., 5., 6.],
        [1., 2., 3.]], dtype=oneflow.float32)

Example 2:

>>> import oneflow as flow
>>> import numpy as np

>>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]],
...                           [[1, 2, 3], [4, 5, 6]]]))
>>> indices = flow.tensor(np.array([[1, 0],
...                                 [0, 1]]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
>>> out
tensor([[[4., 5., 6.],
         [1., 2., 3.]],
        [[1., 2., 3.],
         [4., 5., 6.]]], dtype=oneflow.float32)
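With exactly one batch dimension, as in Example 2, the same values can be reproduced with NumPy advanced indexing by pairing each row of indices with its batch position. This is a sketch for cross-checking results under the semantics inferred from the examples; batch_gather_np is a hypothetical helper, not part of OneFlow, and it only handles a 2-D indices tensor.

```python
import numpy as np


def batch_gather_np(params, indices):
    """Hypothetical NumPy equivalent of batch_gather for one batch dimension."""
    params = np.asarray(params)
    indices = np.asarray(indices)
    if indices.ndim != 2:
        raise ValueError("this sketch handles exactly one batch dimension")
    # Shape (B, 1): broadcasts against indices of shape (B, N).
    batch = np.arange(indices.shape[0])[:, None]
    # out[b, n, ...] = params[b, indices[b, n], ...]
    return params[batch, indices]


x = np.array([[[1, 2, 3], [4, 5, 6]],
              [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
idx = np.array([[1, 0], [0, 1]], dtype=np.int64)
print(batch_gather_np(x, idx))
```

Using a broadcast batch index instead of a Python loop keeps the gather in a single vectorized indexing operation.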