损失与优化#
损失函数#
损失函数定义了模型要优化的目标,即衡量模型预测值与真实值之间的差距。损失函数输出一个数值,该数值代表了单次预测的好坏程度。损失值越大,说明模型的预测越不准确。训练的目标就是最小化损失值。
回归任务的目标是预测一个连续值。
均方误差 (MSE):也称为 L2 损失。它计算的是预测值与真实值之差的平方的平均值。由于平方的存在,MSE 对较大的误差给予更重的惩罚。
\[ \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 \]在一些教材或理论推导中会将 \(\frac{1}{n}\) 写为 \(\frac{1}{2}\),这是为了在求导时使系数抵消,使表达式更简洁。对于优化过程而言,这个常数系数不影响梯度的方向,因此不会改变模型参数优化的最终结果。现代深度学习框架的默认实现通常采用求均值(除以n)的方式。
平均绝对误差 (MAE):也称为 L1 损失。它计算的是预测值与真实值之差的绝对值的平均值。相比 MSE,MAE 对异常值不那么敏感。
\[ \text{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i| \]Huber 损失 (Huber Loss):可以看作是 MSE 和 MAE 的结合。它通过一个超参数 \(\delta\) 来控制:当误差小于 \(\delta\) 时,它等同于 MSE(平方损失);当误差大于 \(\delta\) 时,它等同于 MAE(线性损失)。这使得它在误差较小时表现平滑,同时对大的异常值不那么敏感。
\[\begin{split} L_\delta(y, f(x)) = \begin{cases} \frac{1}{2}(y - f(x))^2 & \text{for } |y - f(x)| \le \delta \\ \delta \cdot (|y - f(x)| - \frac{1}{2}\delta) & \text{otherwise} \end{cases} \end{split}\]
分类任务的目标是预测一个离散的类别。
交叉熵损失 (Cross-Entropy Loss):是分类任务中最常用的损失函数。其梯度是真实概率与预测概率的区别。它衡量的是模型预测的概率分布与真实的概率分布之间的差异。
二元交叉熵 (Binary Cross-Entropy):用于二分类问题。输出层只有一个神经元,并使用 Sigmoid 激活函数输出一个概率值 \(p\)。
\[ L = -[y \log(p) + (1-y) \log(1-p)] \]其中 \(y\) 是真实标签(0 或 1),\(p\) 是模型预测为类别 1 的概率。
分类交叉熵 (Categorical Cross-Entropy):用于多分类问题。输出层有 N 个神经元(N 为类别数),并使用 Softmax 激活函数输出一个概率分布。
\[ L = -\sum_{i=1}^{N} y_i \log(p_i) \]其中 \(y_i\) 是一个独热编码的真实标签,\(p_i\) 是模型预测为类别 \(i\) 的概率。
优化器#
优化器是根据反向传播计算出的梯度来更新网络参数(权重和偏置)的算法,它的目标是找到一组能使损失函数最小化的参数。
梯度下降(Gradient Descent)是所有优化算法的基础。其核心思想是:沿着梯度下降最快的方向(梯度的反方向)调整参数,从而逐步减小损失值。
更新规则如下:
新参数 = 旧参数 - 学习率 × 梯度
学习率 是一个超参数,它控制着每次参数更新的步长。学习率过大可能导致模型在最优点附近震荡甚至发散;学习率过小则会导致模型收敛速度过慢。
根据每次更新使用的数据量,梯度下降分为三种变体:
批量梯度下降 (Batch Gradient Descent):每次更新都使用整个训练集的数据,计算精确但速度慢。
随机梯度下降 (Stochastic Gradient Descent, SGD):每次更新仅使用训练集中的一个样本。速度快,但更新方向不稳定,损失函数会剧烈波动。
小批量梯度下降 (Mini-batch Gradient Descent):是上述两者的折中,也是最常用的方法。每次更新使用一小批(mini-batch)数据(如 32, 64, 128 个样本)。
另两个著名的优化器是Momentum和Adam:
Momentum:引入冲量的概念。它在更新参数时,不仅考虑当前的梯度,还考虑历史的更新方向。
Adam (Adaptive Moment Estimation):是目前最流行、最常用的优化器之一。它结合了 Momentum 和另一种名为 RMSprop 的优化器的思想,能够为网络中的每一个参数计算自适应的学习率。Adam 的优势在于对学习率不不那么敏感,但效果可能比 SGD+Momentum 差一点。