Base class for running neural networks in Graph Mode.

class oneflow.nn.Graph

Base class for training or evaluating a neural network in graph mode.

To use graph mode for model training or evaluation in OneFlow, you should:

  1. Define your customized graph as a subclass of nn.Graph.

  2. Call super().__init__() in your subclass’s __init__().

  3. Add modules to your graph as regular attributes.

  4. Define the computation logic in the build() method.

  5. Instantiate your graph, then call it.

>>> import oneflow as flow

>>> class LinearGraph(flow.nn.Graph):
...    def __init__(self):
...        super().__init__()
...        # Add a module to the graph.
...        self.linear = flow.nn.Linear(3, 8, False)
...    def build(self, x):
...        # Use the module to build the computation logic of the graph.
...        return self.linear(x)

# Instantiate the graph
>>> linear_graph = LinearGraph()
>>> x = flow.randn(4, 3)

# The first call on the graph runs its build() method to
# trace a computation graph. The computation graph is then
# optimized and executed for the first time.
>>> linear_graph(x).shape
oneflow.Size([4, 8])

# Later calls on the graph execute the computation graph directly.
>>> linear_graph(x).shape
oneflow.Size([4, 8])
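The trace-once, execute-many behavior described in the comments above can be modeled in a few lines of plain Python. This is a toy sketch, not OneFlow's implementation: the names Symbol, ToyGraph, AffineGraph and build_calls are hypothetical. On the first call, build() runs on a placeholder that records each operation; later calls replay the recorded operations without running build() again.

```python
class Symbol:
    """Hypothetical placeholder input that records the operations applied to it."""

    def __init__(self, ops=None):
        self.ops = list(ops or [])

    def __mul__(self, k):
        return Symbol(self.ops + [lambda v: v * k])

    def __add__(self, b):
        return Symbol(self.ops + [lambda v: v + b])


class ToyGraph:
    """Hypothetical stand-in for nn.Graph: traces build() once, then replays it."""

    def __init__(self):
        self._ops = None       # the recorded "computation graph"
        self.build_calls = 0   # how many times build() has actually run

    def build(self, x):
        raise NotImplementedError

    def __call__(self, value):
        if self._ops is None:
            # First call: trace build() on a placeholder and keep the recorded ops.
            self.build_calls += 1
            self._ops = self.build(Symbol()).ops
        # Every call, including the first, executes the recorded ops on real input.
        for op in self._ops:
            value = op(value)
        return value


class AffineGraph(ToyGraph):
    def build(self, x):
        return x * 2 + 1


g = AffineGraph()
print(g(3), g(10), g.build_calls)  # 7 21 1
```

In this toy, any plain Python code inside build() (a print() call, for example) runs only during the first, tracing call, which mirrors why build() is described above as running only on the first call.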

Note that Graph cannot be nested at the moment.


Initializes the internal Graph states. It MUST be called in the __init__ method of the subclass.

>>> import oneflow as flow
>>> class SubclassGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__() # MUST be called
...         # Then define the graph attributes
...     def build(self):
...         pass
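Why the call is mandatory can be illustrated with a toy base class that keeps an internal registry created in __init__ and fills it whenever an attribute is assigned. This is a hypothetical sketch (ToyGraphBase, _blocks, Good and Bad are made-up names), not OneFlow's code; the actual failure mode and error message in OneFlow may differ.

```python
class ToyGraphBase:
    """Hypothetical stand-in for the attribute bookkeeping a graph base class does."""

    def __init__(self):
        # Internal registry; bypass our own __setattr__ while creating it.
        object.__setattr__(self, "_blocks", {})

    def __setattr__(self, name, value):
        # Registering an attribute needs the registry created in __init__.
        if "_blocks" not in self.__dict__:
            raise AttributeError("call super().__init__() before assigning attributes")
        self._blocks[name] = value
        object.__setattr__(self, name, value)


class Good(ToyGraphBase):
    def __init__(self):
        super().__init__()
        self.linear = "linear-module"


class Bad(ToyGraphBase):
    def __init__(self):
        self.linear = "linear-module"  # super().__init__() was forgotten


print(Good()._blocks)  # {'linear': 'linear-module'}
try:
    Bad()
except AttributeError as err:
    print("AttributeError:", err)
```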

The build() method must be overridden to define the neural network computation logic.

The build() method of nn.Graph is very similar to the forward() method of nn.Module. It is used to describe the computation logic of a neural network.

When a graph object is called for the first time, the build() method is called implicitly to build the computation graph.

If needed, call the modules’ train() or eval() method before the first call of your graph, so that the modules execute the correct training or evaluation logic, because the computation graph is built during that first call.

>>> import oneflow as flow
>>> class MyGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.linear = flow.nn.Linear(3, 8, False)
...     def build(self, x):
...         return self.linear(x)

>>> linear_graph = MyGraph()
>>> x = flow.randn(4, 3)
>>> y = linear_graph(x) # The build() method is called implicitly

Note that the build() method’s inputs and outputs only accept positional arguments at the moment. Each argument must be one of these types:

  • Tensor

  • list of Tensor

  • None
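The accepted argument types can be written down as a small check. This is a hypothetical helper, not part of OneFlow: check_graph_args and the stand-in Tensor class are made up for illustration, and OneFlow’s own validation and error messages may differ.

```python
class Tensor:
    """Hypothetical stand-in for oneflow.Tensor, used only for this sketch."""


def check_graph_args(*args, **kwargs):
    """Raise TypeError unless every argument is a Tensor, a list of Tensor, or None."""
    if kwargs:
        raise TypeError(f"keyword arguments are not supported: {sorted(kwargs)}")
    for i, arg in enumerate(args):
        is_tensor_list = isinstance(arg, list) and all(isinstance(t, Tensor) for t in arg)
        if not (arg is None or isinstance(arg, Tensor) or is_tensor_list):
            raise TypeError(f"argument {i} has unsupported type {type(arg).__name__}")


check_graph_args(Tensor(), [Tensor(), Tensor()], None)  # accepted
for bad_call in (lambda: check_graph_args(3), lambda: check_graph_args(x=Tensor())):
    try:
        bad_call()
    except TypeError as err:
        print("rejected:", err)
```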

add_optimizer(optim: oneflow.nn.optimizer.optimizer.Optimizer, *, lr_sch: Optional[oneflow.nn.optimizer.lr_scheduler.LrScheduler] = None)

Add an optimizer and, optionally, a learning rate scheduler to the graph.

To train with nn.Graph, you need to do two more things:

  1. Add at least one optimizer (learning rate schedulers are optional) with the add_optimizer() method.

  2. Call the loss tensor’s backward() method in the build() method.

Note that the computation graph automatically executes these methods:

  • the optimizer’s clip_grad(), if the optimizer is configured to do gradient clipping.

  • optimizer’s step().

  • optimizer’s zero_grad().

  • the learning rate scheduler’s step().
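Written out by hand, the work a training graph performs on each call corresponds roughly to the eager-style sequence below. This is a sketch assuming the order of the list above; Recorder and eager_train_step are hypothetical names, and the recorder objects only log method names instead of updating real parameters.

```python
class Recorder:
    """Hypothetical stand-in that records which methods were called, in order."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def __getattr__(self, method):
        # Any missing method is recorded as "<name>.<method>" instead of doing real work.
        return lambda *args, **kwargs: self.log.append(f"{self.name}.{method}")


def eager_train_step(loss, optimizer, lr_sch=None, clip_grad=True):
    """Sketch of what one nn.Graph training call takes care of, written out eagerly."""
    loss.backward()            # the only step you write yourself in build()
    if clip_grad:
        optimizer.clip_grad()  # only if gradient clipping is configured
    optimizer.step()
    optimizer.zero_grad()
    if lr_sch is not None:
        lr_sch.step()


log = []
eager_train_step(Recorder("loss", log), Recorder("optimizer", log), Recorder("lr_sch", log))
print(log)
# ['loss.backward', 'optimizer.clip_grad', 'optimizer.step', 'optimizer.zero_grad', 'lr_sch.step']
```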

Also note that, for the moment, backward() can only be called on a scalar tensor. You may call Tensor.sum() or Tensor.mean() to reduce the loss tensor to a scalar tensor.

>>> import oneflow as flow
>>> loss_fn = flow.nn.MSELoss(reduction="sum")
>>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
>>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)
>>> class LinearTrainGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = model
...         self.loss_fn = loss_fn
...         # Add an optimizer
...         self.add_optimizer(optimizer)
...     def build(self, x, y):
...         y_pred = self.model(x)
...         loss = self.loss_fn(y_pred, y)
...         # Call loss tensor's backward(), loss tensor must be a scalar tensor
...         loss.backward()
...         return loss

>>> linear_graph = LinearTrainGraph()
>>> x = flow.randn(10, 3)
>>> y = flow.randn(10)
>>> for t in range(3):
...     loss = linear_graph(x, y)

Parameters:

  • optim (oneflow.optim.Optimizer) – The optimizer.

  • lr_sch (LrScheduler, optional) – The learning rate scheduler, see oneflow.optim.lr_scheduler. Default: None.

set_grad_scaler(grad_scaler: Optional[oneflow.amp.grad_scaler.GradScaler] = None)

Set the GradScaler for gradient and loss scaling.
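Gradient and loss scaling is commonly implemented as dynamic loss scaling: the loss is multiplied by a scale factor before backward(), gradients are divided by the same factor, and the factor shrinks when non-finite gradients appear and grows after a run of finite steps. The pure-Python sketch below shows that generic technique on plain floats; ToyDynamicScaler and its parameter names are assumptions chosen for illustration, not the oneflow.amp.GradScaler API, and OneFlow’s internals may differ.

```python
import math


class ToyDynamicScaler:
    """Hypothetical dynamic loss scaler working on plain Python floats."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps with finite gradients

    def scale_loss(self, loss):
        # Multiply the loss so small gradients do not underflow in low precision.
        return loss * self.scale

    def unscale(self, scaled_grads):
        # Divide gradients back and report whether any of them is inf or NaN.
        grads = [g / self.scale for g in scaled_grads]
        found_inf = any(not math.isfinite(g) for g in grads)
        return grads, found_inf

    def update(self, found_inf):
        # Shrink the scale after an overflow; grow it after growth_interval clean steps.
        if found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


scaler = ToyDynamicScaler(init_scale=8.0, growth_interval=2)
grads, found_inf = scaler.unscale([16.0, 8.0])
print(grads, found_inf)  # [2.0, 1.0] False
```

In a training loop built on this scheme, the optimizer step is skipped whenever found_inf is True, which is what makes shrinking the scale after an overflow safe.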


Call an nn.Graph subclass instance to run your customized graph.

Call your customized graph after the instantiation:

g = CustomGraph()  # an instance of your nn.Graph subclass
out_tensors = g(input_tensors)  # same inputs as build()

The inputs of the __call__ method must match the inputs of the build() method, and __call__ returns outputs matching the outputs of build().

Note that the first call takes longer than later calls, because nn.Graph generates and optimizes the computation graph during the first call.

Do not override this method.

property name

Name auto-generated for this graph.

property training

The graph is in training mode if it has an optimizer.

debug(v_level: int = 0, ranks: Optional[Union[int, List[int]]] = None, mode: bool = True) → None

Open or close debug mode of the graph.

In debug mode, information about building the computation graph and warnings are printed. Otherwise, only errors are printed.

Use v_level to choose the verbosity of debug information. The default level is 0 and the maximum level is 1. v_level 0 prints warnings and the graph creation stages; v_level 1 additionally prints the graph build information of each module.

Use ranks to choose which ranks print the debug information.

g = CustomGraph()
g.debug()  # Open debug mode
out_tensors = g(input_tensors)  # Will print log for debug at the first call

Parameters:

  • v_level (int) – the verbosity level of debug information. Default: 0. Maximum: 1.

  • ranks (int or list(int)) – the ranks that print debug information. Default: rank 0. Any valid rank can be chosen; ranks=-1 enables debug output on all ranks.

  • mode (bool) – whether to set debug mode (True) or not (False). Default: True.
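How the ranks argument is interpreted, per the description above, can be summarized as a small function. resolve_debug_ranks and world_size are hypothetical names used only for this sketch; OneFlow’s own handling may differ in details such as validation.

```python
def resolve_debug_ranks(ranks=None, world_size=1):
    """Hypothetical helper: turn the `ranks` argument of debug() into a rank list."""
    if ranks is None:
        return [0]                       # default: only rank 0 prints
    if isinstance(ranks, int):
        ranks = [ranks]
    if -1 in ranks:
        return list(range(world_size))   # -1 means every rank prints
    for r in ranks:
        if not 0 <= r < world_size:
            raise ValueError(f"invalid rank {r} for world size {world_size}")
    return sorted(set(ranks))


print(resolve_debug_ranks(None, 4))    # [0]
print(resolve_debug_ranks(2, 4))       # [2]
print(resolve_debug_ranks([3, 1], 4))  # [1, 3]
print(resolve_debug_ranks(-1, 4))      # [0, 1, 2, 3]
```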


For printing the graph structure.

The graph structure can be printed after graph instantiation.

After the first call of the graph, its inputs and outputs are added to the graph structure.

g = CustomGraph()

out_tensors = g(input_tensors)
print(g)  # Input and output information is now included