oneflow.nn.graph.graph_config.GraphConfig.enable_zero

GraphConfig.enable_zero(mode: bool = True, *, stage: int = 2, shard_min_size: int = 1024, shard_restore_level: int = 1)

Enable the ZeRO redundancy optimizer.
This optimization reduces the memory consumed by optimizer states, as described in the ZeRO paper (https://arxiv.org/abs/1910.02054).
The default ZeRO stage is 2.
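The ZeRO paper estimates per-device memory for model states under mixed-precision Adam. For Ψ parameters, it counts 2Ψ bytes of fp16 parameters, 2Ψ bytes of fp16 gradients and KΨ bytes (K = 12) of fp32 optimizer states, spread across N_d data-parallel devices. The sketch below evaluates that model. It assumes that OneFlow's stages 1, 2 and 3 match the paper's partitioning of optimizer states, then gradients, then parameters. It also ignores activations and temporary buffers. `zero_bytes_per_device` is a hypothetical helper and is not part of OneFlow.

```python
def zero_bytes_per_device(num_params: float, num_devices: int, stage: int, k: int = 12) -> float:
    """Per-device bytes for model states, following the ZeRO paper's memory model.

    Assumption: stage 1 shards optimizer states, stage 2 also shards gradients,
    stage 3 also shards parameters; stage 0 means plain data parallelism.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0 (no ZeRO), 1, 2 or 3")
    params = 2 * num_params      # fp16 parameters
    grads = 2 * num_params       # fp16 gradients
    opt_states = k * num_params  # fp32 master weights + Adam moments (K = 12)
    if stage >= 1:
        opt_states /= num_devices
    if stage >= 2:
        grads /= num_devices
    if stage >= 3:
        params /= num_devices
    return params + grads + opt_states


if __name__ == "__main__":
    for s in range(4):
        gb = zero_bytes_per_device(7.5e9, 64, s) / 1e9
        print(f"stage {s}: {gb:.1f} GB")
```

The paper's example uses Ψ = 7.5B and N_d = 64. For that case, the sketch gives about 120, 31.4, 16.6 and 1.9 GB for stages 0 (no ZeRO) through 3, which is consistent with the figures reported there. The default stage 2 removes most of the optimizer-state and gradient memory while keeping a full parameter copy on each device.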
For example:

```python
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        # Turn on ZeRO with the default settings (stage 2).
        self.config.enable_zero()

    def build(self, x):
        return self.linear(x)

graph = Graph()
```
Parameters
mode (bool) – if True, the optimizer states of data-parallel training are sharded across devices.
stage (int) – ZeRO optimization stage, from 1 to 3.
shard_min_size (int) – minimum size (element count) of a shard of an optimizer state.
shard_restore_level (int) – level at which a sharded parameter is restored to the whole parameter for its consumer operators: 0 means no restore, 1 means soft restore and 2 means hard restore. Note that this parameter is at the pre-alpha stage.
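The effect of `shard_min_size` can be pictured with a simplified partitioner. A flat optimizer state of n elements is split into near-equal contiguous ranges, one per device. If the per-device piece would fall below the minimum, the state stays replicated instead. This rule and the `plan_shards` helper are illustrative assumptions. OneFlow's actual placement pass may choose split axes and apply the threshold differently.

```python
from typing import List, Optional, Tuple


def plan_shards(num_elements: int, num_devices: int,
                shard_min_size: int = 1024) -> Optional[List[Tuple[int, int]]]:
    """Hypothetical sharding rule: one contiguous [start, stop) range per device.

    Returns None when each shard would hold fewer than shard_min_size elements,
    meaning the optimizer state is kept whole (replicated) on every device.
    """
    per_device = -(-num_elements // num_devices)  # ceiling division
    if per_device < shard_min_size:
        return None
    ranges = []
    for rank in range(num_devices):
        start = min(rank * per_device, num_elements)
        stop = min(start + per_device, num_elements)
        ranges.append((start, stop))
    return ranges


if __name__ == "__main__":
    print(plan_shards(4096, 4))  # large enough: four 1024-element shards
    print(plan_shards(1000, 4))  # too small to shard: None
```

Clamping `start` and `stop` to `num_elements` keeps every range valid even when the element count does not divide evenly. The trailing devices may then receive a shorter or empty range.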