oneflow.where
oneflow.where(condition, x=None, y=None)

Return a tensor of elements selected from either x or y, depending on condition. Where an element of condition is greater than 0, the output takes the corresponding element of x; otherwise it takes the corresponding element of y.
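The selection rule above can be sketched in pure Python. This is a minimal illustration, not OneFlow's implementation: it assumes condition, x and y are equally shaped 2-D nested lists (no broadcasting), and where_2d is a hypothetical helper.

```python
# Pure-Python sketch of the element-wise selection rule of oneflow.where.
# Assumptions: equally shaped 2-D nested lists, no broadcasting.
# where_2d is a hypothetical helper, not part of the OneFlow API.

def where_2d(condition, x, y):
    """Pick x[i][j] where condition[i][j] > 0, otherwise y[i][j]."""
    return [
        [xv if c > 0 else yv for c, xv, yv in zip(c_row, x_row, y_row)]
        for c_row, x_row, y_row in zip(condition, x, y)
    ]


if __name__ == "__main__":
    cond = [[0, 1], [1, 0]]
    x = [[10, 20], [30, 40]]
    y = [[-1, -2], [-3, -4]]
    print(where_2d(cond, x, y))  # [[-1, 20], [30, -4]]
```

Each output position depends only on the three elements at the same position, which is why the real operator can also broadcast its inputs before selecting.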
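Note

If x is None and y is None, flow.where(condition) is identical to flow.nonzero(condition, as_tuple=True).

The tensors condition, x and y must be broadcastable.

Parameters
- condition (Tensor): where an element is greater than 0, the output takes the corresponding element of x; otherwise it takes the corresponding element of y.
- x (Tensor, optional): values selected at positions where condition is greater than 0.
- y (Tensor, optional): values selected at positions where condition is not greater than 0.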
Returns
A tensor whose shape is the broadcasted shape of condition, x and y.

Return type
oneflow.Tensor
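The output shape described under Returns can be computed with a small helper. This sketch assumes OneFlow follows the usual NumPy/PyTorch broadcasting rule (shapes aligned from the trailing dimension; each aligned size must match or be 1); broadcast_shape is a hypothetical helper, not a OneFlow function.

```python
# Sketch of the broadcasted output shape of where(condition, x, y),
# assuming standard NumPy/PyTorch-style broadcasting.
# broadcast_shape is a hypothetical helper, not part of the OneFlow API.
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the broadcasted shape of all input shapes, or raise ValueError."""
    result = []
    # Walk dimensions from last to first; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


if __name__ == "__main__":
    # condition (3, 1), x (1, 2), y () (a 0-d tensor) -> output (3, 2)
    print(broadcast_shape((3, 1), (1, 2), ()))  # (3, 2)
```

In the example that follows, all three inputs have shape (3, 2), so the output shape is simply (3, 2).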
For example:
>>> import numpy as np
>>> import oneflow as flow
>>> x = flow.tensor(
...     np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),
...     dtype=flow.float32,
... )
>>> y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32)
>>> condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32)
>>> out = condition.where(x, y)
>>> out
tensor([[1.0000, 0.3139],
        ...
        [0.0478, 1.0000]], dtype=oneflow.float32)
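The single-argument form mentioned in the Note returns indices instead of values: like flow.nonzero(condition, as_tuple=True), it yields one index sequence per dimension of condition. The sketch below mimics that on a 2-D nested list in pure Python. It is an illustration only: nonzero_as_tuple_2d is a hypothetical helper, it returns Python tuples rather than OneFlow tensors, and it treats any element != 0 as selected, which is the usual nonzero convention.

```python
# Pure-Python stand-in for flow.where(condition) / flow.nonzero(condition, as_tuple=True)
# on a 2-D nested list. nonzero_as_tuple_2d is hypothetical; it returns
# (row_indices, col_indices) as tuples, scanning in row-major order.

def nonzero_as_tuple_2d(condition):
    rows, cols = [], []
    for i, row in enumerate(condition):
        for j, value in enumerate(row):
            if value != 0:  # "nonzero" selection
                rows.append(i)
                cols.append(j)
    return tuple(rows), tuple(cols)


if __name__ == "__main__":
    # Same condition as in the example above.
    print(nonzero_as_tuple_2d([[0, 1], [1, 0], [1, 0]]))  # ((0, 1, 2), (1, 0, 0))
```

Pairing the two sequences position by position gives the coordinates (0, 1), (1, 0) and (2, 0), exactly the positions where the example takes its value from x.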