强化蒸馏简明教程
一个小型的PyTorch模型如何仅通过+1/-1反馈从大型语言模型(LLM)中学习?这就是强化蒸馏的作用。
AI蒸馏正在悄然成为现代机器学习最重要的概念之一。这是将较小的模型教成像较大的模型一样行为的过程——将智能、判断和推理压缩到一个紧凑的网络中。 在本文中,我们将探讨一个有趣但意义重大的想法: 使用本地LLM(通过Ollama)作为教师,利用强化学习训练一个小型的PyTorch模型。 教师只提供+1或-1的反馈,学生会随着时间推移而改进。
这不仅仅是一个玩具。 这是以下内容的最简单版本:
- 知识蒸馏
- 偏好对齐
- 基于AI反馈的强化学习(RLAIF)
- 行为克隆
- LLM引导的策略学习
全部浓缩在约200行代码中。
让我们一起了解架构、背后的想法以及为什么它代表了一种新的基于强化学习的模型蒸馏形式。
1、强化蒸馏
传统的蒸馏工作方式如下:
- 大型教师模型输出logits或概率分布。
- 小型学生模型学习模仿这些输出。
- 训练是监督的(交叉熵、KL散度等)
但我们的方法不同。 不是直接给出答案,而是让教师给出判断:
- “你回答正确” → +1
- “你回答错误” → −1
这将训练转化为一个强化学习循环,其中:
- 学生学习一种将问题映射到正确动作的策略
- 教师充当奖励模型
正确性信号被蒸馏到学生的参数中
这是一种轻量级的: 基于强化学习的模型蒸馏 (也称为奖励蒸馏、策略蒸馏或偏好蒸馏)。
这是OpenAI的RLHF流程背后的相同原理,只是简化用于实验。
2、目标:一个能学习回答问题的小型代理
我们创建了一个小型的PyTorch模型,必须回答如下问题:
- “法国的首都是哪里?”
- “说你好。”
- “给我讲个笑话。”
但我们不会给它答案,而是:
- 让它猜测
- 让本地LLM(Ollama)评估这个猜测
- 使用强化学习更新模型
随着时间的推移,学生变得一致且正确。
3、架构
3.1 学生模型(小型神经网络)
一个简单的嵌入→密集→ReLU→密集网络,输出10个可能答案的分布。
class ResponsePolicy(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(VOCAB_SIZE, 32)
self.fc = nn.Sequential(
nn.Linear(32, 64),
nn.ReLU(),
nn.Linear(64, ANSWER_SIZE)
)
该模型不生成文本。 它从预定义列表中选择一个动作:
ANSWERS = [
"yes", "no", "maybe", "I don't know",
"Paris", "Mumbai",
"joke: why did the chicken cross the road?",
"greeting: hello!", "bye!", "thank you!"
]
这使得RL 稳定且高效。
3.2 教师:Ollama本地LLM
教师模型(例如,在Ollama上运行的llama3.2:1b)评估学生输出。
它接收:
- 问题
- 学生答案
- 真实答案
并返回+1或−1,没有任何解释:
def get_llm_reward(question, agent_answer, ground_truth):
prompt = f"""
You are a strict evaluator...
Respond with exactly +1 or -1.
"""
这很强大,因为教师带来了:
- 世界知识
- 语义理解
- 上下文意识
……而无需我们手动提供标签。
4、强化蒸馏步骤
每个训练步骤:
- 输入问题 → 学生预测分布
- 从分布中采样一个答案
- 向Ollama询问+1或−1
- 应用REINFORCE损失:
rl_loss = -log_prob * reward
如果奖励为正 → 应用额外的监督“锁定”损失:
if reward > 0:
ce_loss = ce_loss_fn(logits, target)
loss += 2.0 * ce_loss
这种“混合强化+监督”训练稳定了学习并促进了更快的收敛。
5、训练循环
模型反复采样训练问题之一:
("what is the capital of france?", "Paris")
("say hello", "greeting: hello!")
...
每次迭代都会更新学生直到运行奖励提高。 一个简单的启发式方法会在性能提升时触发模型保存。
6、评估:贪婪模式
训练后,模型从采样切换到选择最高概率的答案。
idx = torch.argmax(probs)
现在你可以通过问同样的问题来测试它——它应该始终正确回答。
为什么这是新型蒸馏
此设置避免了:
- logits
- 教师概率
- 交叉熵来模仿教师
相反,学生收到:
- 基于教师判断的二进制奖励
- 隐含的偏好塑造
- 语义正确性信号
- 基于教师的强化信号
- 小型模型不模仿教师知道的内容。
- 它学习模仿教师奖励行为的方式。
这就是基于强化学习的模型蒸馏的本质。
7、为什么这很重要
你可以训练领域专家
- 教小型模型一个单一领域(例如板球规则、医疗分诊、产品FAQ)。
- 不需要标注数据集
- 教师LLM生成所有正确性信号。
- 小型模型变得出乎意料的好,一个两层的神经网络在教师指导下可以表现得像有智能一样。
- 高效
- 训练轻量且完全本地化。
- 代理系统的基础。
这构成了“自我改进代理”的基础 —— 小模型由大LLM监督。
完整代码:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import ollama # pip install ollama
# ----------------------------
# 固定答案选项(策略输出空间)
# ----------------------------
ANSWERS = [
"yes",
"no",
"maybe",
"I don't know",
"Paris",
"Mumbai",
"joke: why did the chicken cross the road?",
"greeting: hello!",
"bye!",
"thank you!"
]
ANSWER_SIZE = len(ANSWERS)
VOCAB_SIZE = 2000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "response_policy_ollama_nocache.pt"
OLLAMA_MODEL_NAME = "llama3.2:1b" # 或者你拥有的任何本地模型
# ----------------------------
# 简单的训练数据集
# (question, ground_truth_answer)
# ----------------------------
TRAIN_QA = [
("what is the capital of france?", "Paris"),
("what is the capital of india?", "New Delhi"),
("say hello", "greeting: hello!"),
("tell me a joke", "joke: why did the chicken cross the road?"),
("say goodbye", "bye!"),
]
# ----------------------------
# 策略网络
# ----------------------------
class ResponsePolicy(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(VOCAB_SIZE, 32)
self.fc = nn.Sequential(
nn.Linear(32, 64),
nn.ReLU(),
nn.Linear(64, ANSWER_SIZE)
)
def forward(self, tokens):
# tokens: [batch, seq_len]
e = self.embed(tokens) # [batch, seq_len, emb]
x = torch.mean(e, dim=1) # [batch, emb]
logits = self.fc(x) # [batch, ANSWER_SIZE]
return logits
policy = ResponsePolicy().to(device)
optimizer = optim.Adam(policy.parameters(), lr=1e-3)
ce_loss_fn = nn.CrossEntropyLoss() # 用于强烈的“这是正确的”更新
# ----------------------------
# 分词器
# ----------------------------
def tokenize(text: str) -> torch.Tensor:
tokens = text.lower().split()
ids = [hash(w) % VOCAB_SIZE for w in tokens]
return torch.tensor([ids], dtype=torch.long)
# ----------------------------
# 通过Ollama的奖励模型
# ----------------------------
def get_llm_reward(question: str, agent_answer: str, ground_truth: str) -> int:
"""
询问Ollama,判断代理的回答是否与地面真相语义上正确
返回+1或-1。
"""
prompt = f"""
你是一个严格的评估者。
任务:
- 你得到:
- 一个用户的问题。
- 代理的回答。
- 地面真相的正确答案。
规则:
- 如果代理的回答在语义上正确并且与地面真相的意义匹配,
请精确地回复:+1
- 否则,请精确地回复:-1
不要解释,不要额外文本。只输出+1或-1。
问题: {question}
代理回答: {agent_answer}
地面真相答案: {ground_truth}
"""
resp = ollama.chat(
model=OLLAMA_MODEL_NAME,
messages=[
{"role": "user", "content": prompt.strip()}
],
)
content = resp["message"]["content"].strip()
if "+1" in content:
return 1
elif "-1" in content:
return -1
else:
# 如果模型表现不佳的备用方案
return -1
# ----------------------------
# 训练步骤:RL +(可选)额外的监督推动
# ----------------------------
def train_step(question: str, ground_truth: str):
# 1) 前向传递
tokens = tokenize(question).to(device)
logits = policy(tokens) # [1, ANSWER_SIZE]
probs = torch.softmax(logits, dim=-1)
dist = Categorical(probs)
idx = dist.sample() # [1]
action_idx = idx.item()
agent_answer = ANSWERS[action_idx]
log_prob = dist.log_prob(idx) # 标量
# 2) 从Ollama获取奖励
reward = get_llm_reward(question, agent_answer, ground_truth)
# 3) 构建损失
# RL部分(REINFORCE):鼓励或阻止这个采样的动作
rl_loss = -log_prob * reward # reward=+1 => 推动log_prob上升;reward=-1 => 推动下降
loss = rl_loss
# 4) 如果奖励为正,则将其视为“这是正确的”
# 并添加更强的监督项以真正锁定它。
if reward > 0:
target = torch.tensor([action_idx], dtype=torch.long, device=device)
ce_loss = ce_loss_fn(logits, target) # 使这个动作成为最高的logit
# 如果需要,可以更强烈地加权:
loss = loss + 2.0 * ce_loss # 2.0 是超参数
# 5) 优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
return agent_answer, reward, loss.item()
# ----------------------------
# 保存/加载
# ----------------------------
def save_model(path=MODEL_PATH):
torch.save(policy.state_dict(), path)
print(f"[+] 模型已保存到 {path}")
def load_model_if_exists(path=MODEL_PATH):
if os.path.exists(path):
policy.load_state_dict(torch.load(path, map_location=device))
policy.train()
print(f"[+] 从 {path} 加载了现有模型")
else:
print("[!] 没有找到现有模型,从头开始。")
# ----------------------------
# 训练循环
# ----------------------------
def train_with_ollama(num_steps: int = 200):
load_model_if_exists()
running_reward = 0.0
best_running_reward = -999.0
for step in range(1, num_steps + 1):
question, ground_truth = random.choice(TRAIN_QA)
agent_answer, reward, loss = train_step(question, ground_truth)
running_reward = 0.95 * running_reward + 0.05 * reward
if step % 10 == 0:
print(
f"Step {step:4d} | Q: {question} | agent: {agent_answer} | gt: {ground_truth} "
f"| reward={reward:+d} | running_reward={running_reward:+.3f} | loss={loss:.4f}"
)
# 简单的保存启发式方法:每当显著提升running_reward时
if running_reward > best_running_reward + 0.05:
best_running_reward = running_reward
save_model()
# ----------------------------
# 评估:训练后的贪婪回答
# ----------------------------
def eval_greedy():
if not os.path.exists(MODEL_PATH):
print("未找到训练好的模型,先进行训练。")
return
policy.load_state_dict(torch.load(MODEL_PATH, map_location=device))
policy.eval()
print("\n[Eval] 训练后的贪婪回答:\n")
for question, ground_truth in TRAIN_QA:
tokens = tokenize(question).to(device)
with torch.no_grad():
logits = policy(tokens)
probs = torch.softmax(logits, dim=-1)
idx = torch.argmax(probs, dim=-1).item()
answer = ANSWERS[idx]
print(f"Q: {question}\n → Agent: {answer} (GT: {ground_truth})\n")
if __name__ == "__main__":
print("使用Ollama作为奖励训练RL代理(没有显式记忆,但有强大的+1更新)...")
train_with_ollama(num_steps=200)
print("\n训练完成。运行评估...")
eval_greedy()
你构建的是没有人类的RLHF,而不是从人类反馈中进行强化学习(RLHF)
你构建的是从AI反馈中进行强化学习(RLAIF),
“一个大型LLM可以仅通过二进制奖励教一个小的神经网络——使小模型能够继承大模型的知识。”
可以把它看作是没有数据集的蒸馏。 一个由智慧教师指导的自我改进的学生。
原文链接:Reinforcement Distillation: Training a Tiny AI Using Another AI as the Reward Model
汇智网翻译整理,转载请标明出处