oneflow.nn.Graph.build

Graph.build(*args, **kwargs)

The build() method must be overridden in a subclass to define the neural network’s computation logic.

The build() method of nn.Graph is very similar to the forward() method of nn.Module: it describes the computation logic of a neural network.

When a graph object is called for the first time, its build() method is called implicitly to construct the computation graph.

If needed, call the module’s train() or eval() method before the first call to your graph, so that the module runs the intended training or evaluation logic.
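To illustrate why the train()/eval() call has to come before the first call, here is a minimal pure-Python sketch of the "build once, on first call" behavior. It is not OneFlow’s implementation: `Tracer`, `ToyModule`, `ToyGraph`, and `MyToyGraph` are hypothetical names, and the "tracing" here only records scalar multiplies, whereas nn.Graph compiles a real tensor graph. In this sketch, a mode-dependent branch is chosen once, while build() runs during the first call, and later mode changes do not affect the already-built graph.

```python
class Tracer:
    """Hypothetical stand-in for a symbolic tensor: records ops instead of computing."""

    def __init__(self, ops=()):
        self.ops = tuple(ops)

    def __mul__(self, k):
        return Tracer(self.ops + (("mul", k),))


class ToyModule:
    """Hypothetical module with a training flag, mimicking nn.Module.train()/eval()."""

    def __init__(self):
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        # Arbitrary mode-dependent op: scale by 0.5 only in training mode.
        return x * 0.5 if self.training else x


class ToyGraph:
    """Hypothetical sketch: build() runs once, implicitly, on the first call."""

    def __init__(self):
        self._ops = None       # recorded program; None until the first call
        self.build_count = 0   # how many times build() has run

    def build(self, x):
        raise NotImplementedError("subclasses must override build()")

    def __call__(self, x):
        if self._ops is None:
            # First call: run build() on a tracer and keep the recorded ops.
            self.build_count += 1
            self._ops = self.build(Tracer()).ops
        # Every call replays the recorded ops on the real input.
        for _op, k in self._ops:
            x = x * k
        return x


model = ToyModule()


class MyToyGraph(ToyGraph):
    def __init__(self):
        super().__init__()
        self.model = model

    def build(self, x):
        return self.model(x)


g = MyToyGraph()
model.eval()           # set the mode before the first call
print(g(4.0))          # 4.0: the evaluation branch was recorded
model.train()          # too late for g: its build() has already run
print(g(4.0))          # still 4.0
print(g.build_count)   # 1
```

A fresh graph created after `model.train()` would record the training branch instead, which is why the mode should be fixed before the first call.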

For example:

>>> import oneflow as flow
>>> linear = flow.nn.Linear(3, 8, False)
>>> class MyGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = linear
...     def build(self, x):
...         return self.model(x)

>>> linear_graph = MyGraph()
>>> x = flow.randn(4, 3)
>>> linear.eval() # put the linear module into evaluation mode
Linear(in_features=3, out_features=8, bias=False)
>>> y = linear_graph(x) # The build() method is called implicitly

Note

The inputs and outputs of build() may be lists, tuples, or dicts, but each item in them must be one of the following types:

  • Tensor

  • None
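The rule above can be expressed as a small check. The following is a hypothetical helper, not part of OneFlow (`check_graph_io` and `FakeTensor` are illustrative names), that follows the note literally: a value may be a tensor, None, or a list/tuple/dict whose items are tensors or None. Nested containers are treated as invalid here because the note only lists Tensor and None as item types; OneFlow’s actual validation may differ. `FakeTensor` stands in for `flow.Tensor` so the sketch runs without OneFlow installed.

```python
def check_graph_io(value, tensor_type):
    """Hypothetical check: raise TypeError unless value follows the rule in the note.

    Accepts a tensor_type instance, None, or a list/tuple/dict whose items
    (dict values; keys are not checked) are tensor_type instances or None.
    """

    def is_leaf(v):
        return v is None or isinstance(v, tensor_type)

    if is_leaf(value):
        return
    if isinstance(value, (list, tuple)):
        items = value
    elif isinstance(value, dict):
        items = value.values()
    else:
        raise TypeError(f"unsupported type {type(value).__name__}")
    for item in items:
        if not is_leaf(item):
            raise TypeError(f"container item has unsupported type {type(item).__name__}")


class FakeTensor:
    """Stand-in for flow.Tensor so the sketch runs without OneFlow."""


t = FakeTensor()
check_graph_io((t, None), FakeTensor)                # accepted
check_graph_io({"out": t, "aux": None}, FakeTensor)  # accepted
try:
    check_graph_io([t, 3], FakeTensor)
except TypeError as e:
    print(e)  # container item has unsupported type int
```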