oneflow.utils.global_view.to_global
oneflow.utils.global_view.to_global(input, placement=None, sbp=None, warn_on_non_tensor_leaf=True, **kwargs)
Converts the input tensor, or the tensors contained in a list/tuple/dict, to global tensor(s).
Note
Both placement and sbp are required if the input is local; otherwise, at least one of placement and sbp is required.
- Parameters
input (oneflow.Tensor/None/list/tuple/dict) – the input that needs to be converted.
placement (oneflow.placement, optional) – the desired placement of the input. Default: None
sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp or Callable[[Tensor], oneflow.sbp.sbp], optional) – the desired sbp of the input, or a user-defined function that returns the SBP to use for a given tensor. Default: None
warn_on_non_tensor_leaf (bool, optional) – whether to warn when the leaf is not a tensor. Default: True
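The rule from the note above (a local input needs both placement and sbp, a global input needs at least one of them) can be sketched in plain Python. This is only an illustration of the documented rule, not OneFlow's own validation code: the helper name check_to_global_args and the input_is_global flag are hypothetical, and placeholder strings stand in for real placement and sbp objects so that the snippet runs without OneFlow.

```python
def check_to_global_args(input_is_global, placement=None, sbp=None):
    """Hypothetical helper: raise ValueError if the documented rule is violated."""
    if not input_is_global:
        # A local input carries no placement or sbp yet, so both must be given.
        if placement is None or sbp is None:
            raise ValueError("a local input needs both placement and sbp")
    elif placement is None and sbp is None:
        # A global input already has both, so overriding one of them is enough.
        raise ValueError("a global input needs at least one of placement and sbp")


# Placeholder strings stand in for real placement / sbp objects.
check_to_global_args(False, placement="P", sbp="S")  # local, both given: accepted
check_to_global_args(True, sbp="S")                  # global, only sbp given: accepted
```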
- Returns
The converted input.
For a tensor input, please refer to the examples in oneflow.Tensor.to_global().
For an input of another type (taking a state dict as an example):
>>> # Run on 2 ranks respectively
>>> import oneflow as flow
>>> from oneflow import nn
>>> placement = flow.placement("cpu", ranks=[0, 1])
>>> sbp = (flow.sbp.broadcast,)
>>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
>>> global_state_dict = flow.utils.global_view.to_global(model.state_dict(), placement, sbp)
>>> for val in global_state_dict.values():
...     print(val.is_global)
>>> # results on rank 0
True
True
True
True
>>> # results on rank 1
True
True
True
True
Note
For a dict input, such as the state dict of a model, a single unified sbp cannot always be used when calling to_global; in that case the sbp needs to be specified per tensor. This is usually needed when making a graph model's state dict global.
If you want to apply split(0) but some tensors cannot be split along dim 0, you can specify the sbp for those tensors individually. Note that for a tensor of shape (1, n), you can specify the SBP as oneflow.sbp.split(1). For example:
import oneflow as flow

# state_dict and placement are assumed to be defined beforehand.

# Defines a function to return the specified SBP.
def get_sbp(state_dict, tensor):
    if tensor is state_dict["System-Train-TrainStep"]:
        return flow.sbp.broadcast
    if tensor is state_dict["module_pipeline"]["m_stage3.linear.weight"]:
        return flow.sbp.split(1)
    if tensor is state_dict["module_pipeline"]["m_stage3.linear.bias"]:
        return flow.sbp.broadcast
    # Every other tensor is split along dim 0.
    return flow.sbp.split(0)

flow.utils.global_view.to_global(state_dict, placement=placement, sbp=get_sbp)
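Instead of matching tensors by key, the choice could also be driven by tensor shape. The following is a minimal sketch of that idea in plain Python, not part of OneFlow: the helper choose_sbp is hypothetical, it returns descriptive tuples rather than real oneflow.sbp objects so that it runs without OneFlow, and it assumes that a dimension can be split only when its size is at least the number of ranks and that only the first two dimensions are candidates.

```python
def choose_sbp(shape, num_ranks):
    """Hypothetical helper: return ("split", dim) or ("broadcast", None).

    Assumption (not taken from OneFlow): a dimension is splittable when its
    size is at least num_ranks; only dims 0 and 1 are considered.
    """
    for dim, size in enumerate(shape[:2]):
        if size >= num_ranks:
            return ("split", dim)
    # Scalars and tensors that are too small fall back to broadcast.
    return ("broadcast", None)


print(choose_sbp((4, 8), 2))  # regular weight: split along dim 0
print(choose_sbp((1, 8), 2))  # shape (1, n): split along dim 1
print(choose_sbp((), 2))      # scalar: broadcast
```

In a real get_sbp function, a ("split", dim) result would correspond to oneflow.sbp.split(dim) and a ("broadcast", None) result to oneflow.sbp.broadcast.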