oneflow.nn.Graph.add_optimizer¶

Graph.add_optimizer(optim: oneflow.nn.optimizer.optimizer.Optimizer, *, lr_sch: Optional[oneflow.nn.optimizer.lr_scheduler.LRScheduler] = None, is_sparse: bool = False)¶

Add an optimizer, and optionally a learning rate scheduler, to the graph.
To do training with nn.Graph, you need to do two more things:

- Add at least one optimizer (learning rate schedulers are optional) with the add_optimizer() method.
- Call the loss tensor's backward() method inside the build() method.
Note that the computation graph will automatically execute these methods:

- the optimizer's clip_grad(), if the optimizer is configured to do gradient clipping.
- the optimizer's step().
- the optimizer's zero_grad().
- the learning rate scheduler's step().
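The order of these automatic calls can be sketched in plain Python. This is a minimal illustration with hypothetical stand-in classes, not OneFlow's actual implementation:

```python
# Illustrative sketch of the per-iteration order nn.Graph follows after
# backward(). The Optimizer/Scheduler classes are stand-ins that record
# which methods get called; they are not OneFlow classes.
class Optimizer:
    def __init__(self, do_grad_clipping=False):
        self.do_grad_clipping = do_grad_clipping
        self.calls = []

    def clip_grad(self):
        self.calls.append("clip_grad")

    def step(self):
        self.calls.append("step")

    def zero_grad(self):
        self.calls.append("zero_grad")


class Scheduler:
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self):
        self.optimizer.calls.append("lr_step")


def run_one_iteration(optimizer, lr_scheduler=None):
    # 1. Clip gradients, if the optimizer was configured to do so.
    if optimizer.do_grad_clipping:
        optimizer.clip_grad()
    # 2. Apply the parameter update.
    optimizer.step()
    # 3. Clear gradients for the next iteration.
    optimizer.zero_grad()
    # 4. Advance the learning rate schedule.
    if lr_scheduler is not None:
        lr_scheduler.step()


opt = Optimizer(do_grad_clipping=True)
run_one_iteration(opt, Scheduler(opt))
print(opt.calls)  # ['clip_grad', 'step', 'zero_grad', 'lr_step']
```

In eager mode you would write these calls out yourself each iteration; under nn.Graph they are inserted into the compiled graph for you.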
Also note that, for the moment, only scalar tensors are allowed to call backward() in nn.Graph.build(). So you may call methods such as Tensor.mean() to reduce the loss tensor to a scalar.

Note
If you want to output the learning rate information for each step, set the verbose parameter of the lr_scheduler to True, and you will see the result at rank 0. This behavior is the same as in eager mode.
For example:
>>> import oneflow as flow
>>> loss_fn = flow.nn.MSELoss(reduction="sum")
>>> model = flow.nn.Sequential(flow.nn.Linear(3, 1), flow.nn.Flatten(0, 1))
>>> optimizer = flow.optim.SGD(model.parameters(), lr=1e-6)
>>> class LinearTrainGraph(flow.nn.Graph):
...     def __init__(self):
...         super().__init__()
...         self.model = model
...         self.loss_fn = loss_fn
...         # Add an optimizer
...         self.add_optimizer(optimizer)
...     def build(self, x, y):
...         y_pred = self.model(x)
...         loss = self.loss_fn(y_pred, y)
...         # Call loss tensor's backward(); the loss must be a scalar tensor
...         loss.backward()
...         return loss
>>> linear_graph = LinearTrainGraph()
>>> x = flow.randn(10, 3)
>>> y = flow.randn(10)
>>> model.train()  # put the model in training mode
Sequential(
  (0): Linear(in_features=3, out_features=1, bias=True)
  (1): Flatten(start_dim=0, end_dim=1)
)
>>> for t in range(3):
...     loss = linear_graph(x, y)
- Parameters
  - optim (oneflow.optim.Optimizer) – The optimizer.
  - lr_sch (oneflow.optim.lr_scheduler.LRScheduler, optional) – The learning rate scheduler; see oneflow.optim.lr_scheduler. Default is None.
  - is_sparse (bool) – When set to True, treat optim as a sparse optimizer. Default is False.