oneflow.nn.graph.graph_config.GraphConfig.enable_zero

GraphConfig.enable_zero(mode: bool = True, *, stage: int = 2, shard_min_size: int = 1024, shard_restore_level: int = 1)

Enable the Zero Redundancy Optimizer (ZeRO).

This optimization reduces the memory consumed by optimizer states, as described in the ZeRO paper: https://arxiv.org/abs/1910.02054 .

The default ZeRO stage is 2.

For example:

import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        # Shard optimizer states across data-parallel devices (default stage 2).
        self.config.enable_zero()

    def build(self, x):
        return self.linear(x)

graph = Graph()

Parameters
  • mode (bool) – if set to True, optimizer states in data-parallel training are sharded across devices.

  • stage (int) – optimization stage, in the range 1 to 3. Stage 1 shards optimizer states; stage 2 additionally shards gradients; stage 3 additionally shards parameters.

  • shard_min_size (int) – minimum size (in element count) of one shard of an optimizer state.

  • shard_restore_level (int) – level at which a sharded parameter is restored to the whole parameter for consumer operators: level 0 is no restore, level 1 is soft restore, and level 2 is hard restore. Note that this parameter is at the pre-alpha stage.
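To build intuition for how sharding and shard_min_size interact, here is a framework-agnostic sketch (plain Python, not the OneFlow API; the helper name plan_optimizer_state is hypothetical) of the stage-1 idea: each data-parallel rank keeps only about a 1/world_size slice of an optimizer-state tensor, unless the state is smaller than the minimum shard size, in which case it stays replicated on every rank.

```python
# Sketch only: illustrates ZeRO-style state partitioning, not OneFlow internals.

def plan_optimizer_state(num_elements: int, world_size: int,
                         shard_min_size: int = 1024):
    """Return per-rank element counts for one optimizer-state tensor."""
    if num_elements < shard_min_size:
        # Too small to be worth sharding: every rank keeps a full copy.
        return [num_elements] * world_size
    base, rem = divmod(num_elements, world_size)
    # Spread the remainder over the first `rem` ranks.
    return [base + (1 if r < rem else 0) for r in range(world_size)]

# A 1M-element momentum buffer on 4 ranks: each rank stores ~1/4 of it.
print(plan_optimizer_state(1_000_000, 4))  # [250000, 250000, 250000, 250000]
# A 100-element bias state is below shard_min_size and stays replicated.
print(plan_optimizer_state(100, 4))        # [100, 100, 100, 100]
```

With this scheme, per-rank memory for large optimizer states shrinks roughly by a factor of world_size, which is the source of the savings that stages 2 and 3 then extend to gradients and parameters.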