Pytorch基础 - 4. torch.expand() 和 torch.repeat()
目录
1. torch.expand(*sizes)
2. torch.repeat(*sizes)
3. 两者内存占用的区别
在PyTorch中有两个函数可以用来扩展某一维度的张量,即 torch.expand() 和 torch.repeat()
1. torch.expand(*sizes)
【含义】将输入张量在大小为1的维度上进行拓展,并返回扩展更大后的张量
【参数】sizes的shape为torch.Size 或 int,指拓展后的维度, 当值为-1的时候,表示维度不变
import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)print(y1.shape) # torch.Size([4, 3])y2 = x.expand(6, -1)print(y2.shape) # torch.Size([6, 3])
2. torch.repeat(*sizes)
【含义】沿着特定维度扩展张量,并返回扩展后的张量
【参数】sizes的shape为torch.Size 或 int,指对当前维度扩展的倍数
import torchif __name__ == '__main__':x = torch.rand(2, 3)y1 = x.repeat(4, 2)print(y1.shape) # torch.Size([8, 6])
3. 两者内存占用的区别
torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图
torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间
示例如下:
import torchif __name__ == '__main__':x = torch.rand(1, 3)y1 = x.expand(4, 3)y2 = x.repeat(2, 3)print(x.data_ptr(), y1.data_ptr()) # 52364352 52364352print(x.data_ptr(), y2.data_ptr()) # 52364352 8852096