oneflow.chunk

oneflow.chunk(input, chunks, dim)

Splits a tensor into the specified number of chunks. Each chunk is a view of the input tensor. If the size of the input along dimension dim is not divisible by chunks, the last chunk will be bigger than the others.
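To make the size rule above concrete, here is a minimal pure-Python sketch that computes chunk sizes and (start, stop) offsets along a single dimension. It encodes the rule as this page states it (equal sizes, remainder added to the last chunk). It does not call OneFlow, and the helper names `chunk_sizes` and `chunk_bounds` are hypothetical. Check the behavior against the OneFlow version you use.

```python
from typing import List, Tuple


def chunk_sizes(dim_size: int, chunks: int) -> List[int]:
    """Hypothetical helper: size of each chunk along one dimension.

    Follows the rule described above: every chunk gets dim_size // chunks
    elements, and the last chunk also absorbs the remainder.
    """
    if chunks <= 0:
        raise ValueError("chunks must be positive")
    base = dim_size // chunks
    sizes = [base] * (chunks - 1)
    # The last chunk takes whatever is left over.
    sizes.append(dim_size - base * (chunks - 1))
    return sizes


def chunk_bounds(dim_size: int, chunks: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: (start, stop) offsets that a view-based split would slice."""
    bounds, start = [], 0
    for size in chunk_sizes(dim_size, chunks):
        bounds.append((start, start + size))
        start += size
    return bounds


print(chunk_sizes(6, 3))   # [2, 2, 2]
print(chunk_sizes(7, 3))   # [2, 2, 3]
print(chunk_bounds(7, 3))  # [(0, 2), (2, 4), (4, 7)]
```

Because each chunk is described by offsets into the original data, no elements need to be copied. This is what allows each chunk to be a view.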

Parameters
  • input (oneflow.Tensor) – The tensor to split.

  • chunks (int) – Number of chunks to return.

  • dim (int) – Dimension along which to split the tensor.
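The parameters interact simply: only the extent of dimension dim changes, and every other dimension of input is carried over unchanged. The sketch below predicts the output shapes from an input shape. It is a hypothetical helper (`chunk_shapes`), not part of OneFlow. It applies the remainder rule stated in the description.

```python
from typing import List, Sequence, Tuple


def chunk_shapes(shape: Sequence[int], chunks: int, dim: int) -> List[Tuple[int, ...]]:
    """Hypothetical helper: shapes of the pieces produced by splitting `shape`
    into `chunks` along `dim`, with any remainder going to the last chunk."""
    dim_size = shape[dim]
    base = dim_size // chunks
    sizes = [base] * (chunks - 1) + [dim_size - base * (chunks - 1)]
    shapes = []
    for size in sizes:
        out = list(shape)
        out[dim] = size  # only the split dimension changes
        shapes.append(tuple(out))
    return shapes


print(chunk_shapes((5, 3, 6, 9), chunks=3, dim=2))
# [(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]
```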

Returns

List of Tensors.

For example:

>>> import oneflow as flow
>>> import numpy as np

>>> arr = np.random.randn(5, 3, 6, 9).astype(np.float32)
>>> input = flow.tensor(arr)
>>> chunks = 3
>>> # dim 2 has size 6, which splits evenly into 3 chunks of size 2
>>> output = flow.chunk(input, chunks=chunks, dim=2)
>>> out_shape = [t.numpy().shape for t in output]
>>> out_shape
[(5, 3, 2, 9), (5, 3, 2, 9), (5, 3, 2, 9)]
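The description says each chunk is a view of the input. The snippet below illustrates the general idea of a view, using NumPy as a stand-in because basic slicing in NumPy returns views. It does not exercise OneFlow. In particular, it does not show whether writing to a chunk returned by `flow.chunk` is reflected in the input tensor, so treat it only as an analogy.

```python
import numpy as np

arr = np.arange(12, dtype=np.float32).reshape(2, 6)
chunks, dim = 3, 1
size = arr.shape[dim] // chunks  # 6 divides evenly, so every chunk has size 2

# Basic slicing in NumPy returns views, so no data is copied here.
pieces = [arr[:, i * size:(i + 1) * size] for i in range(chunks)]

print([p.shape for p in pieces])                      # [(2, 2), (2, 2), (2, 2)]
print(all(np.shares_memory(p, arr) for p in pieces))  # True

# Writing through a view is visible in the original array.
pieces[0][0, 0] = 100.0
print(arr[0, 0])  # 100.0
```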