用 MLX 进行机器学习研究

大多数机器学习研究都在运行 PyTorch 的 NVIDIA 硬件上进行。工具链成熟,生态系统庞大,如果你需要从 Transformer 中提取内部表示来进行探测实验,你会使用 HuggingFace 和 CUDA GPU。这是默认路径,而且效果很好。

我选择了一条更便宜的路。过去几个月里,我一直在 Apple Silicon 上运行一个实验流程,用于 LLM 行为研究,MLX 是我为一篇正在准备中的论文使用的主要研究平台。

本文讲述的是实际操作中的情况:MLX 在哪些方面足以用于研究,在哪些方面需要针对其边缘情况进行工程处理,以及我学到的关于将消费级硬件作为严肃实验平台的经验。

1、研究问题简述

我正在研究 LLM 在不同类别的题目上的行为差异,这些类别中模型与底层知识的关系各不相同。细节将在论文中阐述,但这里的工程要求才是关键:一个经过精心筛选题目的基准测试系统,一个在多种提示条件下测试模型响应的评估系统,以及每层激活提取,用于训练线性探测模型,以了解模型在每一层产生的内容。整个流程在配备 32GB 统一内存的 Mac Mini M4 上运行,使用来自 mlx-community 仓库之一的 4 位量化 Mistral 7B。

2、为什么选择 MLX

我是一个没有访问 CUDA 集群的独立研究员,而 Mac Mini 是我拥有的设备。

但不太明显的答案是,这个约束条件实际上变成了一个有用的强制函数。当你的推理预算只有一台消费级机器时,你会更深入地思考实验设计。你不能通过投入更多 GPU 来粗暴地解决设计糟糕的筛选过程。每个候选题目都会通过十次推理 passes,使用确定性种子,浪费的 token 会消耗你无法并行化掉的实时时间。

你得到的是一个默认可复现的设置,带有固定种子和可重播的筛选运行。硬件约束要求你在设计实验时更加自律。

值得讨论的是 4 位量化的权衡。量化会压缩模型的权重,而这种压缩会损失信息。4 位模型中的表示比全精度或 8 位版本更嘈杂:激活中的细微影响可能会被抹去,模型的行为模式略有退化。对于依赖于细粒度表示差异的探测实验,8 位或全精度权重会更可取。在 32GB 统一内存上,4 位的 7B 模型(约 4GB)可以舒适地与提取流程并存;在 8 位(约 8GB)时仍然可以容纳,但空间较少;在全精度(约 14GB)时会非常紧张。

实际上,从 4 位开始的理由是,如果一个影响强到足以通过量化噪声显现出来,那么它很可能不是伪影。在 4 位激活上实现明显分离的探测模型测量的是某种稳健的东西。如果在 4 位时影响消失,你无法判断它是根本不存在还是只是低于噪声底,这是一个真正的限制。但对于初始验证,4 位量化起到了保守过滤器的作用:能够幸存下来的影响值得以更高精度进一步研究。

3、提取隐藏状态:钩子

核心工程挑战是从 MLX 模型的每一层提取表示。HuggingFace transformers 拥有 output_hidden_states=True 作为一级参数。MLX-LM 没有。

我的解决方案是临时替换内部模型的 __call__ 方法。MLX-LM 模型具有两层结构:一个拥有语言模型头(lm_head)的外部包装器,和一个拥有嵌入层、transformer 层和最终 norm 的内部 transformer(model.model)。这个钩子手动遍历这些层并收集每一层的输出。

这是完整的实现,包括花费很长时间才弄对的兼容性变通方法:

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

def extract_hidden_states(text: str) -> mx.array:
    """Extract per-layer hidden states. Returns (n_layers+1, seq_len, hidden_dim)."""
    tokens = tokenizer.encode(text)
    input_ids = mx.array([tokens])  # (1, seq_len)
    inner_model = model.model
    collected_states: list[mx.array] = []
    # Save the original __call__ at the class level.
    original_call = inner_model.__class__.__call__
    def patched_call(self_inner, x, cache=None, mask=None, **kwargs):
        # Embedding - models name this layer differently.
        if hasattr(self_inner, 'embed_tokens'):
            h = self_inner.embed_tokens(x)
        elif hasattr(self_inner, 'embedding'):
            h = self_inner.embedding(x)
        else:
            raise AttributeError("Cannot find embedding layer")
        # Causal mask.
        if mask is None and h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)
        # Iterate through transformer layers.
        for i, layer in enumerate(self_inner.layers):
            try:
                if cache is not None:
                    try:
                        h = layer(h, mask=mask, cache=cache[i])
                    except TypeError as te:
                        # Cache signature varies across architectures.
                        if "cache" in str(te) or "unexpected keyword" in str(te):
                            h = layer(h, mask=mask)
                        else:
                            raise
                else:
                    h = layer(h, mask=mask)
            except Exception as e:
                raise RuntimeError(f"Layer {i} forward pass failed: {e}") from e
            # Some layers return (hidden_state, cache_update) tuples.
            if isinstance(h, tuple):
                h = h[0]
            collected_states.append(h)
        # Final norm.
        h = self_inner.norm(h)
        collected_states.append(h)
        return h  # Outer model applies lm_head
    try:
        inner_model.__class__.__call__ = patched_call
        _logits = model(input_ids)
        mx.eval(_logits)  # Force computation before collecting.
        result = mx.stack(collected_states, axis=0)
        result = result[:, 0, :, :]  # Remove batch dim.
        mx.eval(result)  # Force evaluation before return.
    finally:
        inner_model.__class__.__call__ = original_call
    return result

关于为什么代码看起来如此奇怪,有几点需要注意。

它在类级别而不是实例级别替换 __call__。这很重要,因为 MLX-LM 的外部模型调用 self.model(x),这会通过类的 __call__ 分发。修补实例会被静默忽略。

这两个 mx.eval() 调用比看起来更重要。MLX 使用惰性求值:计算在需要结果之前不会执行。如果没有第一个 mx.eval(_logits),当我们尝试堆叠收集的状态时,计算图尚未实现。如果没有第二个 mx.eval(result),堆叠的张量可能会引用在下一次提取调用时被覆盖的图节点。这是最难发现的 bug,因为它间歇性地出现,只在特定的内存压力下,并产生静默错误值而不是错误。

try/finally 很重要,原因与这种修补总是有风险的相同:如果前向传播抛出异常,你需要恢复原始方法。没有它,一次失败的提取会使模型处于不可用状态。

所有防御性代码 — 嵌入名称回退(embed_tokens vs embedding)、缓存签名处理、元组返回检查 — 的存在是因为社区模型在内部约定上各不相同。这些都不是 MLX 本身的 bug — 它们是生态系统尚未完全标准化其模型接口的后果。每个变通方法都花费了一个下午来诊断和几行代码来修复。

4、从激活到探测:完整流程

从单个输入获取层输出是大部分工程工作所在。其他一切都更常规,但 MLX 和 numpy 之间的内存管理值得记录。

对于探测实验,你需要数据集中每个问题的每一层的最后一个 token 表示。每个问题每层一个向量,你要比较一个简单的分类器(逻辑回归)在每一层独立区分类别的效果。

问题是一次处理一个的。MLX 对于可变长度输入的批处理效果不佳,因为不存在像 HuggingFace 提供的填充/注意力掩码基础设施,所以每个问题都有自己的前向传播:

import numpy as np

def extract_batch(model, questions, prompt_prefix=""):
    """Extract last-token hidden states for a list of questions."""
    all_states = []
    for i, q in enumerate(questions):
        text = prompt_prefix + q.text
        # Returns mx.array of shape (n_layers + 1, hidden_dim).
        states_mx = extract_last_token(model, text)
        # Convert to numpy immediately to free MLX memory.
        states_np = np.array(states_mx)
        all_states.append(states_np)
    # Stack in numpy, not MLX.
    return np.stack(all_states, axis=0)  # (n_questions, n_layers+1, hidden_dim)

def extract_last_token(model, text):
    """Convenience: extract full hidden states, take last token."""
    all_states = extract_hidden_states(text)
    return all_states[:, -1, :]  # (n_layers + 1, hidden_dim)

在每次传递后立即转换为 np.array() 是保持内存有界的方法。没有它,MLX 会在其内存池中累积数百个张量,在 32GB 机器上处理几百个问题,该池会被填满。转换为 numpy 会释放 MLX 内存以便在下一个前向传播中重用。

一旦你拥有形状为 (n_questions, n_layers+1, hidden_dim) 的 numpy 数组,探测就是直接的 scikit-learn:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score

def probe_layer(X, y, n_folds=5):
    """Train a binary probe at a single layer. Returns mean AUROC."""
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    aurocs = []
    for train_idx, test_idx in skf.split(X, y):
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X[train_idx])
        X_test = scaler.transform(X[test_idx])
        clf = LogisticRegression(max_iter=1000, class_weight="balanced")
        clf.fit(X_train, y[train_idx])
        proba = clf.predict_proba(X_test)[:, 1]
        aurocs.append(roc_auc_score(y[test_idx], proba))
    return np.mean(aurocs)

# Run independently at every layer.
n_layers = hidden_states.shape[1]
for layer in range(n_layers):
    X = hidden_states[:, layer, :]  # (n_questions, hidden_dim)
    auroc = probe_layer(X, labels)
    print(f"Layer {layer}: AUROC = {auroc:.3f}")

在每一层独立运行探测会产生一条曲线,显示模型中类别之间的区分在哪里最强。在我的实验中,信号在早中期层达到峰值,并在向最终层减弱,这告诉你一些关于模型如何处理信息的信息(虽然解释在论文中,而不是这里)。

所有内容都保存和加载为压缩的 .npz 文件,所以你运行一次提取,可以多次探测而无需重复推理。在 Mac Mini 上,为几百个问题提取层输出大约需要二十分钟。训练所有探测需要不到一秒钟。

5、大规模确定性筛选

基准测试构建为每个候选题目运行十次推理 passes,温度 0.7,使用确定性种子。每个种子计算为 base_seed + (question_index * n_runs) + run_index,每个类别有不同的基础种子。这意味着可以通过使用相同种子重新运行来完全重现任何单个筛选结果。

MLX 的随机数生成很好地支持了这一点。在每次随机生成调用之前,代码固定 mx.random.seed(seed),这会对相同输入产生相同的 token 采样。对于零温度条件(在评估阶段使用),采样器完全绕过 RNG 并对 logits 取 mx.argmax

筛选标准因类别而异。有些要求高准确性,有些针对特定准确度范围,还有些要求一致的失败。每个筛选函数应用自己的过滤器,但它们都共享相同的确定性种子基础设施。

早期试点结果表明,探测模型捕获了模型内部表示中的分离,信号集中在早期层而不是最终层。信任探测实验的数字取决于它们底层的实验基础设施。

6、解析 Mistral 7B 输出:比看起来更难

如果你正在通过 mlx-lm 进行 Mistral 7B 的评估工作,你会在输出解析上花费比预期更多的时间。其响应在格式上变化足够多,单个正则表达式无法覆盖它们,而失败模式是静默的:你得到 None 而不是分数,或者更糟,你从响应中提取了错误的数字。

置信度解析很好地说明了这一点。我的提示以 Confidence (0-100): 结尾,模型通常首先以第一个 token 的形式响应分数,然后是换行符和解释。通常。有时分数出现在解释之后。有时格式化为分数(85/100)或百分比(85%)。有时模型重述标签(Confidence: 85)。有时它将数字埋在句子中。

我用一个五优先级策略来处理这个问题,按顺序尝试:

import re

def parse_confidence(text: str) -> float | None:
    """Extract a confidence score (0-100) from Mistral 7B output.
    Priority order:
      1. Leading bare number  ("85\n\nThe answer is...")
      2. Explicit label        ("Confidence: 85")
      3. Fraction              ("85/100")
      4. Percentage            ("85%")
      5. Trailing bare number  (last 80 chars of response)
    """
    stripped = text.strip()
    # Priority 1: Score as first token.
    lead_match = re.match(r"^(\d{1,3})\b", stripped)
    if lead_match:
        val = int(lead_match.group(1))
        if 0 <= val <= 100:
            return val / 100.0
    # Priority 2-4: Explicit patterns anywhere in the response.
    patterns = [
        r"[Cc]onfidence[:\s]*(\d{1,3})",
        r"(\d{1,3})\s*/\s*100",
        r"(\d{1,3})%",
    ]
    for pat in patterns:
        match = re.search(pat, text)
        if match:
            val = int(match.group(1))
            if 0 <= val <= 100:
                return val / 100.0
    # Priority 5: Any number in the tail of the response.
    tail = text[-80:]
    numbers = re.findall(r"\b(\d{1,3})\b", tail)
    for n_str in reversed(numbers):
        val = int(n_str)
        if 0 <= val <= 100:
            return val / 100.0
    return None

优先级顺序在这里很重要。优先级 1(前导裸数字)捕获约 80% 的响应。当模型将分数放在第一位时,没有歧义,也没有从解释文本中抓取年份或页码的风险。优先级 5(尾随数字)是后备,它最可能提取错误的值,因为模型有时在推理中提到年份、数量或其他数字。将搜索限制在最后 80 个字符可以减少误报,但不能消除它们。

答案检查有其自己的问题。Mistral 7B 的响应在冗长程度上各不相同,所以我规范化模型输出和参考答案(小写,删除冠词和标点符号,折叠空白),然后检查规范化参考是否出现在规范化输出中的任何位置:

def normalise_answer(text: str) -> str:
    text = text.lower().strip()
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    text = re.sub(r"[^\w\s]", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

def check_answer(model_output: str, reference_answers: list[str]) -> bool:
    model_norm = normalise_answer(model_output)
    for ref in reference_answers:
        ref_norm = normalise_answer(ref)
        if ref_norm and ref_norm in model_norm:
            return True
    return False

子字符串匹配而不是精确匹配是一个有意的选择。Mistral 7B 几乎从不给出裸答案;它将它们包装在解释中。"法国的首都是巴黎,这……"需要与参考答案 "Paris" 匹配。短参考答案如果作为无关单词的子字符串出现,可能会产生误报,但在实践中,参考答案足够具体,这并没有成为问题。

拒绝检测更简单,但仍然是特定于模型的。Mistral 7B 使用有限的拒绝词汇:"I don't know"、"I'm not sure"、"I cannot determine" 和大约十几种变体。针对小写输出的短语列表可以可靠地捕获这些。门控检查要求解析器从至少 90% 的置信度条件响应中返回分数;低于此值,校准数字无法信任。我的实现达到约 96%。

这些都不复杂。值得记录的是,这些解析器位于评估中每个指标的上游。如果置信度解析器误读 5% 的分数,期望校准误差会偏移。如果答案检查器有系统性误报,准确性数字会膨胀。这些解析器需要自己的测试套件,根据我的经验,它们需要在评估代码的其余部分之前,而不是之后。

7、我希望从生态系统获得什么

MLX 足够好可以构建,这比大多数框架在其主要社区之外的研究用途所能管理的要好。但是存在摩擦点。

提取层表示的 API 将使每个研究人员免于编写自己的钩子。像 HuggingFace 的 output_hidden_states 参数那样的东西,内置到生成路径中,将是研究用户最有用的补充。

跨社区集合的标准化模型接口将减少嵌入名称、缓存签名和返回类型所需的防御性编码。这部分是关于模型如何转换的,部分是 API 设计问题。

更好地记录惰性求值模型,特别是其与有状态操作如激活收集的交互,将防止静默失败的一类 bug,就像我上面描述的那些。当前文档假设一个专注于推理的用户,不需要拦截中间计算。

8、所以它有效吗?

我最初的问题是消费级 Apple Silicon 是否可以支持可发表的机器学习研究。大多数情况下,是的。工具链不如 PyTorch 成熟,你会编写 CUDA 用户不需要的兼容性代码,你会花费整个下午调试源于为不同主要用例做出的合理设计决策的行为。

筛选过程以完全可复现性处理数百个候选,激活输出干净,探测模型训练。GPU 利用率在构建期间达到 100%,这是正确的瓶颈:计算,而不是工程。

对于拥有 Apple Silicon 且没有 CUDA 访问权限的研究人员,MLX 是一个可行的研究平台。它需要工程投资才能这样使用,但工程是可处理的,结果是站得住脚的。


原文链接: Hidden States on a Mac Mini: Using MLX for Real ML Research

汇智网翻译整理,转载请标明出处