oneflow.comm.all_gather
oneflow.comm.all_gather(tensor_list, tensor)

Gathers tensors from the whole group into a list. After the call, every rank holds one tensor per rank in tensor_list, ordered by rank.
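Parameters
    tensor_list (list[Tensor]): Output list. It should contain one correctly sized tensor per rank in the group to receive the gathered results.
    tensor (Tensor): Tensor to be sent from the current process.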
For example:
>>> # We have 1 process group, 2 ranks.
>>> import oneflow as flow
>>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
>>> # input on rank0
>>> input
tensor([[1, 2],
        [3, 4]], device='cuda:0', dtype=oneflow.int64)
>>> # input on rank1
>>> input
tensor([[2, 3],
        [4, 5]], device='cuda:1', dtype=oneflow.int64)
>>> tensor_list = [flow.zeros(2, 2, dtype=flow.int64) for _ in range(2)]
>>> flow.comm.all_gather(tensor_list, input)
>>> # result on rank0
>>> tensor_list
[tensor([[1, 2],
        [3, 4]], device='cuda:0', dtype=oneflow.int64), tensor([[2, 3],
        [4, 5]], device='cuda:0', dtype=oneflow.int64)]
>>> # result on rank1
>>> tensor_list
[tensor([[1, 2],
        [3, 4]], device='cuda:1', dtype=oneflow.int64), tensor([[2, 3],
        [4, 5]], device='cuda:1', dtype=oneflow.int64)]
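For intuition without a multi-GPU setup, the data movement of all_gather can be modeled in plain Python. This is a hedged, single-process sketch and not part of oneflow: simulate_all_gather is a hypothetical helper, nested lists stand in for tensors, and list indices stand in for ranks. It reproduces the result of the example above, where every rank ends up with the same list ordered by source rank.

```python
from typing import List

Matrix = List[List[int]]  # stands in for a 2-D tensor


def simulate_all_gather(local_inputs: List[Matrix]) -> List[List[Matrix]]:
    """Hypothetical model of all_gather: local_inputs[r] is rank r's tensor.

    Returns one tensor_list per rank. No real communication happens; each
    rank simply receives a copy of every rank's input, ordered by source rank.
    """
    world_size = len(local_inputs)
    results = []
    for _rank in range(world_size):
        # Each receiving rank gets its own copies, one entry per source rank.
        tensor_list = [[row[:] for row in local_inputs[src]] for src in range(world_size)]
        results.append(tensor_list)
    return results


base = [[1, 2], [3, 4]]
# Mirrors the example: rank r holds base + r (like adding get_local_rank()).
inputs = [[[v + r for v in row] for row in base] for r in range(2)]
gathered = simulate_all_gather(inputs)
for rank, tensor_list in enumerate(gathered):
    print(f"rank{rank}:", tensor_list)
```

The per-rank copies reflect that, in a real run, each process owns its own output tensors; only their contents are identical across ranks.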