oneflow.comm.all_gather_into_tensor

oneflow.comm.all_gather_into_tensor(output_tensor, input_tensor)

Gather tensors from all ranks and put them in a single output tensor.

Parameters
  • output_tensor (Tensor) – Output tensor that receives the tensor elements from all ranks. It must be sized to take one of the following forms: (i) a concatenation of all the input tensors along the primary dimension (for the definition of “concatenation”, see oneflow.cat()); (ii) a stack of all the input tensors along the primary dimension (for the definition of “stack”, see oneflow.stack()). The examples below illustrate both supported output forms.

  • input_tensor (Tensor) – Tensor to be gathered from the current rank. The input tensors must have the same size on all ranks.
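The two accepted output shapes follow from the input shape and the number of ranks. The helper below is a minimal pure-Python sketch of that rule. It is hypothetical and not part of OneFlow; it assumes the primary dimension is dimension 0 and that the world size is known to the caller.

```python
def valid_output_shapes(input_shape, world_size):
    """Return the (concatenation, stack) output shapes for an all-gather.

    Hypothetical helper for illustration only; not a OneFlow API.
    Assumes the "primary dimension" is dimension 0.
    """
    input_shape = tuple(input_shape)
    if not input_shape:
        raise ValueError("concatenation form needs at least one dimension")
    # Concatenation form: dimension 0 grows by a factor of world_size.
    concat_shape = (input_shape[0] * world_size,) + input_shape[1:]
    # Stack form: a new leading dimension of size world_size is added.
    stack_shape = (world_size,) + input_shape
    return concat_shape, stack_shape


concat_shape, stack_shape = valid_output_shapes((2, 3), world_size=2)
print(concat_shape)  # (4, 3)
print(stack_shape)   # (2, 2, 3)
```

Both shapes hold the same number of elements, world_size times the size of one input, so they differ only in how the leading dimension is grouped.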

For example:

>>> # We have 1 process group, 2 ranks.
>>> # All tensors below are of flow.int64 dtype and on CUDA devices.
>>> import oneflow as flow
>>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device="cuda") + flow.env.get_rank() * 6
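>>> # tensor_in on rank0
>>> tensor_in
tensor([[1, 2, 3],
        [4, 5, 6]], device='cuda:0', dtype=oneflow.int64)
>>> # tensor_in on rank1
>>> tensor_in
tensor([[ 7,  8,  9],
        [10, 11, 12]], device='cuda:1', dtype=oneflow.int64)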
>>> # Output in concatenation form
>>> tensor_out = flow.zeros(4, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.all_gather_into_tensor(tensor_out, tensor_in)
>>> # result on rank0
>>> tensor_out 
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out 
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], device='cuda:1', dtype=oneflow.int64)
>>> # Output in stack form: shape is (world_size, *tensor_in.shape) = (2, 2, 3)
>>> tensor_out2 = flow.zeros(2, 2, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.all_gather_into_tensor(tensor_out2, tensor_in)
>>> # result on rank0
>>> tensor_out2
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out2
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]], device='cuda:1', dtype=oneflow.int64)
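The relationship between the two output forms can be checked without a GPU or multiple processes. The following is a rough single-process simulation using nested Python lists in place of tensors. The function name simulate_all_gather is hypothetical, and "stack" is taken to mean one leading entry per rank, as in oneflow.stack().

```python
def simulate_all_gather(per_rank_inputs):
    """Simulate an all-gather over 2-D inputs held as nested lists.

    per_rank_inputs[r] stands in for input_tensor on rank r.
    Hypothetical helper for illustration only; not a OneFlow API.
    """
    # Stack form: one leading entry per rank, in rank order.
    stacked = [[list(row) for row in tensor] for tensor in per_rank_inputs]
    # Concatenation form: the rows of every rank placed back to back.
    concatenated = [list(row) for tensor in per_rank_inputs for row in tensor]
    return concatenated, stacked


base = [[1, 2, 3], [4, 5, 6]]
# Mirror the example above: each rank adds rank * 6 to the base tensor.
inputs = [[[v + rank * 6 for v in row] for row in base] for rank in range(2)]

concatenated, stacked = simulate_all_gather(inputs)
print(concatenated)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(stacked)       # [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
```

In both forms the elements appear in rank order, so every rank ends up with identical contents after the call.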