oneflow.tensor_split

oneflow.tensor_split()

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections. This documentation is adapted from: https://pytorch.org/docs/1.10/generated/torch.tensor_split.html.

Parameters
  • input (Tensor) – the input tensor.

  • indices_or_sections (int, list, or tuple) – If indices_or_sections is an integer n, input is split into n sections along dimension dim. If input.size(dim) is divisible by n, every section has size input.size(dim) / n. Otherwise, the first int(input.size(dim) % n) sections have size int(input.size(dim) / n) + 1, and the remaining sections have size int(input.size(dim) / n). If indices_or_sections is a list or tuple of ints, input is split along dimension dim at each of the given indices. For instance, indices_or_sections=[2, 3] and dim=0 produce the tensors input[:2], input[2:3], and input[3:]. If indices_or_sections is a tensor, it must be a zero-dimensional or one-dimensional long tensor on the CPU.

  • dim (int) – dimension along which to split the tensor.
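The size rule for an integer indices_or_sections can be checked without OneFlow. The sketch below is a plain-Python illustration, not part of the OneFlow API; the helper name section_sizes is hypothetical, and it only computes the lengths of the sections along dim, assuming n is a positive integer.

```python
from typing import List


def section_sizes(length: int, n: int) -> List[int]:
    """Lengths of the n sections along a dimension of `length` elements.

    Hypothetical helper mirroring the rule described for an integer
    indices_or_sections; it is not a OneFlow function.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")
    # base = int(length / n), remainder = length % n
    base, remainder = divmod(length, n)
    # The first `remainder` sections receive one extra element.
    return [base + 1 if i < remainder else base for i in range(n)]


print(section_sizes(7, 3))  # [3, 2, 2]
print(section_sizes(6, 3))  # [2, 2, 2]
```

When the dimension divides evenly, remainder is 0 and every section has the same length; otherwise the extra elements are spread over the leading sections, one each.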

Returns

the output TensorTuple.

Return type

oneflow.TensorTuple

For example:

>>> import oneflow as flow

>>> input = flow.rand(3, 4, 5)
>>> output = flow.tensor_split(input, (2, 3), 2)
>>> output[0].size()
oneflow.Size([3, 4, 2])
>>> output[1].size()
oneflow.Size([3, 4, 1])
>>> output[2].size()
oneflow.Size([3, 4, 2])
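The sizes in the example above follow from slicing dimension 2 (length 5) at indices 2 and 3. The plain-Python sketch below reproduces those slice bounds; the helper index_slices is hypothetical and assumes the indices are non-negative, sorted, and no larger than the dimension length.

```python
from typing import List, Sequence, Tuple


def index_slices(length: int, indices: Sequence[int]) -> List[Tuple[int, int]]:
    """(start, stop) bounds of each piece when a dimension of `length`
    elements is split at `indices`.

    Hypothetical illustration only; assumes sorted, in-range indices.
    """
    bounds = [0, *indices, length]
    # Consecutive bounds delimit one piece each: [0, i1), [i1, i2), ..., [ik, length)
    return [(bounds[i], bounds[i + 1]) for i in range(len(bounds) - 1)]


slices = index_slices(5, (2, 3))
print(slices)                                    # [(0, 2), (2, 3), (3, 5)]
print([stop - start for start, stop in slices])  # [2, 1, 2]
```

The piece lengths 2, 1, and 2 match the last dimension of output[0], output[1], and output[2].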