oneflow.scatter

oneflow.scatter(input, dim, index, src)

This operator writes elements of src into input at the positions that index specifies along the axis dim.

Taking a 3-D tensor as an example, the output is specified by:

input[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
input[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
input[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
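The three rules above generalize to any number of dimensions: for every position pos in index, the value src[pos] is written to the output at pos with its dim-th coordinate replaced by index[pos]. The following is a minimal NumPy sketch of that rule, written only to illustrate the semantics. scatter_reference is a hypothetical helper, not a OneFlow API and not OneFlow's implementation. It assumes src is either an array at least as large as index in every dimension, or a single float (as the src parameter allows).

```python
import numpy as np


def scatter_reference(inp, dim, index, src):
    """Hypothetical NumPy reference for the scatter rule (not OneFlow's implementation)."""
    out = np.array(inp, copy=True)   # work on a copy so the caller's array is unchanged
    scalar_src = np.ndim(src) == 0   # src may be a single float instead of an array
    # Visit every position of index; an array src is read at the same position.
    for pos in np.ndindex(*index.shape):
        target = list(pos)
        target[dim] = index[pos]     # replace the dim-th coordinate with the index value
        out[tuple(target)] = src if scalar_src else src[pos]
    return out


# Same data as the flow.scatter example in this section.
base = np.full((3, 5), 2.0)
idx = np.array([[0, 1, 2], [0, 1, 4]])
src = np.array([[0, 10, 20, 30, 40], [50, 60, 70, 80, 90]], dtype=np.float32)
print(scatter_reference(base, 1, idx, src))
print(scatter_reference(base, 1, idx, 9.0))  # scalar src: 9.0 at every indexed position
```

Only positions covered by index are touched; src[0][3], src[0][4], src[1][3] and src[1][4] are never read because index has only three columns.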

input, index and src (if it is a Tensor) must all have the same number of dimensions. It is also required that index.shape[d] <= src.shape[d] for all dimensions d, and that index.shape[d] <= input.shape[d] for all dimensions d != dim. Note that index and src do not broadcast.
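The shape requirements can be checked before calling the operator. The sketch below is a hypothetical helper (check_scatter_shapes is not part of OneFlow) that encodes exactly the rules stated above on plain shape tuples; because the ranks must match and sizes are compared directly, no broadcasting is ever applied. It checks shapes only, not whether the values in index are in range.

```python
def check_scatter_shapes(input_shape, dim, index_shape, src_shape=None):
    """Hypothetical validator for the scatter shape rules.

    src_shape=None stands for a scalar (float) src, which has no shape constraints.
    """
    ndim = len(input_shape)
    # Rule 1: input, index and (tensor) src share the same number of dimensions.
    if len(index_shape) != ndim:
        return False
    if src_shape is not None and len(src_shape) != ndim:
        return False
    for d in range(ndim):
        # Rule 2: index is no larger than src along every dimension.
        if src_shape is not None and index_shape[d] > src_shape[d]:
            return False
        # Rule 3: index is no larger than input along every dimension except dim.
        if d != dim and index_shape[d] > input_shape[d]:
            return False
    return True


print(check_scatter_shapes((3, 5), 1, (2, 3), (2, 5)))  # shapes from the example: valid
print(check_scatter_shapes((3, 5), 1, (2, 3), (1, 5)))  # index has more rows than src
print(check_scatter_shapes((3, 5), 1, (2, 3)))          # scalar src
```

Note that rule 3 skips dim itself, so index may be longer than input along dim (for example index shape (2, 7) with input shape (3, 5) and dim=1), as long as every index value is a valid position.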

Parameters
  • input (Tensor) – The input tensor.

  • dim (int) – The axis along which to index.

  • index (Tensor) – The indices of the elements to scatter.

  • src (Tensor or float) – The source tensor (or scalar value) whose elements are scattered into the output.

Returns

The scattered Tensor.

Return type

Tensor

For example:

>>> import oneflow as flow
>>> import numpy as np

>>> input = flow.ones((3,5))*2
>>> index = flow.tensor(np.array([[0,1,2],[0,1,4]]), dtype=flow.int32)
>>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))
>>> out = flow.scatter(input, 1, index, src)
>>> out
tensor([[ 0., 10., 20.,  2.,  2.],
        [50., 60.,  2.,  2., 70.],
        [ 2.,  2.,  2.,  2.,  2.]], dtype=oneflow.float32)