LLM训练数据调试指南

关于LLM训练的大多数讨论都集中在模型和算法上。我们热衷于实验像GRPO这样的新框架,并期待下一代模型如Gemma-3Qwen-3的发布。然而,在LLM训练中区分成功与失败的主要因素是训练数据集的质量。不幸的是,与其他热门研究领域相比,这个话题受到的关注要少得多。

在本概述中,我们将提供一个以数据为中心的LLM训练调试和优化指南,强调我们可以用来迭代改进数据并开发更强大LLM的实用策略

1、LLM开发生命周期

LLM开发的关键步骤

在训练LLM时,我们遵循一个迭代且经验驱动的过程,主要由两个步骤组成(如上所示):

  1. 训练LLM。
  2. 评估LLM。

要开发一个LLM,我们只需重复这些步骤,最终得到一个在我们感兴趣的应用程序相关评估中表现良好的LLM。

LLM评估。 我们不会详细讨论评估LLM的话题,因为这个问题极其复杂。不过在高层面上,评估LLM有两种方式——要么人工评估,要么自动评估。人工评估可以通过多种方式设置;例如,在两个模型响应中选择更好的一个,或沿着多个质量维度对模型响应进行评分,如下所示。与任何其他数据标注项目一样,我们必须投入精力确保这些人工评估是高质量的并与我们试图衡量的内容保持一致。

(来自[5, 12])

在开发LLM时,人工评估是衡量质量的金标准——我们应该始终依赖人工评估来提供我们的LLM是否变得更好的明确信号。然而,人工评估也非常耗时(即需要几天或几周)!为了避免拖慢我们的迭代速度,我们必须开发自动评估指标,以提供更高效的模型质量代理度量。使用这些自动指标,我们可以在每次人工评估试验之间进行更多的模型迭代,从而更快地提高模型质量,如下所示。

在自动评估方面,通常使用两种主要技术——基准风格评估和LLM评分,如下所示。这两种策略分别测试模型在封闭式和开放式任务上的表现。

(来源)

基准风格评估(例如,多项选择题或问答对)在NLP研究的历史中一直被使用。LLM的现代基准示例包括MMLUGPQA Diamond。这些基准有封闭式的解决方案,但LLM产生的是难以评估的开放式输出。开放式评估最流行的技术是LLM-as-a-Judge,或其他相关技术(例如,奖励模型微调评分器验证器),详情请参见:使用LLM进行评估

调整数据。 一旦我们有了评估设置,就可以开始训练新模型并衡量它们的性能。对于每个新模型,我们都会进行某种干预,希望(有希望)能提升LLM的性能。传统上,AI研究人员对算法和架构非常感兴趣1,有时我们确实会调整这些细节!例如,Llama 4对其训练后流程进行了重大更改2,许多LLM正在将新算法——例如RLVR——纳入其训练流程以提高推理能力。然而,尽管有这些最新发展,大多数干预措施都与数据相关。我们调整训练数据,保持其他所有内容不变,重新训练(或继续训练)模型,然后观察新数据是否提升了模型性能。

(来自[2])

概念上最直接的数据干预就是收集更多训练数据。在LLM开发过程中收集更多数据是很常见的。例如,Llama 2报告[3]指出,模型通过多个阶段进行训练后优化,在每个阶段都会收集更多数据用于进一步训练后优化,如上所示。收集数据在概念上看起来很简单,但数据标注是一个极其复杂和细致的课题,需要正确的策略——通常还需要先前的经验——才能成功执行;详情请参见这里这里

"从人类数据中获得最大收益涉及模型的迭代训练、不断演变且高度详细的数据指令、通过数据工厂业务进行转换,以及其他各种挑战。" - RLHF书

管理数据。 在本报告中,我们不会关注收集更多数据。相反,我们将专注于管理(或调试)我们已有的数据。这是一种与人类数据收集正交的方法,如下所示。为此,我们使用各种技术来识别高质量或低质量的数据,以便修复数据集中的问题,并将训练过程聚焦于最高质量的数据。

获取高质量数据的两个方向(来源

鉴于大多数对LLM质量的干预都与数据相关,数据管理是一个极其重要的课题;例如,有几家初创公司大量优秀论文专注于这个主题。尽管这对LLM训练过程至关重要,但数据相关的主题在AI研究中通常被低估。优化数据并不是一个引人注目或流行的主题,但它往往是训练LLM时区分成功与失败的关键因素。

2、如何管理数据?

简单来说,有两种管理数据的方式:

  1. 直接查看数据。
  2. 使用模型输出调试训练数据。

例如,我们可以通过手动检查或基本的搜索和启发式方法来管理和调试数据。此外,我们可以使用另一个模型来分析我们的数据;例如,标记、分类、分配质量分数等等。所有这些策略都与我们正在创建的下游模型无关——我们直接查看训练数据。然而,一旦我们训练了一个模型,我们就可以通过调试LLM的输出进一步推动数据管理过程,如下所示:

  • 识别差的模型输出。
  • 找到(可能)导致这些输出的数据问题。
  • 通过某种干预修复数据。
  • 重新训练模型。

调试策略。 在本概述中,我们将上述两种策略称为数据聚焦型管理和模型聚焦型管理。有很多术语可以用来指代这些概念,这种命名法肯定不完美;例如,数据聚焦型管理仍然可能涉及模型的使用,我们只是用模型来分析数据,而不是用数据来训练模型。不过,我们将在整篇文章中使用这套术语以保持讨论的清晰和一致。

在讨论这些想法时,我们应该记住,数据聚焦型和模型聚焦型调试并非互斥。事实上,我们几乎应该同时利用两者。数据聚焦型管理不需要训练任何模型,这在LLM开发的早期阶段非常有用。有经验的研究人员在进行任何建模之前会花大量时间分析和理解他们的数据。

随着时间的推移,我们会继续执行这种数据聚焦型分析,但一旦我们训练了模型,新的分析途径就变得可能。要调试和改进LLM,我们必须开发一种多方面的方法,使我们能够更深入地理解我们的模型、数据以及它们之间的联系。

3、数据聚焦型管理:观察数据

要深入了解我们的数据,首先要手动查看数据。在手动检查数据时,我们会开始注意到——并在某些情况下修复——数据中的重要问题和模式。然而,要将这一管理过程扩展到超越我们自己的判断力,我们需要使用基于启发式方法或其他机器学习模型的自动化技术。

(来源)

手动检查。 调试LLM的第一步就是查看模型的训练数据。这应该在我们开始训练任何模型之前进行,并在整个模型开发过程中持续进行。 手动数据检查非常耗时(而且不一定是最有趣的!),但它是LLM开发的重要组成部分。通过花时间手动检查数据,我们能够更好地理解数据,进而更好地理解我们的模型。如果你询问任何LLM研究人员,他们很可能会确认他们花了很多时间手动检查数据。这种不受欢迎的活动是训练LLM成功的关键因素——它不能(也不应该)被避免!

手动数据检查的主要限制在于它不可扩展——作为研究人员,我们只能手动检查有限的数据。一旦我们进行了足够的手动检查3以充分理解数据,就需要开发更好的策略来扩展数据检查工作。

启发式过滤。 手动检查会揭示数据中的许多问题和有趣的模式。例如,我们可能会注意到某些词被非常频繁地重复使用,如下所示。为确保模型不反映数据中的次优模式,我们可以使用启发式方法来查找匹配这些模式的训练示例并进行过滤(或修改)。例如,查找重复使用相同词汇的数据可以通过简单的字符串匹配来完成。这里,我们使用基本的启发式方法来解决数据中的明显限制。

(来源)

还有许多其他数据检查和过滤的启发式方法值得考虑。例如,我们可能会注意到某些数据来源质量更高或具有有用的特性。针对这种情况,我们可以在训练中强调这些数据4,甚至从该来源获取更多数据。类似地,我们可能会注意到数据子集中存在格式化问题,可以通过正则表达式来识别或修复。根据手动检查阶段的观察,可能有几乎无限数量的启发式检查或修复需要应用于训练数据集。

基于模型的过滤。 如果观察到的问题无法通过启发式方法解决,那么我们可以借助机器学习模型来修复这些问题。fastText分类器因其高效性被广泛用于LLM数据过滤——即使在预训练规模下也能高效运行。fastText模型用于LLM数据过滤的具体例子包括语言识别(例如,过滤掉非英语数据)或识别有毒内容。不过,可以轻松训练自定义fastText模型来处理各种定制化的过滤任务。我们只需i) 在想要识别的数据样本上训练模型,ii) 使用模型识别这些数据,iii) 然后移除或保留被识别出的数据,如下所示。

(来源)

我们还可以使用其他类型的模型进行数据过滤。例如,LLM-as-a-Judge风格的模型通常用于过滤数据和创建合成数据。Constitutional AI是一个流行的例子,它使用LLM评判者来创建合成偏好对,Llama 4则使用LLM评判者从其监督微调数据集中移除较容易的示例。我们可以应用类似的方法来识别数据中的任意属性和模式——通常准确度相当高——以达到过滤目的。

"我们使用Llama模型作为评判者,移除了超过50%标记为"容易"的数据,并在剩余的较难数据集上进行轻量级SFT。" - 来自[13]

这些较大的模型相对于fastText模型来说效率要低得多,这限制了它们的使用范围(通常在训练后阶段)。如果将BERT-base(比一些最大的现代LLM小约10,000倍)与fastText模型进行比较,效率和所需硬件的差异是巨大的,如下所示。尽管如此,开发更复杂的数据管理方法和模型是当前AI研究中最有影响力的话题之一。

使用fastText与BERT-base进行数据过滤的对比(来源)

4、模型聚焦型管理:调试LLM的输出

一旦我们开始在数据上训练LLM,我们就可以利用这些LLM来调试训练数据集中的问题。模型聚焦型管理的思路很简单:

  1. 识别模型产生的有问题的或不正确的输出。
  2. 搜索可能导致这些输出的训练数据实例。

有问题的输出的识别通过我们的评估系统来处理。我们可以让人类(甚至我们自己!)通过手动检查来识别差的输出,或者通过自动评估设置高效地查找不正确或得分低的输出。一旦识别出这些有问题的输出,调试LLM就变成了一个搜索问题——我们希望找到与这些差输出相关的训练示例。在本节中,我们将介绍几种常用方法,最后介绍Ai2最近开发的一种低成本高效的数据追踪方法——OLMoTrace [2]。

4.1 搜索训练数据

搜索相关的训练数据与任何其他搜索问题类似,如上所述。唯一的区别是,我们的查询是来自LLM的输出,而不是我们在搜索框中输入的内容。但是,所有相同的搜索技术都可以用来解决这个问题。关于此主题的深入探讨,请查看这个概述。在本节中,我们将简要介绍搜索的关键概念,以及它们如何应用于追踪训练数据。

词汇搜索。 在深度学习普及之前的很多年里,大多数搜索引擎都是纯词汇的,意味着它们依赖关键词(或n-gram)匹配来查找与查询相关的文档。为了高效地找到这些匹配,我们使用一种叫做倒排索引的数据结构。通过计算每个查询与文档之间的匹配数,并考虑每个匹配n-gram的唯一性,我们可以为每个文档得出一个相关性分数。最常用的算法是BM25,其计算公式如下所示。

BM25分数计算公式

尽管这些细节看起来可能很复杂,但我们可以通过像rank_bm25bm25s这样的Python包轻松实现BM25驱动的搜索。使用这些包,我们可以在Python中为数据构建搜索索引,并开始运行搜索,如下面的代码示例所示。可以看出,这个功能很容易原型化并开始使用,无需太多努力!

from transformers import AutoTokenizer
from rank_bm25 import BM25Okapi

tok = AutoTokenizer.from_pretrained(<your tokenizer>)

corpus = [
    "Here is a training example",
    "Here is another training example...",
]

tokenized_corpus = [doc.split(" ") for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)

语义搜索。 尽管词汇搜索强大且高效,这种技术仍然依赖于关键词匹配——语义匹配(即不同词语具有相似含义)无法被该框架捕获。如果我们要处理语义匹配,就需要使用某种形式的向量搜索,如下所示。

一个简单的向量搜索流程

在向量搜索中,我们使用一个嵌入模型为要搜索的每个文档生成嵌入向量。然后,我们将所有这些嵌入向量存储在向量数据库中,这使我们能够使用像层次化可导航小世界(HNSW)这样的算法高效搜索相似的嵌入向量。然后,我们只需嵌入查询并在索引中搜索相似的嵌入向量,就能找到与查询语义相似的文档!这正是检索增强生成(RAG)所做的,用于检索相关文本块以添加到LLM的上下文中;详情请参见这里

双编码器与交叉编码器的区别

上面概述的语义搜索系统使用双编码器,它为每个文档和查询生成独立的嵌入向量——通过余弦相似度分数进行匹配。不过,我们也可以使用交叉编码器,它将文档和查询同时作为输入,输出一个单一的相似度分数。这两种策略的差异如上图所示。公开的资源库中有多种预训练的双编码器和交叉编码器可用,可以进行微调或开箱即用;更多详情请参见这里

现代搜索系统结合了所有这些技术。首先使用双编码器和(BM25)词汇搜索的混合方法来高效检索与查询最相关的文档。然后,使用交叉编码器对检索到的文档进行细粒度排序,将最相关的文档排在列表顶部,如下所示。所有组件都可以随着搜索引擎的使用,通过收集的数据进行微调,以持续提高准确性。

现代AI驱动的搜索框架

将搜索应用于调试。 现在我们已经了解了搜索系统的基础知识,也可以将这些想法应用于调试LLM输出。不过,调试LLM输出的两个特殊考虑因素使其与标准搜索应用不同:

  • LLM训练数据集可能非常庞大(数十万亿个token),这会限制某些技术的使用。
  • 根据使用场景,LLM的输出以及LLM训练所用的文档可能非常长。

如果我们要追踪大型数据集,使用像向量搜索这样的技术——尽管并非不可能——既耗时又昂贵。我们必须首先为整个数据集生成嵌入向量,然后将这些嵌入向量存储在向量数据库中以便搜索。这个过程需要大量的前期准备(包括创建大规模数据管道!),入门门槛较高。

更进一步说,LLM的输出和训练文档可能非常长,这意味着我们需要以不同的方式处理这个搜索问题。与其使用整个输出作为搜索查询,不如考虑输出中较短的片段,并在训练数据中搜索类似的片段。理想情况下,我们想要开发一种追踪训练数据的技术,它应该具备以下特点:

  • 相对容易设置。
  • 在大规模数据集上高效。
  • 能够在(较短的)片段级别上操作。

4.2 Infini-gram

Infini-gram:将无界n-gram语言模型扩展到万亿token规模 [1]

"我们不预先计算n-gram计数表(那将非常昂贵),而是开发了一个名为infini-gram的引擎——由后缀数组驱动——可以在毫秒级延迟内计算∞-gram(以及任意n的n-gram)概率。" - 来自[1]

要理解如何高效追踪大规模数据集,我们需要首先理解infini-gram [1]的概念。简单来说,infini-gram是将n-gram推广到任意大的N值。正如我们将看到的,用于计算infini-gram概率的数据结构也可以用来(非常高效地)定位和统计大规模数据集中任意长度的文本片段。这个属性对于模型聚焦型管理和调试非常有用!

从文本序列创建n-gram

什么是n-gram语言模型? n-gram就是一个有序的N个token(或单词)集合。给定一个文本序列,我们可以将其分解为n-gram,如上所示,其中我们选择N = 3。如果我们将整个文本数据集分解为n-gram,实际上可以通过简单计算给定n-gram在数据集中出现的次数来计算其概率,如下所示。

计算n-gram概率

所有这些计数通常预先计算并存储在计数表中,允许我们快速查询n-gram概率并计算上述表达式。实际上,我们可以使用n-gram概率形成一个简单的语言模型!要使用n-gram预测序列中的下一个token,我们只需:

  1. 查看序列中最后的N - 1个token。
  2. 获得在给定前N - 1个token的情况下每个可能n-gram的概率。
  3. 像任何其他语言模型一样采样下一个token

n-gram的局限性。 实际上,n-gram语言模型在文本生成方面并不出色——你不可能通过计数n-gram来制造一个强大的聊天机器人。虽然这对任何N值都成立,但限制n-gram语言模型性能的关键问题之一是,n-gram计数表的大小随着N呈(近乎)指数级增长。因此,大多数n-gram语言模型只能使用较小的N值——例如,N = 5 是常见设置——并且捕获有意义的、长上下文语言分布的能力有限,如下所示。

(来自[1])

此外,n-gram语言模型还面临稀疏性问题。某些n-gram可能不会出现在我们的数据中,迫使我们退回到更小的n-gram来计算概率——这个概念通常被称为n-gram"回退"。在回退到较小n-gram时形成有效的概率估计实际上相当复杂

让n-gram重新变得相关。 在[1]中,作者提出了一种n-gram语言模型的变体——称为infini-gram(或∞-gram)——它能更好地与现代LLM配合。相对于标准n-gram,infini-gram做了两个关键改变:

  1. 它们像其他现代LLM一样在大量文本数据(万亿级token)上进行训练,从而缓解了稀疏性问题。
  2. 在计算n-gram概率时,N的值可以任意大,从而捕获数据中更有意义的分布。

什么是∞-gram? 通过这些改变,infini-gram解决了我们上面讨论的n-gram语言模型的两个最大问题。这是如何实现的? 假设我们有一个文本序列w。要计算token i的infini-gram,我们考虑序列中所有在token i之前的token,如下所示。

计算infini-gram概率

在这个等式的左侧,infini-gram概率以序列的整个先验上下文为条件,这与之前不同。然而,等式的右侧与n-gram概率完全匹配!n-gram和infini-gram之间的关键区别在于我们如何选择 N 的值

对于n-gram,N是一个(固定的)超参数。相比之下,infini-gram使用回退过程来动态选择N。更具体地说,我们使用可能的最大N——序列中所有前面的token——来测试该表达式的分母,并不断将N减一,直到分母非零,如下所示。

"一旦分母变为正数,我们就停止回退,此时分子可能仍然为零……有效n等于提示词在训练数据中出现的最长后缀的长度加一。" - 来自[1]

如果我们将w'定义为w中直到(包括)token i - 1的子序列,那么这个回退过程就是简单地找到w'在数据集中存在的最长后缀。然后,我们使用通过回退找到的N值,用之前的标准n-gram概率表达式来计算infini-gram概率。

计算∞-gram概率。 要计算infini-gram概率,我们不能像以前那样简单地预先计算计数并存储在表中。N的值是无界的,而且infini-gram在[1]中是在LLM规模的数据集上训练的——这样的计数表将极其庞大。相反,我们使用一种叫做后缀数组的数据结构来创建一个高效计算infini-gram概率的引擎。

六个字符的玩具序列上的后缀数组(来自[1])

后缀数组的概念如上图所示。给定一个长度为L的文本序列w,后缀数组的构建过程如下:

  1. 提取该序列的每个后缀(共有L个)。
  2. 字典序5对后缀进行排序。
  3. 将每个排序后缀的原始索引(排序前)存储在一个列表中——这就是后缀数组

考虑w'w中从token i到token j的任意子数组,其中i < j。任何以w'开头的后缀都会因为数组按字典序排序而连续存储在后缀数组中。利用这个属性,我们可以高效地计算w'w中的计数。我们只需找到数组中第一个和最后一个以w'为前缀的后缀的索引,w'w中的计数就是这两个索引之差。如果我们可以计算w'的计数,我们就可以计算任意infini-gram概率——这个操作可以用来找到 N 并在infini-gram概率表达式中计算两个计数

文本token上的后缀数组(来自[1])

用于LLM的∞-gram。 在LLM的上下文中,我们的序列w是LLM的整个token化训练数据集,其中文档边界用固定的分隔符token标记6,如上所示。这个序列将非常庞大——现代LLM在数万亿个token上进行训练——但后缀数组可以处理这种规模的数据7

"在推理过程中,整个infini-gram索引可以保持在磁盘上,从而最大限度地减少所需的计算资源(不需要GPU,CPU/RAM需求也很低)……我们最优化后的infini-gram引擎可以在平均不到20毫秒的延迟内计算给定n-gram的计数。它可以在40毫秒内计算n-gram语言模型的概率和下一个token分布,在200毫秒内计算∞-gram。" - 来自[1]

例如,[1]中建立在5T token数据集上的后缀数组消耗约35Tb内存。构建这个后缀数组大约需要48小时,构建完成后整个后缀数组可以存储在磁盘上——即使在计算infini-gram概率时也是如此。由此产生的infini-gram引擎可以计算超过两千万亿个唯一n-gram的概率。然而,在这种规模的数据集上检索给定n-gram的计数仍然只需要大约20毫秒!

在实践中使用∞-gram。 完全理解infini-gram背后的思想需要一些时间。幸运的是,整个infini-gram项目——Ai2的任何其他项目一样——是完全开源的!有大量开源工具可用于在Python中使用infini-gram。详情请参见项目网站

%pip install infini_gram 
python -m infini_gram.indexing 
    --data_dir <path to data>
    --save_dir <path to save index>
    --tokenizer llama  # also supports gpt2 and olmo
    --cpus <cpus available>
    --mem <memory available (in Gb)>
    --shards 1  # increase if N > 500B
    --add_metadata 
    --ulimit 1048576

与本概述最相关的工具是infini-gram Python包。几个开放的LLM训练数据集已经在该包中预索引,但我们也可以使用该包通过上述命令在我们自定义的数据集上构建infini-gram索引。索引可用后,我们可以使用infini-gram Python包高效运行各种搜索和计数操作;以下为示例和这里的更多详情。

from infini_gram.engine import InfiniGramEngine
from transformers import AutoTokenizer

# instantiate tokenizer (must match tokenizer used for indexing)
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    add_bos_token=False,
    add_eos_token=False,
)

# connect to infini-gram engine
engine = InfiniGramEngine(
    index_dir=<path to index>,
    eos_token_id=tokenizer.eos_token_id,
)

# sample n-gram / sequence
inp = "This is my sample n-gram sequence."
inp_ids = tokenizer.encode(inp)

# find matching n-grams in dataset
result = engine.find(input_ids=input_ids)

# n-gram count
result = engine.count(input_ids=inp_ids)

# n-gram probability
result = engine.prob(
    prompt_ids=inp_ids[:-1],
    cont_id=inp_ids[-1],
)

# next token distribution
result = engine.ntd(prompt_ids=inp_ids)

# infini-gram probability
result = engine.infgram_prob(
    prompt_ids=inp_ids[:-1],
    cont_id=inp_ids[-1],
)

4.3 OLMoTrace

OLMoTrace:将语言模型输出追溯到数万亿训练token [2]

(来自[2])

OLMoTrace [2]开创了一种新颖的方法,可以高效地将LLM的输出归因于其训练数据中的示例。该方法部署在Ai2演示平台(如上所示)中,可以在几秒钟内执行追踪,检索与LLM输出相关的训练文档。考虑到LLM是在海量数据集上训练的,我们可能会想知道这种实时追踪是如何实现的。幸运的是,我们已经知道了答案:infini-gram

"OLMOTRACE的目的是为用户提供一个工具,用来探索LM可能从哪里学会了生成特定的词序列,重点关注LM输出与训练数据之间最直接的逐字匹配联系。" - 来自[2]

追踪策略。 OLMoTrace的核心思想是找到既存在于模型输出中、又存在于训练数据集中的长且唯一的token序列示例。输入一个提示词和LLM的响应,OLMoTrace将返回以下内容:

  • LLM响应中发现的一组值得注意的文本片段。
  • 与每个响应片段相关联的最相关训练文档片段列表。

与向量搜索不同,模型输出和训练数据之间的这些匹配必须是逐字精确的。精确的token匹配可以通过后缀数组快速识别,正如上一节讨论的那样。然而,确保识别并返回最佳匹配文档需要一个建立在标准infini-gram功能之上的四步算法。

(步骤1)最大匹配片段。 在对LLM的响应进行token化之后,我们找出该响应中满足三个属性的所有文本片段:

  1. 存在性:该片段在训练数据中存在精确匹配。
  2. 最大性:该片段不是另一个匹配片段的子片段。
  3. 自包含性:该片段是完整的;例如,不以不完整的单词开头或结尾,且片段中间不包含标点符号。

这些属性在下面的图中进行了说明。这里,我们可以看到有三个匹配片段。然而,除了一个——用绿色标出的——之外,所有片段都由于不是i) 最大的或ii) 自包含的而被移除。

最大和自包含片段的图示

朴素地计算最大片段是低效的,但[2]中的作者提出了一种更高效的算法,该算法依赖于infini-gram索引中的find操作。给定一个token序列作为输入,find操作返回:

  • 索引中匹配片段的数量。
  • 可用于查找匹配数据片段的段范围8

然而,如果返回的计数为零——表明我们的数据中没有该序列的精确匹配——find操作仍将返回一个(空的)段范围。由于后缀数组是按字典序排序的,该范围的索引对应于该序列在数据集中存在的最长匹配前缀。

# run find operation with infini-gram engine
result = engine.find(input_ids=inp_ids)

"""
### .find() output example (match): 
    {
        'cnt': 10,
        'segment_by_shard': [(13693395, 13693405)],
    }

### .find() output example (no match):
    {
        'cnt': 0,
        'segment_by_shard': [(85267640, 85267640)],
    }
"""

# lookup training documents from .find()
rank_start, rank_end = result['segment_by_shard'][0]
ranks = [r for r in range(rank_start, rank_end)]
for r in ranks:
    docs = engine.get_doc_by_rank(
        s=0,  # assumes suffix array has a single shard
        rank=r,
        max_disp_len=len(inp_ids) * 5,  # size of doc chunk
    )
    doc_text = [tokenizer.decode(d['token_ids']) for d in docs]
    print(f'Number of documents: {len(docs)}')
    print(f'Matching document: {doc_text[0]}')

find操作的这一属性在[2]中被用来创建一种高效的片段匹配算法。如下面的图所示,该算法通过对输入序列的每个后缀运行一次find操作来工作,为每个后缀产生最长匹配前缀。一旦所有这些匹配片段被识别出来,我们可以再次遍历这个列表,移除任何不是最大或自包含的匹配片段。

(来自[2])

(步骤2)片段过滤。 如果如上所述计算出的最大片段列表很长,我们需要某种策略来识别最有用和最相关的片段。为此,[2]中的作者根据片段的单字概率(越低越好)——或片段中每个token的单字概率的乘积——对片段进行评分。给定token的单字概率通常为所有token预先计算并存储在缓存中,其计算公式如下所示。

计算token的单字概率

在[2]中,作者按片段的单字概率对片段进行排序,并只保留列表中前K个片段,其中K = ceil(0.05 x L)L为序列长度。

(步骤3-4)合并片段并获取文档。 为避免杂乱,OLMoTrace会将重叠的片段合并在一起。然后检索这些最终片段对应的文档。但是,与每个片段关联的文档数量可能很大,因此我们必须进行子选择;例如,[2]中的作者为每个片段保留十个文档。为了找到最相关的文档,我们可以根据LLM输出与检索到的文档之间的BM25分数进行排序。

"为了优先显示最相关的文档,在文档面板中,我们按BM25分数降序对所有文档进行排序。每个文档的BM25分数是通过将检索到的文档集合视为语料库,并将用户提示和LM响应的拼接作为查询来计算的。" - 来自[2]
(来自[2])

示例实现。 OLMoTrace的推理流程如上图所示。为了更好地理解其工作原理,让我们使用infini-gram Python包(快速)实现核心功能。要构建infini-gram索引,我们需要将LLM的所有训练数据放入一个单独的目录。infini-gram包期望数据格式化为一个或多个.jsonl文件,每个文件包含textmetadata字段,如下所示。.jsonl文件的每一行对应训练数据集中的一个文档。

{
    'text': 'This is a training sequence for our LLM...',
    'metadata': {
        'source': <url>,
        'category': 'general',
        'year': 2025,
        ...
    },
}

数据格式化完成后,我们可以如前所述构建infini-gram索引。此外,OLMoTrace要求我们预先计算所有token的单字概率。以下代码实现了这两个步骤。这段代码假设我们使用Llama 2 tokenizer执行追踪,并且我们的infini-gram索引只需要一个分片。底层的tokenizer可以修改,在处理非常大的数据集(即超过500B token)时可能需要支持索引的多个分片。

import os
import json
from collections import Counter
import tempfile

from transformers import AutoTokenizer

# load tokenizer / data
enc = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False)
data_rows = [{'text': 'here is some training data'}, ...]

# compute / save unigram probabilities
all_toks = []
for x in data_rows:
    all_toks.extend(enc.encode(x['text']))
total_toks = len(all_toks)
tok_count = Counter(all_toks)
unigram_probs = {}
for tid in tok_count:
    cnt = tok_count[tid]
    unigram_probs[tid] = cnt / total_toks
with open(<save path>, 'w') as json_file:
    json.dump(unigram_probs, json_file, indent=4)

# build infinigram index
data_dir = <path to data>
save_dir = <save index here>
temp_dir = tempfile.TemporaryDirectory()
command = (
    f"python -m infini_gram.indexing --data_dir {data_dir} "
    f"--temp_dir {temp_dir.name} --save_dir {save_dir} "
    f"--tokenizer llama --cpus 12 --mem 64 --shards 1 "
    f"--add_metadata --ulimit 100000 "
)
print(command)
os.system(command)
temp_dir.cleanup()

现在infini-gram索引已经构建完成,我们可以按照[2]中OLMoTrace提出的算法,在训练数据集上追踪文本序列——如下面的代码所示。这段代码返回一组片段及其关联的文档,附带有训练语料库的元数据。

import ast
import math
import random

from infini_gram.engine import InfiniGramEngine
from transformers import AutoTokenizer

def compute_longest_prefix(query, doc):
    """helper function for computing longest prefix of query that exists
    within a document"""

    def shared_prefix_length(list1, list2):
        prefix_length = 0
        for elem1, elem2 in zip(list1, list2):
            if elem1 == elem2:
                prefix_length += 1
            else:
                break
        return prefix_length

    first_id = query[0]
    start_idx = [index for index, value in enumerate(doc) if value == first_id]
    longest_prefix = 0
    for si in start_idx:
        longest_prefix = max(
            longest_prefix,
            shared_prefix_length(query, doc[si:]),
        )
    return longest_prefix

# setup
enc = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False)
engine = InfiniGramEngine(index_dir=<path to index>, eos_token_id=enc.eos_token_id)
unigram_probs = {1: 0.5, 2: 0.5} # load pre-computed probabilities

# LLM output / query to search
generation = 'Here is the output of the LLM that we want to search for in our data.'
gen_ids = enc.encode(generation)


"""
Step One: find maximal matching spans
"""
L = len(gen_ids)
max_doc_toks = len(gen_ids) * 2 # size of spans to retrieve in documents

# find longest prefix match for every suffix in the query
spans = []
for start in range(len(gen_ids) - 1):
    _suffix = gen_ids[start:]
    _suff_res = engine.find(input_ids=_suffix)

    # if no match, get the longest matching prefix using find result
    if _suff_res['cnt'] == 0:
        _shards = _suff_res['segment_by_shard']
        assert len(_shards) == 1 # assume only one shard
        _doc_ids = engine.get_doc_by_rank(
            s=0, # assume only one shard
            rank=_shards[0][0],
            max_disp_len=max_doc_toks,
        )['token_ids']
        matched_toks = compute_longest_prefix(_suffix, _doc_ids) # get longest matching prefix
    elif _suff_res['cnt'] > 0:
        matched_toks = len(_suffix)
    spans.append((start, start + matched_toks))

# remove partial and non-self-contained spans
full_spans = []
for start, end in spans:
    span_ids = gen_ids[start: end]
    span_text = enc.decode(span_ids)

    # check for internal punctuation
    has_internal_punc = False
    punc_chars = "!.\n"
    for ch in span_text[:-1]:
        if ch in punc_chars:
            has_internal_punc = True
            break
    if has_internal_punc:
        continue

    # check if first token is a continuation of a word
    first_tok_id = span_ids[0]
    first_tok = enc.convert_ids_to_tokens(first_tok_id)
    if first_tok[0] != '▁': # assumes Llama 2 token format
        continue

    # no sub-token follows the last token
    if end < len(gen_ids) and tokenizer.convert_ids_to_tokens(gen_ids[end])[0] != "▁":
        continue
    full_spans.append((start, end, span_ids, span_text))

# remove non-maximal spans
maximal_spans = []
max_end_pos = -1
full_spans = sorted(full_spans)
for start, end, ids, text in full_spans:
    if end > max_end_pos:
        maximal_spans.append((start, end, ids, text))
        max_end_pos = end


"""
Step Two: filter to keep long / unique spans
"""
K = math.ceil(0.05 * L)
assert K > 0
filt_spans = []
for start, end, ids, text in maximal_spans:
    span_uni_prob = [unigram_probs.get(_id) for _id in ids]
    span_uni_prob = math.prod(span_uni_prob)
    filt_spans.append((start, end, ids, text, span_uni_prob))
filt_spans = sorted(filt_spans, key=lambda x: x[-1])
filt_spans = filt_spans[:K]
filt_spans = sorted(filt_spans) # sort based on start position again


"""
Step Three: retrieve Enclosing Docs
"""
docs_per_span = 10
span_to_docs = defaultdict(list)
for i, (start, end, ids, text, uni_prob) in enumerate(filt_spans):
    # run retrieval in infinigram index to get documents
    span_res = engine.find(input_ids=ids)
    assert span_res['cnt'] > 0
    assert len(span_res['segment_by_shard']) == 1 # assume only one shard

    rank_start, rank_end = span_res['segment_by_shard'][0]
    ranks = [r for r in range(rank_start, rank_end)]
    if len(ranks) > docs_per_span:
        # retrieve fixed number of documents for each span
        ranks = sorted(random.sample(ranks, docs_per_span))

    # NOTE: we can instead rank documents by BM25 score here!
    for r in ranks:
        _doc = engine.get_doc_by_rank(
            s=0,
            rank=r,
            max_disp_len=max_doc_toks,
        )
        _doc_meta = ast.literal_eval(_doc['metadata'])['metadata']
        _doc_text = enc.decode(_doc['token_ids'])
        _doc_data = {
            "text": _doc_text,
            **_doc_meta
        }
        span_to_docs[i].append(_doc_data)


"""
Step Four: merge overlapping spans
"""
# get indices of spans to merge together
merged_spans = [[0]]
curr_idx = 0
curr_start = filt_spans[0][0]
curr_end = filt_spans[0][1]
for i, next_span in enumerate(filt_spans[1:]):
    start = next_span[0]
    end = next_span[1]
    if start < curr_end:
        curr_end = max(curr_end, end)
        merged_spans[curr_idx].append(i + 1)
    else:
        curr_start, curr_end = start, end
        curr_idx += 1
        merged_spans.append([i + 1])
assert len(merged_spans) == curr_idx + 1

# merge spans into a final set
final_spans = []
for ms in merged_spans:
    all_docs = []
    docs_per_merged_span = math.ceil(docs_per_span / float(len(ms))) # subsample docs for spans being merged
    for i in ms:
        # take top docs from each span being merged
        all_docs.extend(span_to_docs[i][:docs_per_merged_span])
    _spans = [filt_spans[i] for i in ms]
    start = min([x[0] for x in _spans])
    end = max([x[1] for x in _spans])
    text = enc.decode(gen_ids[start: end])
    final_spans.append({
        "start": start,
        "end": end,
        "text": text,
        "docs": all_docs,
    })


"""
Step Five: observe tracing results
"""
docs_to_print = 5
print(f'Query Text: {enc.decode(gen_ids)}')
for i, sp in enumerate(final_spans):
    print("\n" + "="*20 + f" SPAN {i + 1} / {len(final_spans)} " + "="*20)
    print(f"Span Text: {sp['text']}\n")
    for j, doc in enumerate(sp['docs']):
        print("-"*10 + f" Document {j + 1} / {len(sp['docs'])} " + "-"*10)
        for k in ['text', 'movie_id', 'src_lang', 'start_frame', 'end_frame']:
            if k == 'text':
                v = doc[k].replace('\n', ' ')
            else:
                v = doc[k]
            print(f"- {k} --> {v}")

正如我们所看到的,OLMoTrace的核心功能并不复杂——大部分复杂代码已经被infini-gram包封装好了!如果你感兴趣,我强烈建议你在自己的模型和数据上测试这段代码,感受一下它能返回什么样的结果!

OLMoTrace使用场景(来自[2])

OLMoTrace的应用。 OLMoTrace专门用于查找LLM输出与其训练数据之间精确匹配的长且唯一的片段。精确匹配是查找可能对LLM的某个输出有贡献的训练数据的有用代理。在[2]中,考虑了几种不同的使用场景:

  • 事实核查:将LLM做出的陈述与其训练数据中的类似陈述进行比较。
  • 创意表达:检查LLM的"创意"输出是否真的具有创意,还是直接从训练数据中复制而来。
  • 推理能力:检查LLM是否从训练数据中复制了用于解决可验证问题(例如,数学)的推理过程。

在每种情况下,我们都可以通过追踪LLM的输出,找到训练数据中具有显著逐字匹配的区域,从而了解关于LLM的新信息。

4.4 推理模型与未来研究

LLM训练的阶段(来自[4, 5, 6])

扩展到推理模型。 如上所示,LLM通常经过几个阶段的训练,每个阶段都有独特的数据风格:

尽管数据格式不同,我们可以将OLMoTrace应用于每个训练阶段,只需做少量修改!我们可以轻松地在监督示例和偏好对上构建infini-gram索引(尽管我们可能希望区别对待偏好对中的正面和负面补全)。对于RLVR,我们可能需要更深入地思考数据应该如何被追踪

在使用RLVR训练LLM时,我们有一个带有可验证解的数据集;例如,一个带有已知解的数学问题或一个带有测试用例的编程问题。我们可以轻松检查LLM是否正确解决了这些问题(例如,通过字符串匹配或更健壮的方法),如上所示。然后,模型通过大规模RL训练驱动的自我进化过程,自行学习如何解决这些问题,如DeepSeek-R1 [7]所展示的。

"我们探索了LLM在没有任何监督数据的情况下发展推理能力的潜力,重点关注它们通过纯强化学习过程的自我进化。" - 来自[7]

在RL训练期间,我们在[7]中看到LLM学会了输出复杂的思维链——有时长达数千个token!——以提高它们的推理能力。然而,如果我们想索引这些推理轨迹,就会遇到一个有趣的问题。也就是说,推理轨迹实际上并不是我们训练数据的一部分——它们是在RL训练过程中由LLM生成的

(来自[7])

类似地,LLM生成的补全由奖励模型排序,并用于RLHF期间的策略更新;详情请参见此处的解释。如果我们想捕获在RL训练期间——包括RLHF和RLVR——学习的模式,我们必须跟踪LLM在训练期间生成的补全。有了这些补全,我们可以像其他训练数据一样进行索引,将它们添加到infini-gram索引中,并使用OLMoTrace进行追踪。

相关(和未来)研究。 尽管OLMoTrace很有用,但精确匹配并不能保证因果关系——LLM生成某个输出可能有多种原因。仅仅因为我们找到了与LLM输出相似的训练数据,并不意味着这些数据必然导致了该输出。

为了更深入地理解LLM的输出,几个并行研究方向正在探索替代的可解释性策略。例如,最近发表了大量关于教LLM在生成输出时引用来源的论文[8, 9, 10],如下所示。

(来自[8])

这种引用来源的能力可以融入到LLM的标准训练过程中——例如,预训练[8]或RLHF [9]——使模型学会何时以及如何为其答案提供证据。然而,仍然不能保证这些引用真正解释了输出是如何生成的。

机械可解释性领域旨在研究神经网络内部结构,以理解为什么它们会产生某些输出。虽然深度神经网络通常被视为黑盒,但当在微观层面(即小规模的权重集合)上研究时,我们可以发现这些网络中许多重复的电路和特征。例如,视觉网络往往有专门用于检测曲线、边缘等的单元。

机械可解释性的话题主要由Anthropic推广。在最近的一份报告中,研究人员使用字典学习对Claude Sonnet中的特征进行了大规模研究。如上所示,这项研究发现了数百万个用于高级概念的特征,例如人物、地点、代码中的错误等。

"我们已经确定了Claude Sonnet(我们部署的大型语言模型之一)内部如何表示数百万个概念。这是首次对现代、生产级大型语言模型的内部进行详细观察。" - 来自[11]

此外,作者分析了特征之间的"距离",发现了一些有趣的属性;例如,金门大桥特征与恶魔岛特征很接近。这类研究虽然仍处于起步阶段,但可以说是真正理解LLM为何以及如何产生某些输出的最有前景的途径。

5、结束语

正如我们所了解到的,优化训练数据集是LLM训练过程中最有影响力和最重要的方面之一。要有效地管理和调试数据,我们应该从观察数据本身开始——而不是训练模型!首先,我们应该手动检查数据,并理解其各种属性、模式和特点。为了扩展手动检查过程,我们可以依赖启发式方法(在可能的情况下)和机器学习模型;例如,fastText或LLM评判者。这种数据聚焦型管理过程专注于在训练任何LLM之前修复问题并提高数据质量!

"我注意到的一个模式是,优秀的AI研究人员愿意手动检查大量数据。更重要的是,他们构建了能够快速手动检查数据的基础设施。虽然这并不光鲜,但手动检查数据能让你对问题产生宝贵的直觉。" - Jason Wei

一旦我们开始训练LLM,就可以利用LLM的输出来发现数据中的问题。更具体地说,我们可以:

  1. 通过评估框架识别有问题的LLM输出。
  2. 将这些输出追踪到训练数据中的相应区域。

虽然我们可以使用标准的搜索技术——如词汇搜索或向量搜索——来追踪数据,但也有专门为LLM开发的专用追踪技术,如OLMoTrace [2]。这些技术易于(且快速)设置,信息量丰富,并且可以扩展到任意大的数据集,使其成为调试LLM训练数据集的非常实用的选择


原文链接: A Guide for Debugging LLM Training Data

汇智网翻译整理,转载请标明出处