oneflow.argmax

oneflow.argmax()

Returns the indices of the maximum values of a Tensor along a specified dimension.

Parameters
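  • input (oneflow.Tensor) – the input Tensor.

  • dim (int, optional) – the dimension to reduce. If None (the default), the input is flattened and a single index into the flattened tensor is returned, as in the first example below.

  • keepdim (bool, optional) – whether the output tensor retains the reduced dimension (with size 1). Ignored if dim=None.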
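To illustrate how dim and keepdim shape the result, here is a minimal pure-Python sketch of the semantics for a 2-D input given as nested lists. The helper argmax_ref is hypothetical and not part of OneFlow; it also assumes ties resolve to the first occurrence, which this page does not specify for oneflow.argmax.

```python
from typing import List, Optional

Matrix = List[List[float]]


def argmax_ref(data: Matrix, dim: Optional[int] = None, keepdim: bool = False):
    """Hypothetical reference argmax for a 2-D nested list (not a OneFlow API).

    Ties resolve to the first occurrence (an assumption of this sketch).
    """
    if dim is None:
        # Flatten in row-major order and return one index; keepdim is ignored.
        flat = [x for row in data for x in row]
        return max(range(len(flat)), key=flat.__getitem__)
    if dim < 0:
        dim += 2  # normalize negative dims for a 2-D input
    if dim == 1:
        # Reduce across columns: one index per row, shape (rows,) or (rows, 1).
        result = [max(range(len(row)), key=row.__getitem__) for row in data]
        return [[i] for i in result] if keepdim else result
    if dim == 0:
        # Reduce across rows: one index per column, shape (cols,) or (1, cols).
        n_rows, n_cols = len(data), len(data[0])
        result = [max(range(n_rows), key=lambda r: data[r][c]) for c in range(n_cols)]
        return [result] if keepdim else result
    raise ValueError("dim must be None, 0, 1, -1 or -2 for a 2-D input")


data = [[1, 3, 8, 7, 2],
        [1, 9, 4, 3, 2]]
print(argmax_ref(data))                          # flattened index
print(argmax_ref(data, dim=1))                   # one index per row
print(argmax_ref(data, dim=1, keepdim=True))     # reduced dim kept with size 1
```

With keepdim=True the reduced dimension survives with size 1, so the per-row result has shape (2, 1) rather than (2,).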

Returns

A Tensor of dtype int64 containing the indices of the maximum values of input.

Return type

oneflow.Tensor

For example:

>>> import oneflow as flow

>>> input = flow.tensor([[1, 3, 8, 7, 2],
...            [1, 9, 4, 3, 2]], dtype=flow.float32)
>>> output = flow.argmax(input)
>>> output
tensor(6, dtype=oneflow.int64)
>>> output = flow.argmax(input, dim=1)
>>> output
tensor([2, 1], dtype=oneflow.int64)
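In the first example, dim is None, so the result 6 indexes the flattened tensor. Assuming row-major (C-order) flattening, which is consistent with that output, the flat index can be mapped back to a (row, column) position with divmod. A pure-Python sketch using the same data:

```python
# Same values as the example input above, as nested lists.
data = [[1, 3, 8, 7, 2],
        [1, 9, 4, 3, 2]]

# Row-major flattening (assumed to match how argmax flattens when dim is None).
flat = [x for row in data for x in row]

# Index of the first maximum in the flattened list.
flat_index = max(range(len(flat)), key=flat.__getitem__)

# Convert the flat index back to a 2-D coordinate: divide by the row length.
row, col = divmod(flat_index, len(data[0]))
print(flat_index, (row, col), data[row][col])
```

Here the flat index 6 corresponds to row 1, column 1, which holds the value 9.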