oneflow.nn.Graph.build
Graph.build(*args, **kwargs)

The build() method must be overridden to define the computation logic of the neural network.

The build() method of nn.Graph is very similar to the forward() method of nn.Module: it describes the computation logic of a neural network. When a graph object is called for the first time, its build() method is called implicitly to build the computation graph.

If needed, call the module's train() or eval() method before the first call of your graph, so that the module runs the correct training or evaluation logic. For example:
>>> import oneflow as flow
>>> linear = flow.nn.Linear(3, 8, False)
>>> class MyGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = linear
...
...     def build(self, x):
...         return self.model(x)
...
>>> linear_graph = MyGraph()
>>> x = flow.randn(4, 3)
>>> linear.eval()  # run the linear module in evaluation mode
Linear(in_features=3, out_features=8, bias=False)
>>> y = linear_graph(x)  # build() is called implicitly on this first call
Note

The inputs and outputs of the build() method support list/tuple/dict, but each item in them must be one of these types:

Tensor
None
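The type rule in the note above can be made concrete with a small checker. This is a hypothetical, pure-Python sketch and not part of the OneFlow API: `check_build_io` and `FakeTensor` are illustrative names, and the tensor class is passed in as a parameter so the sketch runs without OneFlow installed. It reads the note literally: a value may be a tensor, None, or a list/tuple/dict whose items are each a tensor or None. Because the note does not say whether deeper nesting is accepted, the sketch conservatively rejects nested containers.

```python
from typing import Any


def check_build_io(value: Any, tensor_type: type) -> None:
    """Hypothetical check of one build() input or output against the note.

    Allowed: a tensor, None, or a list/tuple/dict whose items are each a
    tensor or None. Raises TypeError otherwise (including nested containers,
    which this sketch assumes are not allowed).
    """

    def is_leaf(item: Any) -> bool:
        return item is None or isinstance(item, tensor_type)

    if is_leaf(value):
        return  # a bare tensor or None is accepted as-is
    if isinstance(value, (list, tuple)):
        items = list(enumerate(value))  # (index, item) pairs
    elif isinstance(value, dict):
        items = list(value.items())  # (key, item) pairs
    else:
        raise TypeError(
            f"unsupported type {type(value).__name__}; "
            "expected Tensor, None, or a list/tuple/dict of them"
        )
    for key, item in items:
        if not is_leaf(item):
            raise TypeError(
                f"item {key!r} has unsupported type {type(item).__name__}; "
                "expected Tensor or None"
            )


class FakeTensor:
    """Placeholder standing in for oneflow.Tensor so the sketch runs anywhere."""


check_build_io(FakeTensor(), FakeTensor)                        # accepted
check_build_io({"x": FakeTensor(), "mask": None}, FakeTensor)   # accepted
try:
    check_build_io((FakeTensor(), 3), FakeTensor)
except TypeError as err:
    print(err)  # item 1 has unsupported type int; expected Tensor or None
```

Passing the tensor type as an argument keeps the checker independent of any particular framework; with OneFlow installed, `flow.Tensor` would be the natural argument.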