Pytorch 容器 - 3. Module的参数添加:register_parameter(),register_buffer()
Module类内置了很多函数,其中本文主要介绍常用的属性设置函数,包括向module添加参数的register_parameter(),register_buffer()。官方文档如下:Module — PyTorch 1.7.0 documentation
这两种方法均可以往模型中额外添加参数。不同的是register_parameter() 添加的参数在模型训练时可以正常更新,但register_buffer()则不进行参数更新。在保存模型时,两者的参数都会保存。
示例1:分别使用以下四种形式来定义参数,查看网络更新的参数
- 常用的 nn.Sequential() ,即nn.Module子类的形式
- 使用 register_buffer() 定义参数
- 使用register_parameter() 定义参数
- 使用 Net 类的自定义属性定义参数
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super().__init__()# 使用 nn.Sequential定义参数self.features = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))# 使用 register_buffer()定义参数self.register_buffer('reg_buffer', torch.randn(1, 6))# 使用 register_parameter()定义一组参数self.register_parameter('reg_param', nn.Parameter(torch.randn(1, 6)))# 按照类的属性定义普通变量self.param_attr = torch.randn(1, 6)def forward(self, x):return xmodel = Net()
for item in model.named_parameters():print(item[0], item[1])# reg_param Parameter containing:
# tensor([[ 1.3185, -1.8815, 0.1480, -0.9918, 1.0958, -0.6511]],
# requires_grad=True)
# features.0.weight Parameter containing:
# tensor([[[[-0.0772, -0.2269, 0.2916],
# [ 0.3244, 0.1007, 0.1207],
# [ 0.3328, -0.2957, 0.1106]]],
#
#
# [[[-0.1478, -0.2451, -0.0846],
# [ 0.2457, 0.2737, -0.1498],
# [-0.1811, -0.0386, -0.2116]]]], requires_grad=True)
# features.0.bias Parameter containing:
# tensor([-0.0358, 0.0178], requires_grad=True)
从结果可以看出,网络中可以训练的参数仅包括:使用 nn.Sequential()定义的参数 和 使用register_parameter() 定义的参数,而buffers和普通类属性定义的参数不可以更新。
示例2:同样以四种形式来定义参数,查看模型可以保存的参数
print(model.state_dict().keys())
# odict_keys(['reg_param', 'reg_buffer', 'features.0.weight', 'features.0.bias'])
从结果可以看出,模型保存的参数包括:使用 nn.Sequential()定义的参数、使用register_parameter() 定义的参数 以及 使用register_buffer() 定义的参数,而普通类属性定义的参数不会被保存。