> 文章列表 > Pytorch基础 - 4. torch.expand() 和 torch.repeat()

Pytorch基础 - 4. torch.expand() 和 torch.repeat()

Pytorch基础 - 4. torch.expand() 和 torch.repeat()

目录

1.  torch.expand(*sizes)

2. torch.repeat(*sizes)

3. 两者内存占用的区别


在PyTorch中有两个函数可以用来扩展某一维度张量,即 torch.expand() 和 torch.repeat()

1.  torch.expand(*sizes)

含义】将输入张量在大小为1的维度上进行拓展,并返回扩展更大后的张量

【参数】sizes的shape为torch.Size 或 int,指拓展后的维度, 当值为-1的时候,表示维度不变

import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)print(y1.shape)  # torch.Size([4, 3])y2 = x.expand(6, -1)print(y2.shape)  # torch.Size([6, 3])

2. torch.repeat(*sizes)

【含义】沿着特定维度扩展张量,并返回扩展后的张量

【参数】sizes的shape为torch.Size 或 int,指对当前维度扩展的倍数

import torchif __name__ == '__main__':x = torch.rand(2, 3)y1 = x.repeat(4, 2)print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)y2 = x.repeat(2, 3)print(x.data_ptr(), y1.data_ptr())  # 52364352 52364352print(x.data_ptr(), y2.data_ptr())  # 52364352 8852096