重新实现Pix2Seq的思考
几个月前,我发现自己又陷入了一场关于目标检测的讨论,充满了特征金字塔网络、区域提议网络和非极大值抑制块。作为一名在各种目标检测项目中投入大量时间的人,我感到一种熟悉的感觉:这本应是一个概念上简单的任务,却如此复杂。
然后我想起了谷歌研究院的一篇论文,我之前浏览过但从未完全消化:“Pix2Seq: 用于目标检测的语言建模框架”。 其核心前提似乎极其简单:如果我们把目标检测当作语言建模来处理呢?不需要锚点、不需要NMS、不需要复杂的损失函数——只需一个模型查看图像并以结构化的方式输出预测。
Pix2Seq论文介绍了一种新的目标检测方法,其中:
- 不使用专门的架构,而是将检测视为语言建模任务,并使用通用的编码器-解码器Transformer
- 边界框和类别标签表示为一系列标记
- 模型学习自回归地生成这些序列
坦白说——我持怀疑态度,但很感兴趣。我也很好奇尝试从头开始重建它会学到什么;这个过程确实改变了我对计算机视觉、自然语言处理和任何序列建模之间关系的看法。与其创建一个完全忠实的复现,我想探索核心思想在现代技术和我的工程偏好下如何工作。
这篇博文是我分享这些经验的尝试——令人惊讶的见解、优雅的解决方案,以及我不得不后退一步欣赏原始论文和它所开启的可能性的时刻。这不是一篇深入探讨Transformer架构或全面文献综述的理论文章。相反,我想分享我在实现这个模型时获得的实践经验。
实现框架选择:虽然原始的Pix2Seq论文使用了TensorFlow,但我选择用PyTorch实现所有内容,因为我更喜欢PyTorch;这被证明是一个好决定,因为PyTorch在处理复杂的分词和生成逻辑时调试体验更直观!
这个选择需要将一些TensorFlow特定的概念(比如他们的数据管道和损失计算)转换为PyTorch的惯用方式,但论文中的核心算法洞察在不同框架中可以无缝迁移。数学概念——量化、序列增强和约束生成——与框架无关。
为了从这篇博客文章中获得最大收益,我强烈建议配合阅读配套仓库;先阅读该部分,然后去查看代码。本文中的一些代码片段为了便于解释而简化,请查看仓库以找到工作的实现,以及重现本文输出所需的所有代码。
1、目标检测中的复杂性危机
在深入探讨Pix2Seq的优雅之前,值得欣赏我们试图逃避的东西。如果你以前从事过目标检测,你会知道现代系统是工程奇迹——但它们也是累积复杂性的纪念碑。
考虑一下典型的两阶段检测器如Faster R-CNN的情况:
- 通过主干网络进行特征提取
- 在多个尺度和宽高比上生成锚点
- 区域提议网络建议潜在的对象位置
- RoI对齐提取固定大小的特征
- 双头预测进行分类和边界框回归
- 复杂的损失函数平衡分类、定位和对象性
- 非极大值抑制消除重复检测
在之前的博客文章中,我和一位同事重新实现了Yolov7,我们花了几个小时理解损失函数是如何工作的,我们的写作风格跨越了多页。
尽管基于Transformer的方法如DETR最初看起来很有希望——消除了部分架构复杂性,并直接预测对象集合而无需生成锚点或NMS——但它引入了自己的挑战,即匈牙利匹配算法用于训练。在每次前向传递中,DETR需要解决一个分配问题,以最优方式匹配预测对象和真实对象,最小化一个平衡分类、定位和对象性分数的组合损失函数。这种二分匹配过程虽然数学上优雅,但增加了实现的复杂性,并且当训练没有按预期收敛时,调试起来可能很困难。因此,尽管DETR代表了向更原则性目标检测的真实进步,但它仍然需要大量的任务特定工程——只是在管道的不同部分。
每个组件都经过多年的研发才达到完美。损失函数本身就是艺术品——精心平衡多个目标的手动调整权重。然而,当时在我们实施过程中让我印象深刻的是:所有这些复杂性存在是因为我们在对抗一个根本性的不匹配:我们想要(对象列表)和我们得到它的方式(空间网格上的密集预测)。
如果我们能完全消除这种不匹配呢?
2、Pix2Seq革命:图像作为标记序列
与其通过日益复杂的空间推理来预测对象,我们为什么不教模型读取图像并用结构化语言描述它所看到的内容呢?Pix2Seq论文的主要洞察可以归纳为这一点——如果我们把图像、边界框和标签表示为一串标记,并训练模型直接输出这些标记。
这种转换在概念上很简单:
传统检测之前:
图像 → 复杂架构 → 密集预测 → NMS → 对象列表
之后(Pix2Seq):
图像 → 简单编码器-解码器 → 标记序列 → 对象列表
然而,起初让我困惑的是如何将连续的、无序的边界框世界转化为离散的、有序的语言世界?让我们探讨一下这是如何实现的。
3、将坐标转换为标记
我遇到的第一个挑战是整个方法中最基本的问题:如何将边界框及其相关类表示为标记?这不像文本分词那样有自然的单词边界。我们处理的是连续坐标,可以取几乎无限的值,我们需要将它们映射到离散符号。
让我们通过一个具体的例子来说明这一点。考虑一个坐标为[0.125, 0.333, 0.875, 0.667]
(XYXY格式,归一化到[0,1]范围)的边界框。
论文的方法非常简单:将坐标空间[0,1]划分为固定数量的均匀区间,然后将每个坐标四舍五入到最近的区间。我们可以这样实现:
class TokenProcessor:
def quantize(self, boxes: torch.Tensor) -> torch.Tensor:
# 将坐标缩放到量化范围 [0, bins-1]
# 并四舍五入到最近的整数(实际的量化步骤)
boxes = torch.round(boxes * (self.quantization_bins - 1))
# 限制到有效范围(处理任何边缘情况)
boxes = torch.clamp(boxes, 0, self.quantization_bins - 1)
# 移动到坐标词汇表范围
boxes = boxes + self.coord_vocab_shift
return boxes.long()
现在,让我们通过一个具体的例子来跟踪1000个量化区间的步骤:
# 原始归一化坐标
original_coord = 0.125
# 步骤1:缩放到 [0, 999] 范围
scaled = 0.125 * (1000 - 1) = 0.125 * 999 = 124.875
# 步骤2:四舍五入到最近的整数
quantized = round(124.875) = 125
# 步骤3:添加词汇表偏移量(例如,coord_vocab_shift = 1000)
token = 125 + 1000 = 1125
所以连续坐标 0.125
变成了离散的标记 1125
。
量化精度取决于区间的数量相对于图像分辨率。让我们分析一下:
def analyze_quantization_precision(image_size=640, quantization_bins=1000):
"""计算坐标量化的精度。"""
# 每个量化区间的像素数
pixels_per_bin = image_size / quantization_bins
print(f"每个区间的像素数: {pixels_per_bin:.2f}")
# 最大量化误差(以像素为单位)
max_error_pixels = pixels_per_bin / 2
print(f"最大量化误差: {max_error_pixels:.2f} 像素")
# 作为图像尺寸的百分比
error_percentage = (max_error_pixels / image_size) * 100
print(f"最大误差作为图像的百分比: {error_percentage:.3f}%")
analyze_quantization_precision(image_size=640, quantization_bins=1000)
# 每个区间的像素数: 0.64
# 最大量化误差: 0.32 像素
# 最大误差作为图像的百分比: 0.050%
在640×640图像上使用1000个区间,我们获得了亚像素精度!论文显示即使500个区间(每个区间1.28像素)在实践中也表现良好。
在推理过程中,我们需要将标记转换回坐标:
class TokenProcessor:
def dequantize(self, tokens: torch.Tensor) -> torch.Tensor:
# 移除坐标词汇表偏移
tokens = tokens - self.coord_vocab_shift
# 缩放回 [0, 1] 范围
tokens = torch.clamp(tokens, 0, self.quantization_bins - 1)
# 转换回 [0,1] 归一化坐标
return tokens.float() / (self.quantization_bins - 1)
示例反量化:
# 模型输出的标记
token = 1125
# 移除词汇表偏移
quantized = 1125 - 1000 = 125
# 转换为归一化坐标
normalized = 125 / (1000 - 1) = 125 / 999 = 0.1251
与原始值 0.125
相比,约有 0.0001
的差异,这是此过程引入的量化误差。
这种量化方案简单但效果显著。它将无限的可能边界框空间转换为离散、可学习的词汇表,同时保持准确目标检测所需的精度。现在,让我们看看如何定义我们的标记词汇表。
不同量化级别对边界框放置的影响如下所示。在这种情况下,图像是480x640的。
4、定义我们的词汇表
词汇表设计是Pix2Seq优雅之处的另一个领域,因为它有意的结构。在原始论文中,作者为特殊标记、类别和坐标标记保留了固定的范围,无论使用场景如何。这是为了在扩展到多个数据集或任务时提供灵活性。
在我的实现中,为了最小化未使用的标记数量,我固定了显式的特殊标记,并根据类别数量和量化区间的数量动态定义其余的词汇表。我认为这使模型更容易理解和调试,但如果模型在下游进行微调,就需要额外的思考以确保类别正确映射。
我定义了模型的词汇表如下:
# 默认保留10个标记用于特殊标记
num_special_tokens = 10
BASE_VOCAB_SHIFT = num_special_tokens
# 特殊标记
PADDING_TOKEN = 0 # 用于变长序列
BOS_TOKEN = 1 # 序列开始
EOS_TOKEN = 2 # 序列结束
# 类别标记(例如,10-89对于80个COCO类别)
class_range = range(BASE_VOCAB_SHIFT, BASE_VOCAB_SHIFT + num_classes)
# 数据增强框的假类别标记
FAKE_CLASS_TOKEN = BASE_VOCAB_SHIFT + num_classes
coord_vocab_shift = FAKE_CLASS_TOKEN + 1
# 坐标标记(例如,91-1090对于1000个区间)
coord_range = range(coord_vocab_shift, coord_vocab_shift + quantization_bins)
请注意,量化选择直接影响模型的词汇表大小:
# COCO(80个类别,1000个区间)的词汇表分解
special_tokens = 10 # BOS, EOS, PAD等
class_tokens = 80 # 每个COCO类别一个
fake_class_token = 1 # 用于序列增强
coordinate_tokens = 1000 # 量化区间
total_vocab_size = special_tokens + class_tokens + fake_class_token + coordinate_tokens
# = 10 + 80 + 1 + 1000 = 1091 个标记
即使我们实现了亚像素精度,这与现代语言模型相比仍然非常小;甚至GPT-3使用了50,000+个标记。这为我们提供了足够的空间来增加量化区间的数量,如果我们想在更大的图像上进行训练的话。
此时,你可能会想知道FAKE_CLASS_TOKEN
到底有什么用途——我们将在讨论数据增强时详细说明。
5、序列构造
现在我们有了词汇表,下一个挑战是将对象集转换为序列。与文本不同,图像中的对象没有自然顺序,盒子的顺序不应该重要。我实现的解决方案遵循论文的方法:每个对象变成一个5个标记的序列 [y_min, x_min, y_max, x_max, class]
,多个对象只是简单地连接在一起。在训练期间,标注是随机排序的。
像许多其他Google论文一样,论文使用 y_min
作为第一个坐标,而不是通常预期的 x_min
。我更倾向于遵循标准惯例,并在整个代码库中一致使用 [x_min, y_min, x_max, y_max, class]
,所以我确保这种转换只发生在编码和解码的分词器中。
下面呈现了这种逻辑的一个简化版本:
class TokenProcessor:
def build_sequences(self, boxes: torch.Tensor, labels: torch.Tensor):
"""将边界框转换为标记序列。"""
batch_size, num_boxes = boxes.shape[:2]
# 将XYXY转换为YXYX格式(论文的惯例)
boxes = boxes[..., [1, 0, 3, 2]]
# 将坐标量化为离散标记
boxes = self.quantize(boxes) # [B,N,4]
# 为类别标签添加词汇表偏移
target_labels = labels + self.BASE_VOCAB_SHIFT
# 为每个对象创建5个标记的序列
target_seq = torch.cat([boxes, target_labels.unsqueeze(-1)], dim=-1)
# 展平为每张图像的单个序列
target_seq = target_seq.reshape(batch_size, -1)
# 添加序列边界标记
target_seq = self._add_boundary_tokens(target_seq)
return target_seq
代码库中的逻辑还处理了诸如包含填充标记和确保正确的边界框范围等细节,这里为清晰起见省略了。
这种表示的美妙之处在于,它将“找到所有对象”的复杂问题转化为熟悉的“生成下一个标记”的问题。每一种语言建模技术——注意力机制、教师强制、KV缓存——突然适用于目标检测。
6、数据增强
到目前为止,数据增强在训练深度学习模型时的好处已经广为人知;通过增加训练数据的多样性来提高鲁棒性,而不收集额外的真实世界样本。然而,由于Pix2Seq是一个自回归模型——逐个标记生成序列——这为我们的增强策略带来了一些额外的考虑。
自回归Transformer模型通常使用一种称为教师强制的方法进行训练——在训练期间提供正确的前一个标记作为输入,以实现并行计算并防止错误积累,从而使Transformer训练更快更稳定——Pix2Seq也不例外。在使用教师强制训练时,模型只看到完美的地面真实序列。但在推理时,它必须逐步生成序列,可能会早期出错,这些错误会在生成过程中累积。
这创造了所谓的暴露偏差——模型从未被训练来纠正自己的错误。在目标检测中,这表现为:
- 提前终止:模型在找到所有对象之前停止生成
- 较差的重复处理:模型生成冗余的检测
- 级联错误:一个错误的坐标预测会打乱整个框
因此,除了我们在目标检测中通常期望的常用方法外,我们还必须考虑序列增强;教模型处理不完美、噪声的序列——使其在生成预测时更加稳健。
Pix2Seq采用的数据增强方法可以总结为三个主要类别:
- 图像增强:
- 框增强
- 序列增强
让我们逐一探讨这些。
6.1 图像增强
尽管论文中并未详细讨论,但检查代码发现,图像增强的方法相当标准;没有使用复杂的技巧如Mixup和Mosaic。我没有试图精确复制作者使用的超参数,而是创建了一个遵循其方法核心思想的图像增强管道,包括翻转、缩放抖动、裁剪和颜色增强。
一些使用的图像增强示例如下:
6.2 框增强
虽然图像增强的方法相当标准,但框增强的方法才是有趣的开始。在深入具体增强之前,让我们首先回顾一下它们为何需要。
目标检测模型需要对各种挑战具有鲁棒性:
- 对象定位的小变化
- 错误检测
- 重叠和重复检测
因此,我们希望教会我们的模型:
- 准确的定位:精确预测框坐标
- 正确的分类:正确识别对象类别
- 错误检测拒绝:避免检测不存在的对象
- 重复抑制:避免对同一对象的多次检测
Pix2Seq通过全面的框增强策略解决了这些挑战,创建了正负训练示例。
现在,让我们探讨具体的增强方法。
框抖动
框抖动为现有框添加小的随机扰动,帮助模型应对轻微的定位变化:
- 创建真实对象的轻微变化
- 教会模型应对小坐标差异
- 保持原始类别标签,因为这些是有效的示例
在这里,我使用了极端值来说明效果;代码库从截断正态分布中采样以保持变化小而现实。当然,这应该根据您在领域中可能观察到的变化来设置!
移动真实框以创建硬负样本
框移动将框移动到新位置,同时保留其大小。这有助于模型学习避免错误位置的假阳性检测:
- 使用真实对象形状但将其移动到错误位置
- 创建看似合理但不正确的检测
- 帮助模型学习空间上下文
- 用假类别标记表示“无效检测”
6.3 随机框生成
生成完全随机的框有助于模型学习拒绝任意框提案:
- 创建多样化的负样本
- 帮助模型学习拒绝任意框提案
- 用假类别标记
- 使用正态分布生成大小以创造多样性
6.4 序列增强
虽然图像和框增强解决了空间鲁棒性问题,但序列增强则处理了自回归生成的独特挑战。这就是Pix2Seq的方法变得特别创新的地方,因为它直接解决了教师强制固有的曝光偏差问题。
Pix2Seq的序列增强策略基于三个关键原则:
- 标签污染:在训练期间故意提供错误的类别标签
- 噪声标记集成:教模型处理和生成“假”检测
- 随机排序:打乱序列中对象的顺序
让我们详细探讨每一个。
标签污染策略
标签污染机制通过一个巧妙的两阶段概率过程实现,如corrupt_class_labels
方法所示。让我们检查代码的关键部分以了解它是如何工作的:
class TokenProcessor
def corrupt_class_labels(
self, labels: torch.Tensor, padding_mask: torch.Tensor
) -> torch.Tensor:
"""根据指定策略污染类别标签。
对于所有策略,我们首先决定是否保留原始标签(50%的概率)。
然后对于我们要污染的标签,应用策略的噪声类型:
- NONE: 保持所有标签不变
- RANDOM: 替换为随机的有效类别
- RANDOM_AND_FAKE: 随机类别和假标记的均等分割
参数:
labels: 类别标签 [B,N]
padding_mask: 布尔掩码,True表示填充 [B,N]
"""
# 对于 'none' 策略或如果禁用污染,返回原始标签
if (
self._corruption_strategy == LabelCorruptionStrategy.NONE
or not self._corrupt_class_labels
):
return labels
batch_size, num_labels = labels.shape
valid_tokens = ~padding_mask
# 首先决定保留哪些有效标记(50%的概率)
keep_mask = (
torch.rand(batch_size, num_labels, device=labels.device) < 0.5
) & valid_tokens
# 从原始标签开始
corrupted = labels.clone()
# 为污染创建随机类别标签
rand_cls = torch.randint(
self.BASE_VOCAB_SHIFT, self.BASE_VOCAB_SHIFT + self.num_classes, (batch_size, num_labels), device=labels.device
)
if self._corruption_strategy == LabelCorruptionStrategy.RANDOM:
# 对于我们不保留的标记,替换为随机类别
corrupted = torch.where(valid_tokens & ~keep_mask, rand_cls, corrupted)
elif self._corruption_strategy == LabelCorruptionStrategy.RANDOM_AND_FAKE:
# 对于我们不保留的标记,决定是随机还是假
noise_mask = torch.rand(batch_size, num_labels, device=labels.device) < 0.5
tokens_to_corrupt = valid_tokens & ~keep_mask
# 在 noise_mask 为 True 时应用随机类别
corrupted = torch.where(tokens_to_corrupt & noise_mask, rand_cls, corrupted)
# 在 noise_mask 为 False 时应用假标记
fake_cls = torch.full_like(
labels, self.FAKE_CLASS_TOKEN, device=labels.device
)
corrupted = torch.where(
tokens_to_corrupt & ~noise_mask, fake_cls, corrupted
)
return corrupted
从这里可以看出,算法首先随机决定保留50%的有效对象。这确保了模型在训练期间始终看到一些正确的标签,防止完全混淆。对于剩余的50%的对象,策略决定了应用哪种类型的噪声:
我实现的策略有:
RANDOM
: 替换为随机的有效类别ID(COCO的0-79)RANDOM_AND_FAKE
: 随机类别和假标记的均等分割NONE
: 保持所有标签不变
让我们通过运行我们的图像和标签通过分词器几次来看看实际情况。
# 创建具有不同污染策略的分词器
tokenizer_none = TokenProcessor(
quantization_bins=1000, noise_bbox_weight=1.0, eos_token_weight=1.0,
max_seq_len=500, num_classes=80,
corrupt_class_labels=False,
corruption_strategy=LabelCorruptionStrategy.NONE,
verbose=False
)
tokenizer_random = TokenProcessor(
quantization_bins=1000, noise_bbox_weight=1.0, eos_token_weight=1.0,
max_seq_len=500, num_classes=80,
corrupt_class_labels=True,
corruption_strategy=LabelCorruptionStrategy.RANDOM,
verbose=False
)
tokenizer_mixed = TokenProcessor(
quantization_bins=1000, noise_bbox_weight=1.0, eos_token_weight=1.0,
max_seq_len=500, num_classes=80,
corrupt_class_labels=True,
corruption_strategy=LabelCorruptionStrategy.RANDOM_AND_FAKE,
verbose=False
)
print(f"使用COCO图像 {image_id} 有 {len(class_ids)} 个对象")
print(f"对象类别: {class_ids}")
print(f"类别名称: {[category_names[cid]['name'] for cid in class_ids]}")
# 应用图像增强以获取归一化框(如实际管道中所做的)
augmented_image, augmented_boxes, augmented_labels, unpadded_size = eval_augmentor(
image, xyxy_bboxes, class_ids, normalize_boxes=True
)
# 转换为张量
aug_boxes_tensor = torch.tensor(augmented_boxes, dtype=torch.float32).unsqueeze(0)
aug_labels_tensor = torch.tensor(augmented_labels, dtype=torch.long).unsqueeze(0)
# 提取类别标记的帮助函数
def extract_class_tokens(sequence, tokenizer, num_objects):
class_positions = [5 + i*5 for i in range(num_objects)]
return [(sequence[0][pos] - tokenizer.BASE_VOCAB_SHIFT).item() for pos in class_positions]
# 显示不同试验中的污染变化
num_trials = 5
print(f"\n{num_trials} 次试验中的污染变化:")
corruption_results = []
fake_class_id = tokenizer_mixed.FAKE_CLASS_TOKEN - tokenizer_mixed.BASE_VOCAB_SHIFT
for trial in range(num_trials):
# 生成具有不同污染策略的序列
input_seq_none, target_seq_none, weights_none = tokenizer_none.build_sequences(aug_boxes_tensor, aug_labels_tensor)
input_seq_random, target_seq_random, weights_random = tokenizer_random.build_sequences(aug_boxes_tensor, aug_labels_tensor)
input_seq_mixed, target_seq_mixed, weights_mixed = tokenizer_mixed.build_sequences(aug_boxes_tensor, aug_labels_tensor)
num_objects = len(augmented_labels)
# 提取类别标记
input_classes_random = extract_class_tokens(input_seq_random, tokenizer_random, num_objects)
input_classes_mixed = extract_class_tokens(input_seq_mixed, tokenizer_mixed, num_objects)
target_classes = extract_class_tokens(target_seq_none, tokenizer_none, num_objects)
# 检查更改
random_changes = [f"{category_names[orig]['name']}→{category_names[new]['name'] if 0 <= new < 80 else 'INVALID'}"
for orig, new in zip(target_classes, input_classes_random) if orig != new]
mixed_changes = [f"{category_names[orig]['name']}→{'FAKE' if new == fake_class_id else category_names[new]['name'] if 0 <= new < 80 else 'INVALID'}"
for orig, new in zip(target_classes, input_classes_mixed) if orig != new]
random_result = ', '.join(random_changes) if random_changes else "none"
mixed_result = ', '.join(mixed_changes) if mixed_changes else "none"
print(f"试验 {trial+1}: Random({random_result}) Mixed({mixed_result})")
corruption_results.append((len(random_changes), len(mixed_changes)))
# 汇总统计
total_random = sum(r[0] for r in corruption_results)
total_mixed = sum(r[1] for r in corruption_results)
print(f"汇总: Random 被污染 {total_random}/{num_trials * num_objects}, Mixed 被污染 {total_mixed}/{num_trials * num_objects}")
使用COCO图像 382088 有 1 个对象
对象类别: [17]
类别名称: ['马']
污染变化在 5 次试验中:
试验 1: Random(none) Mixed(none)
试验 2: Random(马→瓶子) Mixed(马→苹果)
试验 3: Random(马→椅子) Mixed(马→卡车)
试验 4: Random(马→键盘) Mixed(none)
试验 5: Random(马→消防栓) Mixed(马→狗)
汇总: Random 被污染 4/5, Mixed 被污染 3/5
假类别标记
假类别标记是Pix2Seq设计中的一个巧妙之处。它弥合了框增强(生成噪声框)和序列生成(必须学会处理虚假检测)之间的差距。
假标记在框增强过程中被分配给噪声框,在序列增强过程中,真实对象也可以被污染为假类别。这创造了一个统一的框架,让模型学会生成和处理“非对象”检测。
将我们之前看到的图像通过这个过程,我们会得到以下输出:
BBoxAugmentation 初始化为:
- num_classes: 80
- 将生成假标签为: 80
预期序列长度: 32
Tokenizer max_seq_len: 500
框增强结果:
总对象数: 6
真实对象: 1
假对象: 5
假对象位置: tensor([0, 1, 2, 4, 5])
Tokenizer 词汇表:
真实类别范围: 10 到 89
假类别标记: 90
坐标标记开始: 91
带噪声框的序列:
[[1,
837, 279, 977, 946, 90,
310, 451, 696, 1008, 90,
261, 217, 647, 774, 90,
347, 380, 904, 766, 27,
355, 91, 741, 391, 90,
556, 424, 1090, 1090, 90,
2, 0, 0, 0, 0]]
位置 0: FAKE_TOKEN
位置 1: FAKE_TOKEN
位置 2: FAKE_TOKEN
位置 3: 马
位置 4: FAKE_TOKEN
位置 5: FAKE_TOKEN
随机对象排序
随机对象排序解决了一个微妙但重要的问题:在自回归生成中,模型不应学习基于序列中对象出现顺序的虚假相关性。这种实现发生在数据整理器中,而不是分词器本身,通过一个简单但有效的排列:
# 从 Pix2SeqCollator.__call__ 方法
if self._is_training:
idx = torch.randperm(num_boxes, device=boxes.device)
boxes = boxes[idx]
labels = labels[idx]
这个简单的操作有深远的影响。如果没有随机排序,模型可能会学习“人们通常在序列中首先出现”或“汽车通常在建筑之后被检测到。”这样的模式会损害泛化能力,并使模型对不同的对象分布变得脆弱。
关键的见解是,这种随机化发生在序列生成之前,确保模型学习处理任何顺序的对象。在推理时,模型会遇到自然检测到的任何顺序的对象,这种训练过程确保它可以处理任何安排。
7、综合起来
现在我们已经看到了关键组成部分,让我们将所有这些步骤放在一起,以可视化这可能是什么样子。
第一次检查图像输出时,我对论文方法推荐的假框数量感到惊讶,但看起来这在实践中确实有效!
这种数据增强方法将自回归模型的潜在弱点(对早期错误的敏感性)转化为优势,创建了一个概念上优雅且实践上稳健的系统;模型学会了预测FAKE
类标记来处理噪声对象,同时预测真实对象的正确坐标和类别。这创造了一个既是生成器(对于真实对象)又是鉴别器(对于假对象)的模型。
8、训练流程
现在我们已经探讨了Pix2Seq如何通过数据增强创建丰富的训练信号——生成带有小扰动的真实对象、作为硬负样本的移动框、作为容易负样本的随机框,以及复杂的序列增强策略——我们需要了解模型实际上如何从这种真实和合成数据的混合中学习。
好消息是,整体训练过程相当通用,应该对任何训练过语言模型的人都很熟悉。关键的见解是,不是所有标记都应该对损失做出相等的贡献。我们刚刚创建的假框有一个特定的目的:教模型识别和拒绝无效检测。但这需要调整我们的训练目标,以便适当处理不同类型的标记;这是通过精心设计的标记权重方案实现的。
我们数据流水线中的大部分复杂性——处理变长序列、数据增强和标记化——发生在我们将数据整理成批次时:
class Pix2SeqCollator:
def __call__(self, batch):
"""将一批图像/框转换为模型输入。"""
# 标准图像处理
images = torch.stack([x["image"] for x in batch])
# 变长框序列需要填充
max_boxes = max(x["num_boxes"] for x in batch)
# 在训练时应用框增强
if self.is_training:
augmented_boxes, augmented_labels = [], []
for item in batch:
boxes, labels = self.bbox_augmentor.augment_bbox(
item["boxes"], item["labels"],
n_noise_bbox=max_boxes - item["num_boxes"]
)
augmented_boxes.append(boxes)
augmented_labels.append(labels)
# 转换为标记序列
input_seq, target_seq, token_weights = self.token_processor.build_sequences(
boxes=torch.stack(augmented_boxes),
labels=torch.stack(augmented_labels)
)
return {
"image": images,
"input_seq": input_seq,
"target_seq": target_seq,
"token_weights": token_weights
}
在此过程中,我们在构建序列时计算我们的标记权重,如下所示:
# 计算对象的标记权重
# 通过将它们的标签与假类别标记进行比较来检查哪些对象是假的/噪声对象
is_fake = target_labels == self.FAKE_CLASS_TOKEN
# 计算边界框坐标标记的权重(每个对象4个标记:y_min, x_min, y_max, x_max)
# 假对象应该学习预测“假”类别但不学习坐标
bbox_weights = torch.where(
is_padding.unsqueeze(-1), # 扩展填充掩码以匹配框维度 [B,N,1] -> [B,N,4]
torch.zeros_like(boxes, dtype=torch.float32), # 填充标记的权重为0(在损失中忽略)
torch.where(
is_fake.unsqueeze(-1), # 扩展假掩码以匹配框维度 [B,N,1] -> [B,N,4]
torch.full_like(boxes, self.noise_bbox_weight, dtype=torch.float32), # 假对象对于坐标有指定的权重
torch.ones_like(boxes, dtype=torch.float32), torch.ones_like(boxes, dtype=torch.float32), # 真实对象对于坐标有权重1.0(完全学习坐标)
),
)
# 计算类别标记的权重(每个对象1个标记)
# 假对象仍然对于类别标记有权重1.0,以便模型学会预测“假”
label_weights = torch.where(
is_padding, # 检查这是否是填充标记
torch.zeros_like(labels, dtype=torch.float32), # 填充标记的权重为0(被忽略)
torch.ones_like(
labels, dtype=torch.float32
), # 真实和假对象对于类别预测都有权重1.0
)
# 将坐标和类别权重合并为一个权重张量
# 每个对象有5个标记:[y_min, x_min, y_max, x_max, class]
token_weights = torch.cat(
[bbox_weights, label_weights.unsqueeze(-1)], dim=-1
) # [B, N, 5] 其中5 = 4个坐标标记 + 1个类别标记
这应该能让模型学习:
- 对于真实对象(来自真实数据+抖动):正确学习坐标和类别(所有权重=1.0)
- 对于假对象(来自移动/随机生成):学习将它们分类为“假”,但不要浪费精力学习坐标精度(坐标权重=0.0,类别权重=1.0)
- 对于填充:完全忽略(所有权重=0.0)
在我实现的过程中,我一开始感到困惑:如果这些是假对象,我们不应该完全忽略它们吗?然而,仔细思考后,理由就变得清晰了;模型必须学会主动拒绝无效对象,通过预测FAKE标记,而不是简单地忽略它们。这创造了一个在推理时能够区分有效对象和噪声的判别模型。当模型在生成过程中遇到一个虚假检测时,它已经学会了将其分类为假,而不是感到困惑。
这导致了模型的行为,即学会识别非对象预测(通过预测FAKE
类别),而不会浪费容量学习合成噪声的具体坐标。这是直接嵌入到序列建模框架中的对抗训练。
9、损失函数
有了我们的标记权重,实际的损失函数变得简单:
class Pix2SeqTrainer
def _calculate_loss(self, logits, target_seq, token_weights):
"""标准的语言建模损失,带有标记权重。"""
# 简单的交叉熵 - 与语言建模相同
loss = F.cross_entropy(logits, target_seq, reduction='none')
# 应用我们精心设计的权重
weighted_loss = loss * token_weights
# 按有效标记归一化
num_valid = (target_seq != self.PADDING_TOKEN).sum().clamp(min=1)
return weighted_loss.sum() / num_valid
其优雅令人惊叹。几十年的目标检测研究产生了越来越复杂的损失函数——焦点损失用于类别不平衡,IoU损失用于定位,复杂的匹配算法用于分配。然而,通过这种方法,复杂性从损失函数转移到了数据表示中;这是一种深刻不同的设计哲学,利用了序列建模的力量,而不是与之对抗。因此,我们只剩下用于训练语言或分类模型的相同交叉熵损失。
10、探索架构
现在我们已经拆解了我们的损失函数,让我们看看模型架构本身。如果我们回忆一下标准的Transformer架构——我们会认识到这是一个相当直接的实现;没有任何专用组件和没有任务特定的头部——只是一个干净的编码器-解码器Transformer,将目标检测视为序列到序列的问题。
让我们了解一下关键组件是什么:
它看起来可能有点复杂,所以让我们分解一下简化前向传递的高层次信息流:
class Pix2seqModel(nn.Module):
def forward(self, images, tgt, tgt_padding_mask=None):
encoded, _ = self.encode(images)
return self.decode(tgt, encoded, tgt_padding_mask=tgt_padding_mask)
正如我们所期望的,我们的图像输入经过编码步骤,然后将此输出与包含我们预测的序列一起传递给解码步骤。让我们依次分解这些,从encode
开始:
class Pix2seqModel(nn.Module):
def encode(self, images):
"""将图像处理为上下文化的补丁表示。"""
# 使用预训练的ViT提取补丁特征
features = self.vit.forward_features(images) # [B, num_patches+1, embed_dim]
# 投影到模型维度
encoded = self.encoder_proj(features)
encoded = self.encoder_norm(encoded)
# 添加位置信息
pos_emb = self.pos_embed(encoded)
encoded = encoded + pos_emb
# 应用变压器层进行跨补丁推理
for encoder_block in self.transformer_encoder:
encoded = encoder_block(encoded)
return encoded
在这里,我们可以看到编码器相当标准——一个Vision Transformer加上额外的变压器层。当我第一次看到这个架构时,我质疑额外的编码器层(超出预训练的ViT)是否对性能是必要的。我进行了几次实验,移除了这些层,发现模型可以学习任务,但要慢得多。因此,我怀疑如果我们只是扩大视觉编码器,我们可以进一步简化这个架构。然而,这些实验成本高昂且耗时,所以我选择不再继续。
现在,让我们探索decode
方法,其中目标检测转化为序列建模:
def decode(self, tgt, encoder_input, use_cache=False):
"""自回归生成标记序列。"""
# 嵌入标记并添加学习的位置编码
decoded = self.token_embedding(tgt)
decoded = decoded + self.dec_pos_embed[:, :tgt.size(1)]
# 应用因果变压器层与交叉注意
for decoder_block in self.transformer_decoder:
decoded = decoder_block(
decoded, # 自注意生成的标记
encoder_input, # 交叉注意图像特征
use_cache=use_cache # KV缓存用于高效推理
)
decoded = self.decoder_norm(decoded)
return self.output_proj(decoded) # 投影到词汇表
在实现过程中让我印象深刻的是交叉注意机制如何自然地处理视觉-语言桥梁。每个生成的标记可以关注图像的相关部分,创建一个动态、内容感知的生成过程。与DETR的固定对象查询不同,这种方法让模型根据当前生成的内容决定关注什么。
请注意,这里,论文使用了学习的位置嵌入用于解码器序列,而不是正弦位置嵌入。
11、现代架构实验:超越原始论文
虽然我按照论文的架构实现了忠实的版本(标准的Transformer和正弦位置编码),但我也不得不尝试一些现代的Transformer组件。这导致了我的LlamaPix2Seq
变体,它结合了:
旋转位置嵌入(RoPE):而不是学习的位置嵌入,我为解码器序列位置和视觉编码器的空间位置实现了RoPE:```
class RoPEMultiHeadAttention(MultiHeadAttention):
def init(self, embedding_dim, num_heads,
q_max_seq_len=8192,kv_max_seq_len=None,
q_rope_base=500000.0, k_rope_base=None
):
# 查询和键/值序列的不同RoPE参数# 允许文本和视觉的不同上下文长度
super().init(embedding_dim, num_heads, is_causal, bias, dropout)
# 预计算不同序列类型的RoPE嵌入
q_cos, q_sin = precompute_rope_params(
head_dim=self.head_dim, context_length=q_max_seq_len,
theta_base=q_rope_base
)
if q_max_seq_len != kv_max_seq_len:
k_cos, k_sin = precompute_rope_params(
head_dim=self.head_dim, context_length=kv_max_seq_len,
theta_base=k_rope_base
)
else:
k_cos, k_sin = q_cos, q_sin
SwiGLU前馈网络: 用来自LLaMA的SwiGLU变体替换标准的MLP层:
class SwiGLUFFN(nn.Module):
def forward(self, x):
return self.w3(F.silu(self.w1(x)) * self.w2(x))
RMSNorm: 使用RMS归一化而不是LayerNorm,以获得更好的训练稳定性。
这些修改部分是出于好奇心——我想看看现代语言建模的进展是否能转移到这个视觉-语言设置中。结果令人鼓舞——获得了类似的性能水平,但收敛速度比论文中的变体慢得多。我怀疑,由于这个架构的偏差更少,如果我们增加训练数据集的大小,它会表现得更好。
12、综合所有内容
现在我们了解了我们的数据管道、模型和损失函数,训练过程看起来与任何语言模型都非常相似:
class Pix2SeqTrainer(Trainer):
def calculate_train_batch_loss(self, batch):
images = batch["image"] # [B,3,H,W]
input_seq = batch["input_seq"] # [B,S] 带有噪声对象
target_seq = batch["target_seq"] # [B,S] 向前偏移一位
token_weights = batch["token_weights"] # [B,S] 权重方案
# 通过编码器-解码器进行前向传递
logits = self.model(images, input_seq)
# 重塑为标准的交叉熵损失
B, S, V = logits.shape
logits = logits.view(-1, V)
target_seq = target_seq.view(-1)
token_weights = token_weights.view(-1)
# 与任何自回归模型相同的损失函数
loss = F.cross_entropy(logits, target_seq, reduction='none')
weighted_loss = loss * token_weights
# 按照有效(非填充)标记的数量进行归一化
num_valid = (target_seq != self.PADDING_TOKEN).sum().clamp(min=1)
return weighted_loss.sum() / num_valid
在训练过程中,让我不断感到惊讶的是,一切感觉都如此“正常”。没有复杂的损失平衡,没有为不同头设置的学习率调度,也没有为特殊组件仔细初始化。这只是将序列建模应用于视觉任务;这感觉像是未来的方向!
13、推理:从慢到快
在推理过程中,我首先发现自回归生成非常慢。真的非常慢。对于每个检测到的对象,模型需要依次生成5个标记,而天真地,每个标记生成都需要通过视觉编码器处理整个图像。对于一张图片中有10个对象的情况,这就需要50次连续的前向传递!
13.1 速度问题
def naive_inference(self, images):
"""慢速推理 - 每次生成标记时都要重新编码图像。"""
batch_size = images.size(0)
sequences = torch.full((batch_size, 1), self.BOS_TOKEN)
for step in range(self.max_sequence_length):
# 这是瓶颈:每一步都要重新编码图像
encoder_output = self.encode(images) # 很昂贵!
logits = self.decode(sequences, encoder_output)
next_tokens = self.sample_next_token(logits[:, -1])
sequences = torch.cat([sequences, next_tokens], dim=1)
if all_sequences_ended(sequences):
break
return sequences
问题在于我们在每个标记生成步骤中都要重新编码相同的图像特征。这是浪费的——因为我们的图像特征在序列生成期间不会改变。
13.2 解决方案:KV缓存和编码器重用
解决方案涉及两个互补的优化,可以显著加快推理速度:
def fast_inference(self, images):
"""使用缓存计算的快速推理。"""
batch_size = images.size(0)
# 优化1:仅编码一次并重复使用
encoder_output = self.encode(images) # 只计算一次!
# 初始化解码器状态
sequences = torch.full((batch_size, 1), self.BOS_TOKEN)
for step in range(self.max_sequence_length):
# 优化2:使用缓存的解码器状态 - 仅处理新标记
logits = self.decode(
sequences[:, -1:], # 仅最后一个标记
encoder_output, # 重复使用的编码器输出
use_cache=True, # 启用KV缓存
)
# 应用约束并采样
allowed_tokens = self.get_allowed_tokens(step, sequences)
next_tokens = self.constrained_sample(logits, allowed_tokens)
sequences = torch.cat([sequences, next_tokens], dim=1)
if self.all_ended(sequences):
break
return sequences
让我们通过一个具体的例子来分解发生了什么。假设我们要生成一个20个标记的序列(4个对象 × 5个标记):
- 编码器重用节省了冗余计算:
# 天真:对同一张图像进行20次编码
step 1: encode(image) → decode([BOS])
step 2: encode(image) → decode([BOS, token1]) # 同样的图像!
step 3: encode(image) → decode([BOS, token1, token2]) # 同样的图像再次!
...
# 优化:仅编码一次,重复使用
encode(image) → 缓存结果
step 1: 使用缓存编码解码[BOS]
step 2: 使用缓存编码解码[BOS, token1]
step 3: 使用缓存编码解码[BOS, token1, token2]
这给了我们一个不错的常数加速,但真正的复杂性问题在于解码器。
- KV缓存是关键所在。在标准的Transformer解码中,每一步都需要重新处理整个序列:
# 没有KV缓存:O(S²)复杂度
step 1: 处理1个标记 ([BOS])
step 2: 处理2个标记 ([BOS, token1])
step 3: 处理3个标记 ([BOS, token1, token2])
...
step 20: 处理20个标记 ([BOS, token1, ..., token19])
# 总工作量: 1 + 2 + 3 + ... + 20 = 210 个标记处理操作
# 一般情况下: 1 + 2 + ... + S = S(S+1)/2 ≈ S²/2 = O(S²)
然而,有了KV缓存,我们可以存储之前步骤中的键和值表示:
# 有KV缓存:O(S)复杂度
step 1: 处理1个标记 ([BOS]) + 缓存其K,V状态
step 2: 处理1个标记 ([token1]) + 重用步骤1的K,V缓存
step 3: 处理1个标记 ([token2]) + 重用步骤1-2的K,V缓存
...
step 20: 处理1个标记 ([token19]) + 重用步骤1-19的K,V缓存
# 总工作量: 1 + 1 + 1 + ... + 1 = 20 个标记处理操作
# 一般情况下: S × 1 = O(S)
因此,我们可以总结这些优化的综合影响如下:
- 编码器重用: O(S) → O(1) 对图像处理(常数因子加速)
- KV缓存: O(S²) → O(S) 对序列处理(根本性的复杂度减少)
更重要的是,O(S) 而不是 O(S²) 的扩展意味着对于包含许多对象的图像,加速效果会更加显著。
结合这些优化,我的测试图像上的推理速度显著提高——从难以使用(每张图像几秒钟)变为实际可行(几百毫秒)。
13.3 约束生成:强制结构
在推理过程中,我很快意识到,并非所有的标记序列都作为对象描述有意义。生成 [y_min=50, x_min=30, y_max=20, x_max=70, class=car]
会产生一个无效的框,其中 y_max < y_min
。原始论文提到使用核采样进行生成,但没有详细说明结构约束,因此这成为我关键的实现补充。
我实现了基于序列位置的结构有效性约束系统:
class TokenMaskCache:
def get_allowed_tokens(self, pattern_pos: int, cur_seq: torch.Tensor):
"""根据序列位置约束有效的标记。"""
if pattern_pos == 0: # y_min 位置
# 可以生成坐标或结束序列
allowed = coordinate_tokens + [EOS_TOKEN]
elif pattern_pos == 1: # x_min 位置
# 仅允许坐标
allowed = coordinate_tokens
elif pattern_pos == 2: # y_max 位置
# 必须大于2步前的 y_min
y_min = cur_seq[:, -2] - self.coord_vocab_shift
allowed = coordinate_tokens[y_min+1:]
elif pattern_pos == 3: # x_max 位置
# 必须大于2步前的 x_min
x_min = cur_seq[:, -2] - self.coord_vocab_shift
allowed = coordinate_tokens[x_min+1:]
elif pattern_pos == 4: # class 位置
# 仅允许有效的类别标记
allowed = class_tokens
return allowed
这个约束系统确保了每个生成的序列对应一个有效的边界框。有趣的是,这些约束不会损害生成质量——它们实际上有助于防止模型浪费概率质量在不可能的序列上。
13.4 采样方法
对于实际的标记采样,我实现了核采样(top-p)结合结构约束:
def sample_next_token(self, logits, allowed_tokens):
"""使用核采样从约束分布中采样。"""
# 应用硬性结构约束
constrained_logits = logits.masked_fill(~allowed_tokens, float('-inf'))
# 应用核采样以获得多样性
if self.top_p > 0:
sorted_logits, sorted_indices = torch.sort(constrained_logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# 移除超过累积阈值的标记
indices_to_remove = cumulative_probs > self.top_p
indices_to_remove[..., 1:] = indices_to_remove[..., :-1].clone()
indices_to_remove[..., 0] = 0
# 应用掩码
for batch_idx in range(constrained_logits.size(0)):
batch_indices_to_remove = sorted_indices_to_remove[batch_idx].scatter(
0, sorted_indices[batch_idx], sorted_indices_to_remove[batch_idx]
)
constrained_logits[batch_idx][batch_indices_to_remove] = float('-inf')
# 从精炼的分布中采样
probs = F.softmax(constrained_logits / self.temperature, dim=-1)
return torch.multinomial(probs, num_samples=1)
结构约束和概率采样的相互作用创造了一个既有效又多样化的生成过程——正是我们想要的对象检测。
14、结束语
在我进行的训练运行中,我大致遵循了论文中概述的配方,值得注意的是,我使用了WSD学习率调度器,以便更容易从检查点继续训练;超参数的完整列表作为配置文件存储在代码库中。在4xA100 GPU上训练,当我从头开始在COCO上训练300个epoch,图像尺寸为640x640时,我能够再现论文中的结果,达到类似的级别(约44 AP)。
一些正确的预测示例如下:
然而,在分析我的训练结果时,我发现了一件起初看起来像我的评估代码中的错误的事情:中间模型检查点有时会比最终的“最佳”模型产生更好的视觉结果。让我们进一步探讨这一点,并思考为什么这可能是这种情况。
15、错误分析:分析中间检查点的预测
让我们看一下我的训练日志中的三个具体示例,这些示例展示了这种令人困惑行为的不同方面。这些模式在数十张图像和多个训练运行中重复出现。我最初以为是我的检查点加载存在错误,但在仔细验证后,这些模式是真实的且系统性的。
15.1 示例1:变成狗的羊
考虑以下图像 —— 包含一个明显可见的羊以及背景中一些难以识别的小物体 —— 以及模型在其训练周期中的不同时间点所做的预测。
我们可以看到,在第20个epoch时,模型正确地将羊标记为“sheep”,并识别出背景中的一个小物体为“car”(在仔细检查后是正确的)。它错过了我本人也无法识别的一些非常难看的标签(person, cell phone)。然而,从第90到300个epoch,它学会了将羊称为“dog”——尽管一开始它是正确的。
这尤其引人注目:模型学会了对它最初正确的东西变得错误。对我来说,这表明它不仅仅是学习对象识别,而是学习训练数据中特定的标注模式。
15.2 示例2:伞和靴子
现在,让我们看一个不同的例子,其中初始预测并不明确更好,但展示了一些有趣的改变。
在第10–20个epoch时,模型试图预测靴子周围的物体(标签错误),但完全忽略了伞。到第40个epoch时,模型完全没有预测任何东西——完全的检测失败。然而,从第50个epoch到第300个epoch,模型正确预测了伞(位置和标签),但完全忽略了靴子。
值得注意的是这一进程:模型从认识到有些东西存在(靴子)到彻底放弃,再到最终学习复制精确的标注模式——包括系统性地忽略明显可见的物体。这显示了模型不仅学习了要检测什么,还学习了基于训练模式不检测什么。
15.3 示例3:三明治分类混淆
在我们的最后一个例子中,我们有多个重叠的“三明治”标签(虽然它们看起来更像是涂抹酱料的饼干),加上几乎看不见的背景物体(键盘、遥控器、鼠标)。
在第120个epoch时,模型将所有三明治对象预测为“pizza”——考虑到它们的外观,这可能是一个合理的猜测。它正确识别了键盘和鼠标,但错过了遥控器。到第210个epoch时,模型表现出显著改进:它正确地将3个对象标记为“sandwich”,但将另外两个错误地标记为“donut”(也应该是“sandwich”)。它正确识别了键盘和鼠标,现在检测到了遥控器,但将其称为“cell phone”。虽然不完美,但这代表了更完整的检测。
但到了第300个epoch,模型的表现有所下降:它只正确预测了一个对象为“sandwich”,另一个为“donut”,似乎完全放弃了检测其余的对象。它仍然正确识别了键盘和鼠标,但第210个epoch的更完整的预测已经退化为不确定的部分检测。
这一进程说明了核心问题:合理的视觉猜测(“pizza”)→ 改进但不完美的检测 → 表现下降,遗漏了物体。第210个epoch的中间检查点虽然不完美,但比最终模型显示出更好的对象检测完整性。这表明继续训练导致模型变得更加保守和不确定,宁愿完全错过物体也不愿自信地误分类。
16、训练动态如何驱动这些行为
上述例子可能看起来像奇特的训练异常,但它们实际上是神经网络从数据中学习模式的必然结果。为了理解为什么模型学会了忽略靴子、把羊误认为狗以及对三明治变得不确定,我们需要查看不同机制的作用——从数据集范围内的模式学习到标记级别的损失计算。分解数学公式揭示了看似反直觉的行为实际上是模型完美优化的结果,只是不一定总是符合我们的期望。
16.1 学习数据集模式而非视觉现实
羊变狗的例子揭示了模型如何学习整个数据集的统计模式,这些模式可以覆盖个别图像的正确行为。这很可能是由于数千个示例中的上下文和行为模式学习造成的:
羊变狗的例子揭示了这种训练动态中最严重的问题。让我们看看交叉熵损失如何驱动模型学习统计模式而不是视觉特征:
# COCO数据集中可能发生的情况:
# 图像1: [动物 + 跳跃 + 在房子附近 + 窗口] → 标记为 "dog"
# 图像2: [动物 + 伸展 + 居住环境] → 标记为 "dog"
# 图像3: [动物 + 跳跃 + 在窗口附近 + 房子] → 标记为 "sheep"(我们的例子)
# 图像4: [动物 + 站立在后腿上 + 门口] → 标记为 "dog"
# ...数百个其他示例
# 模型学到:这些上下文特征 → "dog"(95%的时间)
# 统计模式覆盖了个别图像的正确性
模型发展出了与上下文模式和标签频率之间的关联。虽然我们特定的羊图像始终被标记为“sheep”,但模型在数据集中遇到了压倒性的证据,表明在这一上下文中——跳向房屋窗户的动物——几乎总是被标记为“dog”。COCO中的羊通常出现在牧场、田野或农场中,而不是在住宅环境中表现出“狗”的行为。
我们可以这样思考。为了举例,假设模型学到的上下文特征与我们定义的概念相吻合。我们可以将其表示为:
## 模型学到的上下文特征(简化):
contextual_features = [
near_window: 0.9,
residential_setting: 0.8,
jumping_behaviour: 0.9,
human_interaction_context: 0.7
]
# 数据集上的学习关联:
if near_window + jumping_behaviour + residential_setting > threshold:
most_likely_label = "dog" # 基于成千上万张图像的统计数据
# 这种上下文模式覆盖了单个图像的视觉特征
当然,模型不太可能学习到这些类型的功能,但这有助于说明问题。
这解释了为什么模型会学会错误:它不是基于视觉特征单独区分羊和狗,而是学习到在训练数据中,上下文和行为是标签的更强预测因素。一只行为像狗的羊,在类似狗的上下文中,会被预测为狗,因为成千上万张图像的统计数据表明如此。
16.2 交叉熵损失和序列对齐问题
伞和靴子的例子展示了交叉熵损失如何通过标记级别的计算直接驱动有问题的行为。为了回顾交叉熵损失的确切工作原理,可以在这里找到简要概述。
当第20个epoch的模型尝试预测靴子时,损失计算如下:
# 模型生成(第20个epoch):
predicted_sequence = [BOS, y1, x1, y2, x2, "shoe", EOS] # 尝试检测靴子
# 真实目标期望:
target_sequence = [BOS, y3, x3, y4, x4, "umbrella", EOS] # 仅标记伞
# 交叉熵损失计算: -log(p(target_token))
# 让我们计算每个位置的损失:
# 第0位: BOS标记
model_probs = [0.95, 0.01, 0.01, ...] # 模型对BOS非常确定
target_token = BOS # 真实目标希望BOS
loss_pos_0 = -log(0.95) = 0.05 # 低损失 - 两者都同意BOS
# 第1位: 第一个坐标 (y_min)
model_probs = [0.02, 0.03, 0.04, ..., 0.78, ...] # 模型对靴子的y坐标非常确定
target_token = y3 # 真实目标希望伞的y坐标
loss_pos_1 = -log(0.03) = 3.5 # 高损失 - 坐标不匹配
# 第2-4位: 剩余坐标
# 类似的不对齐,每个贡献约3-4
# 第5位: 类别标记 - 这里变得非常糟糕
model_probs = [0.01, 0.02, ..., 0.85, ..., 0.001, ...] # 非常确定它是鞋
target_token = "umbrella" # 真实目标希望伞
loss_pos_5 = -log(0.001) = 6.9 # 巨大的损失 - 类别完全错误
# 第6位: EOS标记
model_probs = [0.05, 0.9, 0.02, ...] # 模型认为序列应该结束
target_token = EOS # 真实目标也希望在此处结束
loss_pos_6 = -log(0.9) = 0.1 # 低损失 - 两者都同意结束
# 该序列的总损失:
total_loss = 0.05 + 3.5 + 3.2 + 3.8 + 3.4 + 6.9 + 0.1 = 20.95
请注意,这里我们假设模型正确预测了类别,即使在我们观察到的预测中并没有这样做。但是,我们可以注意到这实际上不会有任何区别!无论预测的类别是“shoe”还是“fire hydrant”,模型仍会因类别错误而受到惩罚。请记住,交叉熵损失中没有类别的相似性概念。
模型为尝试检测靴子而收到巨大的负面信号(损失约为21),即使它们很明显可见。但还有一个关键因素:生成对象的顺序对损失计算至关重要。
考虑如果模型先生成伞,然后生成靴子会发生什么:
# 替代生成顺序:
predicted_sequence = [BOS, y3, x3, y4, x4, "umbrella", y1, x1, y2, x2, "shoe", EOS]
target_sequence = [BOS, y3, x3, y4, x4, "umbrella", EOS]
# 损失计算:
# 位置0-5: 完美匹配!每个位置约0.05损失 = 0.3 总计
# 位置6: 模型预测y1(靴子y坐标),目标希望EOS
model_probs = [0.02, 0.01, ..., 0.78, ..., 0.001, ...] # 模型对y1非常确定
target_token = EOS # 真实目标希望序列结束
loss_pos_6 = -log(0.001) = 6.9 # 预测坐标而不是EOS的高损失
# 位置7+: 依赖于实现中的填充和截断处理
# 总损失: 0.3 + 6.9 + 填充惩罚 = ~8-10(远低于21)
注意:确切的损失取决于实现细节,如序列填充和长度归一化,但关键是完美匹配前6个标记会带来更低的损失,而不是从一开始就错配
这揭示了基于序列的对象检测的一个根本性问题:模型学会了优先生成“确定”的对象以最小化损失,即使其他对象同样可见。训练中的随机打乱帮助模型接触到不同的顺序,但在任何特定的前向传递中,顺序仍然极大地影响损失。
这创造了额外的学习偏差:
- 首先生成“通常标记”的对象(伞)
- 完全抑制“通常未标记”的对象(靴子)
- 避免生成可能导致序列错配的对象
损失函数为没有部分信用来承认靴子的位置存在对象,并且如果靴子在“正确”伞之前生成,还会严厉惩罚靴子检测。即使有随机打乱,这种顺序依赖性强化了模型对不完整标注模式的偏见。
16.3 交叉熵损失和学习的不确定性
三明治混淆显示了不确定性是如何从不一致的训练信号中产生的,而不是主动学习的。让我们看看损失计算:
# 第120个epoch: 模型自信地预测“pizza”
model_probs_pizza = [0.01, 0.02, ..., 0.78, ..., 0.05, ...] # 确信的猜测
target_token = "sandwich" # 真实目标希望三明治
loss_pizza = -log(0.05) = 3.0 # 中等损失 - 错误但合理的猜测
# 第300个epoch: 模型现在不确定且困惑
model_probs_confused = [0.15, 0.18, 0.12, 0.14, 0.11, ...] # 非常不确定
target_token = "sandwich" # 仍然希望三明治
loss_confused = -log(0.14) = 1.97 # 更低的损失但来自学习的不确定性
这显示了不确定性如何从不一致的训练信号中产生,而不是主动学习的。模型并没有发展出“习得无助”——它在冲突的数据下找到了最优解。当训练集中包含同时标记为“三明治”和“披萨”的脆饼时,模型学会了最小化所有这些不一致示例的损失的概率分布。
模型的概率分布现在更加平坦——它给“三明治”(0.14)、“披萨”(0.18)、“甜甜圈”(0.15)等分配了大致相等的概率。这不是作为一种策略学习的不确定性,而是对不一致训练标签的数学最优响应。损失层次是:自信正确(0.2)< 不确定(1.97)< 自信错误(3.0)。
16.4 训练进度:多种机制共同作用
上面的三个例子是为了说明我在训练运行中观察到的不同学习机制的共同作用:
早期训练(第1-50个epoch):
- 模型基于视觉特征尝试检测可见对象
- 通常标签错误但显示出良好的视觉识别
- 对模糊情况做出自信的猜测(将脆饼猜为披萨)
- 较高的训练损失,较好的视觉理解
中期训练(第50-150个epoch):
- 模型开始学习数据集范围内的统计模式(羊→狗转换)
- 交叉熵损失驱动未标记对象的抑制(靴子消失)
- 开始对不一致标签发展学习的不确定性
- 损失下降,出现学习的偏差
后期训练(第150个epoch以上):
- 模型紧密复制标注模式(伞是,靴子否)
- 统计模式覆盖了单个图像的正确性(羊作为狗)
- 交叉熵损失在模糊情况下创建保守的不确定性
- 最低的训练损失,完美复制数据集偏差
模型学会了预测数据集模式而不是检测对象。它发展出多种类型的学习关联:上下文模式(羊→狗)、标注完整性模式(靴子通常未标记)和不确定性模式(模糊的食物项目)。
17、COCO标注的挑战
在进行本项目的错误分析时,我再次意识到COCO数据集中存在多少潜在问题;这些特定示例揭示了影响所有检测模型的COCO数据集的系统性问题,但在Pix2Seq中尤为明显。虽然COCO是目前最广泛使用的对象检测数据集之一,但它存在标注问题,这些问题很可能存在于许多人工标注的检测数据集中:
- 缺失的对象标注:背景或次要对象通常被标注者忽略。我的伞/靴子示例完美地说明了这一点——靴子明显可见但未被标注,代表了标注者专注于“主要”对象并忽略次要对象的广泛模式。
- 基于尺度的标注偏差:小对象(面积 < 3²² 像素)相比大对象被系统性地低估。标注者经常忽略或跳过明显可见的小对象,比如我三明治示例中的小背景物品(键盘、鼠标、遥控器)。
- 不一致的分类模式:外观相似的对象根据上下文被不一致地标注。我的羊→狗示例揭示了模型如何学习统计模式而不是视觉差异——可能是因为COCO中在模糊上下文中的动物被不同的标注者以不同方式标注。
- 复杂场景和标注疲劳:在有许多对象的图像(繁忙的街道场景、杂乱的房间)中,标注质量显著下降。标注过程后期的对象更有可能被遗漏,因为标注者的注意力减弱。这解释了为什么模型学会保守地检测多个对象。
- 类别边界模糊性:相似类别的区分在标注者之间差异很大。“杯子”与“马克杯”、“沙发”与“椅子”、“卡车”与“公共汽车”等的区分创建了系统性不一致。
- 上下文标注偏差:在“不寻常”上下文中的对象更可能被忽略或错误标注。一只在房屋窗户前跳跃的羊(类似狗的行为)通过学习的统计模式被处理,而不是仔细的视觉分析。
- 遮挡和部分对象处理:部分遮挡的对象不一致地被标注。有时一个人的头部在车后会被标注,有时则不会,这取决于标注者的判断和疲劳。
- 时间和文化不一致性:COCO由不同团队在多年内标注,导致标准的变化。此外,标注者可能更可能注意到并正确标注他们文化背景熟悉的对象。
- 图像边界效应:部分被截断在图像边缘的对象不一致地处理——有时被标注,有时被忽略,从而形成关于何时检测截断对象的学得偏差。
因此,一个以最小化交叉熵损失为目标的模型学会了这些标注模式作为特征,而不是错误,例如:
- “靠近伞的靴子通常不被标注”
- “在某些上下文中,像羊的动物更可能被标注为狗”
- “小背景对象不一致地被标注,所以应保持低置信度”
- “复杂场景中超过前几个对象通常被忽略”
- “部分遮挡的对象取决于上下文进行标注”
- “当对模糊食物类别不确定时,避免自信预测”
这些不是随机错误——它们是系统性学得的偏差,使模型通过匹配标注模式来实现较低的训练损失,而不是检测所有可见对象。模型基本上学会了复制人类标注者的行为,包括他们的偏差、疲劳模式和不一致性。
这些问题源于大规模人工标注的基本挑战:主观性、疲劳、不一致的指导方针以及彻底标注复杂视觉场景的固有难度。
17.1 为什么我继续使用COCO尽管存在这些问题
尽管存在这些重大挑战,我还是选择继续使用COCO有两个原因:
- 与论文的一致性:为了公平评估该方法并与报告的结果进行比较;提供一个我努力的目标
- 行业标准:COCO仍然是最广泛使用的基准,使结果与其他方法可比
但这次经历强化了在生产环境中,你希望仔细审计你的训练数据质量并可能使用更完整的标注或替代训练策略。
18、为什么这对Pix2Seq特别有问题
虽然这个问题影响了所有对象检测模型,但对Pix2Seq来说,它尤其突出有几个关键原因:
18.1 序列到序列 vs 空间匹配
传统的对象检测方法使用灵活的空间匹配算法(DETR中的匈牙利匹配,YOLO/R-CNN中的IoU匹配),可以在一定程度上缓解标注问题。如果预测的框与真实框有显著重叠,即使类别略有不同或框坐标不完全对齐,也可能在训练中被视为“正确”匹配。
Pix2Seq,像语言模型一样,使用逐标记的序列匹配和交叉熵损失。要么序列中的每个标记都完全匹配,否则就不匹配。当模型预测一个未标注的对象时,序列中的每个后续标记都会错配,造成级联惩罚。
18.2 标注质量 vs 自监督
关键区别不在于训练方法——语言模型使用相同的直接序列匹配方法。区别在于数据质量和系统性遗漏:
- 语言建模:虽然有多种有效的文本延续方式,但训练数据代表了一条实际写就的有效路径。模型从完整的自然序列中学习。
- 对象检测:训练数据有系统性遗漏——图像中确实存在的对象但未被标注者标注。
当语言模型遇到模糊的延续时,它学会适当表示不确定性。但它们很少遇到训练数据中明显可见的信息被系统性遗漏的情况,就像检测标注中缺失的对象那样。
18.3 没有内置的对标注缺口的鲁棒性
空间检测方法有一些自然的鲁棒性——稍微错误的框可能仍然有足够的重叠被认为是正确的。序列方法没有这种灵活性。序列中第一个未标注的对象会影响所有后续标记位置。
然而,这表明在语言建模中解决类似问题的相同扩展方法也应该在这里起作用。如果有足够的规模和高质量的标注,Pix2Seq应该像现代LLM一样对训练数据中的不一致具有鲁棒性。
18.4 机器学习中的一个更广泛的教训
这一观察揭示了机器学习中的一个基本原则:优化损失函数并不等于优化期望行为。当你的真实数据存在系统性偏差或不完整时,标准训练程序会学习并放大这些问题,而不是纠正它们。
Pix2Seq中的中间模型现象特别清晰,因为模型的预测是直接可解释的。与CNN检测器中的密集特征图不同,我们可以轻松观察模型在训练过程中的行为变化,并确切知道它在学习什么。
18.5 前进的道路:高质量的扩展
鉴于语言建模的成功,自然的问题是这些问题是可以通过简单收集更多训练数据来解决的。答案是微妙但最终乐观的。
使用当前数据质量进行扩展可能会加强现有的偏差。更多具有相同标注缺口的图像会教模型更自信地忽略未标注的对象,更多住宅环境下狗的例子会加强导致羊误判的上下文关联。
然而,使用高质量、完整的标注进行扩展可能会解决大部分这些问题——而语言建模的证据表明这种方法是有效的。当在多样且高质量的数据上进行训练时,LLM对不一致和边缘案例表现出惊人的鲁棒性。
序列增强已经有所帮助:有趣的是,Pix2Seq的许多设计选择——假对象、标签损坏、噪声框训练——已经旨在缓解这些类型的标注问题。模型学会了区分真实对象和合成噪声,这增强了对不完整标注的鲁棒性。
交叉熵损失不是问题:与其围绕损失函数进行工程(这在语言建模中表现得非常好),前进的道路可能是:
- 用数量级更多的训练数据扩展该方法
- 大幅提高标注质量——更完整的标注,一致的指南
- 利用现有的序列增强框架,该框架已经提供了鲁棒性
语言建模的证据表明,只要有足够规模和数据质量,这些上下文学习问题会变成特性而不是缺陷——模型会学习到羊实际上出现在住宅环境中的完整分布,靴子应该被标注,以及如何处理真正模糊的情况。
真正限制并不是训练目标或模型架构——而是我们尚未将推动语言建模变革的扩展定律应用到对象检测领域,并且使用适当高质量的数据。
标注质量问题是计算机视觉研究中的一个“肮脏秘密”,值得在文献中得到更多关注。下次你看到训练损失下降并自动假设模型正在变好时,请记住:有时最有洞察力的模型还没有学会做出“正确”的错误。
19、结论:重新思考基础的优雅
从零开始重新实现Pix2Seq是一次罕见的经历,从根本上改变了我对问题领域的看法。这最初是对一种替代对象检测方法的好奇心,变成了对表示、简洁性和重新思考基本假设的力量的深入课程。
原始论文的突破不仅仅是关于对象检测——它证明了表示比架构更重要。通过找到将视觉理解表达为语言的正确方式,作者解锁了计算机视觉任务的几十年NLP研究。现在回过头来看,Pix2Seq显得很有预见性。我们正处在GPT-4o、带视觉的Claude和LLaVA的时代——所有都遵循同样的核心见解,即视觉和语言比看起来更相似。始于NLP的序列到序列革命现在正在AI的每个角落发生。
但或许最深刻的教训是关于激进简化的力量。在一个痴迷于架构复杂性的领域中,Pix2Seq提醒我们,最具革命性的突破往往来自于问:“如果我们完全以不同的方式思考呢?” 作者证明了有些最重要的研究不是通过添加复杂性,而是通过找到去除它的方法。在一个常常沉迷于架构创新和性能优化的领域中,Pix2Seq提醒我们,有时最具革命性的方法也是最明显的——一旦有人向你展示如何看到它。
下次你面对一个复杂的工程问题时,考虑问自己:“如果我们完全以不同的方式表示呢?” 答案可能会改变一切。
本文讨论的完整PyTorch实现可在 GitHub上找到,包括忠实复现和现代变体,所有标记化逻辑、序列增强、约束生成和全面评估代码。它旨在作为一个干净、文档齐全的参考,用于理解和扩展Pix2Seq方法。
原文链接:Rethinking Object Detection as Language Modelling: Lessons from Reimplementing Pix2Seq
汇智网翻译整理,转载请标明出处