oneflow.comm.reduce_scatter_tensor

oneflow.comm.reduce_scatter_tensor(output_tensor, input_tensor)

Reduces, then scatters a tensor to all ranks.
Parameters
    output_tensor (Tensor) – Output tensor. It should have the same size on all ranks.
    input_tensor (Tensor) – Input tensor to be reduced and scattered. Its size should be the output tensor size times the world size. The input tensor can have one of the following shapes: (i) a concatenation of the output tensors along the primary dimension, or (ii) a stack of the output tensors along the primary dimension. For the definition of "concatenation", see oneflow.cat(); for the definition of "stack", see oneflow.stack().
For example:
>>> # We have 1 process group, 2 ranks.
>>> # All tensors below are of flow.int64 dtype and on CUDA devices.
>>> import oneflow as flow
>>> tensor_in = flow.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=flow.int64, device="cuda")
>>> tensor_in
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], device='cuda:0', dtype=oneflow.int64)
>>> # Input in concatenation form
>>> tensor_out = flow.zeros(2, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.reduce_scatter_tensor(tensor_out, tensor_in)
>>> # result on rank0
>>> tensor_out
tensor([[ 2,  4,  6],
        [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out
tensor([[14, 16, 18],
        [20, 22, 24]], device='cuda:1', dtype=oneflow.int64)
>>> # Input in stack form: a stack of the two (2, 3) output tensors
>>> tensor_in2 = tensor_in.reshape(2, 2, 3)
>>> tensor_out2 = flow.zeros(2, 3, dtype=flow.int64, device="cuda")
>>> flow.comm.reduce_scatter_tensor(tensor_out2, tensor_in2)
>>> # result on rank0
>>> tensor_out2
tensor([[ 2,  4,  6],
        [ 8, 10, 12]], device='cuda:0', dtype=oneflow.int64)
>>> # result on rank1
>>> tensor_out2
tensor([[14, 16, 18],
        [20, 22, 24]], device='cuda:1', dtype=oneflow.int64)
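The reduce-then-scatter semantics above can also be checked on a single process. The sketch below simulates it with NumPy: `simulated_reduce_scatter_tensor` is a hypothetical helper (not part of the oneflow API) that sums the per-rank inputs and splits the result into per-rank chunks along the primary dimension, reproducing the rank-0 and rank-1 results from the example.

```python
# Single-process sketch of reduce_scatter_tensor semantics, using NumPy.
# Illustrative only: the real oneflow.comm.reduce_scatter_tensor runs
# across multiple ranks and devices.
import numpy as np

def simulated_reduce_scatter_tensor(inputs_per_rank, world_size):
    """Sum the per-rank inputs (reduce), then split the result into
    world_size equal chunks along dim 0 (scatter); chunk i goes to rank i."""
    reduced = np.sum(inputs_per_rank, axis=0)     # the "reduce" step
    return np.split(reduced, world_size, axis=0)  # the "scatter" step

world_size = 2
# Every rank holds the same (4, 3) input: a concatenation of the two
# (2, 3) output tensors along the primary dimension.
tensor_in = np.arange(1, 13).reshape(4, 3)
outputs = simulated_reduce_scatter_tensor([tensor_in] * world_size, world_size)

print(outputs[0])  # rank 0 receives [[ 2  4  6] [ 8 10 12]]
print(outputs[1])  # rank 1 receives [[14 16 18] [20 22 24]]
```

The same function covers the stack form: an input of shape (world_size, 2, 3) split along dim 0 yields one (2, 3) chunk per rank.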