oneflow.utils.global_view.to_global

oneflow.utils.global_view.to_global(input, placement=None, sbp=None, warn_on_non_tensor_leaf=True, **kwargs)

Converts the input tensor, or the input tensor(s) contained in a list/tuple/dict, to global tensor(s).

Note

Both placement and sbp are required if the input is local, otherwise at least one of placement and sbp is required.
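
A minimal sketch of this rule (assuming a single-rank run; when only sbp is passed for a global tensor, the existing placement is assumed to be kept):

import oneflow as flow

placement = flow.placement("cpu", ranks=[0])
sbp = flow.sbp.broadcast

# A local tensor: both placement and sbp must be given.
x = flow.ones(2, 3)
x_global = flow.utils.global_view.to_global(x, placement=placement, sbp=sbp)

# An already-global tensor: one of the two is enough, e.g. only
# changing the sbp (the current placement is assumed to be kept).
x_split = flow.utils.global_view.to_global(x_global, sbp=flow.sbp.split(0))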

Parameters
  • input (oneflow.Tensor/None/list/tuple/dict) – the input that needs to be converted.

  • placement (oneflow.placement, optional) – the desired placement of the input. Default: None

  • sbp (oneflow.sbp.sbp, list/tuple of oneflow.sbp.sbp, or Callable[[Tensor], oneflow.sbp.sbp], optional) – the desired sbp of the input, or a self-defined function that returns the SBP to use. Default: None

  • warn_on_non_tensor_leaf (bool, optional) – whether to warn when the leaf is not a tensor. Default: True

Returns

The converted input.

For a tensor input: please refer to the examples in oneflow.Tensor.to_global().

For an input of another type (taking a state dict as an example):

>>> # Run on 2 ranks respectively
>>> import oneflow as flow
>>> from oneflow import nn
>>> placement = flow.placement("cpu", ranks=[0, 1]) 
>>> sbp = (flow.sbp.broadcast,) 
>>> model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2)) 
>>> global_state_dict = flow.utils.global_view.to_global(model.state_dict(), placement, sbp) 
>>> for val in global_state_dict.values(): 
...     print(val.is_global) 
>>> # results on rank 0
True
True
True
True
>>> # results on rank 1
True
True
True
True
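
If the converted dict contains non-tensor leaves (for example a saved step counter), the warn_on_non_tensor_leaf parameter controls whether a warning is emitted for them. A minimal sketch, assuming non-tensor leaves are passed through unchanged:

mixed = {"weight": flow.ones(2, 2), "step": 100}  # "step" is not a tensor
# Suppress the warning about the non-tensor leaf; the tensor leaves are still converted.
out = flow.utils.global_view.to_global(
    mixed, placement=placement, sbp=sbp, warn_on_non_tensor_leaf=False
)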

Note

For a dict-type input, such as the state dict of a model, a single unified sbp sometimes cannot be used when calling to_global; in that case the sbp needs to be specified per tensor. This is typically used to make a graph model's state dict global.

If you want to apply split(0) but some tensors cannot be split along dim 0, those tensors can be given their own sbp. It is worth noting that, for a tensor of shape (1, n), you can specify the SBP as oneflow.sbp.split(1). For example (the per-tensor SBP is returned by the callable defined below):

# Defines a function that returns the specified SBP for particular tensors.
def get_sbp(state_dict, tensor):
    if tensor is state_dict["System-Train-TrainStep"]:
        return oneflow.sbp.broadcast
    if tensor is state_dict["module_pipeline"]["m_stage3.linear.weight"]:
        return oneflow.sbp.split(1)
    if tensor is state_dict["module_pipeline"]["m_stage3.linear.bias"]:
        return oneflow.sbp.broadcast
    return oneflow.sbp.split(0)

flow.utils.global_view.to_global(state_dict, placement=placement, sbp=get_sbp)
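
As the example shows, the callable receives the state dict and the current tensor, so identity checks against entries of the dict can be used to choose a per-tensor SBP, with oneflow.sbp.split(0) as the fallback for every other tensor.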