oneflow.batch_gather
oneflow.batch_gather()
Gather the elements in the batch dimensions.
- Parameters
For example:
Example 1:
>>> import oneflow as flow
>>> import numpy as np
>>> x = flow.Tensor(np.array([[1, 2, 3],
...                           [4, 5, 6]]))
>>> indices = flow.tensor(np.array([1, 0]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
>>> out
tensor([[4., 5., 6.],
        [1., 2., 3.]], dtype=oneflow.float32)
Example 2:
>>> import oneflow as flow
>>> import numpy as np
>>> x = flow.Tensor(np.array([[[1, 2, 3], [4, 5, 6]],
...                           [[1, 2, 3], [4, 5, 6]]]))
>>> indices = flow.tensor(np.array([[1, 0],
...                                 [0, 1]]).astype(np.int64))
>>> out = flow.batch_gather(x, indices)
>>> out
tensor([[[4., 5., 6.],
         [1., 2., 3.]],
        [[1., 2., 3.],
         [4., 5., 6.]]], dtype=oneflow.float32)
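The two examples are consistent with the following reading: the leading dimensions of indices act as batch dimensions, and the last dimension of indices selects entries along the matching axis of the input. Below is a minimal NumPy sketch of that reading. It is inferred from the examples only and is not OneFlow's implementation; the helper name batch_gather_reference is hypothetical.

```python
import numpy as np


def batch_gather_reference(x, indices):
    """NumPy sketch of the gather pattern shown in the examples (an assumption)."""
    x = np.asarray(x)
    indices = np.asarray(indices)

    # The last dimension of `indices` picks entries along this axis of `x`;
    # all earlier dimensions are treated as batch dimensions.
    axis = indices.ndim - 1

    # Append singleton dimensions so `indices` has the same rank as `x`,
    # then broadcast over the trailing (non-batch) dimensions of `x`.
    idx = indices.reshape(indices.shape + (1,) * (x.ndim - indices.ndim))
    idx = np.broadcast_to(idx, indices.shape + x.shape[axis + 1:])

    # out[b..., i, t...] = x[b..., indices[b..., i], t...]
    return np.take_along_axis(x, idx, axis=axis)


# Inputs from Example 1 and Example 2.
out1 = batch_gather_reference([[1, 2, 3], [4, 5, 6]], [1, 0])
out2 = batch_gather_reference(
    [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
    [[1, 0], [0, 1]],
)
```

With the inputs from Example 1 and Example 2, out1 and out2 hold the same values as the tensors shown above (as NumPy integer arrays rather than float32 tensors).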