> 文章列表 > 【pytorch函数笔记】torch.split

【pytorch函数笔记】torch.split

【pytorch函数笔记】torch.split

官方文档:https://pytorch.org/docs/stable/generated/torch.split.html?highlight=split

torch.split(tensorsplit_size_or_sectionsdim=0)

Splits the tensor into chunks. Each chunk is a view of the original tensor.

If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.

If split_size_or_sections is a list, then tensor will  be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

张量拆分为。每个块都是原始张量的一个视图。
如果split_size_or_sections是整型,那么张量将被拆分为大小相等的块(如果可能的话)。如果沿着给定维度dim的张量大小不能被split_size整除,则最后一个块将更小。
如果split_size_or_sections是一个列表,那么张量将被拆分为len(split_size _or_section)块,其大小根据split_sze_or_secttions为dim。

参数:

  • tensor (Tensor) – tensor to split.需要分裂的tensor

  • split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk单个块的大小

  • dim (int) – dim默为0,即按行分类;dim=1按列分裂

返回类型:List[Tensor]

import torcha = torch.arange(10).reshape(5, 2)
print(a)
torch.split(a, 2)
torch.split(a, [1, 4])
torch.split(a, 1)
torch.split(a, [3,2])# dim=1的时候,按列分裂a = torch.arange(10).reshape(2, 5)
print(a)
torch.split(a, [3,2],1)

结果: