TabFM:表格数据零样本基础模型

表格数据是企业数据基础设施的支柱,为关键预测性机器学习应用的很大一部分提供动力。从预测客户流失到识别金融欺诈,表格回归和分类任务无处不在。多年来,有监督的树基算法如 AdaBoostXGBoost随机森林等,一直在这一领域占据主导地位,在结构化数据上提供了稳健的性能。

然而,部署这些传统模型的生命周期带来了显著的瓶颈。为一个新数据集拟合一个 XGBoost 模型不仅仅是执行一个 .fit() 步骤那么简单;它不可避免地需要繁琐的手工工作。数据科学家必须投入无数小时进行广泛的超参数优化和特定领域的特征工程,才能从原始数据中提取出可靠的信息。

另一方面,更广泛的机器学习领域的最新进展——特别是大语言模型(LLM)的发展——已经改变了我们处理新任务的方式。LLM 通过上下文学习(ICL)展示了零样本预测的卓越能力。这种技术允许预训练模型通过提供输入上下文中的示例和指令来学习新任务,而无需更新任何底层模型权重。

今天,我们推出 TabFM,一个专门为表格数据分类和回归设计的基础模型。通过将表格预测构建为 ICL 问题,TabFM 消除了手动模型训练、超参数调优和复杂特征工程的需求。我们很高兴分享这种方法如何让用户通过单次前向传播即可在从未见过的表格上生成高质量的预测。TabFM 现已在我们 Hugging FaceGitHub 仓库中提供。

1、工作原理

传统的机器学习范式依赖于更新特定于给定数据集分布的模型参数。相比之下,ICL 范式完全绕过了这一点。TabFM 不需要为每个新任务进行传统的训练阶段,而是将整个数据集——包括历史训练样本和目标测试行——作为一个统一的提示输入。模型在推理时直接从该上下文中学习解释列和行之间的关系。

然而,将 ICL 应用于表格数据并不像对自然语言进行分词那样简单。标准语言模型处理一维有序序列,但表格本质上是二维且无序的:交换两行或两列不会改变数据的底层含义。为了有效处理这些多样化的表格结构,同时实现可扩展的零样本预测,TabFM 综合了 TabPFNTabICL 等架构的优势,形成了一种新颖的混合设计。如下所示的架构依赖于三个关键机制:

  • 交替行和列注意力:首先,原始表格通过多层注意力模块处理。类似于 TabPFN,这一步骤对列(特征)和行(样本)交替应用注意力。通过持续关注这两个维度,模型学习了丰富的表示,自然地捕获了复杂的特征交互和依赖关系。这种深度上下文化有效地完成了数据科学家原本需要繁琐手工特征构建的繁重工作。
  • 行压缩:在此上下文化之后,每个单独行的丰富跨注意力信息被压缩为单一的密集向量表示。
  • 上下文学习(ICL):最后,一个专用的 Transformer 对这一压缩嵌入序列进行操作。采用 TabICL 的高效方法,对这些压缩的行向量(而非原始未压缩的网格)进行注意力计算,大幅降低了计算成本。这确保了即使对于更大的数据集,预测步骤仍然保持极高的计算效率。
TabFM 模型架构。

2、大规模合成数据训练

构建基础模型的典型方法是使用高容量神经网络在大量多样化数据上进行训练。然而,表格 ML 的一个主要障碍是高质量、多样化的表格数据集——尤其是反映真实工业数据分析所需的大规模表格——在开源空间中极度稀缺。工业表格通常包含专有的模式结构和敏感信息,使其无法用于广泛的预训练。

由于合成表格可以生成到任意大的规模,它们实际上是在此规模上预训练基础模型的唯一可行选择。因此,TabFM 完全在数亿个合成数据集上进行训练。这些数据集使用结构因果模型(SCM)动态生成,其中包含各种随机函数。这种大规模合成生成捕获了现实世界表格数据中普遍存在的各种分布和复杂特征关系。因此,该模型能够很好地泛化到未见过的真实表格,正如我们在下面的基准测试中所展示的那样。

3、性能与基准测试

为了严格测试 TabFM 与现有最先进方法的对比,我们在 TabArena 上进行了评估,这是一个基于头对头胜率计算 Elo 分数的动态基准系统。这一全面评估涵盖了 38 个分类数据集和 13 个回归数据集,样本量从 700 到 150,000 不等。

如下面的性能图所示,我们对模型的两个不同配置进行了基准测试:

  • TabFM:代表模型的即用能力。预测通过单次前向传播生成,无需调优或交叉验证。
  • TabFM-集成:此配置通过引入交叉特征和 SVD(奇异值分解)特征进一步提升性能。我们使用非负最小二乘求解器计算 32 路集成的最优权重。对于分类任务,此变体还加入 Platt 缩放作为额外的校准步骤。

有关全面的 TabArena 基准结果——包括详细的逐折指标和与特定基线模型的头对头胜率——请访问我们的 GitHub 页面

TabArena 分类(上)和回归(下)中前 10 名模型的 Elo 评分(↑)。 (D) = 默认; (T+E) = 调优 + 集成。分数越高表示性能越好。

4、结束语

通过将表格预测重新构建为上下文学习问题,TabFM 利用混合注意力架构和海量合成训练数据,天然地捕获了复杂的特征交互。这种方法成功消除了手动特征工程、超参数优化和重复模型训练的传统瓶颈,并且持续优于经过大量调优的行业标准监督算法。TabFM 将现代基础模型的即用便利性直接带入了表格 ML 工作流程,使从业者能够通过单次前向传播生成高度准确的预测。

为了方便使用,TabFM 正在被直接集成到 Google BigQuery 中。在接下来的几周内,用户将能够使用简单的 AI.PREDICT SQL 命令在 BigQuery 中执行高级回归和分类——无需任何 ML 专业知识。


原文链接:Introducing TabFM: A zero-shot foundation model for tabular data

汇智网翻译整理,转载请标明出处