oneflow.comm.all_gather_into_tensor
oneflow.comm.all_gather_into_tensor(output_tensor, input_tensor)

Gather input_tensor from all ranks and write the gathered data, ordered by rank, into output_tensor on every rank.
Parameters
- output_tensor (Tensor) – Output tensor that receives the tensor elements from all ranks. It must be sized in one of the following forms: (i) a concatenation of all the input tensors along the primary dimension (for the definition of "concatenation", see oneflow.cat()); (ii) a stack of all the input tensors along the primary dimension (for the definition of "stack", see oneflow.stack()). For an input of shape (d0, d1, ...) and a world size of W, these forms correspond to output shapes (W * d0, d1, ...) and (W, d0, d1, ...), respectively. The examples below illustrate both supported output forms.
- input_tensor (Tensor) – Tensor to be gathered from the current rank. The input tensors must have the same size on all ranks.
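The shape rule for the two output forms can be sketched in plain Python. `expected_output_shape` is a hypothetical helper written only for illustration; it is not part of OneFlow, it only computes shapes, and it performs no communication.

```python
# Hypothetical helper (not part of OneFlow): the output shape that a
# concatenation-form or stack-form output tensor would need.
def expected_output_shape(input_shape, world_size, form="concat"):
    """Return the required output shape for the 'concat' or 'stack' form."""
    input_shape = tuple(input_shape)
    if form == "concat":
        # Like oneflow.cat(..., dim=0): the primary dimension grows by world_size.
        if not input_shape:
            raise ValueError("concat form needs an input with at least one dimension")
        return (world_size * input_shape[0],) + input_shape[1:]
    if form == "stack":
        # Like oneflow.stack(..., dim=0): a new leading dimension of size world_size.
        return (world_size,) + input_shape
    raise ValueError(f"unknown form: {form!r}")


# Two ranks, each contributing a (2, 3) tensor.
print(expected_output_shape((2, 3), 2, "concat"))  # (4, 3)
print(expected_output_shape((2, 3), 2, "stack"))   # (2, 2, 3)
```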
For example:
>>> # We have 1 process group with 2 ranks.
>>> # All tensors below are of flow.int64 dtype and on CUDA devices.
>>> import oneflow as flow
>>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6]], dtype=flow.int64, device="cuda") + flow.env.get_rank() * 6
>>> # input on rank0
>>> tensor_in
tensor([[1, 2, 3],
        [4, 5, 6]], device='cuda:0', dtype=oneflow.int64)
>>> # input on rank1
>>> tensor_in
tensor([[ 7,  8,  9],
        [10, 11, 12]], device='cuda:1', dtype=oneflow.int64)
>>> # Output in concatenation form: shape (2 * 2, 3)
>>> tensor_out = flow.zeros(4, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.all_gather_into_tensor(tensor_out, tensor_in)
>>> # result on rank0
>>> tensor_out
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], device='cuda:1', dtype=oneflow.int64)
>>> # Output in stack form: shape (2, 2, 3)
>>> tensor_out2 = flow.zeros(2, 2, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.all_gather_into_tensor(tensor_out2, tensor_in)
>>> # result on rank0
>>> tensor_out2
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],
        [[ 7,  8,  9],
         [10, 11, 12]]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out2
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],
        [[ 7,  8,  9],
         [10, 11, 12]]], device='cuda:1', dtype=oneflow.int64)
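To see the data movement without GPUs or multiple processes, the following sketch simulates the two-rank setup in a single Python process using nested lists. The helpers `make_input`, `gather_concat`, `gather_stack`, and `flatten` are hypothetical; they only mimic what oneflow.cat() and oneflow.stack() along dim 0 would produce and say nothing about how OneFlow actually performs the collective.

```python
# Single-process simulation (no OneFlow, no GPUs) of what every rank receives.
WORLD_SIZE = 2


def make_input(rank):
    # Mirrors: flow.tensor([[1, 2, 3], [4, 5, 6]]) + rank * 6
    return [[v + rank * 6 for v in row] for row in [[1, 2, 3], [4, 5, 6]]]


def gather_concat(inputs):
    # Concatenate along dim 0, like oneflow.cat(inputs, dim=0)
    return [row for tensor in inputs for row in tensor]


def gather_stack(inputs):
    # Stack along a new dim 0, like oneflow.stack(inputs, dim=0)
    return [tensor for tensor in inputs]


def flatten(nested):
    # Row-major flattening of an arbitrarily nested list
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


inputs = [make_input(rank) for rank in range(WORLD_SIZE)]
concat_out = gather_concat(inputs)  # identical on every rank
stack_out = gather_stack(inputs)    # identical on every rank

print(concat_out)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(stack_out)   # [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
print(flatten(concat_out) == flatten(stack_out))  # True
```

Both forms hold the same elements in the same row-major order, with rank 0's data first; only the shape of the output tensor differs.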