从零开始手写 Softmax 回归:理解多分类背后的魔法

你有没有想过,像 Gmail 自动把邮件分到“社交”、“促销”或“论坛”这样的功能,背后是怎么实现的?或者手机相册如何自动识别照片里是“猫”还是“狗”?这些都属于多分类问题——而 Softmax 回归,正是解决这类问题最基础、最重要的工具之一。

今天,我们就抛开现成的深度学习框架(比如 torch.nn.Linear + CrossEntropyLoss),从零开始一行行实现一个完整的 Softmax 回归模型,用它来识别 Fashion-MNIST 数据集中的衣服图片(T恤、裤子、鞋子等10类)。这不仅是一次编程练习,更是一次深入理解机器学习核心机制的旅程。


🧩 为什么叫 “Softmax”?它和“硬预测”有什么关系?

想象一下,模型对一张图片的判断是:

  • T恤: 10%
  • 裤子: 5%
  • 套头衫: 80%
  • 连衣裙: 3%
  • 凉鞋: 2%

虽然模型内心有“偏好”,但最终我们必须给用户一个明确答案:“这是套头衫”。这种必须选一个类别的决策,叫做 硬预测(Hard Prediction)

Softmax 的作用,就是把模型原始输出的任意数字(比如 [2.1, -0.5, 3.8, ...]转换成一个合法的概率分布——所有概率非负,且加起来等于1。这样我们才能说“套头衫的概率是80%”。

💡 小知识:Gmail 内部可能计算了每个类别的概率,但最终展示时必须“硬”选一个文件夹。这就是 Softmax + argmax 的典型应用。

不过,这里有个值得思考的问题:返回概率最大的分类标签总是最优解吗?在大多数日常场景(如邮件分类、图像标签)中,这样做没问题。但在医疗诊断这类高风险场景下,即使模型认为“癌症”的概率不是最高(比如只有40%),也可能需要医生进一步检查——此时直接做“硬预测”就可能掩盖重要风险。因此,是否做硬预测,取决于应用场景对不确定性的容忍度


🛠️ 第一步:准备数据 —— 把图片变向量

Fashion-MNIST 数据集包含 7 万张 28×28 像素的灰度图,共10类衣物。我们先加载数据,并设置批量大小为256:

1
2
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

每张图原本是 28×28 的矩阵,我们把它展平成长度为 784 的向量。虽然这会丢失空间结构信息(后面卷积网络会解决),但对线性模型来说足够了。


⚙️ 第二步:初始化模型参数

Softmax 回归本质上是一个线性模型

  • 输入:784 维向量(图片像素)
  • 输出:10 个数(每个类别的“得分”)

所以需要:

  • 权重矩阵 W:形状 (784, 10)
  • 偏置向量 b:形状 (10,)

我们用均值为0、标准差为0.01的正态分布初始化 W,b 初始化为0:

1
2
3
4
5
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

🔍 注意:requires_grad=True 是为了让 PyTorch 自动计算梯度,方便后续更新参数。


🔥 第三步:实现 Softmax 函数 —— 把分数变概率

Softmax 公式长这样:

$$
\mathrm{softmax}(\mathbf{X}){ij} = \frac{\exp(\mathbf{X}{ij})}{\sum_k \exp(\mathbf{X}_{ik})}
$$

翻译成人话就是:

  1. 对每个元素取指数(exp)→ 让所有数变正
  2. 对每一行求和 → 得到该样本的“总能量”
  3. 每个元素除以该行总和 → 归一化成概率

代码实现(以 PyTorch 为例):

1
2
3
4
def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(dim=1, keepdim=True) # 按行求和,保持维度
return X_exp / partition # 广播机制自动除

✅ 验证:输出的每一行加起来都是1!

然而,这里隐藏着一个严重的数值稳定性问题。试想:如果某个神经元输出特别大,比如 50,那么 exp(50) 会远远超出计算机能表示的范围,变成无穷大(inf),导致整个 Softmax 输出变成 NaNinf。类似地,在计算交叉熵损失时,如果预测概率为0(比如由于下溢),log(0) 会变成负无穷,损失爆炸。

怎么解决?工程实践中常用一个技巧:减去每行的最大值。因为 Softmax 对输入加上任意常数是不变的(数学上可证明),所以我们可以安全地计算:

1
2
3
4
def stable_softmax(X):
X = X - X.max(dim=1, keepdim=True).values # 先减最大值
X_exp = torch.exp(X)
return X_exp / X_exp.sum(dim=1, keepdim=True)

这样,最大的指数项变成 exp(0) = 1,其他项都 ≤1,彻底避免了溢出。


🧠 第四步:定义模型 —— 线性变换 + Softmax

模型很简单:输入 X → 线性变换(XW + b)→ Softmax → 输出概率分布。

1
2
def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

注意 X.reshape((-1, 784)) 是为了处理批量输入(比如一次256张图)。


❌ 第五步:定义损失函数 —— 交叉熵(Cross-Entropy)

我们希望:真实标签对应的预测概率越高越好。交叉熵损失正好衡量这一点:

$$
\text{Loss} = -\log(p_{\text{true class}})
$$

比如真实标签是“套头衫”(索引2),而模型预测 p[2] = 0.8,那么损失就是 -log(0.8) ≈ 0.22;如果 p[2] = 0.1,损失就变成 2.3 —— 惩罚更大!

关键技巧:如何高效取出“真实标签对应的概率”?

假设:

  • y = [0, 2] 表示两个样本的真实标签(第0类和第2类)
  • y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]] 是预测概率

我们想取出 y_hat[0, 0]y_hat[1, 2],即 [0.1, 0.5]

PyTorch 中可以用高级索引:

1
y_hat[[0, 1], y]  # 等价于 [y_hat[0, y[0]], y_hat[1, y[1]]]

于是交叉熵损失函数为:

1
2
def cross_entropy(y_hat, y):
return -torch.log(y_hat[range(len(y_hat)), y])

⚠️ 注意:如果 y_hat 中某个位置的概率为0(理论上不会发生,但数值误差可能导致极小值),log(0) 会出错。实际框架(如 PyTorch 的 CrossEntropyLoss)内部会加一个极小值(如 1e-8)防止这种情况。


✅ 第六步:计算分类精度 —— 我们到底对了多少?

精度 = 正确预测数 / 总样本数。

步骤:

  1. y_hat 每行取 argmax → 得到预测类别
  2. 和真实标签 y 比较 → 得到布尔数组(对/错)
  3. 求和 → 正确数量
1
2
3
4
5
def accuracy(y_hat, y):
if y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1) # 取概率最大的类别
cmp = (y_hat.type(y.dtype) == y) # 类型要一致!
return float(cmp.sum())

📌 重要提醒:精度虽然直观,但不可导,不能直接用来训练模型。所以我们用交叉熵损失来优化,用精度来评估。这也是为什么我们常说“优化的是损失,关心的是精度”。


🚀 第七步:训练模型 —— 小批量随机梯度下降

训练循环非常经典:

  1. 前向传播:计算预测 y_hat
  2. 计算损失 l = cross_entropy(y_hat, y)
  3. 反向传播:l.backward() 自动计算梯度
  4. 更新参数:用 SGD(随机梯度下降)

我们复用之前线性回归的 sgd 函数:

1
2
def updater(batch_size):
return d2l.sgd([W, b], lr=0.1, batch_size=batch_size)

然后跑10个epoch:

1
2
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

你会看到训练损失下降,训练和测试精度上升(最终约80%+)!

🔍 如何理解训练过程中的三个关键指标?

在训练过程中,我们会同时监控以下三个指标:

1️⃣ train loss(训练损失)

  • 含义:模型在训练数据上计算出的平均交叉熵损失值
  • 作用:衡量模型对训练样本的“拟合程度”。
    • 损失越小 → 模型对训练数据的预测越接近真实标签。
    • 训练过程中,我们通过优化算法(如SGD)不断降低这个值
  • 注意:如果 train loss 降得很低,但测试精度不高,可能说明模型过拟合了(死记硬背训练数据,泛化能力差)。

✅ 在你的实验中,train loss 从高到低下降(比如从 1.0 降到 0.3),说明模型正在“学会”分类。

2️⃣ train acc(训练精度)

  • 含义:模型在训练数据上的分类准确率(即预测正确的比例)。
  • 计算方式
    正确预测的训练样本数 / 总训练样本数
  • 意义:反映模型在“见过的数据”上表现如何。
    • 初期较低(比如 0.6),随着训练逐渐上升(比如到 0.85)。
  • 局限性:不能完全代表模型真实能力,因为模型已经“看过”这些数据。

3️⃣ test acc(测试精度)

  • 含义:模型在从未见过的测试数据上的分类准确率
  • 为什么重要
    这才是衡量模型泛化能力(generalization)的关键指标!
    • 它回答了这个问题:“模型学到的知识,能用到新图片上吗?”
  • 理想情况test acc 接近 train acc(比如 train acc=85%,test acc=83%)→ 模型学得扎实。
  • 危险信号train acc 很高(95%),但 test acc 很低(70%)→ 过拟合

✅ 在你的 Fashion-MNIST 实验中,最终 test acc 达到 80%+,说明这个简单的 Softmax 回归模型已经能有效识别衣物类别!

  • 趋势:随着训练进行,损失下降,两个精度都上升。
  • 关键观察test acc 始终略低于 train acc,这是正常现象;但如果差距越来越大,就要警惕过拟合。

💡 一句话总结
train loss 是优化目标,train acc 是训练表现,test acc 才是真实水平。


🔮 第八步:预测看看!

训练完后,我们可以让模型“看图说话”:

1
predict_ch3(net, test_iter)

🧠 小结:Softmax 回归的核心思想

步骤 关键操作 目的
输入 图片 → 784维向量 标准化输入
模型 线性变换 + Softmax 输出合法概率分布
损失 交叉熵 惩罚“真实类别概率低”的预测
优化 SGD 最小化损失,提高精度
评估 精度(argmax + 比较) 衡量实际性能

💡 惊人事实:这个看似简单的模型,只用了线性运算 + 一个非线性激活(Softmax),就在10分类任务上达到了80%+的准确率!

不过,当我们把 Softmax 回归用于更复杂的任务时,比如预测下一个单词,就会遇到新挑战。自然语言中词汇量可能高达几十万甚至上百万。如果直接用 Softmax 输出每个词的概率,就需要一个超大的权重矩阵(比如 512 × 100,000),不仅内存爆炸,计算也极其缓慢。这就是所谓的“大词汇表问题”。解决方案包括:层次 Softmax负采样词嵌入共享等技术——它们都是现代语言模型(如 Word2Vec、BERT)的重要组成部分。


🎉 结语

通过从零实现 Softmax 回归,我们不仅掌握了多分类的基本原理,还亲手构建了数据加载、模型定义、损失计算、训练循环、评估预测的完整 pipeline。更重要的是,我们在实现过程中不断思考:数值稳定性如何保证硬预测是否总是合理大规模输出如何处理?这些问题的答案,正是从“会调库”走向“懂原理”的关键一步。

下次当你看到 AI 自动分类邮件、识别物体、翻译语言时,不妨想想:背后可能就运行着这样一个优雅而强大的 Softmax 回归(或它的深度扩展)!