> 文章列表 > 【人工智能概论】 网络模块构建工具——torch. nn.Sequential、torch.nn.ModuleList

【人工智能概论】 网络模块构建工具——torch. nn.Sequential、torch.nn.ModuleList

【人工智能概论】 网络模块构建工具——torch. nn.Sequential、torch.nn.ModuleList

【人工智能概论】 网络模块构建工具——torch. nn.Sequential、torch.nn.ModuleList

文章目录

  • 【人工智能概论】 网络模块构建工具——torch. nn.Sequential、torch.nn.ModuleList
  • 一. 简介(相同点)
  • 二. 应用特点(不同点
    • 2.1 nn.Sequential
    • 2.2 nn.ModuleList

一. 简介(相同点)

  • nn.ModuleList 和 nn.Sequential都被用来封装多个层。

二. 应用特点(不同点)

  • nn.Sequential 中 nn.Module 子模块的添加顺序也是其前向传递(forward)的顺序,因此在 forward 函数中可以直接调用;
  • nn.ModuleList 的 nn.Module 子模块间没有顺序依赖关系,因此需要在 forward 函数中显式定义子模块间的前向传递关系;
  • 实际上,就是nn.Sequential有forward()方法,而nn.ModuleList只是将一系列层装入列表,并没有forward()方法。

2.1 nn.Sequential

  • nn.Sequential定义了一个网络,里面的模块是按照顺序进行排列的,因此必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。
  • 代码举例:
class net_seq(nn.Module):def __init__(self):super(net_seq, self).__init__()self.seq = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())      def forward(self, x):return self.seq(x)
  • nn.Sequential中可以使用OrderedDict来指定每个module的名字,而不是采用默认的命名方式(按序号 0,1,2,3…)。
  • 代码举例:
from collections import OrderedDictnet3= nn.Sequential(OrderedDict([('conv', nn.Conv2d(3, 3, 3)),('batchnorm', nn.BatchNorm2d(3)),('activation_layer', nn.ReLU())]))

2.2 nn.ModuleList

  • nn.ModuleList将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。
  • 与 Python 自带的 list 类似,nn.ModuleList有 extend,append 等操作,这使得它更加灵活,extend是添加一个新的modulelist ,append是添加一个新的module。
  • 代码举例:
class LinearNet(nn.Module):def __init__(self, input_size, num_layers, layers_size, output_size):super(LinearNet, self).__init__()self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])self.linears.append(nn.Linear(layers_size, output_size)
  • 有时网络中有很多相似或重复的层,一般会采用 for 循环来创建它们,而不是一行一行地写,比如:
class net_list(nn.Module):def __init__(self):super(net_list, self).__init__()layers = [nn.Linear(10, 10) for i in range(5)]self.linears = nn.ModuleList(layers)def forward(self, x):for layer in self.linears:x = layer(x)return x