论文: Fast Inference from Transformers via Speculative Decoding
作者: Yaniv Leviathan, Matan Kalman, Yossi Matias
发表: ICML 2023 | arXiv:2211.17192

一句话总结: 在不改变输出分布的前提下,用小模型猜测 + 大模型并行验证,把自回归解码从串行瓶颈中解放出来,实现 2-3X 加速。

一、问题:自回归解码的串行瓶颈

大语言模型生成 个 token 需要串行跑 次 forward pass。而 decode 阶段通常是 memory bandwidth bound——GPU 大部分时间在等数据而不是在算数据,算力严重闲置。

1
2
GPU 有大量闲置算力,但自回归解码无法利用它。
因为下一个 token 的计算依赖上一个 token 的结果,天然串行。

Speculative Decoding 的核心问题是:能不能让 GPU 同时验证多个 token,而不是一个一个生成?


二、核心思想:猜测 + 验证

借鉴 CPU 分支预测的投机执行思想:

  1. 用一个小而快的 draft model 自回归猜 个 token( 是预设的猜测长度,比如 就是一口气猜 7 个)
  2. 个猜测一次性喂给 target model 并行计算所有位置的概率分布
  3. 从前往后逐个验证,拒绝第一个不一致的猜测
  4. 在拒绝位置从修正分布中采样替代 token

下图展示了一次完整的投机解码过程。每一行是一轮猜测-验证:绿色是被接受的猜测,红色是被拒绝的,蓝色是修正采样。target model 只跑了 9 次,却生成了 38 个 token。

Speculative Decoding 示意图:逐步猜测与验证

关键保证:最终输出分布与直接用大模型生成完全相同。不是近似,是精确等价。

1
2
3
标准解码:每生成 1 个 token,跑 1 次大模型
投机解码:每跑 1 次大模型,产出 1 到 γ+1 个 token
最坏不比标准解码差,最好一次产出 γ+1 个

三、Speculative Sampling:完整 Walkthrough

我们用一个从头到尾的具体例子,走完整个流程。

设定

假设词表只有 {A, B, C},我们让 draft model 猜 个 token。

Step 1:draft model 自回归猜 3 个 token。

draft model 每一步从自己的分布 中采样:

  • 位置 1: = {A:0.2, B:0.7, C:0.1},采到 B
  • 位置 2: = {A:0.6, B:0.1, C:0.3},采到 A
  • 位置 3: = {A:0.1, B:0.1, C:0.8},采到 C

猜测序列:[B, A, C]

Step 2:target model 并行算出每个位置的分布。

把 prefix、prefix+[B]、prefix+[B,A]、prefix+[B,A,C] 一次性喂给大模型,得到 4 个分布:

  • 位置 1: = {A:0.5, B:0.3, C:0.2}
  • 位置 2: = {A:0.7, B:0.2, C:0.1}
  • 位置 3: = {A:0.1, B:0.5, C:0.4}
  • 位置 4: = {A:0.3, B:0.4, C:0.3}(这个备用,后面解释)

注意:大模型这里只跑了一次 forward pass(输入长度为 prefix + 3 个猜测 token),利用 causal mask 同时得到所有位置的分布。

Step 3:从左到右逐个验证。

对每个位置,抽一个随机数 ,和 比较:

1
2
3
4
5
6
7
8
位置 1:猜了 B。p₁(B)/q₁(B) = 0.3/0.7 ≈ 0.43
抽到 r₁ = 0.35 < 0.43 → ✅ 接受

位置 2:猜了 A。p₂(A)/q₂(A) = 0.7/0.6 ≈ 1.17 > 1,截断为 1
无论 r₂ 是多少都 < 1 → ✅ 接受

位置 3:猜了 C。p₃(C)/q₃(C) = 0.4/0.8 = 0.5
抽到 r₃ = 0.72 > 0.5 → ❌ 拒绝

验证在位置 3 停住了。 前 2 个猜测被接受。

Step 4:在拒绝位置从修正分布采样。

位置 3 被拒绝,需要从修正分布 中重新采样:

1
2
3
p₃ - q₃:    A=0.0, B=0.4, C=-0.4
max(0, .): A=0.0, B=0.4, C=0.0
归一化: A=0.0, B=1.0, C=0.0

中采样,得到 B。

最终输出:[B, A, B] — 3 个 token,而大模型只跑了 1 次。

如果全部接受呢?

如果 3 个猜测全部通过验证,那就直接用第 4 个位置的分布 采样一个额外 token。输出 4 个 token()。这就是最好情况。

如果第 1 个就拒绝呢?

在位置 1 从修正分布采样,输出 1 个 token。这就是最坏情况——和标准解码一样,不会更差。

为什么输出分布和大模型一致?

对任意 token ,最终采到它的概率有两条路径:

第一项是”被采到且被接受”的概率,等于

第二项是”拒绝后从修正分布采到”的概率,其中:

两项之和

这就是精确等价的证明。不是近似,是严格恒等式。

总结:一次迭代的完整流程

1
2
3
4
5
6
7
8
9
10
11
┌─────────────────────────────────────────────────────┐
│ 1. Draft model 自回归猜 γ 个 token │
│ 2. Target model 一次 forward pass 得到 γ+1 个分布 │
│ 3. 从左到右验证:r < p(x)/q(x) 则接受,否则停止 │
│ 4. 在停止位置从修正分布 norm(max(0, p-q)) 采样 │
│ (或全部接受时,从第 γ+1 个分布直接采样) │
│ 5. 输出:接受的猜测 + 1 个新采样 token │
│ │
│ 保证:最少 1 个 token,最多 γ+1 个 token │
│ 保证:输出分布 = 大模型直接生成的分布 │
└─────────────────────────────────────────────────────┘

四、理论分析

期望产出

为平均 acceptance rate,即每个位置被接受的概率(反映 近似 的程度)。

一次迭代的产出 = 被接受的猜测数 + 1(最后从修正分布或 采样的 token)。

位置 被接受,要求前面 全部通过验证(一旦某个拒绝,后面就不看了)。每个独立通过的概率是 ,所以位置 被接受的概率是

这就是等比数列求和:

代几个数感受一下:

期望产出 含义
0.5 5 1.97 小模型猜得一般,平均每轮 ~2 个 token
0.8 5 4.13 小模型不错,平均每轮 ~4 个 token
0.9 7 6.13 小模型很好,平均每轮 ~6 个 token

加速比

标准解码生成 1 个 token 跑 1 次大模型,耗时

投机解码一次迭代:draft model 跑 次 + target model 跑 1 次 = ,其中 是 draft 与 target 的单次推理时间比值。

1
2
3
标准解码每 token 成本 = T
投机解码每 token 成本 = 一次迭代耗时 / 期望产出
= T(γc + 1) / [(1 - α^(γ+1)) / (1 - α)]

加速比 = 标准成本 / 投机成本:

关键结论:只要 ,就一定有加速。 实验中 ,所以 在 0.5-0.7 就足够获得 2-3X 加速。

给定 ,存在最优 低时少猜(猜多浪费), 高且 小时多猜。


五、实验结果

论文在 T5-XXL (11B) 上实验,单张 TPU-v4,batch size 1:

任务 Draft 模型 Temp 加速
英德翻译 T5-small (77M) 0 7 0.75 3.4X
英德翻译 T5-small (77M) 1 7 0.62 2.6X
摘要生成 T5-small (77M) 0 5 0.65 3.1X
摘要生成 T5-base (250M) 0 5 0.73 3.0X

核心观察:

  • argmax 解码加速更大(分布更集中,小模型更容易猜对)
  • T5-small (77M) 作为 11B 的 draft 最优,平衡了
  • 即使 bigram 做 draft(, ),也能提供 1.25X 免费加速

另外值得注意:LaMDA 137B 用 LaMDA 8B 做 draft, 达到 0.75,说明方法可以扩展到百亿级模型。


六、方法特性

不需要任何训练。 对 draft model 没有限制——同架构小模型、n-gram、从 context 复制的启发式、甚至随机模型都行。对任何已部署的大模型,找个靠谱的小模型就能立刻加速。

算力换时间。 wall-clock time 下降,但总算力不一定减少。本质上是利用 decode 阶段的闲置算力做并行验证。对 memory bandwidth bound 的场景收益最大。


七、评价与后续

Speculative Decoding 的贡献是框架性的:

  1. 正确的抽象:模型不动,推理方式可以变。系统层面的洞察,不是模型层面的优化
  2. 精确等价:加速和精确可以兼得,输出分布严格不变
  3. 开创研究路线:后续 EAGLE、Medusa、SpecInfer、Lookahead Decoding 等工作都建立在这个框架上

今天 Speculative Decoding 已是 vLLM、TensorRT-LLM、SGLang 等主流推理框架的标配组件,并演化出了多个方向,如 EAGLE,MTP 等。