机器学习笔记之正则化(三)权重衰减角度(偏差方向)
引言
上一节从直观现象的角度观察权重W\\mathcal WW是如何出现权重衰减的,并且介绍了W\\mathcal WW的权重衰减是如何抑制过拟合的发生。本节从偏差方向观察权重衰减。
回顾:关于目标函数中的λ,C\\lambda,\\mathcal Cλ,C
回顾基于拉格朗日乘数法角度的优化问题表示如下:
这里以
L2L_2L2正则化为例。
{L(W,λ)=J(W)+λ(∣∣W∣∣2−C)s.t.λ>0\\begin{cases} \\mathcal L(\\mathcal W,\\lambda) = \\mathcal J(\\mathcal W) + \\lambda (||\\mathcal W||_2 - \\mathcal C) \\\\ s.t. \\quad \\lambda > 0 \\end{cases}{L(W,λ)=J(W)+λ(∣∣W∣∣2−C)s.t.λ>0
将目标函数L(W,λ)\\mathcal L(\\mathcal W,\\lambda)L(W,λ)展开,可得到如下形式:
L(W,λ)=J(W)+λ∣∣W∣∣2⏟标准正则化−λ⋅C\\mathcal L(\\mathcal W,\\lambda) = \\underbrace{\\mathcal J(\\mathcal W) + \\lambda ||\\mathcal W||_2}_{标准正则化} - \\lambda \\cdot \\mathcal CL(W,λ)=标准正则化J(W)+λ∣∣W∣∣2−λ⋅C
其中C=WTW\\mathcal C = \\sqrt{\\mathcal W^T\\mathcal W}C=WTW,也就是L2L_2L2正则化范围的半径;而λ\\lambdaλ表示拉格朗日参数。
关于λ,C\\lambda,\\mathcal Cλ,C,无论我们修改哪一个参数,最终都会影响正则化在权重空间中的 有效范围:
-
当λ\\lambdaλ确定,C\\mathcal CC发生变化时,由于λ⋅C\\lambda \\cdot \\mathcal Cλ⋅C项的目的就是为了在梯度下降每次迭代过程中,找到一个大小相等、方向相反的向量。λ⋅C\\lambda \\cdot \\mathcal Cλ⋅C就是调节大小的。
λ\\lambdaλ确定条件下,仅在当前迭代步骤中,C\\mathcal CC也被确定,从而正则化范围被确定。那么这个大小相等、方向相反的向量在权重空间中一定时当前正则化范围中的最优解。这个解大概率会在范围的边缘上,与J(W)\\mathcal J(\\mathcal W)J(W)在权重空间中的等高线相切;
-
当C\\mathcal CC确定,λ\\lambdaλ不确定时,也就是说权重范围被确定,我们假设通过不断尝试不同的λ\\lambdaλ,使得它与当前迭代步骤的梯度向量大小相等、方向相反。这个值它仅能保证在正则化范围中。但不能确定它是当前迭代步骤的最优解——可能在正则化范围内存在权重点,它距离J(W)\\mathcal J(\\mathcal W)J(W)的等高线中心更近。
在深度学习框架中,如PyTorch\\text{PyTorch}PyTorch。它的处理方式是第一种,我们手动设置λ\\lambdaλ具体值,让神经网络自己调节正则化范围C\\mathcal CC:
这里以
Adam\\text{Adam}Adam算法为例。这里不深究,仅作一个关于正则化参数的描述。
from torch import optim as optimOptimizer = optim.Adam(TrainModel.parameters(), lr=LearningRate, weight_decay=WeightDecay)
其中weight_decay
参数就是参数λ\\lambdaλ的取值。在模型训练过程中,每一次反向传播过程我们会更新W\\mathcal WW的信息,自然也会更新C=WTW\\mathcal C = \\sqrt{\\mathcal W^T\\mathcal W}C=WTW的信息。
正则化与非正则化之间的偏差
下图表示的是权重空间中某损失函数J(W)\\mathcal J(\\mathcal W)J(W)的等高线,而虚线表示某一梯度方向。
上述箭头的指向表示权重向最优权重优化的过程。值得注意的是,由于初始化权重Winit\\mathcal W_{init}Winit是随机的,这意味着虚线/箭头不是唯一确定的。
并且并不是箭头最终指向的点就是最优解。实际上,在迭代过程中,只要Winit\\mathcal W_{init}Winit被确定下来,那么该Winit\\mathcal W_{init}Winit到J(W)\\mathcal J(\\mathcal W)J(W)等高线中心方向上的任意一点,都是对应每次迭代中的最优解。
这里依然以L1,L2L_1,L_2L1,L2正则化为例。假设我们的梯度优化方向就是虚线方向,那么观察迭代过程中加入L1,L2L_1,L_2L1,L2正则化后最优解的路径与未加入正则化情况下最优解的期望逻辑之间的关系:
这里
蓝色点表示添加正则化后的最优解路径;
红色点表示未加入正则化后最优解的期望路径。
该图片来自于文章下方的视频链接,下同,侵删。
通过观察图像可以发现:无论是L1L_1L1还是L2L_2L2正则化,它们的路径均与期望路径之间存在偏差。
这个偏差如何计算??? 并且这个偏差对于权重衰减有什么关联关系???
偏差的计算过程
这里假设W∗\\mathcal W^*W∗表示损失函数J(W)\\mathcal J(\\mathcal W)J(W)条件下产生的最优解。那么J(W∗)\\mathcal J(\\mathcal W^*)J(W∗)就表示该最优解对应的损失函数的最优解;
最优解表示‘最小训练误差’时的权重。
而J^(W)\\hat J(\\mathcal W)J^(W)表示正则化后的损失函数,并假设该函数在W^\\hat {\\mathcal W}W^中取得最优解J^(W^)\\hat J(\\hat {\\mathcal W})J^(W^)。我们需要讨论的是:W∗\\mathcal W^*W∗和W^\\hat {\\mathcal W}W^之间的偏差是多少。从数学角度观察,我们希望通过公式表达W∗\\mathcal W^*W∗和W^\\hat {\\mathcal W}W^之间的函数关系/关联关系。
-
我们通过泰勒公式将损失函数J(W)\\mathcal J(\\mathcal W)J(W)进行近似表达。这里仅将其扩展至二次:
与
泰勒公式的标准式进行比较,这里将常数
a0=W∗a_0 = \\mathcal W^*a0=W∗,真正的变量只有
W\\mathcal WW.
J(W)≈J(W∗)+11!∇WJ(W∗)⋅(W−W∗)+12!J′′(W∗)⋅(W−W∗)2=J(W∗)+∇WJ(W∗)+12(W−W∗)TH(W−W∗)\\begin{aligned} \\mathcal J(\\mathcal W) & \\approx \\mathcal J(\\mathcal W^*) + \\frac{1}{1!}\\nabla_{\\mathcal W} \\mathcal J(\\mathcal W^*) \\cdot (\\mathcal W - \\mathcal W^*) + \\frac{1}{2!} \\mathcal J''(\\mathcal W^*) \\cdot (\\mathcal W - \\mathcal W^*)^2 \\\\ & = \\mathcal J(\\mathcal W^*) + \\nabla_{\\mathcal W} \\mathcal J(\\mathcal W^*) + \\frac{1}{2} (\\mathcal W - \\mathcal W^*)^T \\mathcal H (\\mathcal W - \\mathcal W^*) \\end{aligned}J(W)≈J(W∗)+1!1∇WJ(W∗)⋅(W−W∗)+2!1J′′(W∗)⋅(W−W∗)2=J(W∗)+∇WJ(W∗)+21(W−W∗)TH(W−W∗)
其中H\\mathcal HH表示Hession\\text{Hession}Hession矩阵,它表示损失函数关于W∗\\mathcal W^*W∗的二阶导结果。上式中,由于J(W∗)\\mathcal J(\\mathcal W^*)J(W∗)是最值点,因而∇WJ(W∗)=0\\nabla_{\\mathcal W} \\mathcal J(\\mathcal W^*) = 0∇WJ(W∗)=0。最终可将上式写成如下形式:
J(W)≈J(W∗)+12(W−W∗)TH(W−W∗)\\mathcal J(\\mathcal W) \\approx \\mathcal J(\\mathcal W^*) +\\frac{1}{2} (\\mathcal W - \\mathcal W^*)^T \\mathcal H (\\mathcal W - \\mathcal W^*)J(W)≈J(W∗)+21(W−W∗)TH(W−W∗) -
对上述损失函数J(W)\\mathcal J(\\mathcal W)J(W)关于变量W\\mathcal WW求解梯度:
由于
W∗\\mathcal W^*W∗已知,那么
J(W∗)\\mathcal J(\\mathcal W^*)J(W∗)也是常数,其梯度为
000.这里用到矩阵求导公式
.
∂∂W[(W−W∗)TH(W−W∗)]=∂∂(W−W∗)[(W−W∗)TH(W−W∗)]⋅∂(W−W∗)∂W=2⋅H(W−W∗)⋅(1−0)=2⋅H(W−W∗)∇WJ(W)=∇WJ(W∗)⏟=0+∇W[12(W−W∗)TH(W−W∗)]=12⋅2⋅H(W−W∗)=H(W−W∗)\\begin{aligned} \\frac{\\partial}{\\partial \\mathcal W} \\left[(\\mathcal W - \\mathcal W^*)^T \\mathcal H (\\mathcal W - \\mathcal W^*)\\right] & = \\frac{\\partial}{\\partial (\\mathcal W - \\mathcal W^*)}\\left[(\\mathcal W - \\mathcal W^*)^T \\mathcal H (\\mathcal W - \\mathcal W^*)\\right] \\cdot \\frac{\\partial (\\mathcal W - \\mathcal W^*)}{\\partial \\mathcal W}\\\\ & = 2 \\cdot \\mathcal H(\\mathcal W - \\mathcal W^*) \\cdot (1 - 0) \\\\ & = 2 \\cdot \\mathcal H(\\mathcal W - \\mathcal W^*)\\\\ \\nabla_{\\mathcal W} \\mathcal J(\\mathcal W) & = \\underbrace{\\nabla_{\\mathcal W} \\mathcal J(\\mathcal W^*)}_{=0} + \\nabla_{\\mathcal W} \\left[\\frac{1}{2} (\\mathcal W - \\mathcal W^*)^T \\mathcal H (\\mathcal W - \\mathcal W^*)\\right] \\\\ & = \\frac{1}{2} \\cdot 2 \\cdot \\mathcal H(\\mathcal W - \\mathcal W^*) \\\\ & = \\mathcal H(\\mathcal W - \\mathcal W^*) \\end{aligned}∂W∂[(W−W∗)TH(W−W∗)]∇WJ(W)=∂(W−W∗)∂[(W−W∗)TH(W−W∗)]⋅∂W∂(W−W∗)=2⋅H(W−W∗)⋅(1−0)=2⋅H(W−W∗)==0∇WJ(W∗)+∇W[21(W−W∗)TH(W−W∗)]=21⋅2⋅H(W−W∗)=H(W−W∗)
-
上述结果是未使用正则化条件下关于W\\mathcal WW的梯度结果。这里依然以L2L_2L2正则化为例,观察正则化后的损失函数J^(W)\\hat {\\mathcal J}(\\mathcal W)J^(W)的结果以及梯度结果表示为如下形式:
{J^(W)=J(W∗)+12(W−W∗)TH(W−W∗)+α2WTW⏟正则化项∇WJ^(W)=H(W−W∗)+α2⋅2⋅W=H(W−W∗)+α⋅W\\begin{cases} \\begin{aligned} \\hat {\\mathcal J}(\\mathcal W) = \\mathcal J(\\mathcal W^*) + \\frac{1}{2} (\\mathcal W - \\mathcal W^*)^T \\mathcal H(\\mathcal W - \\mathcal W^*) + \\underbrace{\\frac{\\alpha}{2} \\mathcal W^T\\mathcal W}_{正则化项} \\end{aligned} \\\\ \\begin{aligned} \\nabla_{\\mathcal W} \\hat {\\mathcal J}(\\mathcal W) & = \\mathcal H(\\mathcal W - \\mathcal W^*) + \\frac{\\alpha}{2} \\cdot 2 \\cdot \\mathcal W \\\\ & = \\mathcal H(\\mathcal W - \\mathcal W^*) + \\alpha \\cdot \\mathcal W \\end{aligned} \\end{cases}⎩⎨⎧J^(W)=J(W∗)+21(W−W∗)TH(W−W∗)+正则化项2αWTW∇WJ^(W)=H(W−W∗)+2α⋅2⋅W=H(W−W∗)+α⋅W -
由于W^\\hat {\\mathcal W}W^是使损失函数J^(W)\\hat {\\mathcal J}(\\mathcal W)J^(W)取得最小值的解。这意味着∇WJ^(W)∣W=W^=0\\nabla_{\\mathcal W} \\hat {\\mathcal J}(\\mathcal W) |_{\\mathcal W = \\hat {\\mathcal W}} = 0∇WJ^(W)∣W=W^=0。则有:
这里
I\\mathcal II表示单位矩阵。
H(W^−W∗)+α⋅W^=0⇒(H+α⋅I)W^=H⋅W∗⇒W^=(H+α⋅I)−1H⋅W∗\\begin{aligned} & \\quad \\mathcal H(\\hat {\\mathcal W} - \\mathcal W^*) + \\alpha \\cdot \\hat {\\mathcal W} = 0 \\\\ & \\Rightarrow(\\mathcal H + \\alpha \\cdot \\mathcal I) \\hat {\\mathcal W} = \\mathcal H \\cdot \\mathcal W^* \\\\ & \\Rightarrow \\hat {\\mathcal W} = (\\mathcal H + \\alpha \\cdot \\mathcal I)^{-1} \\mathcal H \\cdot \\mathcal W^* \\end{aligned}H(W^−W∗)+α⋅W^=0⇒(H+α⋅I)W^=H⋅W∗⇒W^=(H+α⋅I)−1H⋅W∗ -
关于Hession\\text{Hession}Hession矩阵的性质,如果J(W)\\mathcal J(\\mathcal W)J(W)在权重空间内连续可导,那么H\\mathcal HH是一个对阵矩阵。对H\\mathcal HH进行特征值分解:H=QΛQT\\mathcal H = \\mathcal Q \\Lambda \\mathcal Q^TH=QΛQT,并将该结果代入到上式中:
其中
Q\\mathcal QQ是一个正交矩阵;
Λ\\LambdaΛ是对角矩阵,其对角线上元素是
H\\mathcal HH的特征值。
关于
Q\\mathcal QQ的性质,这里用到的有:
QT=Q−1⇒QTQ=QQT=1\\mathcal Q^T = \\mathcal Q^{-1} \\Rightarrow \\mathcal Q^T\\mathcal Q = \\mathcal Q \\mathcal Q^T = 1QT=Q−1⇒QTQ=QQT=1- α⋅I\\alpha \\cdot \\mathcal Iα⋅I
同样也是一个对称矩阵,同样也可以进行特征值分解。该部分在
线性回归——岭回归中存在相似的步骤。
上述公式对应《深度学习(花书)》P143 7.1 参数范围惩罚 7.13
W^=(QΛQT+α⋅I)−1QΛQT⋅W∗=[QΛQT+Q(α⋅I)QT]−1⋅W∗=[Q(Λ+α⋅I)QT]−1QΛQT⋅W∗=(QT)−1⏟=Q(Λ+α⋅I)−1Q−1Q⏟=IΛQT⋅W∗=Q(Λ+α⋅I)−1ΛQT⋅W∗\\begin{aligned} \\hat {\\mathcal W} & = (\\mathcal Q \\Lambda \\mathcal Q^T + \\alpha \\cdot \\mathcal I)^{-1} \\mathcal Q \\Lambda\\mathcal Q^T \\cdot \\mathcal W^* \\\\ & = \\left[\\mathcal Q \\Lambda \\mathcal Q^T + \\mathcal Q(\\alpha \\cdot \\mathcal I) \\mathcal Q^T\\right]^{-1} \\cdot \\mathcal W^* \\\\ & = [\\mathcal Q (\\Lambda + \\alpha \\cdot \\mathcal I) \\mathcal Q^T]^{-1} \\mathcal Q \\Lambda \\mathcal Q^T \\cdot \\mathcal W^* \\\\ & = \\underbrace{(\\mathcal Q^T)^{-1}}_{=\\mathcal Q}(\\Lambda + \\alpha \\cdot \\mathcal I)^{-1} \\underbrace{\\mathcal Q^{-1} \\mathcal Q}_{=\\mathcal I} \\Lambda \\mathcal Q^T \\cdot \\mathcal W^* \\\\ & = \\mathcal Q(\\Lambda + \\alpha \\cdot \\mathcal I)^{-1} \\Lambda \\mathcal Q^T \\cdot \\mathcal W^* \\end{aligned}W^=(QΛQT+α⋅I)−1QΛQT⋅W∗=[QΛQT+Q(α⋅I)QT]−1⋅W∗=[Q(Λ+α⋅I)QT]−1QΛQT⋅W∗==Q(QT)−1(Λ+α⋅I)−1=IQ−1QΛQT⋅W∗=Q(Λ+α⋅I)−1ΛQT⋅W∗
-
观察上式中的Λ+α⋅I\\Lambda + \\alpha \\cdot \\mathcal IΛ+α⋅I,由于Λ\\LambdaΛ和α⋅I\\alpha \\cdot \\mathcal Iα⋅I都是对角阵,那么它们相加依然是对角阵。对应的逆矩阵(Λ+α⋅I)−1(\\Lambda + \\alpha \\cdot \\mathcal I)^{-1}(Λ+α⋅I)−1就是主对角阵各元素取倒数的结果。由于后面乘了一个Λ\\LambdaΛ,那么取倒数后的分子就是Λ\\LambdaΛ中的各元素。至此,W^\\hat {\\mathcal W}W^与W∗\\mathcal W^*W∗之间满足如下关联关系:
这里我们不关心正交矩阵
Q\\mathcal QQ,它仅是一组正交基,无论在哪一组正交基下,不影响
W∗\\mathcal W^*W∗和
W^\\hat {\\mathcal W}W^之间的关系
.其中
W^i\\hat {\\mathcal W}_iW^i表示
W^\\hat {\\mathcal W}W^的第
iii个分量;同理,
Wi∗\\mathcal W_i^*Wi∗表示
W∗\\mathcal W^*W∗的第
iii个分量。
λi\\lambda_iλi表示对角矩阵
Λ\\LambdaΛ第
iii行第
iii列的元素。
W^i=λiλi+α⋅Wi∗\\hat {\\mathcal W}_i = \\frac{\\lambda_i}{\\lambda_i + \\alpha} \\cdot \\mathcal W_i^*W^i=λi+αλi⋅Wi∗
-
观察上式,显然W^\\hat {\\mathcal W}W^与W∗\\mathcal W^*W∗之间取决于α\\alphaα的结果:
需要注意的是,在
正则化——权重衰减(直观印象)中介绍过,
α\\alphaα不是的
λ\\lambdaλ,它内部包含了
WTW−C⇒WTW−C2\\sqrt{\\mathcal W^T\\mathcal W} - \\mathcal C \\Rightarrow \\mathcal W^T\\mathcal W - \\mathcal C^2WTW−C⇒WTW−C2的多余信息.因此
α\\alphaα的取值是不确定的。
- 当α=0\\alpha = 0α=0时,WTW\\mathcal W^T\\mathcal WWTW对应的系数为0⇒W^=W∗0 \\Rightarrow \\hat {\\mathcal W} = \\mathcal W^*0⇒W^=W∗。
- 当α>0\\alpha > 0α>0时,W^<W∗\\hat {\\mathcal W} < \\mathcal W^*W^<W∗;
- 当α<0\\alpha < 0α<0时,W^>W∗\\hat {\\mathcal W} > \\mathcal W^*W^>W∗。
可以看出,在L2L_2L2正则化中,本质上时通过α\\alphaα对W∗\\mathcal W^*W∗进行缩放得到的正则化权重结果,这也是权重衰减的本质。
相关参考:
《深度学习(花书)》7.1 参数范数惩罚
“L1和L2正则化”直观理解(之二),为什么又叫权重衰减?到底哪里衰减了?