> 文章列表 > 加速深层网络收敛:批量规范化

加速深层网络收敛:批量规范化

加速深层网络收敛:批量规范化

(一)背景

在多层感知机中,中间层的变量可能具有更广的变化范围(由于输入的变化造成的),这可能造成不同层之间参数量级的不统一,造成深层网络的收敛困难。

因此需要批量规范化技术。

(二)原理

批量规范化的原理如下:

在每次训练迭代过程中,我们首先规范化小批量,使均值为 0 0 0、方差为 1 1 1,然后对其应用比例系数和比例偏移。

B N ( x ) = γ ∗ x − μ σ + β BN(x)=\\gamma\\ *\\ \\frac{x-\\mu}{\\sigma}+\\beta BN(x)=γ  σxμ+β

其中 γ \\gamma γ β \\beta β分别是拉伸参数和偏移参数,均为可学习参数。

其中 σ \\sigma σ计算时通常会加上噪声,防止除以 0 0 0的情况出现。

应用批量规范化将每一层的输入主动居中,从而防止了中间层剧烈的变化。

(三)使用

1.全连接层

通常将批量规范化层置于全连接层与激活函数之间:

y_hat = Relu(BatchNorm1d(Linear(X)))

2.卷积层

通常将批量规范化置于卷积层与激活函数之间:

y_hat = Relu(BatchNorm2d(Conv2d(X)))

3.注意事项

(1)在应用批量规范化时,批量大小的设置往往会变得更加重要。只有应用足够大的批量,批量规范化才是稳定的。

其最适合解决 50 50 50~ 100 100 100的中等批量大小问题;

(2)在训练模式与测试模式下,批量规范化层的表现不同:

在训练模式下,由于无法计算整个数据集的均值和方差,所以对小批量使用规范化;

在测试模式下,可以计算整个数据集的均值和方差,通常应用滑动平均估算来实现。

(四)实现

class BatchNorm(nn.Module):""" 批量规范化 """# num_features为全连接层的特征数或者卷积层的通道数def __init__(self, num_features, num_dims):super(BatchNorm, self).__init__()assert num_dims in (2, 4)if num_dims == 2:   # 全连接层shape = (1, num_features)else:               # 卷积层shape = (1, num_features, 1, 1)self.gamma = nn.Parameter(torch.ones(shape))    # 比例拉伸self.beta = nn.Parameter(torch.zeros(shape))    # 比例偏移self.moving_mean = torch.zeros(shape)   # 均值self.moving_var = torch.ones(shape)     # 方差def forward(self, X):if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)X, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)   # 其中eps为噪声,momentum为滑动平均估算的参数return X

gammabeta的维度与输入维度数量一致,除了相应的特征维度,其余维度大小均为 1 1 1

因为我们要计算输入的每一个特征的均值与方差。

直观理解的话,特征就是全连接层输入((batch_size, features))的每一列、图片的每一层。

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):if not torch.is_grad_enabled():  # 非训练模式X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:                            # 训练模式assert len(X.shape) in (2, 4)if len(X.shape) == 2:  # 全连接层#  X shape(batch_size, features)mean = X.mean(dim=0)                    # mean shape(features)var = ((X - mean) ** 2).mean(dim=0)     # var shape(features)else:                  # 卷积层#  X shape(batch_size, channels, h, w)mean = X.mean(dim=(0, 2, 3), keepdim=True)                  # mean shape(1, channels, 1, 1)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)   # var shape(1, channels, 1, 1)X_hat = (X - mean) / torch.sqrt(var + eps)  # 广播机制moving_mean = momentum * moving_mean + (1 - momentum) * meanmoving_var = momentum * moving_var + (1 - momentum) * varY = gamma * X_hat + betareturn Y, moving_mean.data, moving_var.data