Pytorch 使用经验分享知识点总结

在科研过程中总结的一些琐碎的pytorch相关知识点。
1. 数据加载
-
锁页内存(
pin_memory)是决定数据放在锁业内存还是硬盘的虚拟内存中,默认值为 False。如果设置为True,则表示数据放在锁业内存中。注意:显卡中的内存全部是锁页内存,所以放在锁页内存中可以加快读取速度。当计算机内存充足时,可将该值设置为 True。这一参数一般在data_loader()函数中设置。 -
num_worker取值最好是 2的幂次方-1, 如 0,1,3,7 等,因为会自动加 1. 默认值为 1. 这一参数一般在data_loader()函数中设置。 -
GPU 利用率为低是因为显卡在等数据,解决办法(1)优化
data_loader()函数;(2)增大batch size 等
2. 数据操作
- pytorch两个基本对象:
Tensor(张量)和Variable(变量)。 torch.Tensor与torch.tensor的区别:
torch.Tensor(data):将数据转化torch.FloatTensor类型。
torch.tensor(data):根据数据类型或者dtype参数值将数据转化为torch.FloatTensor、torch.LongTensor、torch.DoubleTensor等类型。torch.contiguous():类似于 C++ 中的深拷贝。详解见此篇博客。torch.stack()作用:用于连接大小相同的张量,并扩展维度,类比torch.cat(). 注意:在哪个维度上操作,就将 dim 设置为哪个维度。 详解见此篇博客。- 使用
torch.zeros()创建的张量默认在 CPU 上,如要在 GPU 上使用记得进行数据转移。 - 解决 torch 对象打印时有省略号的问题:
torch.set_printoptions(threshold=np.inf),该命令多用于打印完整日志。 - numpy 类型数据只能在 CPU 上运行。注意数据在torch类型与numpy类型间相互转换时数据的存放位置(如:不能将GPU上的张量数据直接转化为numpy类型数据)。
3. 模型操作
3.1 模式切换
model.eval()与model.train()区别在于是否启用 归一化层 + dropout,前者不启用,后者启用。
3.2 梯度更新
-
Module中的层在定义时,相关Variable的requires_grad参数默认是True。而用户手动定义Variable时,参数requires_grad默认值是False,volatile值也默认为False。volatile的优先级比requires_grad高,volatile属性为True的节点不会求导(所以可以在测试阶段设置为 True)。 如果要修改可使用variable_name.require_grad_(True)实现。 -
反向传播中梯度回传与更新的实现三步走: (1)
optimizer.zero_grad()(梯度清零)(2)loss.backward()(梯度回传)(3)optimizer.step()(梯度更新) -
model.zero_grad ()和optimizer.zero_grad ()使用区别:当optimizer = optim.Optimizer (net.parameters ()),即网络中参数均未冻结,全部需要更新时,二者等效,其中Optimizer可以是Adam、SGD等优化器;若网络中部分参数被冻结或多个网络共用同一个优化器,则二者不等价。详解见此篇博客。 -
with torch.no_grad()作用:停止autograd模块的工作,以起到加速和节省显存的作用。一般用在验证和测试阶段。注意:新版本Pytorch中,volatile已被弃用,需替换为:with torch.no_grad().
3.3 模型保存与加载
torch.save(model, path):将训练好的模型 model 保存至 path 路径下。torch.load(model_path, map_location):将给定路径的预训练模型加载至指定设备上,详解见此篇博客。
参考资料
- Pytorch中contiguous()函数理解_.contiguous()_清晨的光明的博客-CSDN博客
- pytorch拼接函数:torch.stack()和torch.cat()–详解及例子_python torch拼接_紫芝的博客-CSDN博客
- pytorch之model.zero_grad() 与 optimizer.zero_grad()_models.zero_grad()_旺旺棒棒冰的博客-CSDN博客
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()_宁静致远*的博客-CSDN博客
