
PyTorch Notes - Weight Normalization

Welcome to follow my CSDN: https://spike.blog.csdn.net/
Article URL: https://blog.csdn.net/caroline_wendy/article/details/129708143

Weight Normalization (WN) is an optimization technique for deep neural networks that speeds up training and improves generalization. Its basic idea is to decompose each neuron's weight vector into two factors, a direction and a magnitude: the direction is normalized to unit norm, while the magnitude is kept as a separate trainable parameter. This reduces the path dependence of gradient descent and thus helps avoid poor local optima and saddle points. In addition, WN can improve initialization and regularization, since it lowers the condition number of the weight matrix, and it can be combined with techniques such as batch normalization or dropout.

Weight normalization decouples the connection weight vector w into a parameter vector v and a parameter scalar g, separating its Euclidean norm from its direction: w = g * v / ||v||, where ||v|| denotes the Euclidean norm of v. Optimizing w is thereby converted into optimizing g and v, which control the length and the direction of w, respectively.
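As a minimal sketch of this decomposition (the values below are made up purely for illustration), the identity w = g * v / ||v|| can be verified with plain tensor operations:

import torch

v = torch.tensor([3.0, 4.0])  # direction parameter (not yet unit-norm)
g = torch.tensor(2.0)         # magnitude parameter

w = g * v / v.norm()          # w = g * v / ||v||
print(w)                      # tensor([1.2000, 1.6000])
print(w.norm())               # tensor(2.) -- ||w|| equals g by construction

Note that ||w|| = g regardless of v, which is exactly what decouples length from direction.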

WN has the following advantages:

  • WN rescales the gradient and projects it onto a direction away from the current w, which corrects the gradient update direction and accelerates convergence.
  • WN self-stabilizes the norm of the gradient, so larger learning rates can be used without causing gradients to explode or vanish.
  • WN does not depend on the number of samples, so it can be applied to small mini-batches and to dynamic networks.
  • WN introduces essentially no extra parameters (only one scalar magnitude per output unit), so it barely increases model complexity or memory usage.
  • WN computes no normalization statistics, so it is more efficient than methods such as BN that must compute them.
  • WN behaves better in noise-sensitive settings, such as generative models and reinforcement learning.

WN is implemented with a reparameterization trick, and frameworks such as PyTorch and TensorFlow provide corresponding interfaces. In PyTorch, for example, the nn.utils.weight_norm function applies WN to a linear or convolutional layer. It is used as follows:

import torch.nn as nn
import torch.nn.functional as F

# A simple network with one hidden layer as an example
class Model(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_size):
        super(Model, self).__init__()
        # Apply WN to the linear layers with weight_norm
        self.dense1 = nn.utils.weight_norm(nn.Linear(input_dim, hidden_size))
        self.dense2 = nn.utils.weight_norm(nn.Linear(hidden_size, output_dim))

    def forward(self, x):
        x = self.dense1(x)
        x = F.leaky_relu(x)
        x = self.dense2(x)
        return x
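As a quick smoke test (the dimensions here are arbitrary, chosen only for illustration), the wrapped model behaves like any ordinary nn.Module:

import torch

model = Model(input_dim=10, output_dim=2, hidden_size=16)
out = model(torch.randn(4, 10))  # a batch of 4 samples
print(out.shape)                 # torch.Size([4, 2])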

Paper:

  • Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks
    • Tim Salimans and Diederik P. Kingma, OpenAI, 2016.6.4
  • API: torch.nn.utils.weight_norm
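For reference, the function's signature is torch.nn.utils.weight_norm(module, name='weight', dim=0): with the default dim=0 one magnitude is kept per output row, while dim=None computes a single norm over the entire weight tensor. A brief sketch of the resulting parameter shapes:

import torch.nn as nn

layer = nn.utils.weight_norm(nn.Linear(3, 4), name='weight', dim=0)
print(layer.weight_g.shape)  # torch.Size([4, 1]) -- one magnitude per output row
print(layer.weight_v.shape)  # torch.Size([4, 3]) -- direction, same shape as weight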

Definition: analogous to a matrix factorization

  • weight_g specifies the magnitude
  • weight_v specifies the direction
  • That is, weight is decoupled into weight_g and weight_v, and gradients are computed for both of them (see the sketch below).
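A minimal sketch of that last point (toy shapes only): after wrapping a layer, the backward pass produces gradients for weight_g and weight_v instead of a single weight:

import torch
import torch.nn as nn

wn = nn.utils.weight_norm(nn.Linear(3, 4, bias=False))
loss = wn(torch.randn(2, 3)).sum()
loss.backward()
print(wn.weight_g.grad.shape)  # torch.Size([4, 1])
print(wn.weight_v.grad.shape)  # torch.Size([4, 3])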

Source code:

import torch
import torch.nn as nn

batch_size = 2
feat_dim = 3
hid_dim = 4
inputx = torch.randn(batch_size, feat_dim)

linear = nn.Linear(feat_dim, hid_dim, bias=False)
wn_linear = torch.nn.utils.weight_norm(linear)  # wrapped module

# Magnitude of each row of the weight matrix, i.e. g
weight_magnitude = torch.tensor([linear.weight[i, :].norm() for i in torch.arange(linear.weight.shape[0])], dtype=torch.float32).unsqueeze(-1)
print(f"linear.weight: {linear.weight.shape}")  # weight matrix
print(f"weight_magnitude: {weight_magnitude.shape}")

weight_direction = linear.weight / weight_magnitude  # matrix of unit row vectors
print(f"linear.weight:\n{linear.weight}")
print(f"weight_magnitude:\n{weight_magnitude}")
print(f"weight_direction:\n{weight_direction}")
# The squared entries of each unit vector sum to 1
print(f"magnitude of weight_direction:\n{(weight_direction ** 2).sum(dim=-1)}")

linear_weight = weight_direction * weight_magnitude
print(f"weight_direction*weight_magnitude:\n{linear_weight}")
# Method 1: explicit matrix multiplication
print(f"inputx @ w_norm.T:\n{inputx @ linear_weight.T}")
# Method 2: the original module
print(f"linear(inputx):\n{linear(inputx)}")
# Method 3: the wrapped module keeps the output unchanged; the single weight
# parameter is replaced by weight_g and weight_v
print(f"wn_linear(inputx):\n{wn_linear(inputx)}")

print("parameters of wn_linear:")
for n, p in wn_linear.named_parameters():
    print(n, p)

print("construct weight of linear:")
wn_v_norm = torch.tensor([wn_linear.weight_v[i, :].norm() for i in torch.arange(wn_linear.weight_v.shape[0])], dtype=torch.float32).unsqueeze(-1)
wn_weight = wn_linear.weight_g * (wn_linear.weight_v / wn_v_norm)
print(wn_weight)

# input: [B, C, T], weight: [Co, Ci, 1]
conv1d = nn.Conv1d(feat_dim, hid_dim, kernel_size=1, bias=False)
wn_conv1d = torch.nn.utils.weight_norm(conv1d)
print(f"wn_conv1d:\n{wn_conv1d.weight.shape}")

# Compute the magnitude of each output channel
conv1d_weight_magnitude = torch.tensor([conv1d.weight[i, :, :].norm() for i in torch.arange(conv1d.weight.shape[0])], dtype=torch.float32).reshape(-1, 1, 1)
print(f"conv1d_weight_magnitude: {conv1d_weight_magnitude.shape}")
print(f"conv1d.weight: {conv1d.weight.shape}")
conv1d_weight_direction = conv1d.weight / conv1d_weight_magnitude

print("parameters of wn_conv1d:")
for n, p in wn_conv1d.named_parameters():
    print(n, p, p.shape)

print("construct weight of conv1d:")
wn_conv1d_v_norm = torch.tensor([wn_conv1d.weight_v[i, :, :].norm() for i in torch.arange(wn_conv1d.weight_v.shape[0])], dtype=torch.float32).reshape(-1, 1, 1)
weight = wn_conv1d.weight_g * (wn_conv1d.weight_v / wn_conv1d_v_norm)
print(weight)
print(f"conv1d.weight:\n{conv1d.weight}")
# print(f"conv1d_weight_magnitude:\n{conv1d_weight_magnitude}")
# print(f"conv1d_weight_direction:\n{conv1d_weight_direction}")