GraphConfig.enable_zero(mode: bool = True, *, stage: int = 2, shard_min_size: int = 1024, shard_restore_level: int = 1)

Enable the ZeRO redundancy optimizer.

This optimization reduces the memory consumed by optimizer states, as described in the ZeRO paper.

The default ZeRO stage is 2.

For example:

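import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        # nn.Graph must be initialized before submodules or config are set.
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        # Turn on ZeRO with the default arguments (stage 2).
        self.config.enable_zero()
    def build(self, x):
        return self.linear(x)

graph = Graph()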

Parameters

  • mode (bool) – if set to True, the optimizer states of data-parallel training are sharded across devices.

  • stage (int) – optimization stage, in the range 1 to 3.

  • shard_min_size (int) – minimum size (element count) of a shard of an optimizer state.

  • shard_restore_level (int) – level at which a sharded parameter is restored to the whole parameter for consumer operators: level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this parameter is at a pre-alpha stage.
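
To give a rough feel for what the stage parameter trades off, the sketch below applies the per-device memory accounting from the ZeRO paper for mixed-precision Adam training: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for optimizer states. It is back-of-the-envelope arithmetic, not a measurement of OneFlow's behavior. The function name `zero_model_state_bytes` is hypothetical, and stage 0 is used here only to denote "ZeRO disabled"; it is not a value accepted by `enable_zero`.

```python
def zero_model_state_bytes(num_params, num_devices, stage, opt_bytes_per_param=12):
    """Estimate per-device model-state memory (bytes) under the ZeRO paper's accounting.

    Hypothetical helper for illustration only; it is not part of OneFlow.
    stage 0 means "no ZeRO"; stages 1 to 3 follow the ZeRO paper.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0 (disabled), 1, 2, or 3")
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")

    weights = 2.0 * num_params                    # fp16 parameters
    grads = 2.0 * num_params                      # fp16 gradients
    opt_states = float(opt_bytes_per_param) * num_params  # e.g. Adam: fp32 copy, momentum, variance

    if stage >= 1:
        opt_states /= num_devices   # stage 1: shard optimizer states
    if stage >= 2:
        grads /= num_devices        # stage 2: also shard gradients
    if stage >= 3:
        weights /= num_devices      # stage 3: also shard parameters
    return weights + grads + opt_states


if __name__ == "__main__":
    # Compare stages for a 1-billion-parameter model on 8 devices.
    for s in (0, 1, 2, 3):
        gib = zero_model_state_bytes(1_000_000_000, 8, s) / 2**30
        print(f"stage {s}: {gib:.2f} GiB per device")
```

Under this accounting the optimizer states dominate, which is why sharding them (stage 1) already gives most of the savings, and the default stage 2 additionally shards gradients.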