Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.tensor_split.html.
input (Tensor) – the input tensor.
indices_or_sections (int or a list) – If indices_or_sections is an integer n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, each section has size input.size(dim) / n. If it is not divisible by n, the first int(input.size(dim) % n) sections have size int(input.size(dim) / n) + 1, and the rest have size int(input.size(dim) / n). If indices_or_sections is a list or tuple of ints, input is split along dimension dim at each of the indices in the list, tuple or tensor. For instance, indices_or_sections=[2, 3] and dim=0 would result in the tensors input[:2], input[2:3], and input[3:]. If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.
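The sizing rules above can be checked with a small pure-Python sketch. `tensor_split_sizes` is a hypothetical helper, not part of OneFlow: it only computes the length each piece would have along dim, assuming non-negative split indices, and does not create or touch any tensors.

```python
def tensor_split_sizes(length, indices_or_sections):
    """Hypothetical helper (not a OneFlow API): sizes of the pieces that
    tensor_split would produce along a dimension of the given length."""
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        if n <= 0:
            raise ValueError("number of sections must be positive")
        q, r = divmod(length, n)
        # The first (length % n) sections get one extra element each.
        return [q + 1] * r + [q] * (n - r)
    # Split points give input[:i0], input[i0:i1], ..., input[ik:];
    # None stands for "start of dimension" / "end of dimension".
    bounds = [None] + list(indices_or_sections) + [None]
    positions = list(range(length))
    return [len(positions[a:b]) for a, b in zip(bounds[:-1], bounds[1:])]


print(tensor_split_sizes(5, (2, 3)))  # [2, 1, 2]
print(tensor_split_sizes(7, 3))       # [3, 2, 2]
```

Slicing a plain list mirrors the input[:2], input[2:3], input[3:] description, so each piece size follows the same slice semantics rather than being computed by hand.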
dim (int) – dimension along which to split the tensor.
- Returns: the output TensorTuple.
- Return type: oneflow.TensorTuple
>>> import oneflow as flow
>>> input = flow.rand(3, 4, 5)
>>> output = flow.tensor_split(input, (2, 3), 2)
>>> output[0].size()
oneflow.Size([3, 4, 2])
>>> output[1].size()
oneflow.Size([3, 4, 1])
>>> output[2].size()
oneflow.Size([3, 4, 2])