从 AdamW 到 Muon 优化器
在本文中,我们将讨论一种新的优化器——Muon,它被用于 Kimi2 的训练(一个拥有 1T 参数和 32B 活跃参数的模型)。Muon 优化器生成的损失曲线是每一位 ML 研究者的梦想。
我们将从简要讨论深度学习中的 Adam 和 AdamW 开始,然后深入探讨 Muon 及其在训练和参数更新中的作用。
有关 AdamW 的更全面的博客文章,请参阅这里,或这个 YouTube 播放列表
1、Adam 和 AdamW
Adam 是之前所有优化器精妙技巧的集大成者。它包含动量(momentum),即梯度的移动平均值。它从 RMProp 中获得了梯度方差的灵感(方差的移动平均值),并且还对移动平均值初始值的零初始化进行了偏差修正。
这里,β₁ 和 β₂ 分别是用于计算梯度(动量)和梯度平方(方差)移动平均值的常数。由于 m 和 v 的初始值被设为零,移动平均值在早期步骤中会偏向零。这导致训练开始时更新幅度较小。为了修正这种初始化偏差,我们用期望值对 m 和 v 进行归一化,有效地将它们缩放到反映其真实大小的值。
将所有内容整合在一起
2、Adam 和 AdamW 中的正则化与 Bug 修复
在深度学习和机器学习中,正则化通常通过将权重的 L1 或 L2 范数加到损失函数中来实现。这有助于通过将权重的大小与整体损失联系起来来约束权重的大小,防止它们在训练过程中增长过大。
正则化需要等同于权重衰减,这能提高我们神经网络的泛化能力并限制权重的更新。基于动量的更新中,权重衰减的形式如下
然而,对于 Adam,如果你按照这里的推导,它是
权重衰减中有一个额外的因子,甚至削弱了权重衰减的效果
因此,在 AdamW 中,作者没有将权重衰减加到梯度上,而是将其加到最终的更新中,这样就没有额外的因子了。
3、为什么我们需要一个新的优化器?
Adam 和 AdamW 需要两样东西:(i) 一阶矩,(ii) 二阶矩。这些通常以全精度(fp32 或 tf32)存储,消耗的内存是权重参数数量的 8 倍(一阶矩 4 字节,二阶矩 4 字节)。
此外,由于这些是移动平均值,我们需要对梯度进行全收集(all-gather)来计算一阶矩和二阶矩,这会带来大量的通信开销。
需要正交性(以下是 Muon 优化器博客文章中的一段摘录)
出于经验性的动机,我们观察到通过手动检查,SGD-momentum 和 Adam 为基于 Transformer 的神经网络中的 2D 参数生成的更新通常具有非常高的条件数。也就是说,它们几乎是低秩矩阵,所有神经元的更新仅被少数几个方向所主导。我们推测正交化有效地增加了其他"稀有方向"的尺度,这些方向在更新中幅度很小但对学习至关重要。
我们在各种形式中应用归一化,例如层内的层归一化、跨批次的批归一化以及归一化的权重初始化。这些实践表明归一化有助于提高泛化能力并加速训练收敛。为什么不对梯度也进行归一化呢?
4、Muon
Muon 的公式非常简单;它不需要跟踪一阶矩和二阶矩,只保留动量(Nesterov 动量)。它在进行权重更新之前对动量矩阵进行归一化。
对于归一化,它使用 Newton-Schulz 方法而不是奇异值分解(SVD)。SVD 对于大小为 R^{m x n} 的矩阵的时间复杂度为 O(mn²)。我们将对大型前馈层应用归一化,这会带来大量的计算开销。因此,作者使用 Newton-Schulz 方法。
PyTorch 代码
# Pytorch code
def newtonschulz5(G, steps=5, eps=1e-7):
assert G.ndim == 2
#the coefficients derived empirically
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
#normalize the matrix with the max-norm so that values
#have a range of (0,1)
X /= (X.norm() + eps)
#We are going matmul the X matrix with its transpose
#which will be of dimension row x row.
#To reduce computational complexity, the row < column
#hence take a transpose
if G.size(0) > G.size(1):
X = X.T
for _ in range(steps):
A = X @ X.T
#b and c are scalar-matrix multiplication
#c is multiplied with degree 5 (X @ X.T squares the matrix)
B = b * A + c * A @ A
X = a * X + B @ X
#flip back the dimension
if G.size(0) > G.size(1):
X = X.T
return X
一个巧妙的技巧是翻转矩阵,使我们始终处理宽矩阵(列数大于行数)而不是高矩阵(列数小于行数)。这减少了矩阵 X @ X.T 中的维度数。
与 Adam 相比的另一个改进是只使用一个缓冲区而不是两个。Adam 有一阶矩和二阶矩缓冲区,但 Muon 只有一个动量缓冲区。这将每个参数的内存占用减少了 50%(从 8 字节减少到 4 字节)。
Muon 是范数约束下的最速下降:在 Muon 中,我们的梯度更新被归一化并位于 Schatten p-范数(或 SVD 中奇异值的范数)中。如果我们取 SVD 中奇异值的最大值,就得到了谱范数,这在 PCA 中使用。此外,归一化是静态的,因此我们得到稳定的更新。
Adam 和 AdamW 对动量有动态适应性;因此,所有权重维度都被动态更新,这可能在更新中缺乏稳定性。
5、扩展 Muon
在 Muon Clip 论文中,作者做了更多技巧来扩展 Muon
5.1 权重衰减
他们将权重衰减添加到更新中,因为他们观察到随着训练的进行,权重和 RMS 范数持续增长。
按矩阵维度缩放(乘以):
- 假设矩阵的维度是 A x B。在 Newton-Schulz 方法中,我们按维度归一化矩阵,因此对于较大的矩阵,更新会较小,而对于较小的矩阵(Q、K 和 V 投影矩阵的头部维度)更新会较大。 因此,作者按矩阵维度的平方根来放大更新。
匹配 Adam 的 RMS 范数更新:
- 在这篇博客文章中,Muon 仅用于线性矩阵,不用于非矩阵操作,如嵌入维度,后者具有独热输入。这种稀疏性会破坏归一化。
- 因此,AdamW 仍用于 RMSNorm、LM 头部和嵌入参数等矩阵。由于我们保持 Muon 和 AdamW 的超参数相同,我们需要匹配 AdamW 的 RMS 范数。
- 经验上,这个值大约在 0.2 到 0.4 之间;因此,作者将更新缩放了 0.2。最终,我们得到
注意:作者使用了 5 次 Newton-Schulz 迭代,即我们对梯度矩阵进行 5 次迭代的归一化。
6、论文的要点
与 AdamW 相比,Muon 沿着更多的维度进行优化。这从论文中 SVD 熵随时间步变化的图中可以明显看出。
如果预训练优化器也是 Muon,那么 Muon 在 SFT 中的效果也很好。
原文链接: From AdamW to Muon Optimizer
汇智网翻译整理,转载请标明出处