Transformer:共享注意力头
在前面的章节中,我们解决了扩展Transformer时最根本的挑战之一:训练期间注意力的二次成本。稀疏注意力和FlashAttention为我们提供了处理更长序列的强大工具。
但一旦模型训练完成并部署——当它真正在和你对话时——一个不同的问题占据了主导地位。
挑战不再仅仅是计算注意力。而是内存。
具体来说,是关于一种叫做KV缓存的东西,它正是生成文本如此消耗内存的核心。要理解为什么我们需要共享注意力头,我们首先需要理解KV缓存问题。
1、LLM实际上如何生成文本
在深入技术之前,让我们先建立背景。
当你向LLM发送消息并得到回复时,模型并不是一次性产生整个回复。它逐个token地生成文本,在一个循环中。为了生成每个新词,模型会查看之前的每一个词——整个对话历史——并决定下一个是什么。
自回归生成: 模型逐个token生成,每个新token依赖于所有之前的token。这被称为自回归生成,因为每个输出都作为下一步的输入反馈回来。
这是一个优美的设计——它使模型能够在长回复中保持连贯性。但它造成了一个严重的性能问题。
2、KV缓存:一个必要的代价
问题简单来说:
每次模型生成新token时,它需要计算该新token与所有之前token之间的注意力。要计算注意力,你需要每个之前token的键和值向量。
在没有任何优化的情况下,模型会在每一步从头重新计算所有过去token的K和V向量。如果你已经生成了500个token,它为了生成第501个token会重新做全部500次计算。这极其浪费——这些向量并没有改变。
解决方案是KV缓存:将每个过去token的键和值向量存储在GPU内存中,并在后续每步中复用它们。只有最新token的K和V向量需要重新计算。
类比: 想象你在写一个很长的故事,一次一个词。没有KV缓存的话,你在写每个新词之前都要从头重新阅读整个故事。KV缓存就像保留一个持续更新的摘要表——你只需瞥一眼这个表,然后把新词添加上去。
这极大地加速了推理。现代LLM如果没有它将无法使用。
但它带来了高昂的代价:内存。
KV缓存随着每生成一个token而增长。对于一个处理10万token上下文、有32个注意力层的模型,KV缓存可能消耗数十GB的GPU内存——通常比模型权重本身还多。这限制了:
- 批次大小 — 同时服务的用户更少
- 上下文长度 — 对话可以有多长有一个硬上限
- 推理成本 — 更多内存意味着更昂贵的硬件
为什么是键和值,而不是查询? 查询只需要为当前token提出问题。一旦你为token i计算了注意力,那个查询就再也不需要了。键和值则不同,它们代表每个过去token向任何未来查询提供的内容——所以它们需要被存储和复用。
3、多头注意力:原始设计
要理解解决方案,我们首先需要理解问题的来源:多头注意力(MHA),在原始的"Attention is All You Need"论文中引入(Vaswani等,2017)。
MHA的关键洞察是,单一注意力机制可能只能捕获token之间的一种关系类型。通过并行运行多个注意力机制——每个被称为一个"头"——模型可以同时寻找不同的模式:语法、语义、共指、位置关系等等。
在有H个头的标准MHA中,每个头有自己独立的一组:
- Q(查询)投影 — 这个token在寻找什么
- K(键)投影 — 这个token向搜索者提供什么
- V(值)投影 — 这个token携带的信息
这看起来像:
Head 1: Q₁, K₁, V₁ → 独立的注意力输出₁
Head 2: Q₂, K₂, V₂ → 独立的注意力输出₂
...
Head H: Qₕ, Kₕ, Vₕ → 独立的注意力输出ₕ
最终输出 = concat(输出₁, 输出₂, ..., 输出ₕ) × Wₒ
每个头学习关注序列的不同方面。输出然后被拼接并投影回模型的嵌入空间。
为什么它效果极好: MHA给模型提供了每个token丰富的多视角视图。不同的头自然地专业化——一些追踪动词指的是哪个主语,其他追踪句法结构,还有一些关注标点模式。
为什么在推理时成为问题: 每个头都有自己的K和V矩阵。它们全部需要为每个token存储在KV缓存中。有H个头和L层,KV缓存大小按以下方式扩展:
KV缓存大小 ∝ n × H × L × d_h × 2
其中n是序列长度,d_h是每个头的维度。乘以2是因为你需要K和V两者。对于一个有32个头和32层的大模型,这增长得非常快。
这是启发下一代注意力设计的内存墙。
4、多查询注意力:激进的效率
多查询注意力(MQA)由Noam Shazeer在2019年提出,其想法简单得令人惊讶:
如果所有查询头共享一个单一的键和值头呢?
MQA不是使用H个独立的K和V投影,而是只使用一个:
Head 1: Q₁ ─┐
Head 2: Q₂ ─┤── K(共享),V(共享)
... │
Head H: Qₕ ─┘
每个查询头仍然提出不同的问题(不同的Q投影),但它们都从相同的K和V中查找答案。KV缓存缩减了H倍——从H组K和V向量减少到只有一组。
对内存的影响是戏剧性的。 对于一个有32个头的模型,MQA将KV缓存减少到原来大小的1/32。
直觉: 想象一个由32名研究人员组成的团队,每个人都有自己独特的研究问题(查询)。在标准MHA中,每名研究人员维护自己的私人事实库(K和V)。在MQA中,他们都共享一个单一的公共库。研究人员仍然提出不同的问题,但他们都从同一个来源查找答案。
权衡: 共享单一K/V头是相当激进的。这意味着所有查询头都被限制在基于相同的键比较来进行注意力——它们只是以不同方式加权相同的值。这可能损害模型质量,特别是在需要细致的、多方面推理的任务上。研究已经证实,与完整MHA相比,MQA可能导致质量下降和训练不稳定。
在哪里使用: PaLM、StarCoder、Falcon 7B和其他早期大型模型采用MQA以获得推理效率。
📘 参考文献:Shazeer(2019)。Multi-Query Attention: One write-head is all you need。
5、分组查询注意力:恰到好处的解决方案
分组查询注意力(GQA)由Ainslie等人在2023年引入,并迅速成为现代LLM中的主导方法。关键洞察很简单:
MHA(H个K/V头)太昂贵。MQA(1个K/V头)损害质量。那G个K/V头呢,其中1 < G < H?
GQA将查询头分为G组。每组共享一个K/V头:
组 1: Q₁, Q₂, Q₃, Q₄ ── K₁, V₁ (组内共享)
组 2: Q₅, Q₆, Q₇, Q₈ ── K₂, V₂ (组内共享)
...
组 G: Q_{H-3}...Qₕ ── Kɢ, Vɢ (组内共享)
这创造了一个平滑的光谱:
- G = H的GQA 等同于标准MHA(每个查询头有自己的K/V)
- G = 1的GQA 等同于MQA(所有查询头共享一个K/V)
- G在1和H之间的GQA 是有趣的中间地带
直觉: 回到研究团队的类比。GQA就像将你的32名研究人员组织成8个4人小组。每个小组在其4名成员之间共享一个参考库,但不同小组有不同的库。这比32个私人库高效得多,同时给每个小组比所有人的单一公共库更专业的知识。
数字: 从8B到405B参数的Llama 3模型都使用G=8个键/值头的GQA。这意味着KV缓存是完整MHA大小的8/32 = 1/4,同时保留了大部分质量。
为什么特别是G=8? 这个选择平衡了质量和效率。Ainslie等人的研究表明,8组以与MQA相当的速度捕获了MHA大部分的表达能力。低于4组开始明显损害质量;超过16组提供递减的内存节省。
为什么GQA获胜: 这是一个罕见的近乎帕累托改进的案例——推理时几乎和MQA一样快,建模时几乎和MHA一样准确。它还能很好地与跨多GPU的张量并行配合,这也是大型模型实际部署的方式。
今天谁在使用:
📘 参考文献:Ainslie等(2023)。GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints。EMNLP 2023。
6、完整光谱:可视化MHA、GQA、MQA
并排看这三种设计很有帮助。以H = 8个头为例:
多头注意力(MHA):
Q₁K₁V₁ Q₂K₂V₂ Q₃K₃V₃ Q₄K₄V₄ Q₅K₅V₅ Q₆K₆V₆ Q₇K₇V₇ Q₈K₈V₈
(8个K/V头 — 最高质量,最大内存)
分组查询注意力(GQA, G=2):
Q₁Q₂Q₃Q₄──K₁V₁ Q₅Q₆Q₇Q₈──K₂V₂
(2个K/V头 - 平衡的质量和内存)
多查询注意力(MQA):
Q₁ Q₂ Q₃ Q₄ Q₅ Q₆ Q₇ Q₈ ──── K₁V₁
(1个K/V头 - 最小内存,潜在质量损失)
7、多头潜在注意力:下一个前沿
虽然GQA减少了K/V头的数量,但 多头潜在注意力(MLA)采取了根本不同的方法,由DeepSeek在其V2模型(2024年)中引入,并在DeepSeek-V3和R1中继续使用。
核心理念:不是减少头的数量,而是将K/V向量本身压缩到一个低维潜在空间中后再存储。
以下是概念上的工作方式:
把它想象成文件的压缩算法。MLA不是存储完整的K和V矩阵,而是存储它们高度压缩的"草稿"。当未来的查询需要它们时,它即时将草稿解压。
技术上: MLA使用低秩矩阵分解(在精神上类似于LoRA微调技术)。输入被向下投影到一个小的潜在向量中,而向上投影在计算时重建完整的K和V。只有微小的潜在向量需要存在于KV缓存中。
结果是惊人的。 DeepSeek-V2将其KV缓存比前代减少了93.3%。这转化为5.76倍的更高生成吞吐量——不仅仅是内存节省,而是相同硬件上显著更快的响应。根据一项研究,将基于GQA的模型如LLaMA-2转换为MLA在8K上下文长度下实现了10.6倍的推理加速。
声明: 与MQA和GQA——它们明确减少注意力的表达能力——不同,DeepSeek声称MLA在不牺牲模型质量的情况下实现了KV压缩,甚至可能略有提升。独立评估总体上确认了有竞争力的性能,尽管情况是复杂的。
代价: MLA增加了架构复杂性。解压步骤意味着推理时更多的矩阵乘法。它更难改写到使用MHA或GQA训练的现有模型上。而且当整个生态系统围绕它设计时(自定义推理内核、量化等),收益才最大化。
📘 参考文献:DeepSeek-AI(2024)。DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model。
8、总结:推理中注意力的演进
让我们退后一步看完整的弧线:
其中d_c ≪ H × d_h — 潜在维度远小于完整的K/V维度。
这一演进中的每一步都由相同的底层压力驱动:随着上下文窗口变得更长,LLM服务更多用户,KV缓存成为关键瓶颈,而不仅仅是小开销。
8.1 张量并行的角度
一个值得理解的实际细节:现代LLM通常在多个GPU上并行运行(这被称为张量并行)。每个GPU处理一部分注意力头。
在标准MHA中,跨GPU分配很简单——每个GPU获得自己的头及完整的K和V。但在MQA中,只有一组K/V,这意味着每个GPU都需要一份副本。这种复制浪费了内存带宽——这是GQA(而非MQA)成为主导选择的关键原因之一。
GQA的组自然地映射到张量并行配置。如果你有8个GPU和8个K/V组,每个GPU恰好处理一个组——干净、高效,没有浪费的复制。
8.2 一个具体的直觉:研究图书馆再回顾
让我们用一个一致的类比将整章内容串联起来。
想象一个由32名研究人员组成的团队(= 32个注意力头)在一个大型研究项目上工作(= 生成长回复)。每名研究人员需要查阅历史档案(= KV缓存)中到目前为止写过的所有内容。
多头注意力: 每名研究人员有自己的私人完整档案。最大知识,最大空间——需要维护32个完整档案。
多查询注意力: 所有32名研究人员共享一个单一的公共档案。最小空间——但有专业需求的研究人员发现共享档案有时太泛化。
分组查询注意力: 研究人员被组织成8个4人小组。每个小组共享一个与其专业相关的聚焦档案。良好的平衡——8个档案而不是32个,每个比单一共享档案更专业。
多头潜在注意力: 每名研究人员存储一张压缩摘要卡片而不是完整档案。完整档案可以在需要时从卡片重建。最小存储,(理想情况下)最小质量损失。
9、为什么这超越了架构论文也很重要
对于大多数读者来说,你永远不会从头实现这些机制。但理解它们在实际中有重要意义:
在选择模型时: 使用G=8的GQA的模型将比使用完整MHA的相同模型在同一硬件成本下服务两倍的并发用户。这直接影响你是否能负担得起部署。
在阅读基准测试时: 模型之间的推理速度比较常常被注意力变体所混淆。一个初看较慢的模型可能只是使用MHA,而竞争对手使用GQA或MLA。
在理解上下文限制时: 模型的"200K上下文"声明不仅仅关于训练——它直接由GQA或MLA等KV缓存优化实现,使存储200K个token的K/V向量变得可行。
在追踪领域进展时: MHA → MQA → GQA → MLA的进展代表了一个活跃的研究前沿。理解这种演进为你解释下一代注意力变体做好了准备。
原文链接: Transformers & LLMs — Part 8: Sharing Attention Heads
汇智网翻译整理,转载请标明出处