Implementing Reinforcement Learning with Linear and Deep Function Approximation
From Tabular Methods to Linear Function Approximation
In reinforcement learning, tabular methods are effective when the state and action spaces are small and discrete. However, the number of states grows exponentially with the number of state variables, a problem known as the 'curse of dimensionality'. Storing a value for every state-action pair becomes infeasible in terms of memory, and learning each value individually takes too long. Function approximation addresses this by generalizing from visited states to estimate the values of states that have not been visited.
The simplest form of function approximation is linear. The action-value function Q(s, a) is approximated by the inner product of a feature vector and a parameter vector, i.e. a weighted sum of features:
Q(s, a) ≈ φ(s, a)ᵀ θ
Here, φ(s, a) represents the feature vector corresponding to the state-action pair, and θ is the parameter vector (weights) to be learned.
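As a minimal sketch of this inner product, the snippet below evaluates Q(s, a) for a hand-picked feature vector and weight vector. The numbers and the helper name q_value are illustrative assumptions, not values from any particular environment.

```python
import numpy as np


def q_value(phi, theta):
    """Return the linear estimate phi(s, a)^T theta as a Python float."""
    return float(phi @ theta)


# Hypothetical 4-dimensional feature vector for one state-action pair.
phi = np.array([1.0, 0.0, 0.5, 0.0])
# Hypothetical weights; in practice these are learned.
theta = np.array([2.0, -1.0, 4.0, 3.0])

q = q_value(phi, theta)  # 1*2 + 0*(-1) + 0.5*4 + 0*3
print(q)
```

Only the features that are non-zero for (s, a) contribute to the estimate, which is why the choice of φ determines how values generalize across states.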
Feature Representation Strategies
Tabular Features: This is a special case of function approximation where the feature vector is a one-hot encoded vector. If the state space has size |S| and the action space has size |A|, the feature vector has a dimension of |S| × |A|. For a specific state-action pair (s, a), only the element at the corresponding index is 1, and all others are 0. While this allows for exact value estimation, it does not resolve the memory issue for large state spaces.
Fixed Sparse Representation (FSR): To reduce the dimensionality, FSR maps states to features more compactly. For a grid-world environment of size n × n, instead of using n² features per action, one can use 2n features (n for the x-coordinate and n for the y-coordinate). The feature vector for a state-action pair activates one x-feature and one y-feature, reducing the parameter count per action from quadratic (n²) to linear (2n) in the grid side length. The trade-off is that states sharing a coordinate also share a weight, so values are no longer represented exactly.
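The two representations can be compared side by side with a short sketch. The function names, the 10 × 10 grid, the four actions, and the row-major mapping from state index to (x, y) are assumptions made for illustration.

```python
import numpy as np


def tabular_features(state, action, n_states, n_actions):
    """One-hot vector of length n_states * n_actions."""
    phi = np.zeros(n_states * n_actions)
    phi[action * n_states + state] = 1.0
    return phi


def fsr_features(state, action, grid_size, n_actions):
    """Sparse vector of length 2 * grid_size * n_actions with two active entries."""
    x, y = state % grid_size, state // grid_size  # row-major indexing (assumed)
    phi = np.zeros(2 * grid_size * n_actions)
    base = action * 2 * grid_size
    phi[base + x] = 1.0              # x-coordinate feature
    phi[base + grid_size + y] = 1.0  # y-coordinate feature
    return phi


grid_size, n_actions = 10, 4
n_states = grid_size * grid_size
state, action = 23, 1  # state index 23 corresponds to x = 3, y = 2

phi_tab = tabular_features(state, action, n_states, n_actions)
phi_fsr = fsr_features(state, action, grid_size, n_actions)

print(phi_tab.size, int(phi_tab.sum()))  # 400 features, 1 active
print(phi_fsr.size, int(phi_fsr.sum()))  # 80 features, 2 active
print(np.flatnonzero(phi_fsr))           # indices of the active FSR features
```

Because every state in the same column shares an x-feature, an update for one state also shifts the estimates of its column neighbours. This is the source of generalization in FSR, and also the reason it can only represent values that split into an x-term plus a y-term for each action.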
Parameter Optimization
The goal is to find the parameter vector θ that minimizes the mean squared error between the approximated value and the target value. The loss function is defined as:
L(θ) = E[(Q_target − φ(s, a)ᵀ θ)²]
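Treating the target Q_target = r + γ max_{a'} φ(s', a')ᵀ θ as a constant and taking a stochastic gradient descent step on this loss gives the (semi-gradient) update rule: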
θ ← θ + α [r + γ max_{a'} φ(s', a')ᵀ θ − φ(s, a)ᵀ θ] φ(s, a)
This approach effectively applies Q-learning updates to the weight vector rather than a table entry.
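To make the update concrete, here is a small sketch of a single step on hand-picked numbers. The function name q_learning_step, the three-dimensional features, and the values of α, γ, and the reward are all assumptions chosen so the arithmetic is easy to follow.

```python
import numpy as np


def q_learning_step(theta, phi_sa, phi_next_all, reward, alpha, gamma, done):
    """Apply one semi-gradient Q-learning update and return (new_theta, td_error).

    phi_sa:       feature vector of the visited pair (s, a), shape (d,)
    phi_next_all: feature vectors of (s', a') for every action a', shape (|A|, d)
    """
    q_sa = phi_sa @ theta
    # Bootstrap from the greedy next action unless the episode ended.
    q_next = 0.0 if done else np.max(phi_next_all @ theta)
    td_error = reward + gamma * q_next - q_sa
    return theta + alpha * td_error * phi_sa, td_error


theta = np.array([0.0, 1.0, 2.0])
phi_sa = np.array([1.0, 0.0, 0.0])          # features of (s, a)
phi_next_all = np.array([[0.0, 1.0, 0.0],   # features of (s', a' = 0)
                         [0.0, 0.0, 1.0]])  # features of (s', a' = 1)

theta, td_error = q_learning_step(theta, phi_sa, phi_next_all,
                                  reward=1.0, alpha=0.5, gamma=0.9, done=False)
print(td_error)  # 1.0 + 0.9 * 2.0 - 0.0
print(theta)     # only the weight of the active feature changes
```

The weight vector moves only along φ(s, a), so with one-hot features this reduces to the familiar update of a single table entry.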
Implementing Linear Q-Learning
The following Python implementation demonstrates a Q-learning agent using linear function approximation. We define a class LinearQLearner that supports both tabular features and fixed sparse representations.
import numpy as np
import random


class LinearQLearner:
    def __init__(self, env, feature_type='sparse'):
        self.env = env
        self.gamma = 0.9
        self.feature_type = feature_type
        # Determine feature dimension based on representation type
        if feature_type == 'tabular':
            self.feature_dim = env.state_num * env.action_num
        else:  # Fixed Sparse Representation
            # Example: 10 (x) + 10 (y) per action = 20 * 4 actions = 80
            self.feature_dim = (env.grid_size + env.grid_size) * env.action_num
        self.weights = np.zeros((self.feature_dim, 1))

    def get_features(self, state, action):
        """Constructs the feature vector phi(s, a)."""
        phi = np.zeros((self.feature_dim, 1))
        if self.feature_type == 'tabular':
            # One-hot encoding for state-action pair
            index = action * self.env.state_num + state
            phi[index] = 1
        else:
            # Fixed Sparse Representation
            # Assuming state is an index, convert to coordinates
            y = state // self.env.grid_size
            x = state % self.env.grid_size
            # Activate features for x and y coordinates for the given action
            base_idx = action * (self.env.grid_size * 2)
            phi[base_idx + x] = 1
            phi[base_idx + self.env.grid_size + y] = 1
        return phi

    def get_q_value(self, state, action):
        phi = self.get_features(state, action)
        return np.dot(phi.T, self.weights)[0, 0]

    def greedy_policy(self, state):
        q_values = [self.get_q_value(state, a) for a in range(self.env.action_num)]
        return np.argmax(q_values)

    def epsilon_greedy_action(self, state, epsilon):
        if random.random() < 1 - epsilon:
            return self.greedy_policy(state)
        return random.randint(0, self.env.action_num - 1)

    def train(self, num_episodes, alpha, epsilon):
        for episode in range(num_episodes):
            state = 0  # Initial state
            done = False
            steps = 0
            while not done and steps < 100:
                action = self.epsilon_greedy_action(state, epsilon)
                next_state, reward, done = self.env.step(state, action)
                # Calculate TD Target
                phi_current = self.get_features(state, action)
                q_current = np.dot(phi_current.T, self.weights)[0, 0]
                if done:
                    td_target = reward
                else:
                    # Max Q value for next state
                    next_action = self.greedy_policy(next_state)
                    q_next = self.get_q_value(next_state, next_action)
                    td_target = reward + self.gamma * q_next
                # Gradient Descent Update
                error = td_target - q_current
                self.weights += alpha * error * phi_current
                state = next_state
                steps += 1
        return self.weights

Nonlinear Approximation with Deep Q-Networks (DQN)
While linear function approximation is simple and mathematically tractable, its representational capacity is limited by the chosen features. Complex environments often require nonlinear function approximation to capture intricate patterns. Deep Q-Networks (DQN) use convolutional neural networks (CNNs) to approximate the Q-function directly from raw pixel inputs.
DQN introduces two critical innovations to stabilize training:
- Experience Replay: Transitions (s, a, r, s') are stored in a replay memory. During training, mini-batches are sampled randomly from this memory. This breaks the correlation between consecutive samples and improves data efficiency.
- Target Network: A separate target network is used to calculate the TD target values. This network is updated less frequently (or via soft updates) than the main training network, preventing the optimization process from chasing a moving target.
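The soft-update variant of the target network mentioned above can be sketched without any deep learning library. The weights are plain Python lists, the online weights are held fixed, and the names soft_update and hard_update along with the value τ = 0.01 are assumptions for illustration.

```python
def soft_update(target, online, tau):
    """Move each target weight a fraction tau toward the online weight."""
    return [tau * w + (1.0 - tau) * t for t, w in zip(target, online)]


def hard_update(online):
    """Copy the online weights into the target wholesale."""
    return list(online)


online = [1.0, -2.0]  # hypothetical online-network weights, held fixed here
target = [0.0, 0.0]   # hypothetical target-network weights
tau = 0.01

steps = 0
# Count soft updates until the target has closed half of the initial gap.
while abs(online[0] - target[0]) > 0.5 * abs(online[0]):
    target = soft_update(target, online, tau)
    steps += 1

print(steps)                # number of soft updates needed to halve the gap
print(hard_update(online))  # a hard update closes the gap in one copy
```

Each soft update shrinks the gap by a factor of (1 − τ), so with τ = 0.01 the loop needs 69 updates to halve it (0.99^69 ≈ 0.4998). A smaller τ gives a slower-moving, steadier target; a periodic hard update instead keeps the target frozen between copies.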
DQN Implementation with PyTorch
The following code implements a DQN agent using PyTorch. The architecture consists of convolutional layers to process image input and fully connected layers to output Q-values for each action.
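The size of the first fully connected layer depends on the spatial size left after the convolutions, so it is worth checking by hand. The sketch below uses the standard output-size formula for an unpadded convolution with dilation 1, floor((n + 2p − k) / s) + 1, with the kernel and stride values of the three convolutional layers in the listing that follows; the helper names and the two input sizes tried are assumptions.

```python
def conv_output_size(size, kernel, stride, padding=0):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stack_output_size(size, layers):
    """Apply conv_output_size for each (kernel, stride) pair in order."""
    for kernel, stride in layers:
        size = conv_output_size(size, kernel, stride)
    return size


# (kernel, stride) pairs matching the three Conv2d layers in ConvQNetwork.
layers = [(8, 4), (4, 2), (3, 1)]

for side in (84, 80):
    out = stack_output_size(side, layers)
    # Input side length, output side length, flattened size with 64 channels.
    print(side, out, 64 * out * out)
```

An 84 × 84 input shrinks to 20, then 9, then 7, giving 64 × 7 × 7 = 3136 flattened features, whereas an 80 × 80 input ends at 6 × 6 and 2304 features. The input size of the first linear layer has to match whichever resolution is actually fed to the network.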
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

class ConvQNetwork(nn.Module):
    def __init__(self, input_channels, num_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        # The flattened size is hard-coded for 84x84 inputs with 4 stacked frames:
        # 84 -> 20 -> 9 -> 7, so 64 * 7 * 7 = 3136 features.
        # (An 80x80 input would instead give 64 * 6 * 6 = 2304.)
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        # Normalize pixel values
        x = x.float() / 255.0
        conv_out = self.conv(x)
        # Flatten for fully connected layers
        flat_out = conv_out.view(x.size(0), -1)
        return self.fc(flat_out)

class DQNAgent:
    def __init__(self, input_shape, num_actions, lr=1e-6):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_actions = num_actions
        # Evaluation Network
        self.q_net = ConvQNetwork(input_shape[0], num_actions).to(self.device)
        # Target Network
        self.target_net = ConvQNetwork(input_shape[0], num_actions).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(50000)
        self.batch_size = 32
        self.gamma = 0.99
        self.tau = 0.01  # Soft update parameter

    def select_action(self, state, epsilon):
        # Epsilon-greedy strategy
        if random.random() > epsilon:
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=self.device).unsqueeze(0)
            # No gradients are needed for action selection
            with torch.no_grad():
                q_values = self.q_net(state_tensor)
            return q_values.argmax(1).item()
        return random.randrange(self.num_actions)

    def update_model(self):
        if len(self.memory) < self.batch_size:
            return
        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        # Explicit dtypes avoid surprises with integer rewards or boolean done flags
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)
        # Current Q values for the actions actually taken, shape (batch,)
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Target Q values
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        # Loss and Backpropagation
        loss = nn.MSELoss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Soft update of target network
        for target_param, local_param in zip(self.target_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

    def train(self, env, max_steps):
        epsilon = 1.0
        state = env.reset()
        for step in range(max_steps):
            action = self.select_action(state, epsilon)
            # Assumes an environment whose step() returns (obs, reward, done, info)
            next_state, reward, done, _ = env.step(action)
            self.memory.push(state, action, reward, next_state, done)
            state = next_state
            self.update_model()
            if done:
                state = env.reset()
            # Epsilon decay
            epsilon = max(0.1, epsilon - 1e-5)
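
The target computation inside update_model can be checked in isolation. The following NumPy sketch reproduces the same formula, r + (1 − done) · γ · max over a' of Q_target(s', a'), on a hand-made batch; the function name td_targets and all numbers are assumptions for illustration.

```python
import numpy as np


def td_targets(rewards, next_q, dones, gamma):
    """Compute r + gamma * max_a' Q_target(s', a'), with no bootstrap at terminal steps.

    rewards: shape (B,)    next_q: shape (B, |A|)    dones: shape (B,), bool
    """
    max_next_q = next_q.max(axis=1)
    not_done = 1.0 - dones.astype(np.float64)
    return rewards + not_done * gamma * max_next_q


rewards = np.array([1.0, 0.0, -1.0])
next_q = np.array([[0.5, 2.0],   # target-network outputs for each next state
                   [1.0, 0.0],
                   [3.0, 4.0]])
dones = np.array([False, False, True])

targets = td_targets(rewards, next_q, dones, gamma=0.99)
print(targets)  # the last entry equals the reward alone because the episode ended
```

Masking with (1 − done) matters because a terminal transition has no successor state to bootstrap from; without the mask, the last target would wrongly include 0.99 × 4.0.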