oneflow.where

oneflow.where(condition, x=None, y=None)

Return a tensor of elements selected from either x or y, depending on condition. If an element of condition is larger than 0, the corresponding element of x is taken; otherwise the corresponding element of y is taken.
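The selection rule can be illustrated with a minimal pure-Python sketch. `where_reference` is a hypothetical helper written for illustration only. It is not OneFlow's implementation and does no broadcasting, so all three inputs must be equally shaped 2-D nested lists. It follows the "nonzero selects x" reading from the condition parameter below. For 0/1 condition values, which is what this page's example uses, that reading and "larger than 0" agree.

```python
def where_reference(condition, x, y):
    """Element-wise select for equally shaped 2-D nested lists.

    Hypothetical helper for illustration only; it does no broadcasting.
    """
    return [
        # Assumption: a nonzero condition entry picks from x, a zero entry picks from y.
        [xv if c != 0 else yv for c, xv, yv in zip(c_row, x_row, y_row)]
        for c_row, x_row, y_row in zip(condition, x, y)
    ]


condition = [[1, 0], [0, 1]]
x = [[1, 2], [3, 4]]
y = [[10, 20], [30, 40]]
print(where_reference(condition, x, y))  # [[1, 20], [30, 4]]
```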

Note

If x is None and y is None, flow.where(condition) is identical to flow.nonzero(condition, as_tuple=True).
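NumPy has the same one-argument shortcut, which makes the identity in the note easy to check. NumPy documents `np.where(condition)` as shorthand for `np.nonzero(condition)`, and both return one index array per dimension. The sketch below uses NumPy only as a stand-in. It assumes `flow.nonzero(condition, as_tuple=True)` likewise returns one index tensor per dimension, in row-major order.

```python
import numpy as np

condition = np.array([[0, 1], [1, 0], [1, 0]])

# One-argument form: a tuple of index arrays, one per dimension.
rows, cols = np.where(condition)
print(rows, cols)  # [0 1 2] [1 0 0]

# np.nonzero gives the same tuple (the NumPy counterpart of as_tuple=True).
nz_rows, nz_cols = np.nonzero(condition)
assert (rows == nz_rows).all() and (cols == nz_cols).all()
```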

The tensors condition, x, y must be broadcastable.
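The sketch below shows what "broadcastable" means here. It uses NumPy's `np.where` and assumes OneFlow applies the same broadcasting rules. A column-shaped condition, a row-shaped x, and a scalar y (scalars are allowed per the parameter list) combine into one output. That output's shape is the broadcast of all three input shapes.

```python
import numpy as np

# condition has shape (3, 1); x has shape (1, 2); y is a scalar (shape ()).
condition = np.array([[1], [0], [1]])
x = np.array([[5.0, 6.0]])
y = 0.0

out = np.where(condition, x, y)
print(out.shape)  # (3, 2)
print(out)
# [[5. 6.]
#  [0. 0.]
#  [5. 6.]]

# The output shape equals the broadcast of the three input shapes.
assert out.shape == np.broadcast_shapes(condition.shape, x.shape, np.shape(y))
```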

Parameters
  • condition (IntTensor) – When 1 (nonzero), yield x, otherwise yield y

  • x (Tensor or Scalar) – value (if x is a scalar) or values selected at indices where condition is True

  • y (Tensor or Scalar) – value (if y is a scalar) or values selected at indices where condition is False

Returns

A tensor whose shape is the broadcast shape of condition, x, and y.

Return type

Tensor

For example:

>>> import numpy as np
>>> import oneflow as flow
>>> x = flow.tensor(
...    np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),
...    dtype=flow.float32,
... )
>>> y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32)
>>> condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32)
>>> out = condition.where(x, y)
>>> out
tensor([[1.0000, 0.3139],
        ...
        [0.0478, 1.0000]], dtype=oneflow.float32)