从零开始手写 Softmax 回归
从零开始手写 Softmax 回归:理解多分类背后的魔法
你有没有想过,像 Gmail 自动把邮件分到“社交”、“促销”或“论坛”这样的功能,背后是怎么实现的?或者手机相册如何自动识别照片里是“猫”还是“狗”?这些都属于多分类问题——而 Softmax 回归,正是解决这类问题最基础、最重要的工具之一。
今天,我们就抛开现成的深度学习框架(比如 torch.nn.Linear + CrossEntropyLoss),从零开始一行行实现一个完整的 Softmax 回归模型,用它来识别 Fashion-MNIST 数据集中的衣服图片(T恤、裤子、鞋子等10类)。这不仅是一次编程练习,更是一次深入理解机器学习核心机制的旅程。
🧩 为什么叫 “Softmax”?它和“硬预测”有什么关系?
想象一下,模型对一张图片的判断是:
- T恤: 10%
- 裤子: 5%
- 套头衫: 80%
- 连衣裙: 3%
- 凉鞋: 2%
虽然模型内心有“偏好”,但最终我们必须给用户一个明确答案:“这是套头衫”。这种必须选一个类别的决策,叫做 硬预测(Hard Prediction)。
而 Softmax 的作用,就是把模型原始输出的任意数字(比如 [2.1, -0.5, 3.8, ...])转换成一个合法的概率分布——所有概率非负,且加起来等于1。这样我们才能说“套头衫的概率是80%”。
💡 小知识:Gmail 内部可能计算了每个类别的概率,但最终展示时必须“硬”选一个文件夹。这就是 Softmax + argmax 的典型应用。
不过,这里有个值得思考的问题:返回概率最大的分类标签总是最优解吗?在大多数日常场景(如邮件分类、图像标签)中,这样做没问题。但在医疗诊断这类高风险场景下,即使模型认为“癌症”的概率不是最高(比如只有40%),也可能需要医生进一步检查——此时直接做“硬预测”就可能掩盖重要风险。因此,是否做硬预测,取决于应用场景对不确定性的容忍度。
🛠️ 第一步:准备数据 —— 把图片变向量
Fashion-MNIST 数据集包含 7 万张 28×28 像素的灰度图,共10类衣物。我们先加载数据,并设置批量大小为256:
1 | batch_size = 256 |
每张图原本是 28×28 的矩阵,我们把它展平成长度为 784 的向量。虽然这会丢失空间结构信息(后面卷积网络会解决),但对线性模型来说足够了。
⚙️ 第二步:初始化模型参数
Softmax 回归本质上是一个线性模型:
- 输入:784 维向量(图片像素)
- 输出:10 个数(每个类别的“得分”)
所以需要:
- 权重矩阵 W:形状 (784, 10)
- 偏置向量 b:形状 (10,)
我们用均值为0、标准差为0.01的正态分布初始化 W,b 初始化为0:
1 | num_inputs = 784 |
🔍 注意:
requires_grad=True是为了让 PyTorch 自动计算梯度,方便后续更新参数。
🔥 第三步:实现 Softmax 函数 —— 把分数变概率
Softmax 公式长这样:
$$
\mathrm{softmax}(\mathbf{X}){ij} = \frac{\exp(\mathbf{X}{ij})}{\sum_k \exp(\mathbf{X}_{ik})}
$$
翻译成人话就是:
- 对每个元素取指数(
exp)→ 让所有数变正 - 对每一行求和 → 得到该样本的“总能量”
- 每个元素除以该行总和 → 归一化成概率
代码实现(以 PyTorch 为例):
1 | def softmax(X): |
✅ 验证:输出的每一行加起来都是1!
然而,这里隐藏着一个严重的数值稳定性问题。试想:如果某个神经元输出特别大,比如 50,那么 exp(50) 会远远超出计算机能表示的范围,变成无穷大(inf),导致整个 Softmax 输出变成 NaN 或 inf。类似地,在计算交叉熵损失时,如果预测概率为0(比如由于下溢),log(0) 会变成负无穷,损失爆炸。
怎么解决?工程实践中常用一个技巧:减去每行的最大值。因为 Softmax 对输入加上任意常数是不变的(数学上可证明),所以我们可以安全地计算:
1 | def stable_softmax(X): |
这样,最大的指数项变成 exp(0) = 1,其他项都 ≤1,彻底避免了溢出。
🧠 第四步:定义模型 —— 线性变换 + Softmax
模型很简单:输入 X → 线性变换(XW + b)→ Softmax → 输出概率分布。
1 | def net(X): |
注意 X.reshape((-1, 784)) 是为了处理批量输入(比如一次256张图)。
❌ 第五步:定义损失函数 —— 交叉熵(Cross-Entropy)
我们希望:真实标签对应的预测概率越高越好。交叉熵损失正好衡量这一点:
$$
\text{Loss} = -\log(p_{\text{true class}})
$$
比如真实标签是“套头衫”(索引2),而模型预测 p[2] = 0.8,那么损失就是 -log(0.8) ≈ 0.22;如果 p[2] = 0.1,损失就变成 2.3 —— 惩罚更大!
关键技巧:如何高效取出“真实标签对应的概率”?
假设:
y = [0, 2]表示两个样本的真实标签(第0类和第2类)y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]是预测概率
我们想取出 y_hat[0, 0] 和 y_hat[1, 2],即 [0.1, 0.5]。
PyTorch 中可以用高级索引:
1 | y_hat[[0, 1], y] # 等价于 [y_hat[0, y[0]], y_hat[1, y[1]]] |
于是交叉熵损失函数为:
1 | def cross_entropy(y_hat, y): |
⚠️ 注意:如果
y_hat中某个位置的概率为0(理论上不会发生,但数值误差可能导致极小值),log(0)会出错。实际框架(如 PyTorch 的CrossEntropyLoss)内部会加一个极小值(如1e-8)防止这种情况。
✅ 第六步:计算分类精度 —— 我们到底对了多少?
精度 = 正确预测数 / 总样本数。
步骤:
- 对
y_hat每行取argmax→ 得到预测类别 - 和真实标签
y比较 → 得到布尔数组(对/错) - 求和 → 正确数量
1 | def accuracy(y_hat, y): |
📌 重要提醒:精度虽然直观,但不可导,不能直接用来训练模型。所以我们用交叉熵损失来优化,用精度来评估。这也是为什么我们常说“优化的是损失,关心的是精度”。
🚀 第七步:训练模型 —— 小批量随机梯度下降
训练循环非常经典:
- 前向传播:计算预测
y_hat - 计算损失
l = cross_entropy(y_hat, y) - 反向传播:
l.backward()自动计算梯度 - 更新参数:用 SGD(随机梯度下降)
我们复用之前线性回归的 sgd 函数:
1 | def updater(batch_size): |
然后跑10个epoch:
1 | num_epochs = 10 |
你会看到训练损失下降,训练和测试精度上升(最终约80%+)!
🔍 如何理解训练过程中的三个关键指标?
在训练过程中,我们会同时监控以下三个指标:
1️⃣ train loss(训练损失)
- 含义:模型在训练数据上计算出的平均交叉熵损失值。
- 作用:衡量模型对训练样本的“拟合程度”。
- 损失越小 → 模型对训练数据的预测越接近真实标签。
- 训练过程中,我们通过优化算法(如SGD)不断降低这个值。
- 注意:如果
train loss降得很低,但测试精度不高,可能说明模型过拟合了(死记硬背训练数据,泛化能力差)。
✅ 在你的实验中,
train loss从高到低下降(比如从 1.0 降到 0.3),说明模型正在“学会”分类。
2️⃣ train acc(训练精度)
- 含义:模型在训练数据上的分类准确率(即预测正确的比例)。
- 计算方式:
正确预测的训练样本数 / 总训练样本数 - 意义:反映模型在“见过的数据”上表现如何。
- 初期较低(比如 0.6),随着训练逐渐上升(比如到 0.85)。
- 局限性:不能完全代表模型真实能力,因为模型已经“看过”这些数据。
3️⃣ test acc(测试精度)
- 含义:模型在从未见过的测试数据上的分类准确率。
- 为什么重要?
这才是衡量模型泛化能力(generalization)的关键指标!- 它回答了这个问题:“模型学到的知识,能用到新图片上吗?”
- 理想情况:
test acc接近train acc(比如 train acc=85%,test acc=83%)→ 模型学得扎实。 - 危险信号:
train acc很高(95%),但test acc很低(70%)→ 过拟合。
✅ 在你的 Fashion-MNIST 实验中,最终
test acc达到 80%+,说明这个简单的 Softmax 回归模型已经能有效识别衣物类别!
- 趋势:随着训练进行,损失下降,两个精度都上升。
- 关键观察:
test acc始终略低于train acc,这是正常现象;但如果差距越来越大,就要警惕过拟合。
💡 一句话总结:
train loss是优化目标,train acc是训练表现,test acc才是真实水平。
🔮 第八步:预测看看!
训练完后,我们可以让模型“看图说话”:
1 | predict_ch3(net, test_iter) |
🧠 小结:Softmax 回归的核心思想
| 步骤 | 关键操作 | 目的 |
|---|---|---|
| 输入 | 图片 → 784维向量 | 标准化输入 |
| 模型 | 线性变换 + Softmax | 输出合法概率分布 |
| 损失 | 交叉熵 | 惩罚“真实类别概率低”的预测 |
| 优化 | SGD | 最小化损失,提高精度 |
| 评估 | 精度(argmax + 比较) | 衡量实际性能 |
💡 惊人事实:这个看似简单的模型,只用了线性运算 + 一个非线性激活(Softmax),就在10分类任务上达到了80%+的准确率!
不过,当我们把 Softmax 回归用于更复杂的任务时,比如预测下一个单词,就会遇到新挑战。自然语言中词汇量可能高达几十万甚至上百万。如果直接用 Softmax 输出每个词的概率,就需要一个超大的权重矩阵(比如 512 × 100,000),不仅内存爆炸,计算也极其缓慢。这就是所谓的“大词汇表问题”。解决方案包括:层次 Softmax、负采样、词嵌入共享等技术——它们都是现代语言模型(如 Word2Vec、BERT)的重要组成部分。
🎉 结语
通过从零实现 Softmax 回归,我们不仅掌握了多分类的基本原理,还亲手构建了数据加载、模型定义、损失计算、训练循环、评估预测的完整 pipeline。更重要的是,我们在实现过程中不断思考:数值稳定性如何保证?硬预测是否总是合理?大规模输出如何处理?这些问题的答案,正是从“会调库”走向“懂原理”的关键一步。
下次当你看到 AI 自动分类邮件、识别物体、翻译语言时,不妨想想:背后可能就运行着这样一个优雅而强大的 Softmax 回归(或它的深度扩展)!