
Fine-Tuning Gemma Models with Hugging Face and Parameter-Efficient Methods

Tech · May 9

Google DeepMind's Gemma language models are available with open weights on Hugging Face. The family includes 2B and 7B parameter variants, each offered in both pre-trained and instruction-tuned versions. The models are fully supported on the Hugging Face platform and can be deployed and fine-tuned on services such as Vertex AI Model Garden and Google Kubernetes Engine.

Gemma models are well-suited for prototyping and experimentation using free GPU resources such as those provided by Colab. This guide demonstrates parameter-efficient fine-tuning (PEFT) of Gemma models on GPU and Cloud TPU hardware using the Hugging Face Transformers and PEFT libraries.

The Case for Parameter-Efficient Fine-Tuning

Full-parameter fine-tuning of language models demands significant memory and computational resources, which can be prohibitive on public platforms like Colab or Kaggle. For enterprise use, the cost of adapting models to various domains is a key optimization target. PEFT provides a method to achieve adaptation at a lower cost.

Efficient Fine-Tuning of Gemma with PyTorch on GPU and TPU

The Gemma implementation in Hugging Face Transformers is optimized for both PyTorch and PyTorch/XLA, giving both TPU and GPU users access to the models. The release includes improvements for using Fully Sharded Data Parallel (FSDP) with PyTorch/XLA via SPMD, which also benefits other Hugging Face models. This article focuses on PEFT for Gemma, specifically Low-Rank Adaptation (LoRA).

Fine-Tuning Large Language Models with Low-Rank Adaptation (LoRA)

LoRA is a PEFT technique that fine-tunes only a small subset of a model's parameters. It works by freezing the original model and training adapter layers decomposed into low-rank matrices. The PEFT library offers a simple abstraction for selecting which model layers to apply these adapter weights to.

from peft import LoraConfig

lora_settings = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

This configuration applies rank-8 adapters to all of Gemma's linear projection layers: the attention projections (q_proj, k_proj, v_proj, o_proj) and the MLP projections (gate_proj, up_proj, down_proj).
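To see why this is cheap, compare trainable parameter counts: a dense d_in × d_out projection holds d_in·d_out weights, while a rank-r adapter adds only r·(d_in + d_out). A back-of-envelope sketch (the layer sizes below are illustrative, not Gemma's actual shapes):

```python
def lora_param_counts(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Return (dense weight params, LoRA adapter params) for one linear layer."""
    full = d_in * d_out            # training the dense weight directly
    adapter = r * (d_in + d_out)   # low-rank factors A (r x d_in) and B (d_out x r)
    return full, adapter

full, adapter = lora_param_counts(d_in=2048, d_out=2048, r=8)
print(full, adapter)  # 4194304 vs 32768 -> under 1% of the dense layer
```

The same ratio holds per targeted layer, which is why the optimizer state and gradients for a LoRA run fit comfortably on a single free-tier GPU.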

The following example uses QLoRA, a method that employs 4-bit quantization of the base model for greater memory efficiency during fine-tuning. This requires the bitsandbytes library and passing a BitsAndBytesConfig when loading the model.

Prerequisites

Access to Gemma model files requires submitting a consent form on the Hugging Face website.
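The snippets that follow read an access token from the HF_TOKEN environment variable. A small guard can make the failure mode obvious when the token is missing (require_hf_token is a hypothetical helper for this guide, not part of any library):

```python
import os

def require_hf_token() -> str:
    # Hypothetical helper: fail fast with a clear message instead of a
    # KeyError deep inside from_pretrained().
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set. Create a token in your Hugging Face account "
            "settings after accepting the Gemma license, then export it as HF_TOKEN."
        )
    return token
```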

Fine-Tuning Gemma to Generate Quotes in a Specific Format

Assuming access is granted, you can fetch the model from the Hugging Face Hub.

First, load the tokenizer and model with 4-bit quantization configured.

import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/gemma-2b"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map={"": 0},
    token=os.environ['HF_TOKEN'],
)

Test the base model with a familiar quote prompt.

prompt_text = "Quote: Imagination is more"
device = "cuda:0"
model_inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

model_outputs = model.generate(**model_inputs, max_new_tokens=20)
print(tokenizer.decode(model_outputs[0], skip_special_tokens=True))

The base model may produce a reasonable completion but not in a desired structured format. The goal is to fine-tune it to output quotes followed by the author on a new line.
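Concretely, each training target pairs a quote with its author on the following line. A tiny sketch of the format (the quote below is just sample data):

```python
def format_quote(quote: str, author: str) -> str:
    # The structure the fine-tuned model should reproduce:
    # the quote on one line, the author on the next.
    return f"Quote: {quote}\nAuthor: {author}"

print(format_quote("Imagination is more important than knowledge.", "Albert Einstein"))
# Quote: Imagination is more important than knowledge.
# Author: Albert Einstein
```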

Load a dataset of English quotes for training.

from datasets import load_dataset

dataset = load_dataset("Abirate/english_quotes")
dataset = dataset.map(lambda samples: tokenizer(samples["quote"]), batched=True)
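With batched=True, the mapping function receives a dict of columns (lists) rather than one example at a time, and any columns it returns are merged into the dataset. A pure-Python sketch of that contract (fake_tokenize and toy_batch are illustrative stand-ins, not the datasets API):

```python
def fake_tokenize(batch):
    # Stand-in for tokenizer(...): turns each quote into a list of ids.
    return {"input_ids": [[ord(c) for c in q[:3]] for q in batch["quote"]]}

toy_batch = {"quote": ["abc", "xyz"], "author": ["A", "B"]}
# datasets merges the returned columns into the existing ones:
mapped = {**toy_batch, **fake_tokenize(toy_batch)}
print(mapped["input_ids"])  # [[97, 98, 99], [120, 121, 122]]
```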

Proceed with fine-tuning using the LoRA configuration.

import transformers
from trl import SFTTrainer

def format_training_example(example):
    # Pair each quote with its author on a new line -- the structure the
    # fine-tuned model should learn to reproduce.
    output_text = f"Quote: {example['quote'][0]}\nAuthor: {example['author'][0]}"
    return [output_text]

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=10,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        optim="paged_adamw_8bit"
    ),
    peft_config=lora_settings,
    formatting_func=format_training_example,
)
trainer.train()
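One detail worth noting in the arguments above: gradient accumulation multiplies the small per-device batch. A quick sanity check of the effective batch size (plain arithmetic, no library calls):

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4

# Gradients from 4 micro-batches of 1 example are accumulated before each
# optimizer step, so every update sees an effective batch of 4 examples.
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch)  # 4
```

This trades a few extra forward/backward passes for a larger effective batch without growing peak memory, which is the usual compromise on free-tier GPUs.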

Test the fine-tuned model with the quote prompt again.

prompt_text = "Quote: Imagination is"
model_inputs = tokenizer(prompt_text, return_tensors="pt").to(device)

model_outputs = model.generate(**model_inputs, max_new_tokens=20)
print(tokenizer.decode(model_outputs[0], skip_special_tokens=True))

The output should now match the desired format.

Accelerated Fine-Tuning on TPU with FSDP over SPMD

Hugging Face transformers supports the latest FSDP implementation in PyTorch/XLA for faster training on TPUs. Enable this feature by adding an FSDP configuration to the transformers.Trainer.

from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

fsdp_configuration = {
    "fsdp_transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True
}

trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=TrainingArguments(
        per_device_train_batch_size=64,
        num_train_epochs=100,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last=True,
        fsdp="full_shard",
        fsdp_config=fsdp_configuration,
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
