oneflow.where
oneflow.where(condition, x=None, y=None)

Return a tensor of elements selected from either x or y, depending on condition. Where an element of condition is larger than 0, the corresponding element of x is taken; otherwise the corresponding element of y is taken.
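The selection rule above can be illustrated with a minimal pure-Python sketch. `where_reference` is a hypothetical helper written for illustration, not part of OneFlow. It assumes `condition`, `x` and `y` are nested lists of identical shape, so it deliberately leaves out broadcasting.

```python
# Hypothetical reference helper (not a OneFlow API): mirrors the element-wise
# rule "take x where condition > 0, otherwise take y".
# Assumes condition, x and y are nested lists with identical shapes.
def where_reference(condition, x, y):
    if isinstance(condition, list):
        # Walk the three nested lists in lockstep, one sub-list at a time.
        return [where_reference(c, a, b) for c, a, b in zip(condition, x, y)]
    # Leaf element: apply the selection rule.
    return x if condition > 0 else y


cond = [[0, 1], [1, 0], [1, 0]]
x = [[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]
y = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
print(where_reference(cond, x, y))
```

Unlike the real operator, this sketch does not broadcast and returns Python lists rather than a tensor.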
Note

If x is None and y is None, flow.where(condition) is identical to flow.nonzero(condition, as_tuple=True).

The tensors condition, x and y must be broadcastable.

Parameters
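- condition (Tensor): where an element is larger than 0, the corresponding element of x is selected; otherwise the corresponding element of y is selected.
- x (Tensor, optional): values selected at positions where condition is larger than 0. Defaults to None.
- y (Tensor, optional): values selected at positions where condition is not larger than 0. Defaults to None.

Returns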
A tensor of shape equal to the broadcasted shape of condition, x and y.

Return type

oneflow.Tensor
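The output shape can be worked out before calling the operator. The sketch below assumes OneFlow follows the usual NumPy/PyTorch broadcasting rules (align shapes from the right; in each position the sizes must match or one of them must be 1). `broadcast_shape` is a hypothetical helper for illustration, not a OneFlow function.

```python
from itertools import zip_longest


# Hypothetical helper (not a OneFlow API). ASSUMES NumPy/PyTorch-style
# broadcasting: shapes are aligned on the right, missing leading dims count
# as 1, and each aligned pair of sizes must be equal or contain a 1.
def broadcast_shape(*shapes):
    result = []
    # Walk dimensions from the last one backwards.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_shape((3, 2), (3, 2), (3, 2)))
print(broadcast_shape((3, 1), (1, 2), ()))
```

In the doctest example, condition, x and y are all of shape (3, 2), so the result is (3, 2). A (3, 1) condition combined with a (1, 2) x and a 0-d y would also broadcast to (3, 2).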
For example:
>>> import numpy as np
>>> import oneflow as flow
>>> x = flow.tensor(
...     np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]),
...     dtype=flow.float32,
... )
>>> y = flow.tensor(np.ones(shape=(3, 2)), dtype=flow.float32)
>>> condition = flow.tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32)
>>> out = condition.where(x, y)
>>> out
tensor([[1.0000, 0.3139],
        ...
        [0.0478, 1.0000]], dtype=oneflow.float32)
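The Note above equates the single-argument call flow.where(condition) with flow.nonzero(condition, as_tuple=True). The following pure-Python sketch shows what that tuple form contains: one index list per dimension, giving the coordinates of every nonzero element in row-major order. `nonzero_as_tuple` is a hypothetical helper operating on nested lists; it is an illustration, not OneFlow's implementation.

```python
# Hypothetical helper (not a OneFlow API) illustrating the layout of
# flow.nonzero(condition, as_tuple=True): one index list per dimension,
# listing the coordinates of every nonzero element in row-major order.
def nonzero_as_tuple(condition):
    coords = []

    def visit(node, prefix):
        if isinstance(node, list):
            for i, child in enumerate(node):
                visit(child, prefix + (i,))
        elif node != 0:
            coords.append(prefix)

    visit(condition, ())

    # Count dimensions so an all-zero input still yields one
    # (empty) index list per dimension.
    ndim, node = 0, condition
    while isinstance(node, list):
        ndim += 1
        node = node[0] if node else None

    if not coords:
        return tuple([] for _ in range(ndim))
    # Transpose the coordinate tuples into per-dimension index lists.
    return tuple(list(axis) for axis in zip(*coords))


print(nonzero_as_tuple([[0, 1], [1, 0], [1, 0]]))  # ([0, 1, 2], [1, 0, 0])
```

OneFlow itself returns a tuple of index tensors rather than Python lists; the sketch only shows how the indices are laid out.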