oneflow.nn.functional.gumbel_softmax

oneflow.nn.functional.gumbel_softmax(x: Tensor, dim: int, tau: float = 1.0, hard: bool = False) → Tensor

Samples from the Gumbel-Softmax distribution, a differentiable approximation of sampling with argmax. Plain argmax has two drawbacks: its output does not reflect the probability distribution of the model's output, and it cannot participate in gradient back-propagation. Gumbel-Softmax addresses both.

Gumbel is defined as:

\[Gumbel_i = -\log(-\log(U_i)),\ U_i \sim U(0,1)\]

Perturb the input with the Gumbel noise:

\[\tilde{In}_i = In_i + Gumbel_i\]

Calculate the Softmax of the perturbed input with temperature tau:

\[gumbel\_softmax(In)_i=\frac{e^{\tilde{In}_i/tau}}{\sum_{j=1}^n{e^{\tilde{In}_j/tau}}},\ i=1,2,\dots,n\]
Parameters
  • x (oneflow.Tensor) – the input Tensor.

  • dim (int) – the dimension along which Softmax is computed.

  • tau (float) – the temperature used to scale the Gumbel-perturbed input; smaller values of tau make the output closer to one-hot.

  • hard (bool) – if hard=True, the output tensor is discretized to one-hot.
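The steps above can be sketched in plain NumPy. This is an illustrative reimplementation under the stated formulas, not the OneFlow source: `gumbel_softmax_np` is a hypothetical helper name, and the straight-through gradient behavior of `hard=True` is only noted in a comment, since NumPy has no autograd.

```python
import numpy as np

def gumbel_softmax_np(logits, tau=1.0, hard=False, axis=-1):
    # Sample Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1).
    # The lower bound keeps log() away from 0.
    u = np.random.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    # Perturb the input and apply a temperature-scaled softmax
    # (subtracting the max for numerical stability).
    y = (logits + gumbel) / tau
    y = np.exp(y - y.max(axis=axis, keepdims=True))
    soft = y / y.sum(axis=axis, keepdims=True)
    if hard:
        # Discretize to a one-hot vector at the argmax. In an autograd
        # framework the gradient would still flow through `soft`
        # (the straight-through trick); NumPy only returns the values.
        one_hot = np.zeros_like(soft)
        idx = np.expand_dims(soft.argmax(axis=axis), axis)
        np.put_along_axis(one_hot, idx, 1.0, axis=axis)
        return one_hot
    return soft

logits = np.array([[1.0, 2.0, 3.0]])
soft = gumbel_softmax_np(logits, tau=0.5)
hard = gumbel_softmax_np(logits, tau=0.5, hard=True)
print(soft)  # a probability vector: non-negative, sums to 1
print(hard)  # a one-hot vector
```

Lowering `tau` sharpens `soft` toward a one-hot vector, while raising it flattens the sample toward uniform; `hard=True` always returns an exact one-hot sample.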