Fading Coder

One Final Commit for the Last Sprint


Implementing GPT for Text Classification with MindSpore: A Comprehensive Guide

Tech · May 16

Environment Setup

The following code sets up the environment for our GPT implementation with MindSpore: it pins a known-good MindSpore version, installs the required libraries, and points Hugging Face downloads at a mirror endpoint.

%%capture captured_output
# The environment is pre-configured with mindspore==2.2.14
# To change the mindspore version, modify the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# This example is compatible with mindnlp 0.3.1
# If you encounter issues, try installing mindnlp==0.3.1
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
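The `%env` magic only works inside Jupyter. In a plain Python script, the same mirror endpoint can be configured through `os.environ` before any Hugging Face hub access is triggered (a minimal sketch; the URL is the same mirror used above):

```python
import os

# Point Hugging Face downloads at the mirror before any hub access occurs.
# setdefault keeps an endpoint that was already configured in the shell.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
print(os.environ["HF_ENDPOINT"])
```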

Data Preparation

In this section, we prepare the IMDB dataset for our text classification task. We'll load the dataset and explore its structure.

import os
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn
from mindnlp.dataset import load_dataset
from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

# Load the IMDB dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']

# Check the size of the training dataset
print(f"Training dataset size: {imdb_train.get_dataset_size()}")

Dataset Processing

We define a function to process the dataset, which includes tokenization, type conversion, and batching operations.

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    # Check if running on Ascend hardware
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']
    
    if shuffle:
        dataset = dataset.shuffle(batch_size)
    
    # Apply tokenization
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    
    # Batch the dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                           'attention_mask': (None, 0)})
    return dataset
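To see why the Ascend branch pads every sequence to `max_seq_len` while the CPU/GPU branch pads per batch, here is a framework-free sketch of the two padding strategies (`pad_and_truncate` and `pad_batch` are hypothetical helpers for illustration, not part of mindnlp):

```python
def pad_and_truncate(ids, max_len, pad_id=0):
    """Static padding: every sequence becomes exactly max_len tokens,
    as graph-compiled fixed shapes require (the Ascend case)."""
    ids = ids[:max_len]                     # truncation
    mask = [1] * len(ids)                   # attend only to real tokens
    pad = max_len - len(ids)
    return ids + [pad_id] * pad, mask + [0] * pad

def pad_batch(batch, pad_id=0):
    """Dynamic padding: pad only to the longest sequence in this batch,
    which is what padded_batch does on CPU/GPU."""
    longest = max(len(ids) for ids in batch)
    return [pad_and_truncate(ids, longest, pad_id)[0] for ids in batch]

ids, mask = pad_and_truncate([7, 8, 9], max_len=5)
print(ids)    # [7, 8, 9, 0, 0]
print(mask)   # [1, 1, 1, 0, 0]
print(pad_batch([[1, 2], [3, 4, 5, 6]]))  # [[1, 2, 0, 0], [3, 4, 5, 6]]
```

Dynamic padding wastes fewer pad tokens per batch, but static shapes let Ascend compile the graph once.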

Tokenizer Configuration

We set up the GPT tokenizer with special tokens for our text classification task.

from mindnlp.transformers import GPTTokenizer

# Initialize the tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# Add special tokens
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
print(f"Added {num_added_toks} special tokens")
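The three added tokens explain the `+ 3` in the later `resize_token_embeddings` call: the embedding table must grow by exactly the number of new vocabulary entries. A toy illustration of that bookkeeping (a hypothetical three-word vocabulary standing in for the pretrained tokenizer's, plain Python, no mindnlp):

```python
# Toy vocabulary standing in for the pretrained tokenizer's.
vocab = {"the": 0, "movie": 1, "was": 2}
base_vocab_size = len(vocab)

special_tokens = ["<bos>", "<eos>", "<pad>"]
for tok in special_tokens:
    if tok not in vocab:          # add_special_tokens skips known tokens
        vocab[tok] = len(vocab)   # new tokens get the next free id

num_added = len(vocab) - base_vocab_size
print(num_added)                        # 3, matching num_added_toks above
print(base_vocab_size + num_added)      # embedding-table size after resizing
```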

Dataset Splitting and Processing

We split the training data into training and validation sets, then process all datasets for model training.


# Split the training dataset into training and validation sets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

# Process the datasets
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

# Check a sample from the training dataset
sample = next(dataset_train.create_tuple_iterator())
print(f"Sample input_ids shape: {sample[0].shape}")
print(f"Sample attention_mask shape: {sample[1].shape}")
print(f"Sample label: {sample[2]}")
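The `split([0.7, 0.3])` call above shuffles the examples and cuts them into a 70% training set and a 30% validation set. A simplified, dependency-free stand-in for that behavior (`split_dataset` is a hypothetical helper, not MindSpore's implementation):

```python
import random

def split_dataset(items, ratios, seed=0):
    """Shuffle, then cut a sequence into consecutive chunks sized by ratios.
    A toy stand-in for MindSpore's Dataset.split([0.7, 0.3])."""
    items = list(items)
    random.Random(seed).shuffle(items)
    parts, start = [], 0
    for r in ratios[:-1]:
        end = start + round(len(items) * r)
        parts.append(items[start:end])
        start = end
    parts.append(items[start:])   # remainder goes to the last chunk
    return parts

train, val = split_dataset(range(100), [0.7, 0.3])
print(len(train), len(val))  # 70 30
```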

Model Training and Evaluation

We configure and train our GPT model for sequence classification, then evaluate its performance on the test set.

from mindnlp.transformers import GPTForSequenceClassification

# Initialize the model for sequence classification
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# Grow the embedding table by 3 to cover the added <bos>, <eos>, <pad> tokens
model.resize_token_embeddings(model.config.vocab_size + 3)

# Set up the optimizer
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()

# Define callbacks for saving checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', 
                               ckpt_name='gpt_imdb_finetune', 
                               epochs=1, 
                               keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', 
                                 ckpt_name='gpt_imdb_finetune_best', 
                                 auto_load=True)

# Initialize and run the trainer
trainer = Trainer(network=model, 
                 train_dataset=dataset_train,
                 eval_dataset=dataset_val,
                 metrics=metric,
                 epochs=1, 
                 optimizer=optimizer, 
                 callbacks=[ckpoint_cb, best_model_cb],
                 jit=False)
trainer.run(tgt_columns="labels")

# Evaluate the model
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")
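The `Accuracy` metric reported by the evaluator ultimately reduces to comparing the argmax of each logit row against its label. A dependency-free sketch of that computation (the logits below are toy values, not real model output):

```python
def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class matches the label."""
    preds = [row.index(max(row)) for row in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)

# Toy binary-classification logits: column 0 = negative, column 1 = positive.
logits = [[2.0, 0.5], [0.1, 1.2], [3.0, -1.0], [0.2, 0.3]]
labels = [0, 1, 1, 1]
print(accuracy(logits, labels))  # 0.75
```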
