oneflow.nn.graph.graph_config.GraphConfig.enable_amp

GraphConfig.enable_amp(mode: bool = True, *, dtype: oneflow._oneflow_internal.dtype = oneflow.float16)

If set to True, the graph runs in mixed precision mode, which means that both float16 and float32 are used during model training.

For example:

import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.enable_amp(True)  # Use mixed precision mode.

    def build(self, x):
        return self.linear(x)

graph = Graph()

Parameters

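mode (bool, optional) – Whether to enable mixed precision mode. The default value is True.

dtype (oneflow.dtype, optional, keyword-only) – The reduced-precision data type used in mixed precision mode. The default value is oneflow.float16.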
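
As background on why mixed precision mode combines float16 with float32 rather than using float16 alone, the following is a minimal sketch in plain Python. It does not use OneFlow and does not describe how OneFlow implements AMP internally; it only emulates the two number formats with the standard struct module. The helper names to_float16, to_float32 and accumulate are hypothetical and exist only for this illustration.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def to_float32(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def accumulate(start: float, step: float, count: int, round_fn) -> float:
    """Add `step` to `start` `count` times, rounding after every addition."""
    total = round_fn(start)
    rounded_step = round_fn(step)
    for _ in range(count):
        total = round_fn(total + rounded_step)
    return total


# A small update (1e-4) applied 1000 times to a value near 1.0.
half_total = accumulate(1.0, 1e-4, 1000, to_float16)
single_total = accumulate(1.0, 1e-4, 1000, to_float32)

print("float16 accumulation:", half_total)
print("float32 accumulation:", single_total)
```

Near 1.0, adjacent float16 values are about 0.001 apart, so each 1e-4 update is rounded away and the float16 total never moves, while the float32 total ends close to 1.1. This kind of lost update is one common reason mixed precision schemes keep some values in float32.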