LLM 中 100K 上下文窗口背后的秘密

最近有几则关于新型大型语言模型 (LLM) 的公告,这些模型可以使用极大的上下文窗口,例如65K 个标记(MosaicML 的MPT-7B-StoryWriter-65k+ )甚至100K 个标记( Antropic 的引入 100K 上下文窗口)。在 Palm-2技术报告中,Google 没有透露上下文大小,但提到它们“显著增加了模型的上下文长度”。

作为比较,当前的 GPT-4 模型可以处理32K 个输入标记的上下文长度。而大多数开源 LLM 的上下文长度为2K 个标记

这令人印象深刻,因为拥有如此长的上下文长度意味着提示可以有一本书那么大。《了不起的盖茨比》有 72K 个词,210 页,阅读需要 6 个小时,阅读速度为 1.7 分钟/页。因此,模型可以扫描并保留这么多“自定义”信息来处理查询!

我一直在努力弄清楚这在技术上是如何实现的,因此在这篇博文中,我收集了一些零散的信息(这个帖子是第一个线索)并涵盖了以下内容:

  • 为什么上下文长度很重要以及为什么它可以改变游戏规则
  • 在处理较长的上下文长度时,原始 Transformer 架构主要限制是什么
  • Transformer 架构的计算复杂度
  • 目前有哪些优化技术可以加速 Transformer 并将上下文长度增加到 100K

“简短”摘要

在这里以及后面,我们交替使用“上下文长度”、“上下文窗口”和“输入标记的数量”,将它们表示为n

博文有点长,因此有一个总结,其中包含要点和技巧:

  • 第一个问题是注意层计算相对于输入标记数量n的二次时间和空间复杂度。
  • 当嵌入大小d > n时,第二个问题是关于嵌入大小d的线性层的二次时间复杂度。
  • 第三个问题是原始架构中使用的位置正弦嵌入。
  • 在 Transformer 架构中,可学习矩阵权重的形状与输入标记n的数量无关 。
  • 因此,经过训练的 Transformer 具有 2K 上下文长度可以使用任意长度的 token甚至 100K。但如果模型未经训练,在推理过程中将无法对 100K 个 token 产生有意义的结果 100K。
  • 由于nd的二次复杂度,在大型语料库上训练 vanilla Transformer 的成本高得不可行。估计2K 语料长度的 LLaMA 的训练成本约为 300 万美元。因此,100K 语料长度的 LLaMA 的成本约为1.5 亿美元。
  • 一种选择是在 2K 个 token 上下文中训练模型,然后在更长的上下文(例如 65K)中对其进行微调。但由于位置正弦编码,它无法与原始 Transformer 一起使用
  • [技巧 1]要解决这个问题,请删除位置正弦编码并使用ALiBi,这是一种简单而优雅的位置嵌入,不会影响准确性。然后您可以在 2K 上进行训练,并在 100K 上进行微调。
  • [技巧 #2]您不需要计算所有标记之间的注意力分数。有些标记比其他标记更重要,因此可以使用稀疏注意力。它将加快训练和推理的速度
  • [技巧 3] Flash Attention高效地为 GPU 实现了注意层。它使用平铺并避免了无法放入 GPU SRAM的大型中间矩阵(n, n)的实现。它将加快训练和推理的速度
  • [技巧 4]使用多查询注意力机制代替多头注意力机制。这意味着在线性投影 K 和 V 时,所有头之间共享权重。这大大 加快了增量推理的速度。
  • [技巧 #5] 条件计算避免将所有模型参数应用于输入序列中的所有标记。CoLT5仅对最重要的标记应用大量计算,并使用较轻版本的层处理其余标记。这将加快训练和推理的速度
  • [技巧#6]为了适应大环境,你需要 GPU 中大量的 RAM,因此人们使用 80GB A100 GPU。

总而言之,训练和推理的速度越快,可以使用的上下文长度就越大

现在让我们更详细地讨论所有这些要点。

为什么上下文长度很重要

上下文长度是 LLM 的关键限制之一。将其增加到 100K 是一项了不起的成就(我想知道一年后这句话会是什么样子)。

人们想要应用 LLM 的一个重要用例是“将一大堆自定义数据放入 LLM”(与公司或特定问题相关的文档、各种异构文本等)并询问有关这些特定数据的问题,而不是 LLM 在训练期间从互联网上看到的一些抽象数据。

为了克服这个限制,人们采取了各种各样的措施:

  • 尝试总结技巧和复杂的链式提示
  • 维护向量数据库以保留自定义文档的嵌入,然后通过某种相似性度量对它们进行“搜索”
  • 在可能的情况下使用自定义数据对 LLM 进行微调(并非所有商业 LLM 都允许这样做,而且对于开源 LLM 来说这不是一项显而易见的任务)
  • 为该特定数据开发定制的较小 LLM(同样,这不是一项显而易见的任务)

拥有较大的上下文长度可以让已经很强大的 LLM(可以查看整个互联网)查看您的上下文和数据,并以完全不同的方式与您进行更个性化的互动。所有这些都不会改变模型的权重,也不会在“内存中”动态进行“训练”。总的来说,较大的上下文窗口为模型带来了更高的准确性、流畅性和创造力。

这里的一个类比可能是计算机 RAM,操作系统在其中保存所有应用程序的实时上下文。凭借大量的上下文长度,LLM 可以像一台“推理计算机”,保存大量用户上下文。

原始 Transformer 和上下文长度

值得注意的是,在 Transformer 架构中,所有可学习矩阵权重的形状都不依赖于输入标记的数量 n。所有可训练参数(嵌入查找、投影层、softmax 层和注意层)都不依赖于输入长度,并且必须处理可变长度的输入。我们拥有架构的这种开箱即用属性真是太好了

这意味着,如果您训练了一个上下文长度为 2K 的 Transformer 模型,则可以推断任意大小的 token 序列。唯一的问题是,如果模型未 在 100K 上下文长度上进行训练,则在推理过程中,该模型将无法对 100K token 产生有意义的结果。在这种情况下,训练数据分布将远离推理过程中的分布,因此该模型将像此设置中的任何机器学习模型一样失败。

训练大型上下文长度 Transformer 的一个解决方案是分两个阶段进行训练:在 2K 标记上下文长度上训练基础模型,然后在更长的上下文(例如 65K 或 100K)上继续训练(微调)。这正是 MosaicML所做的。但问题是它不适用于原始 Transformer 架构,因此您需要使用一些技巧(请参阅文章后面的技巧 #1)。

回顾多头注意力机制

上下文长度较大的挑战与 Transformer 架构的计算复杂度有关。要讨论复杂性,首先让我们回顾一下注意力层的工作原理。

Q —查询,K — 键和V — 值,这些符号来自与信息检索相关的论文,其中你向系统插入一个“查询”并搜索最接近的“键”
n — 输入的标记数
d — 文本嵌入维度
h — 注意力头的数量
k — Q 和 K 的线性投影大小
v — V 的线性投影大小

多头注意力机制:

  1. 我们有一个查找嵌入层,对于给定的标记,它返回一个大小为(1, d)的向量。因此,对于n 个标记序列,我们得到大小为(n, d)的文本嵌入矩阵X。然后我们将其与位置正弦嵌入相加。
  2. 多头注意力层旨在计算此标记序列的新嵌入,该序列可被视为对 X 进行编码的原始文本,通过 (1) 标记相对于上下文的相对重要性和 (2) 标记的相对位置进行加权。
  3. 我们并行处理这个嵌入矩阵 X (n, d)h 个注意力层 (注意力头)。要获得所有注意力头的QKV,您需要分别将 X 线性投影到kkv维。您可以将 X 乘以h 个形状分别为(d, k)(d, k)(d, v)的矩阵。您可以将其视为将(n, d)乘以(h, d, k)(h, d, k)(h, d, v)
  4. 注意力头返回h 个大小为 (n, v) 的注意力分数矩阵然后我们将所有注意力头(n, h*v)的部分连接起来,并对其进行线性投影,以用于下一步。
LLM 中 100K 上下文窗口背后的秘密
《Attention is All You Need》论文中注意力架构的高级示意图

缩放点积注意力机制

现在,让我们放大一个注意力头

  1. QKV是X的 3 个线性投影,大小分别为(n, k)(n, k)(n, v) ,通过乘以每个头部可学习的权重获得。
  2. 我们通过计算QK (转置)之间的距离(点积)来获得注意力分数。将矩阵(n, k)乘以(k, n)得到矩阵(n, n)。然后我们将其乘以掩码矩阵以将一些标记归零(解码器中需要)。然后我们缩放它并应用 softmax 使其从 0 到 1。这样,我们得到了形状为(n, n)的矩阵,其中n_ij –第 i个和第 j个标记之间的相对注意力分数从 0 到 1,显示了这些标记在长度为n的特定上下文中的“接近程度” 。
  3. 然后我们将这个注意力分数矩阵(n, n)乘以大小为(n, d)的“值”V ,以获得由这些相对注意力分数加权的文本嵌入。
LLM 中 100K 上下文窗口背后的秘密
原论文中,一个头部的Attention Score矩阵就是通过这个公式计算出来的。

让我们看一下多查询注意力论文中的这段代码。它展示了如何使用批处理计算多头注意力,并且每一步的形状都很清晰。它们还包括解码过程中使用的掩码乘法。

LLM 中 100K 上下文窗口背后的秘密
一段非常漂亮的代码,展示了注意力层中每一步的形状。来自Multi-Query论文。

Transformer 的复杂度和上下文长度

2矩阵乘法 (a,b)*(b,c) 的复杂度为O(a*b*c)。为简单起见
我们假设k*h = O(d) ,我们将以此来推导注意力机制的复杂度。

注意力层的复杂性由两部分组成:

  1. 线性投影得到 Q、K、V:将大小为(n, d)的嵌入矩阵乘以h 个可学习矩阵(d, k)(d, k)(d, v)。因此,复杂度约为O(nd²)
  2. 将Q与变换后的K相乘,然后与V 相乘:(n,k) * (k,n) = (n,n)(n,n)*(n,v) = (n,v)。复杂度 ~ O(n²d)

因此,注意层的复杂度为O(n²d + nd²),其中n — 是上下文长度(输入标记的数量),d — 嵌入大小。因此从这里我们可以看出,注意层计算的复杂度是输入标记数量n的二次方,也是嵌入大小 d 的二次方

当 d > n 时,项O(nd²)很重要(例如,在 LLaMa 中,n=2K 和 d=4K)。当 n > d 时
,项O(n²d) 重要(例如,训练 MosaicML 时 n=65K 和 d=4K)。

只是为了提醒你二次增长有多糟糕:
2 000² = 4 000 000,100 000² = 10 000 000 000。

让我举个例子来说明这种二次复杂度如何影响模型训练的价格。训练 LLaMa 的估计价格约为300 万美元,它有 65B 个参数、2K 个上下文长度和 4K 个嵌入大小。估计时间主要是 GPU 训练时间。如果我们将上下文长度从 2K 增加到 100K(50 倍),训练时间也会增加约 50 倍(我们需要更少的迭代,因为上下文更大,但每次迭代都需要更长的时间)。因此,在 100K 上下文上训练 LLaMA 的成本约为1.5 亿美元

关于此计算的一些细节:

对于 token 数量等于n 的情况,注意力机制的复杂度为O(n²d + nd²) ,需要M次迭代才能训练完成。如果我们将上下文长度从n → p*n 增加,则由于上下文长度变大,需要M/p次迭代(为了简单起见,我们假设它是线性的,但根据任务的不同,可能会被高估或低估)。现在我们有 2 个等式:(1)n的复杂度~ M * (n²d + nd²)(2)p*n 的复杂度~ M/p * ((p*n)²d + (p*n)d²)经过一系列简化和划分后,比率 (2)/(1) ~ (d + p*n)/(d + n)

如果d << n,则将n增加p倍将导致迭代次数增加 ~ p 倍。如果d ~ n,则将n增加p倍将导致迭代次数增加 ~ p/2 倍

Transformer 中训练阶段和推理阶段的区别

在深入研究优化技术之前要讨论的最后一件事是训练和推理期间计算的差异。

在训练过程中,您可以并行运行,而在推理过程中生成文本时,您需要按顺序执行,因为下一个标记取决于前一个标记。实现推理的直接方法是逐步计算注意力得分,并缓存以前的结果以供将来的标记使用。

这种区别带来了加速训练和推理的不同方法。这就是为什么下面的一些技巧会优化两个阶段,但有些技巧只会优化推理。

增加上下文长度的优化技术

现在,让我们来谈谈研究人员如何克服所有这些挑战并能够训练具有较大上下文长度的 LLM。

[技巧#1]更好的位置编码——ALiBi

训练大上下文长度 Transformer 的一个解决方案是分两个阶段进行训练:在 2K 个 token 上下文长度上训练基础模型,然后在更长的上下文(例如 65K)上进行微调。但之前我们说过,它不适用于原始 Transformer 架构。为什么?

因为位置正弦编码没有“外推”能力。ALiBI [4] 论文中,作者表明位置正弦编码对推理期间上下文窗口的扩展不具有鲁棒性。在多几个 token 之后,性能开始下降。因此,缺乏“外推”能力基本上意味着您在推理/微调期间不能使用比训练期间更大的上下文长度。术语“外推”和各种位置编码的比较在 [4] 中进行了描述。

在原始的 transformer 论文中,位置正弦嵌入与架构底部的 token 嵌入相加,以添加有关单词顺序的信息。如果你想了解位置正弦嵌入是如何计算的,我推荐这个有趣的视频,其中的解释直观且详细。

因此,第一个技巧是删除位置正弦嵌入,并用另一个位置嵌入替换它——具有线性偏差的注意力(ALiBI)

应用于注意力头(而不是网络底部),并且它以与它们的距离成比例的惩罚来偏置查询关键注意力分数(在 softmax 之前)。

LLM 中 100K 上下文窗口背后的秘密

此技巧可加快训练速度

LLM 中 100K 上下文窗口背后的秘密
在计算每个头部的注意力分数时,ALiBi 会为每个注意力分数(qi·kj,左)添加一个常数偏差(右)。与未修改的注意力子层一样,softmax 函数随后会应用于这些分数,其余计算则不做修改。m 是特定于头部的标量,在整个训练过程中都会设置,但不会进行学习。来自ALiBI 论文

[技巧2] 稀疏注意力

并非所有 100K 大小的上下文中的标记都彼此相关。减少计算量的一种方法是在计算注意力得分时仅考虑某些标记。增加稀疏性的目的是使计算与 n 呈线性关系,而不是二次函数。有几种方法可以选择标记之间的连接,Google博客文章中有一个很好的说明:

LLM 中 100K 上下文窗口背后的秘密
可以将全注意力机制视为一个完整的图。稀疏注意力方法
LLM 中 100K 上下文窗口背后的秘密
稀疏注意力方法

例如,滑动窗口注意力(也称为局部注意力)采用围绕每个标记的固定大小窗口注意力。在此注意力模式中,给定固定窗口大小w,每个标记会关注两侧的w /2 个标记。此模式的计算复杂度为O(n*w),与输入序列长度n成线性比例。为了提高效率,w应该小于n。诀窍在于注意力信息在邻近标记内“流动”整个上下文窗口,近似整个图。

BigBird注意力得分方法结合了全局、局部和随机机制。在论文中,作者展示了一个关键的观察结果:计算的相似度得分数量与不同节点之间的信息流(即一个 token 相互影响的能力)之间存在内在的矛盾。

这个技巧可以加快训练和推理的速度

[技巧#3] FlashAttention — GPU 注意力层的有效实现

注意力层中有几个计算操作被一遍又一遍地重复:

  1. 数量=Q*K
  2. P = softmax(S)
  3. =P*V

记住PSO结果的概念;我们稍后会用到它。FlashAttention 的作者“融合”了这些操作:他们实现了一个注意力层算法,该算法可以高效利用 GPU 内存并计算出精确的注意力。

GPU 要执行操作,输入数据必须存在于名为 SRAM 的“快速”内存中。数据从“慢速”HBM 内存复制到 SRAM,计算结束后再返回 HBM。SRAM内存比 HBM 快得多,但大小要小得多(20MB 对比 A100 40GB GPU 中的 40GB)。

LLM 中 100K 上下文窗口背后的秘密
A100 GPU 内存层次结构。FlashAttention论文

因此,访问 HBM 是一项昂贵的操作

注意力层中关于 GPU 内存利用率的主要问题是“中间”乘法结果P、SO,其大小为(n, n)。我们需要将它们保存到 HBM,并在注意力操作之间再次读取它们。将 P、S 和 O 从 HBM 移到 SRAM 并来回移动是瓶颈,作者在论文中解决了这个问题。

FlashAttention 算法背后的主要思想是将输入 Q、K 和 V 矩阵拆分为块,将这些块从 HBM 加载到 SRAM,然后计算这些块的注意力输出。此过程称为平铺

LLM 中 100K 上下文窗口背后的秘密
左图: FlashAttention 使用平铺来防止大型 n × n 注意力矩阵(虚线框)在 HBM 上实现。在外循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块并将它们加载到 SRAM。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。右图:7.6 倍加速。FlashAttention论文

矩阵乘法”运算已针对 GPU 进行了优化。您可以将此 FlashAttention 算法视为针对 GPU 优化的“注意层”运算的实现。作者将多个乘法和 softmax 运算与平铺和优化的 HBM 访问“融合”。

论文中对 FlashAttention进行了很好的概述

最近以来,PyTorch 2.0 已内置flash-attention。这是作者使用Triton 语言实现的FlashAttention 。

这个技巧可以加快训练和推理的速度

[技巧#4] 多查询注意力机制(MQA)

原始的多头注意力(MHA)在每个头中都有一个用于 K 和 V 矩阵的单独线性层。

在推理过程中,解码器中先前标记的键和值会被缓存,以防止重新计算它们,因此GPU 内存使用量会随着每个生成的标记而增加

多查询注意力机制(MQA) 是一种优化方法,建议在线性投影 K 和 V 时在所有注意力头之间共享权重,因此我们只需要保留 2 个大小为(n, k)(n, v)的矩阵。大型模型最多可以有 96 个注意力头(例如 GPT-3),这意味着使用 MQA 可以节省 96 倍的键/值解码器缓存的内存消耗。

这种优化在生成长文本时尤其有用。例如,上下文长度很大,并且要求进行长篇有意义的分析或总结。

这种方法的主要优点是显著加快了推理过程中增量注意力分数的计算速度。训练速度基本保持不变。例如,PaLM 正在使用它。

[技巧#5] 条件计算

d > n时,速度的瓶颈不是注意层,而是前馈层和投影层。减少 FLOP 的常用方法是采用某种形式的条件计算,避免将所有模型参数应用于输入序列中的所有标记。

在稀疏注意力部分,我们讨论了某些 token 比其他 token 更重要。按照同样的直觉,在CoLT5 论文中,作者将所有前馈和注意力计算分为两个分支:重度轻度。轻度层适用于所有 token,重度层仅适用于重要 token。

“轻度和重度前馈分支仅在隐藏维度上有所不同,轻度分支的隐藏维度比标准 T5 前馈层小,而重度分支的隐藏维度较大”。

事实证明,对于高达 64K 输入标记的极长序列,该方法的速度和准确性都优于现有的LongT5模型。

LLM 中 100K 上下文窗口背后的秘密
具有条件计算的 COLT5 Transformer 层概述。所有标记都由轻度注意和 MLP 层处理,而 q 个路由查询标记比 v 个路由键值标记执行更重的注意,m 个路由标记由更重的 MLP 处理。CoLT5论文

[技巧#6] 大内存 GPU

这不是技巧,而是必需品。为了适应大环境,GPU 需要大 RAM,因此人们使用 80GB A100 GPU。

结论

希望这对你有帮助!我学到了很多,希望你也学到了很多,现在我们可以猜测这些具有数十亿参数的大型语言模型是如何在前所未有的 65-100K 个标记的上下文窗口中进行训练的。

看到不同的聪明人从不同角度解决同一个问题,进行各种优化,并想出很酷的想法,真是令人鼓舞。所有这些都会得出一个有意义且优雅的解决方案。

参考

[1] Antropic推出 100K 上下文窗口
[2] MosaicML 推出的MPT-7B
[3] Google 的Palm-2 技术报告
[4] ALiBI:短训练,长测试:具有线性偏差的注意力机制可实现输入长度外推
[5] FlashAttention:具有 IO 感知的快速且内存高效的精确注意力机制
[6]多查询注意力机制:快速 Transformer 解码:一个写入头就足够了
[8]注意力机制就是您所需要的
[9]位置正弦嵌入视频
[10] FlashAttention 论文概述
[11]滑动窗口注意力机制
[12]使用稀疏注意力方法构建用于更长序列的 Transformer [13] Triton语言中的
FlashAttention实现
[14]如何使用 Triton 和 ClearML 将 HuggingFace 吞吐量提高 193%
[15] ClearML Serving
[16]分析NVIDIA Triton 推理服务器与其他推理引擎的比较
[17] COLT5:具有条件计算功能的更快长距离 Transformer
[18] LongT5:适用于长序列的高效文本到文本 Transformer
[19] PaLM
[20] BigBird注意力机制

RA/SD 衍生者AI训练营。发布者:chris,转载请注明出处:https://www.shxcj.com/archives/4558

(0)
上一篇 2024-08-09 2:43 下午
下一篇 2024-08-09 3:05 下午

相关推荐

发表回复

登录后才能评论
本文授权以下站点有原版访问授权 https://www.shxcj.com https://www.2img.ai https://www.2video.cn