用Unsloth微调Gemma 4 E2B

在有限硬件上使用Unsloth微调Google最新多模态模型的实用、无废话指南。

用Unsloth微调Gemma 4 E2B
微信 ezpoda免费咨询:AI编程 | AI模型微调| AI私有化部署
AI工具导航 | Tripo 3D | Meshy AI | ElevenLabs | KlingAI | ArtSpace | Phot.AI | InVideo

在Gemma 4发布仅几天后,我想测试在免费的Kaggle GPU上能将它推到多远。前景令人兴奋:一个现代多模态模型,能够进行推理、编码,甚至目标检测——全部以开放权重提供。

然而,现实并不那么友好。

我首先尝试使用transformers库,确实修改了我之前微调Gemma 2或Llama 3等模型进行情感分析任务的Notebook,使用transformers、trl和SFTTrainer。内存不足错误、损坏的量化路径和不兼容的PEFT层让我清楚地认识到,微调Gemma 4并不像早期模型那样简单(至少对我来说)。

本文将介绍什么有效、什么失败,以及我最终如何成功使用Unsloth微调Gemma 4 E2B。

1、问题所在

当我首次使用transformers框架进行实验时,推理进行得很顺利。切换到微调时,由于几个因素我遇到了OOM问题。微调比推理需要更多的内存,因为必须存储梯度和优化器状态。多模态塔增加了所需的内存。即使batch size为1也需要大量内存。此外,我还遇到了PEFT不兼容问题,尚未完全适配。Gemma 4使用自定义层。Gemma4ClippableLinear不受支持。LoRA注入失败。在Kaggle上,bitsandbytestransformers版本不匹配。

为什么Unsloth有效?

Unsloth优化了内存和训练。它支持Gemma模型并为发布做好了准备,文档中也支持微调。它更快(快达60%)且使用更少的显存。

2、可行的设置

2.1 初始化模型

我首先在Kaggle T4 x 2环境中安装最新的unsloth库:

!pip install -U -q unsloth

使用unsloth和进行微调所需的库:

from unsloth import FastLanguageModel
import torch
from datasets import load_dataset, Dataset
from trl import SFTTrainer, SFTConfig

运行上述单元格后:

🦥 Unsloth: 将修补你的计算机以启用2倍更快的免费微调。
🦥 Unsloth Zoo 现在将修补所有内容以使训练更快!

接下来,我使用unslothFastLanguageModel初始化模型:

max_seq_length = 512  # 从小开始;成功后再扩展

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "google/gemma-4-E2b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False,  # MoE QLoRA不推荐,密集27B可以
    load_in_16bit = True,  # bf16/16-bit LoRA
    full_finetuning = False,
)

运行上述单元格后:

==((====))== Unsloth 2026.4.2: 快速Gemma4修补。Transformers: 5.5.0。
\\ /| Tesla T4. GPU数量 = 2. 最大内存: 14.563 GB. 平台: Linux.
O^O/ \_/ \ Torch: 2.10.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.6.0
\ / Bfloat16 = FALSE. FA [Xformers = 0.0.35. FA2 = False]
"-____-" 免费许可证: http://github.com/unslothai/unsloth
Unsloth: 快速下载已启用 - 忽略红色的下载进度条!
Unsloth: 对gemma4使用float16精度不起作用!使用float32。

2.2 使用批量推理测试模型

如果我们想先在微调前测试模型(用于情感分析数据集上的推理),我们将模型设置为推理模式:

FastLanguageModel.for_inference(model)

对于相当大数据集上的推理任务,我们更倾向于运行批量推理。下面展示了使用Unsloth初始化的模型为我们的情感分析任务运行批量推理的过程:

def predict_batch(df, model, tokenizer, batch_size=8, max_new_tokens=5):
    y_pred = []
    texts = df["text"].tolist()
    
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i + batch_size]
        
        messages_batch = [
            [{"role": "user", "content": text}]
            for text in batch
        ]
        
        prompts = [
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            for messages in messages_batch
        ]
        
        inputs = tokenizer(
            text=prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(model.device)
        
        input_len = inputs["input_ids"].shape[1]
        
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                use_cache=True,
            )
        
        decoded = tokenizer.batch_decode(
            outputs[:, input_len:],
            skip_special_tokens=True
        )
        
        for out in decoded:
            out = out.strip().lower()
            
            if "positive" in out:
                y_pred.append("positive")
            elif "negative" in out:
                y_pred.append("negative")
            elif "neutral" in out:
                y_pred.append("neutral")
            else:
                y_pred.append("none")
    
    return y_pred

2.3 准备模型进行微调

我们使用unslothFastLanguageModel.get_peft_model准备模型进行参数高效微调。

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    # "unsloth"检查点用于非常长的上下文 + 较低的显存
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    max_seq_length = max_seq_length,
)

然后我们使用SFTTrainer初始化trainer。这里我们设置:

  • per_device_train_batch_size = 1(Gemma 4和可用内存的标准值)
  • gradient_accumulation_steps = 4(非常适合我们的任务)
  • num_training_epochs = 5(这意味着,根据我们训练集的大小和batch size,总共1125步)
trainer = SFTTrainer(
    model = model,
    train_dataset = train_data,
    tokenizer = tokenizer,
    args = SFTConfig(
        max_seq_length = max_seq_length,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        num_train_epochs = 5,
        logging_steps = 25,
        output_dir = "outputs_gemma4_E2B",
        optim = "adamw_8bit",
        seed = 3407,
        dataset_num_proc = 1,
    ),
)

运行后,保存了3个检查点并存在于Notebook输出中(checkpoint-500checkpoint-1000checkpoint-1125)。这些不包含所有模型权重,而只有LoRA权重。要使用微调后的模型,我们可以保存这些检查点并初始化一个类似的模型,然后只需添加LoRA权重。

from unsloth import FastLanguageModel

root_path = "/kaggle/input/notebooks/gpreda/fine-tune-gemma-4-e2b-with-unsloth/"
checkpoint_path = root_path + "outputs_gemma4_E2B/checkpoint-1125"
max_seq_length = 512

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=checkpoint_path,  # 直接指向适配器检查点文件夹
    max_seq_length=max_seq_length,
    load_in_4bit=False,
    load_in_16bit=True,
    full_finetuning=False,
)

FastLanguageModel.for_inference(model)

3、结束语

从我首次在低端GPU计算资源上微调Gemma 4的实验来看,目前还没有简单的方法。OOM问题、不兼容问题让这变得有点困难。根据我的经验,Unsloth现在提供了最实用的路径,但其他工具也会很快跟上。

Gemma 4代表了开放权重模型的重大进步,但今天微调它需要仔细选择工具,并愿意调试底层问题。

Unsloth被证明是目前弥合这一差距的最可靠方式,即使在Kaggle等受限环境中也能实现训练。

虽然生态系统仍在发展,但方向很明确:强大的本地微调正变得可及——只是还不轻松。


原文链接: From OOM Errors to Working Model: Fine-Tuning Gemma 4 E2B Step-by-Step using Unsloth

汇智网翻译整理,转载请标明出处