Fine-Tuning Gemma Models with Hugging Face and Parameter-Efficient Methods
Google DeepMind's Gemma language models are available with open weights on Hugging Face. The family includes 2B and 7B parameter variants, offered in both pre-trained and instruction-tuned versions. The models are supported across the Hugging Face platform and can be deployed and fine-tuned on Google Cloud services such as Vertex AI Model Garden and Google Kubernetes Engine.
Gemma models are well-suited for prototyping and experimentation using free GPU resources such as those provided by Colab. This guide demonstrates parameter-efficient fine-tuning (PEFT) of Gemma models on GPU and Cloud TPU hardware using the Hugging Face Transformers and PEFT libraries.
The Case for Parameter-Efficient Fine-Tuning
Full-parameter fine-tuning of language models demands significant memory and computational resources, which can be prohibitive on public platforms like Colab or Kaggle. For enterprise use, the cost of adapting models to different domains is a key optimization target. PEFT reduces that cost by updating only a small fraction of the model's parameters while keeping the rest frozen.
Efficient Fine-Tuning of Gemma with PyTorch on GPU and TPU
The Gemma implementation in Hugging Face Transformers is optimized for both PyTorch and PyTorch/XLA, giving GPU and TPU users first-class support. The release includes improvements for using Fully Sharded Data Parallel (FSDP) with PyTorch/XLA via SPMD, which also benefit other Hugging Face models. This article focuses on PEFT for Gemma, specifically using Low-Rank Adaptation (LoRA).
Fine-Tuning Large Language Models with Low-Rank Adaptation (LoRA)
LoRA is a PEFT technique that fine-tunes only a small subset of a model's parameters. It works by freezing the original model and training adapter layers decomposed into low-rank matrices. The PEFT library offers a simple abstraction for selecting which model layers to apply these adapter weights to.
from peft import LoraConfig
lora_settings = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
Targeting these modules applies LoRA adapters to all of the model's nn.Linear layers, i.e. the attention and MLP projections.
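As a quick sanity check, the PEFT library can report how many parameters this configuration actually trains. The sketch below assumes a Gemma base model has already been loaded (as done in the next section); SFTTrainer applies the peft_config on its own later, so this wrapping step is purely for inspection.
from peft import get_peft_model
# Wrap an already-loaded base model to inspect the adapter footprint.
peft_model = get_peft_model(model, lora_settings)
# Prints the number of trainable (LoRA) parameters versus total parameters.
peft_model.print_trainable_parameters()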
The following example uses QLoRA, a method that employs 4-bit quantization of the base model for greater memory efficiency during fine-tuning. This requires the bitsandbytes library and passing a BitsAndBytesConfig when loading the model.
Prerequisites
Access to Gemma model files requires submitting a consent form on the Hugging Face website.
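Once the form is accepted, authenticate your environment with a Hugging Face access token. A minimal sketch, assuming the token is exported as HF_TOKEN (the environment variable the snippets below read):
import os
from huggingface_hub import login
# Log in with a user access token; alternatively, run `huggingface-cli login` in a shell.
login(token=os.environ["HF_TOKEN"])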
Fine-Tuning Gemma to Generate Quotes in a Specific Format
Assuming access is granted, you can fetch the model from the Hugging Face Hub.
First, load the tokenizer and model with 4-bit quantization configured.
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "google/gemma-2b"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
Test the base model with a familiar quote prompt.
prompt_text = "Quote: Imagination is more"
device = "cuda:0"
model_inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
model_outputs = model.generate(**model_inputs, max_new_tokens=20)
print(tokenizer.decode(model_outputs[0], skip_special_tokens=True))
The base model may produce a reasonable completion but not in a desired structured format. The goal is to fine-tune it to output quotes followed by the author on a new line.
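For the prompt above, a completion in the target format would look something like:
Quote: Imagination is more important than knowledge.
Author: Albert Einstein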
Load a dataset of English quotes for training.
from datasets import load_dataset
dataset = load_dataset("Abirate/english_quotes")
dataset = dataset.map(lambda samples: tokenizer(samples["quote"]), batched=True)
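It can be helpful to glance at one record to confirm the field names used in the formatting step below:
# Each record in Abirate/english_quotes contains the "quote" and "author" fields
# used by the formatting function (among others).
sample = dataset["train"][0]
print(sample["quote"], sample["author"])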
Proceed with fine-tuning using the LoRA configuration.
import transformers
from trl import SFTTrainer
def format_training_example(example):
    output_text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [output_text]

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_settings,
    formatting_func=format_training_example,
)
trainer.train()
Test the fine-tuned model with the same prompt.
prompt_text = "Quote: Imagination is"
model_inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
model_outputs = model.generate(**model_inputs, max_new_tokens=20)
print(tokenizer.decode(model_outputs[0], skip_special_tokens=True))
The output should now match the desired format.
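If you want to keep the result, the trained LoRA adapter can be saved on its own; it is small compared to the base model because the frozen 4-bit weights are not modified. The output directory name below is just an illustration.
# Save only the LoRA adapter weights and the tokenizer (directory name is an example).
trainer.model.save_pretrained("gemma-2b-quotes-lora")
tokenizer.save_pretrained("gemma-2b-quotes-lora")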
Accelerated Fine-Tuning on TPU with FSDP over SPMD
Hugging Face Transformers supports the latest FSDP implementation in PyTorch/XLA for faster training on TPUs. Enable it by passing an FSDP configuration dictionary to the transformers.Trainer arguments.
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
fsdp_configuration = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True
}
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=TrainingArguments(
        per_device_train_batch_size=64,  # acts as the global batch size under SPMD
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,  # required for SPMD
        fsdp="full_shard",
        fsdp_config=fsdp_configuration,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()