拒绝“死记硬背”:通俗详解机器学习中的权重衰减(Weight Decay)

导读:你是否遇到过模型在训练集上表现完美,一上测试集就“崩盘”的情况?这通常是过拟合在作祟。今天,我们将深入探讨深度学习中最经典、最有效的防过拟合技术之一——权重衰减(Weight Decay)。不用复杂的数学推导,我们用通俗的比喻和代码实战,带你彻底搞懂它。


🤔 为什么模型会“死记硬背”?

想象一下你在备考一场重要的数学考试:

  • 学生 A(过拟合):他把练习册上的每一道题连同答案都背得滚瓜烂熟,甚至连题目的印刷错误都记住了。结果考试时,老师稍微改了个数字,他就彻底懵了。
  • 学生 B(泛化能力强):他没有死记硬背,而是努力理解背后的公式和逻辑。不管题目怎么变,他都能举一反三。

在机器学习中,**过拟合(Overfitting)**就是模型变成了“学生 A”。它过度学习了训练数据中的噪声和细节,导致在面对新数据时表现糟糕。

为了解决这个问题,我们需要给模型立个规矩:“你可以学习规律,但不许把权重搞得太大!” 这就是权重衰减的核心思想。


💡 什么是权重衰减?

1. 核心直觉

在线性模型中,预测公式通常是 $y = w_1x_1 + w_2x_2 + … + b$。这里的 $w$ 就是权重

  • 如果某个权重 $w$ 特别大,说明模型过度依赖某一个特征。这往往是模型在强行拟合数据中的噪声(即“死记硬背”)。
  • 权重衰减通过惩罚那些过大的权重,迫使模型保持“谦虚”,让参数值趋向于更小、更平滑。模型越简单,通常泛化能力越强。

2. 数学原理(简单版)

通常我们训练模型的目标是最小化预测误差(Loss)
加入权重衰减后,目标函数变成了:

$$ \text{总损失} = \underbrace{\text{预测误差}}{\text{猜得准不准}} + \underbrace{\lambda \times \sum w^2}{\text{模型复杂度惩罚}} $$

  • $\sum w^2$:所有权重的平方和。权重越大,这项越大。
  • $\lambda$ (Lambda):一个调节旋钮(超参数)。
    • $\lambda = 0$:完全不惩罚,模型可能过拟合。
    • $\lambda$ 很大:严厉禁止大权重,模型可能变得太简单(欠拟合)。
    • $\lambda$ 适中:在“准确”和“简单”之间找到最佳平衡点。

3. 为什么叫“衰减”?

在每次更新参数时,算法不仅会根据误差调整权重,还会强制让权重乘以一个小于 1 的系数。
公式大致如下:
$$ w_{\text{new}} = (1 - \text{学习率} \times \lambda) \times w_{\text{old}} - \text{学习率} \times \text{梯度} $$

注意前面的 $(1 - \text{学习率} \times \lambda)$ 这一项,它永远小于 1。这意味着,每一步更新,权重都会先自动“缩水”一点点,然后再根据误差进行调整。久而久之,权重就会维持在一个较小的范围内。


💻 代码实战:从原理到应用

为了演示效果,我们构造一个典型的过拟合场景:

  • 真实规律:很简单。
  • 陷阱:给了模型 200 个特征(大部分是噪声),但只有 20 个训练样本
  • 结果:如果不加限制,模型一定会过拟合。

我们将展示两种实现方式:从零手动实现(理解原理)和 框架简洁实现(实际开发)。

方法一:从零开始实现(手动添加惩罚项)

这种方式最直观,我们在计算 Loss 时,手动加上权重的平方和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
from torch import nn
from d2l import torch as d2l # 假设使用动手学深度学习的数据加载工具

# 1. 定义 L2 范数惩罚函数
def l2_penalty(w):
"""计算权重的平方和除以2"""
return torch.sum(w.pow(2)) / 2

def train_scratch(lambd):
# 初始化参数:200个特征,容易过拟合
w = torch.normal(0, 1, size=(200, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

lr, num_epochs = 0.003, 100

for epoch in range(num_epochs):
for X, y in train_iter: # train_iter 需预先定义
# --- 核心步骤 ---
# 1. 前向传播
y_hat = X.matmul(w) + b
# 2. 计算原始 MSE 误差
mse_loss = ((y_hat - y) ** 2).mean()
# 3. 计算总损失 = 原始误差 + lambda * 惩罚项
l = mse_loss + lambd * l2_penalty(w)
# ----------------

l.backward()

# 4. 更新参数 (SGD)
with torch.no_grad():
w -= lr * w.grad
b -= lr * b.grad
w.grad.zero_()
b.grad.zero_()

print(f"[从零实现] Lambda={lambd:.1f} | 训练误差: {mse_loss.item():.4f} | w的L2范数: {torch.norm(w).item():.4f}")

# 运行测试
print("--- 不使用权重衰减 (λ=0) ---")
train_scratch(0) # 预期:训练误差极低,但w的范数很大(过拟合)

print("\n--- 使用权重衰减 (λ=3) ---")
train_scratch(3) # 预期:训练误差稍高,但w的范数显著变小(泛化更好)

代码解读:

  • lambd=0 时,模型为了迎合那 20 个样本的噪声,把权重 $w$ 养得非常大。
  • lambd=3 时,巨大的权重会被严厉惩罚,迫使 $w$ 保持较小。虽然训练集误差稍微变大(因为不能完美拟合噪声了),但这正是我们想要的——模型学会了忽略噪声,关注主要规律。

方法二:简洁实现(使用框架内置功能)

在实际工程中(如 PyTorch, TensorFlow),我们不需要手动写 + lambd * l2_penalty(w)。优化器(Optimizer)内部已经集成了这个功能,只需设置 weight_decay 参数即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def train_concise(wd):
# 定义网络
net = nn.Sequential(nn.Linear(200, 1))

# 初始化参数
for param in net.parameters():
param.data.normal_()

loss = nn.MSELoss(reduction='none')
lr = 0.003

# --- 核心步骤 ---
# 在优化器中直接指定 weight_decay (即 lambda)
# 最佳实践:通常只对权重(weight)进行衰减,不对偏置(bias)进行衰减
trainer = torch.optim.SGD([
{"params": net[0].weight, "weight_decay": wd},
{"params": net[0].bias} # bias 默认 weight_decay=0
], lr=lr)
# ----------------

num_epochs = 100
for epoch in range(num_epochs):
for X, y in train_iter:
trainer.zero_grad()
l = loss(net(X), y)
l.mean().backward()
# step() 内部自动执行了:w = (1 - lr*wd)*w - lr*gradient
trainer.step()

w_norm = net[0].weight.norm().item()
print(f"[简洁实现] Lambda={wd:.1f} | w的L2范数: {w_norm:.4f}")

# 运行测试
train_concise(0)
train_concise(3)

💡 小贴士

  • weight_decay 参数就是公式里的 $\lambda$。
  • 这种做法不仅代码更简洁,而且计算效率更高,因为框架在底层进行了优化。
  • 为什么不对 Bias 做衰减? 偏置项 $b$ 只是整体平移输出,不会像权重那样放大输入特征的噪声,因此通常不需要正则化。

🎯 总结与最佳实践

  1. 目的明确:权重衰减是防止过拟合的利器,它通过限制模型复杂度来提高泛化能力。
  2. 核心机制:在损失函数中加入 $L_2$ 惩罚项,或在参数更新时强制权重“缩水”。
  3. 调参指南
    • $\lambda$ (weight_decay) 是关键。
    • 如果验证集误差远大于训练集误差 $\rightarrow$ 增大 $\lambda$。
    • 如果训练集和验证集误差都很高 $\rightarrow$ 减小 $\lambda$ 或检查其他问题。
    • 常用范围:$10^{-5}$ 到 $10^{-2}$ 之间,具体取决于任务和数据集。
  4. 工程落地:在现代深度学习框架中,直接在优化器(如 SGD, AdamW)中设置 weight_decay 参数即可。推荐使用 AdamW 而不是 Adam,因为 AdamW 修正了 Adam 中权重衰减的实现方式,效果通常更好。

最后的话:机器学习不仅仅是让模型在已知数据上考满分,更重要的是让它具备“举一反三”的能力。权重衰减,就是那个时刻提醒模型“不要骄傲,保持简单”的良师益友。


希望这篇博客能帮你彻底理解权重衰减!如果你觉得有用,欢迎分享给更多的小伙伴。🚀