oneflow.nn.graph.graph_config.GraphConfig.enable_amp¶
GraphConfig.enable_amp(mode: bool = True, *, dtype: oneflow._oneflow_internal.dtype = oneflow.float16)¶

If set to True, the graph uses automatic mixed precision (AMP): both float16 (or the given dtype) and float32 are used during model training.
For example:

    import oneflow as flow

    class Graph(flow.nn.Graph):
        def __init__(self):
            super().__init__()
            self.linear = flow.nn.Linear(3, 8, False)
            self.config.enable_amp(True)  # Use mixed precision mode.

        def build(self, x):
            return self.linear(x)

    graph = Graph()
- Parameters
  mode (bool, optional) – If True (the default), enable mixed precision mode; if False, disable it.
  dtype (oneflow.dtype, optional) – The low-precision dtype used alongside float32. Defaults to oneflow.float16.
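To illustrate what "use both float16 and float32" means, here is a minimal NumPy sketch (not the OneFlow implementation): weights are kept as a float32 master copy, the compute-heavy step runs in float16, and the result is carried back in float32 to limit accumulated rounding error.

```python
import numpy as np

# Conceptual sketch of mixed precision (assumed illustration, not OneFlow code).
w32 = np.array([0.1, 0.2, 0.3], dtype=np.float32)  # float32 master weights
x = np.ones(3, dtype=np.float32)

# Cast to float16 for the compute-heavy step ...
y16 = (w32.astype(np.float16) * x.astype(np.float16)).sum()

# ... then carry the result back in float32 for accumulation / updates.
y32 = np.float32(y16)

print(y16.dtype)  # float16
print(abs(float(y32) - 0.6) < 1e-2)  # True: only a small rounding error remains
```

AMP frameworks automate exactly this pattern per-op, which is why a float32 copy of the parameters still exists even when `enable_amp` is on.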