oneflow.nn.Graph
Base class for running neural networks in static graph mode.
-
class oneflow.nn.Graph
Base class for training or evaluating a neural network in static graph mode.
To use static graph mode for model training or evaluation in OneFlow, you should:
1. Define your customized graph as a subclass of nn.Graph.
2. Add super().__init__() in your subclass’s __init__().
3. Add modules to your graph as regular attributes.
4. Define the computation logic in the build() method.
5. Instantiate your graph, then call it.
For example:
>>> import oneflow as flow
>>> class LinearGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         # Add a module to the graph.
...         self.linear = flow.nn.Linear(3, 8, False)
...     def build(self, x):
...         # Use the module to build the computation logic of the graph.
...         return self.linear(x)

# Instantiate the graph
>>> linear_graph = LinearGraph()
>>> x = flow.randn(4, 3)

# The first call on the graph runs the graph's build() method to
# trace a computation graph. The computation graph is then
# optimized and executed for the first time.
>>> linear_graph(x).shape
oneflow.Size([4, 8])

# Later calls on the graph execute the computation graph directly.
>>> linear_graph(x).shape
oneflow.Size([4, 8])
Note
nn.Graph cannot be nested at the moment.
-
__init__()
Initializes internal Graph states. It MUST be called in the __init__ method of the subclass.
For example:
>>> import oneflow as flow
>>> class SubclassGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()  # MUST be called
...         # Then define the graph attributes
...     def build(self):
...         pass
-
build(*args, **kwargs)
The build() method must be overridden to define the neural network's computation logic.
The build() method of nn.Graph is very similar to the forward() method of nn.Module. It describes the computation logic of a neural network.
When a graph object is called for the first time, the build() method is called implicitly to build the computation graph.
Make sure to call the modules’ train() or eval() method before the first call of your graph, so that the modules execute the right training or evaluation logic if needed.
For example:
>>> import oneflow as flow
>>> linear = flow.nn.Linear(3, 8, False)
>>> class MyGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = linear
...     def build(self, x):
...         return self.model(x)
>>> linear_graph = MyGraph()
>>> x = flow.randn(4, 3)
>>> linear.eval()  # make the linear module execute in evaluation mode
Linear(in_features=3, out_features=8, bias=False)
>>> y = linear_graph(x)  # The build() method is called implicitly
Note
The build() method’s inputs and outputs support list/tuple/dict, but the items in them must be one of these types: Tensor or None.
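Because the rule applies recursively to nested containers, it can help to picture the check as a walk over the structure. The sketch below models it in plain Python. `FakeTensor` and `check_build_io` are hypothetical stand-ins, not OneFlow APIs, and OneFlow's actual validation may differ in detail.

```python
class FakeTensor:
    """Hypothetical stand-in for oneflow.Tensor, used only for illustration."""


def check_build_io(value, path="value"):
    """Return the paths of leaves that are neither a tensor nor None.

    Models the documented rule: list/tuple/dict containers are allowed,
    and every leaf item must be a Tensor or None.
    """
    if value is None or isinstance(value, FakeTensor):
        return []
    if isinstance(value, (list, tuple)):
        bad = []
        for i, item in enumerate(value):
            bad.extend(check_build_io(item, f"{path}[{i}]"))
        return bad
    if isinstance(value, dict):
        bad = []
        for key, item in value.items():
            bad.extend(check_build_io(item, f"{path}[{key!r}]"))
        return bad
    # Any other leaf type (int, str, ...) violates the rule.
    return [path]


t = FakeTensor()
print(check_build_io((t, [t, None], {"logits": t})))  # []
print(check_build_io({"loss": t, "step": 3}))  # ["value['step']"]
```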
-
__call__(*args, **kwargs)
Call an nn.Graph subclass instance to run your customized graph.
Call your customized graph after instantiating it.
For example:
g = CustomGraph()
out_tensors = g(input_tensors)
The inputs of the __call__ method must match the inputs of the build() method, and the __call__ method returns outputs matching the outputs of the build() method.
Note
The first call takes longer than later calls, because nn.Graph generates and optimizes the computation graph at the first call.
Do not override this function.
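The first-call cost comes from tracing and optimizing build() once. Later calls reuse that result. The sketch below is a toy model of this control flow in plain Python. `ToyGraph` and its `trace_count` attribute are hypothetical, and the sketch does not reflect how OneFlow actually compiles or caches graphs.

```python
class ToyGraph:
    """Toy model of nn.Graph's call behavior: trace on the first call, reuse later."""

    def __init__(self):
        self._compiled = None  # stands in for the optimized computation graph
        self.trace_count = 0   # how many times build() has been "traced"

    def build(self, x):
        # User-defined computation logic; subclasses override this.
        raise NotImplementedError

    def __call__(self, *args):
        if self._compiled is None:
            # First call: "trace" build(). Real tracing records and optimizes
            # the ops; this toy only remembers the bound method.
            self.trace_count += 1
            self._compiled = self.build
        # Every call (including the first) executes the stored computation.
        return self._compiled(*args)


class DoubleGraph(ToyGraph):
    def build(self, x):
        return 2 * x


g = DoubleGraph()
print(g(3), g(4), g.trace_count)  # 6 8 1
```

The toy makes the same point as the note: the one-time work happens on the first call, which is why overriding `__call__` would bypass it.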
-
add_optimizer(optim: oneflow.nn.optimizer.optimizer.Optimizer, *, lr_sch: Optional[oneflow.nn.optimizer.lr_scheduler.LRScheduler] = None, is_sparse: bool = False)
Add an optimizer and, optionally, a learning rate scheduler to the graph.
To do training with nn.Graph, you need to do two more things:
1. Add at least one optimizer (learning rate schedulers are optional) with the add_optimizer() method.
2. Call the loss tensor’s backward() method in the build() method.
Note that the computation graph automatically executes these methods:
* the optimizer’s clip_grad(), if the optimizer is set to do gradient clipping.
* the optimizer’s step().
* the optimizer’s zero_grad().
* the learning rate scheduler’s step().
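To see how these automatic steps fit together, the sketch below runs them in the order listed, using a single scalar parameter in plain Python. It is a conceptual model only. The clipping threshold, the "halve every 2 steps" schedule, and all names are illustrative assumptions, not OneFlow internals.

```python
def train_step(param, grad_fn, state, max_norm=1.0):
    """One toy training iteration on a scalar parameter.

    state holds "grad" (accumulated gradient) plus "lr" and "step"
    (learning rate scheduler state).
    """
    # Backward pass (what loss.backward() does inside build()): accumulate grad.
    state["grad"] += grad_fn(param)
    # 1. clip_grad(): clip by norm; for a scalar the norm is its absolute value.
    if max_norm is not None and abs(state["grad"]) > max_norm:
        state["grad"] = max_norm if state["grad"] > 0 else -max_norm
    # 2. Optimizer step(): plain SGD update.
    param -= state["lr"] * state["grad"]
    # 3. zero_grad(): reset the gradient for the next call.
    state["grad"] = 0.0
    # 4. LR scheduler step(): halve the learning rate every 2 steps (assumed schedule).
    state["step"] += 1
    if state["step"] % 2 == 0:
        state["lr"] *= 0.5
    return param


# Minimize (w - 3)^2, whose gradient is 2 * (w - 3).
state = {"grad": 0.0, "lr": 0.5, "step": 0}
p = 0.0
for _ in range(3):
    p = train_step(p, lambda w: 2 * (w - 3.0), state)
print(p, state["lr"])  # 1.25 0.25
```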
Also note that, for the moment, only scalar tensors are allowed to call backward() in nn.Graph.build(). You can call methods such as Tensor.mean() to turn the loss tensor into a scalar tensor.
Note
If you want to output the learning rate information for each step, set the verbose parameter of the lr_scheduler to True, and you will see the result at rank 0.
This feature works the same way as in eager mode.
For example:
>>> import oneflow as flow
>>> loss_fn = flow.nn.MSELoss(reduction="sum")
>>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
>>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)
>>> class LinearTrainGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = model
...         self.loss_fn = loss_fn
...         # Add an optimizer
...         self.add_optimizer(optimizer)
...     def build(self, x, y):
...         y_pred = self.model(x)
...         loss = self.loss_fn(y_pred, y)
...         # Call the loss tensor's backward(); the loss tensor must be a scalar tensor
...         loss.backward()
...         return loss
>>> linear_graph = LinearTrainGraph()
>>> x = flow.randn(10, 3)
>>> y = flow.randn(10)
>>> model.train()  # make the model execute in training mode
Sequential(
  (0): Linear(in_features=3, out_features=1, bias=True)
  (1): Flatten(start_dim=0, end_dim=1)
)
>>> for t in range(3):
...     loss = linear_graph(x, y)
- Parameters
optim (oneflow.optim.Optimizer) – The optimizer.
lr_sch – The learning rate scheduler, see oneflow.optim.lr_scheduler.
is_sparse – When set to True, treat optim as a sparse optimizer. Default is False.
-
set_grad_scaler(grad_scaler: Optional[oneflow.amp.grad_scaler.GradScaler] = None)
Set the GradScaler for gradient and loss scaling.
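Loss scaling matters because small gradients can underflow to zero in float16. The sketch below uses Python's `struct` half-precision format (`'e'`) to round values to float16. It shows that a tiny gradient survives only when scaled before the float16 rounding and unscaled afterwards. The scale factor 1024 is an arbitrary illustrative choice, and this is not how oneflow.amp.GradScaler is implemented.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8     # a gradient smaller than half of float16's smallest subnormal
scale = 1024.0  # illustrative loss-scale factor

# Without scaling, the float16 gradient underflows to zero.
unscaled = to_fp16(grad)
# With loss scaling, the float16 value holds scale * grad, which is
# representable; dividing by scale in full precision recovers the gradient.
recovered = to_fp16(grad * scale) / scale

print(unscaled)       # 0.0
print(recovered > 0)  # True
```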
-
state_dict(destination=None) → Dict[str, Union[Dict[str, oneflow._oneflow_internal.Tensor], oneflow._oneflow_internal.Tensor]]
Returns a dictionary containing the whole state of the graph.
States of modules/optimizers/lr schedulers in a graph are included.
For modules, the keys are the modules’ names in the graph, and the values are the corresponding nn.Module’s state dicts.
Other keys and tensors are states of optimizers/lr schedulers/etc.
- Returns
a dictionary containing the whole state of the graph.
- Return type
dict
-
load_state_dict(state_dict: Dict[str, Union[Dict[str, oneflow._oneflow_internal.Tensor], oneflow._oneflow_internal.Tensor]], strict: bool = True)
Copies module states and other graph states from state_dict into this graph. If strict is True, then the keys of state_dict must exactly match the keys returned by this graph’s nn.Graph.state_dict() function.
- Parameters
state_dict (dict) – a dict containing module states and other graph states.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this graph’s nn.Graph.state_dict() function. Default: True.
Note
nn.Graph’s state dict can only be loaded before the first call of a graph.
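The strict flag compares the keys of the graph's own state_dict() with the keys of the dictionary being loaded. The sketch below models that comparison with plain dicts. `compare_state_keys` is a hypothetical helper, and the key names are made up for illustration. They only mimic the documented layout: module-name keys plus extra optimizer/scheduler entries.

```python
def compare_state_keys(expected: dict, provided: dict, strict: bool = True):
    """Return (missing, unexpected) top-level keys; raise if strict and they differ."""
    missing = sorted(set(expected) - set(provided))
    unexpected = sorted(set(provided) - set(expected))
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return missing, unexpected


# Hypothetical layout: module names map to module state dicts,
# other keys hold optimizer/lr scheduler states.
graph_state = {"linear": {"weight": [0.1, 0.2]}, "sgd_state": [0.0]}
checkpoint = {"linear": {"weight": [0.3, 0.4]}}

# Non-strict loading tolerates the missing optimizer entry.
print(compare_state_keys(graph_state, checkpoint, strict=False))  # (['sgd_state'], [])
```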
-
property name
Name auto-generated for this graph.
-
debug(v_level: int = 0, ranks: Optional[Union[int, List[int]]] = None, mode: bool = True) → None
Open or close the debug mode of the graph.
In debug mode, computation graph building info and warnings are printed. Otherwise, only errors are printed.
Each nn.Module inside an nn.Graph also has a debug() method to enable debug mode.
Use v_level to choose the verbose debug info level. The default level is 0 and the max level is 3.
* v_level 0 prints warnings and graph building stages.
* v_level 1 additionally prints graph build info of each nn.Module.
* v_level 2 additionally prints graph build info of each operation.
* v_level 3 additionally prints more detailed info of each operation.
Use ranks to choose which ranks print the debug information.
For example:
g = CustomGraph()
g.debug()  # Open debug mode
out_tensors = g(input_tensors)  # Prints debug logs at the first call
- Parameters
v_level (int) – choose the verbose debug info level; default v_level is 0, max v_level is 3.
ranks (int or list(int)) – choose the ranks that print the debug information. Default rank: 0. You can choose any valid rank. A rank of -1 means debug on all ranks.
mode (bool) – whether to set debug mode (True) or not (False). Default: True.
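The ranks parameter accepts None, a single int, or a list of ints, where -1 selects every rank. The helper below is a hypothetical sketch of how those documented cases could resolve to a list of ranks for a given world size. It is not OneFlow's code.

```python
from typing import List, Optional, Union


def resolve_debug_ranks(ranks: Optional[Union[int, List[int]]], world_size: int) -> List[int]:
    """Map the documented `ranks` values to the list of ranks that print logs."""
    if ranks is None:
        return [0]  # documented default: rank 0
    if isinstance(ranks, int):
        ranks = [ranks]
    if -1 in ranks:
        return list(range(world_size))  # -1 means debug on all ranks
    for r in ranks:
        if not 0 <= r < world_size:
            raise ValueError(f"invalid rank {r} for world size {world_size}")
    return sorted(set(ranks))


print(resolve_debug_ranks(None, 4))    # [0]
print(resolve_debug_ranks(-1, 4))      # [0, 1, 2, 3]
print(resolve_debug_ranks([2, 1], 4))  # [1, 2]
```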
-
__repr__()
For printing the graph structure.
The graph structure can be printed after graph instantiation.
After the first call of graph, inputs and outputs will be added to the graph structure.
For example:
g = CustomGraph()
print(g)
out_tensors = g(input_tensors)
print(g)  # Input and output info are added
-
class oneflow.nn.graph.graph_config.GraphConfig
For configuration of nn.Graph.
-
enable_amp(mode: bool = True)
If set to True, the graph uses mixed precision mode, which means it uses both float16 and float32 during model training.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.enable_amp(True)  # Use mixed precision mode.

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
mode (bool, optional) – The default value is True.
-
allow_fuse_model_update_ops(mode: bool = True)
If set to True, try to fuse cast + scale + l1_l2_regularize_gradient + model_update into one op to improve performance.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.allow_fuse_model_update_ops(True)

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
mode (bool, optional) – The default value is True.
-
allow_fuse_add_to_output(mode: bool = True)
If set to True, try to fuse a binary element-wise add operator into one of its predecessors to improve performance.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.bn1 = flow.nn.BatchNorm1d(100)
        self.config.allow_fuse_add_to_output(True)

    def build(self, x):
        bn = self.bn1(x)
        out = bn + x
        return out

graph = Graph()
- Parameters
mode (bool, optional) – The default value is True.
-
allow_fuse_cast_scale(mode: bool = True)
If set to True, try to fuse cast and scalar_mul_by_tensor to improve performance.
For example:
import oneflow as flow

def model(x):
    return flow.mul(1, flow.cast(x, flow.int8))

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = model
        self.config.allow_fuse_cast_scale(True)

    def build(self, x):
        return self.m(x)

graph = Graph()
- Parameters
mode (bool, optional) – The default value is True.
-
set_gradient_accumulation_steps(value)
Set the number of steps over which gradients are accumulated.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        # Let the graph do gradient accumulation; pipeline parallelism,
        # for example, depends on gradient accumulation.
        self.config.set_gradient_accumulation_steps(4)

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
value (int) – number of steps.
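Gradient accumulation sums gradients over several micro-steps before a single optimizer update. The plain-Python sketch below checks the arithmetic behind this for a one-weight least-squares model with made-up data. Averaging the gradients of 4 equal-sized micro-batches gives the same result as one gradient over the full batch. How OneFlow actually splits inputs into micro-batches is not modeled here.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.3]
w = 0.5
steps = 4                 # plays the role of set_gradient_accumulation_steps(4)
micro = len(xs) // steps  # micro-batch size (equal-sized micro-batches assumed)

# Accumulate over micro-batches, scaling each gradient by 1/steps.
accumulated = 0.0
for i in range(steps):
    sl = slice(i * micro, (i + 1) * micro)
    accumulated += grad_mse(w, xs[sl], ys[sl]) / steps

full_batch = grad_mse(w, xs, ys)
print(abs(accumulated - full_batch) < 1e-9)  # True
```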
-
set_zero_redundancy_optimizer_mode(mode: str = 'distributed_split')
Set the mode for removing redundancy in optimizer states. This optimization reduces the memory consumption of optimizer states, as described by ZeRO https://arxiv.org/abs/1910.02054 .
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.set_zero_redundancy_optimizer_mode("distributed_split")

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
mode (str) – “distributed_split” or “non_distributed”. “distributed_split” mode shards each optimizer state across devices. “non_distributed” mode places each optimizer state on only one device.
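To see why “distributed_split” saves memory, consider how a single optimizer state could be divided across devices. The sketch below assumes a near-equal contiguous split in which the first shards absorb the remainder. `shard_sizes` is a hypothetical helper, and OneFlow's actual partitioning may differ.

```python
def shard_sizes(numel: int, num_devices: int) -> list:
    """Split `numel` elements into `num_devices` near-equal contiguous shards."""
    base, rem = divmod(numel, num_devices)
    # The first `rem` shards get one extra element.
    return [base + (1 if i < rem else 0) for i in range(num_devices)]


# A momentum buffer for a Linear(3, 8, bias=False) weight has 24 elements.
# Split across 4 devices, each device holds 6 instead of all 24.
print(shard_sizes(24, 4))  # [6, 6, 6, 6]
print(shard_sizes(10, 4))  # [3, 3, 2, 2]
```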
-
set_zero_redundancy_optimizer_min_size_after_split(value)
Set the minimum size of an optimizer state/gradient/parameter after splitting.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.set_zero_redundancy_optimizer_mode("distributed_split")
        self.config.set_zero_redundancy_optimizer_min_size_after_split(1)

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
value (int) – min size value.
-
enable_xla_jit(value=True)
Whether to use xla_jit in xrt.
When this option is enabled, OneFlow checks whether each operator is supported by xla_jit, clusters the supported operators into subgraphs, and then runs those subgraphs with xla_jit.
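The clustering step, which is shared by the xla_jit, TensorRT and OpenVINO options, can be pictured as grouping runs of consecutive supported operators. The sketch below does this over a linear list of op names. Real XRT clustering works on a graph rather than a list, and the op names and the supported set here are invented for illustration.

```python
def cluster_supported(ops, supported):
    """Group maximal runs of consecutive supported ops into clusters.

    Returns (kind, ops) pairs: "subgraph" for a run handed to the backend,
    "fallback" for a single unsupported op left to the regular runtime.
    """
    clusters = []
    run = []
    for op in ops:
        if op in supported:
            run.append(op)
        else:
            if run:
                clusters.append(("subgraph", run))
                run = []
            clusters.append(("fallback", [op]))
    if run:
        clusters.append(("subgraph", run))
    return clusters


ops = ["matmul", "bias_add", "relu", "custom_op", "matmul", "softmax"]
supported = {"matmul", "bias_add", "relu", "softmax"}
for kind, group in cluster_supported(ops, supported):
    print(kind, group)
```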
To use XLA to speed up model execution, you need to compile the XLA version of OneFlow.
Tutorial for building with XLA:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/xrt/README.md#build-with-xla
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.enable_xla_jit(True)  # Use xla_jit in xrt.

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
value (bool, optional) – The default value is True.
-
enable_tensorrt(value=True)
Whether to use TensorRT in xrt.
When this option is enabled, OneFlow checks whether each operator is supported by TensorRT, clusters the supported operators into subgraphs, and then runs those subgraphs with TensorRT.
TensorRT: https://developer.nvidia.com/tensorrt
To use TensorRT to speed up model execution, you need to compile the TensorRT version of OneFlow.
Tutorial for building with TensorRT:
https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/xrt/README.md#build-with-tensorrt
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.enable_tensorrt(True)  # Use TensorRT in xrt.

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
value (bool, optional) – The default value is True.
-
enable_openvino(value=True)
Whether to use OpenVINO in xrt.
When this option is enabled, OneFlow checks whether each operator is supported by OpenVINO, clusters the supported operators into subgraphs, and then runs those subgraphs with OpenVINO.
Please note that OpenVINO only supports inference mode.
OpenVINO: https://www.intel.com/content/www/us/en/developer/tools/openvino-toolkit/overview.html
It is also necessary to compile the XLA or TensorRT version of OneFlow. Tutorial: https://github.com/Oneflow-Inc/oneflow/tree/master/oneflow/xrt#readme
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 8, False)
        self.config.enable_openvino(True)  # Use OpenVINO in xrt.

    def build(self, x):
        return self.linear(x)

graph = Graph()
- Parameters
value (bool, optional) – The default value is True.
-
enable_cudnn_conv_heuristic_search_algo(mode: bool = True)
Whether to let the cuDNN conv operation use the heuristic search algorithm.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = flow.nn.Conv2d(16, 32, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        # Do not let the cuDNN conv operation use the heuristic search algorithm.
        self.config.enable_cudnn_conv_heuristic_search_algo(False)

    def build(self, x):
        return self.m(x)

graph = Graph()
- Parameters
mode (bool, optional) – The default value is True.
-
class oneflow.nn.graph.block_config.BlockConfig
Configurations on Module Block in nn.Graph.
When an nn.Module is added into an nn.Graph, it is wrapped into a ModuleBlock. You can set or get optimization configs on an nn.Module through its ModuleBlock.config.
-
property stage_id
Set/Get the stage id of an nn.Module/ModuleBlock in pipeline parallelism.
When calling stage_id(value: int = None), set a stage id on each module so that the graph can prepare the right number of buffers in the pipeline.
For example:
# m_stage0 and m_stage1 are the two pipeline stages of the network, respectively.
# We can set the stage ID by setting the config.stage_id attribute of a Module.
# Stage IDs are numbered starting from 0 and increase by 1.
self.module_pipeline.m_stage0.config.stage_id = 0
self.module_pipeline.m_stage1.config.stage_id = 1
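The stage-id convention in the example (ids start at 0 and increase by 1) can be checked before building a pipeline. The function below is a hypothetical sketch of such a check over a mapping from module name to stage id. It is not part of OneFlow's API.

```python
def check_stage_ids(stage_ids: dict) -> int:
    """Verify that stage ids form 0, 1, ..., n-1 without gaps; return n."""
    used = sorted(set(stage_ids.values()))
    if used != list(range(len(used))):
        raise ValueError(f"stage ids must be 0..n-1 without gaps, got {used}")
    return len(used)


print(check_stage_ids({"m_stage0": 0, "m_stage1": 1}))  # 2
```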
-
property activation_checkpointing
Set/Get whether to do activation checkpointing in this nn.Module.
For example:
import oneflow as flow

class Graph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear1 = flow.nn.Linear(3, 5, False)
        self.linear2 = flow.nn.Linear(5, 8, False)
        self.linear1.config.activation_checkpointing = True
        self.linear2.config.activation_checkpointing = True

    def build(self, x):
        y_pred = self.linear1(x)
        y_pred = self.linear2(y_pred)
        return y_pred

graph = Graph()
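Activation checkpointing trades compute for memory. Instead of keeping a module's intermediate activations for the backward pass, it keeps only the module's input and recomputes the rest during backward. The scalar sketch below contrasts the two strategies for the chain y = b * relu(a * x). It is a conceptual model with made-up names, not OneFlow's implementation.

```python
def forward(x, a, b, saved):
    """Scalar chain h = a*x, r = relu(h), y = b*r; records intermediates in `saved`."""
    h = a * x
    r = max(h, 0.0)
    saved.update({"x": x, "h": h, "r": r})
    return b * r


def backward_from_saved(saved, a, b):
    """dy/dx computed from the recorded activations."""
    dr = b
    dh = dr * (1.0 if saved["h"] > 0 else 0.0)  # relu derivative
    return dh * a


def run(x, a, b, checkpoint):
    """Return (output, gradient, number of values kept between forward and backward)."""
    saved = {}
    y = forward(x, a, b, saved)
    if checkpoint:
        # Keep only the input; intermediate activations are dropped after forward.
        saved = {"x": saved["x"]}
        stored = len(saved)
        # During backward, recompute the forward pass to rebuild the activations.
        recomputed = {}
        forward(saved["x"], a, b, recomputed)
        grad = backward_from_saved(recomputed, a, b)
    else:
        stored = len(saved)
        grad = backward_from_saved(saved, a, b)
    return y, grad, stored


print(run(2.0, 3.0, 0.5, checkpoint=False))  # (3.0, 1.5, 3)
print(run(2.0, 3.0, 0.5, checkpoint=True))   # (3.0, 1.5, 1)
```

Both strategies produce the same output and gradient. Checkpointing keeps fewer values alive at the cost of running the forward computation a second time.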
-
property