Data Preparation
Import Required Packages
# Import required packages and set seed for reproducibility
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)
Download Shakespeare Dataset
# Downloading the tiny shakespeare dataset
# !wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Read Dataset
# Load the whole corpus into memory as a single string
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
Create Character Mappings
# Vocabulary: every distinct character in the corpus, in sorted order
chars = sorted(set(text))

# Lookup tables between characters and their integer ids
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Return the list of integer ids for the characters of ``s``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the integer ids in ``l``."""
    return ''.join(itos[i] for i in l)
Encode Dataset
# Turn the full corpus into a tensor of 64-bit integer token ids
data = torch.tensor(encode(text), dtype=torch.int64)
Split Dataset
# Hold out the tail of the corpus for validation; the first 90% is for training
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
Set Batch Parameters
# Batch geometry:
#   batch_size - number of sequences drawn per batch
#   block_size - number of tokens in each sequence (the context length)
batch_size, block_size = 4, 8
Generate Batches
def get_batch(split):
    """Sample a random batch of input/target sequences.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked tensors, each holding ``batch_size``
        windows of ``block_size`` tokens. ``y`` is ``x`` shifted one token
        to the right, i.e. the next-token targets.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted target
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)


# Draw one example batch from the training split
xb, yb = get_batch('train')
Attention Mechanism
Single Attention Head
# Reseed so the walkthrough below is reproducible
torch.manual_seed(1337)

# Toy input: 4 sequences, 8 time steps, 32 channels
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Width of the single attention head
head_size = 16

# Bias-free linear projections producing keys, queries and values
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project the input once per role
k = key(x)
q = query(x)
v = value(x)

# Query/key affinities (unscaled in this demo), masked so a position only
# sees itself and earlier positions, then normalised row-wise
tril = torch.tril(torch.ones(T, T))
wei = torch.matmul(q, k.transpose(-2, -1))
wei = wei.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)

# Each output row is the attention-weighted mix of the value rows
out = torch.matmul(wei, v)
Multi-Head Attention
Single Head Class
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Reads the module-level names ``n_embd`` (input width), ``block_size``
    (size of the causal mask) and ``dropout`` (dropout probability).

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer: it travels with the module
        # (device moves, state_dict) but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked self-attention to ``x``.

        Args:
            x: Rank-3 tensor unpacked as (batch, time, channels); ``time``
                must not exceed the size of the stored mask.

        Returns:
            Tensor whose last dimension is ``head_size``.
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scale by the key dimension (head_size) rather than the embedding
        # width: the dot product sums over head_size terms, so this is the
        # factor that keeps the softmax inputs at unit scale.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Block attention to future positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out
Multi-Head Class
class MultiHeadAttention(nn.Module):
    """Several ``Head`` modules run in parallel, concatenated and projected.

    Reads the module-level names ``n_embd`` (output width) and ``dropout``.

    Args:
        num_heads: Number of attention heads.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. Sizing
        # the projection from that (instead of assuming it equals n_embd)
        # keeps the layer valid when n_embd is not a multiple of num_heads.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, join the results and project them.

        Args:
            x: Input tensor passed unchanged to each head.

        Returns:
            The projected, dropout-regularised tensor with last dimension
            ``n_embd``.
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
Expert Module
class Expert(nn.Module):
    """Feed-forward expert: expand to 4x width, ReLU, project back, dropout.

    The dropout probability is read from the module-level name ``dropout``.

    Args:
        n_embd: Input and output width of the expert.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        expand = nn.Linear(n_embd, hidden_width)
        contract = nn.Linear(hidden_width, n_embd)
        self.net = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        """Return the expert's output for ``x``."""
        return self.net(x)
Routing Mechanism
Top-k Router
class TopkRouter(nn.Module):
    """Deterministic router that keeps the ``top_k`` highest-scoring experts.

    Args:
        n_embed: Width of the incoming features.
        num_experts: Number of experts to score.
        top_k: Number of experts kept per position.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        """Score the experts and return sparse gate weights.

        Args:
            mh_output: Features whose last dimension is ``n_embed``.

        Returns:
            ``(router_output, indices)``: gate weights over all experts that
            are zero outside the selected ``top_k`` and sum to one, and the
            indices of the selected experts.
        """
        scores = self.linear(mh_output)
        kept_scores, indices = torch.topk(scores, self.top_k, dim=-1)
        # Start from -inf everywhere so unselected experts get zero weight
        # after the softmax, then write back the selected scores.
        masked = torch.full_like(scores, float('-inf')).scatter(-1, indices, kept_scores)
        return F.softmax(masked, dim=-1), indices
Noisy Top-k Router
class NoisyTopkRouter(nn.Module):
    """Top-k router that perturbs the expert scores with learned-scale noise.

    Args:
        n_embed: Width of the incoming features.
        num_experts: Number of experts to score.
        top_k: Number of experts kept per position.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        """Score the experts with added noise and return sparse gate weights.

        Args:
            mh_output: Features whose last dimension is ``n_embed``.

        Returns:
            ``(router_output, indices)``: gate weights over all experts that
            are zero outside the selected ``top_k`` and sum to one, and the
            indices of the selected experts.
        """
        clean_logits = self.topkroute_linear(mh_output)
        # Softplus keeps the per-expert noise scale strictly positive
        noise_scale = F.softplus(self.noise_linear(mh_output))
        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_scale

        kept_logits, indices = torch.topk(noisy_logits, self.top_k, dim=-1)
        # -inf for unselected experts becomes zero weight after the softmax
        masked = torch.full_like(noisy_logits, float('-inf')).scatter(-1, indices, kept_logits)
        return F.softmax(masked, dim=-1), indices
Sparse Mixture of Experts
class SparseMoE(nn.Module):
    """Sparse mixture-of-experts layer with a per-expert capacity limit.

    Args:
        n_embed: Feature width handled by the router and every expert.
        num_experts: Number of experts.
        top_k: Number of experts each token is routed to.
        capacity_factor: Multiplier applied to the even-split token count
            when computing how many tokens one expert may process.
    """

    def __init__(self, n_embed, num_experts, top_k, capacity_factor=1.0):
        super().__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.num_experts = num_experts

    def forward(self, x):
        """Route each token to its experts and sum the gated expert outputs.

        Args:
            x: Rank-3 tensor unpacked as (batch, sequence, features).

        Returns:
            Tensor shaped like ``x``. Tokens that exceed an expert's
            capacity get no contribution from that expert.
        """
        n_batch, n_steps, width = x.shape
        gates, indices = self.router(x)

        # Work on a flat (tokens, features) view so tokens can be gathered by index
        tokens = x.view(-1, width)
        flat_gates = gates.view(-1, gates.size(-1))

        # Even share of the (token, expert) assignments, scaled by the capacity factor
        assignments = n_batch * n_steps * self.top_k
        capacity = int((assignments / self.num_experts) * self.capacity_factor)

        updates = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            # Flat positions of the tokens that picked this expert, in ascending order
            routed = (indices == expert_id).any(dim=-1).reshape(-1).nonzero(as_tuple=True)[0]
            # Keep only the first `capacity` of them; the rest are dropped for this expert
            kept = routed[:capacity]
            if kept.numel() == 0:
                continue
            weights = flat_gates[kept, expert_id].unsqueeze(1)
            updates.index_add_(0, kept, expert(tokens[kept]) * weights)

        combined = torch.zeros_like(x)
        combined += updates.view(n_batch, n_steps, -1)
        return combined
Model Architecture
class Block(nn.Module):
    """Transformer block: self-attention followed by a sparse MoE layer.

    Both sub-layers are applied to a layer-normalised input and added back
    to the residual stream.

    Args:
        n_embed: Feature width.
        n_head: Number of attention heads; each head gets
            ``n_embed // n_head`` channels.
        num_experts: Number of experts in the MoE layer.
        top_k: Number of experts each token is routed to.
    """

    def __init__(self, n_embed, n_head, num_experts, top_k):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embed // n_head)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and MoE residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.smoe(self.ln2(attended))
class SparseMoELanguageModel(nn.Module):
    """Language model built from a stack of sparse mixture-of-experts blocks.

    Args:
        vocab_size: Number of distinct token ids.
        n_embed: Embedding / feature width.
        n_head: Attention heads per block.
        num_experts: Experts per block.
        top_k: Experts each token is routed to.
        block_size: Number of rows in the position-embedding table; also the
            context length that ``generate`` crops to.
        n_layer: Number of stacked ``Block`` modules.
    """

    def __init__(self, vocab_size, n_embed, n_head, num_experts, top_k, block_size, n_layer):
        super().__init__()
        # Keep the context length on the instance so generate() uses the value
        # this model was built with rather than a module-level global.
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head, num_experts, top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Rank-2 tensor of token ids unpacked as (batch, time).
            targets: Optional tensor of target ids with ``batch * time``
                elements.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits are
            (batch, time, vocab) and ``loss`` is ``None``. With ``targets``
            the logits are returned flattened to (batch * time, vocab)
            alongside the scalar loss.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Positions 0..T-1, created on the same device as the input ids
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, classes) logits and (N,) targets
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens, appending each one to ``idx``.

        Args:
            idx: Rank-2 tensor of token ids (batch, time) used as the prompt.
            max_new_tokens: Number of tokens to sample.

        Returns:
            Tensor of shape (batch, time + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # The position table only has block_size rows, so crop the context
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the last time step predicts the next token
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
Model Initialization and Training
def kaiming_init_weights(m):
    """Re-initialise the weight of ``m`` with Kaiming-normal values.

    Only ``nn.Linear`` modules are touched, and only their ``weight``; every
    other module is left as it is. Intended for use with ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)
# Initialize model
# NOTE(review): vocab_size, n_embed, n_head, num_experts, top_k and n_layer are
# not defined in the visible part of this file -- confirm they are set before
# this point.
model = SparseMoELanguageModel(vocab_size, n_embed, n_head, num_experts, top_k, block_size, n_layer)
model.apply(kaiming_init_weights)

# Move to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Count parameters
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Training loop
for step in range(1000):  # Example with 1000 iterations
    if step % 100 == 0:
        # Periodic progress report on a fresh training batch. No gradients are
        # needed for logging, so skip building the autograd graph.
        xb, yb = get_batch('train')
        xb, yb = xb.to(device), yb.to(device)
        with torch.no_grad():
            _, loss = model(xb, yb)
        print(f"step {step}: loss {loss.item():.4f}")

    # get_batch returns tensors on the data's device; move them to the
    # model's device, otherwise the forward pass fails when running on CUDA.
    xb, yb = get_batch('train')
    xb, yb = xb.to(device), yb.to(device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Model Testing
# Generate text
# Switch to evaluation mode so the dropout layers are disabled while sampling,
# and turn off autograd since no gradients are needed here.
model.eval()
# A single token with id 0 serves as the prompt
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated = model.generate(context, max_new_tokens=200)
print(decode(generated[0].tolist()))