加速深层网络收敛:批量规范化
(一)背景
在多层感知机中,中间层的变量可能具有更广的变化范围(由于输入的变化造成的),这可能造成不同层之间参数量级的不统一,造成深层网络的收敛困难。
因此需要批量规范化技术。
(二)原理
批量规范化的原理如下:
在每次训练迭代过程中,我们首先规范化小批量,使均值为 0 0 0、方差为 1 1 1,然后对其应用比例系数和比例偏移。
B N ( x ) = γ ∗ x − μ σ + β BN(x)=\\gamma\\ *\\ \\frac{x-\\mu}{\\sigma}+\\beta BN(x)=γ ∗ σx−μ+β
其中 γ \\gamma γ和 β \\beta β分别是拉伸参数和偏移参数,均为可学习参数。
其中 σ \\sigma σ计算时通常会加上噪声,防止除以 0 0 0的情况出现。
应用批量规范化将每一层的输入主动居中,从而防止了中间层剧烈的变化。
(三)使用
1.全连接层
通常将批量规范化层置于全连接层与激活函数之间:
y_hat = Relu(BatchNorm1d(Linear(X)))
2.卷积层
通常将批量规范化置于卷积层与激活函数之间:
y_hat = Relu(BatchNorm2d(Conv2d(X)))
3.注意事项
(1)在应用批量规范化时,批量大小的设置往往会变得更加重要。只有应用足够大的批量,批量规范化才是稳定的。
其最适合解决 50 50 50~ 100 100 100的中等批量大小问题;
(2)在训练模式与测试模式下,批量规范化层的表现不同:
在训练模式下,由于无法计算整个数据集的均值和方差,所以对小批量使用规范化;
在测试模式下,可以计算整个数据集的均值和方差,通常应用滑动平均估算来实现。
(四)实现
class BatchNorm(nn.Module):""" 批量规范化 """# num_features为全连接层的特征数或者卷积层的通道数def __init__(self, num_features, num_dims):super(BatchNorm, self).__init__()assert num_dims in (2, 4)if num_dims == 2: # 全连接层shape = (1, num_features)else: # 卷积层shape = (1, num_features, 1, 1)self.gamma = nn.Parameter(torch.ones(shape)) # 比例拉伸self.beta = nn.Parameter(torch.zeros(shape)) # 比例偏移self.moving_mean = torch.zeros(shape) # 均值self.moving_var = torch.ones(shape) # 方差def forward(self, X):if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)X, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9) # 其中eps为噪声,momentum为滑动平均估算的参数return X
gamma
与beta
的维度与输入维度数量一致,除了相应的特征维度,其余维度大小均为 1 1 1。
因为我们要计算输入的每一个特征的均值与方差。
直观理解的话,特征就是全连接层输入((batch_size, features)
)的每一列、图片的每一层。
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled(): # 非训练模式X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else: # 训练模式assert len(X.shape) in (2, 4)if len(X.shape) == 2: # 全连接层# X shape(batch_size, features)mean = X.mean(dim=0) # mean shape(features)var = ((X - mean) ** 2).mean(dim=0) # var shape(features)else: # 卷积层# X shape(batch_size, channels, h, w)mean = X.mean(dim=(0, 2, 3), keepdim=True) # mean shape(1, channels, 1, 1)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True) # var shape(1, channels, 1, 1)X_hat = (X - mean) / torch.sqrt(var + eps) # 广播机制moving_mean = momentum * moving_mean + (1 - momentum) * meanmoving_var = momentum * moving_var + (1 - momentum) * varY = gamma * X_hat + betareturn Y, moving_mean.data, moving_var.data