Fading Coder

Implementing Reinforcement Learning with Linear and Deep Function Approximation

Tech May 19 5

From Tabular Methods to Linear Function Approximation

In reinforcement learning, tabular methods are effective when the state and action spaces are small and discrete. As problems become more complex, however, the number of states grows exponentially with the number of state variables, a problem known as the curse of dimensionality. Storing a value for every state-action pair becomes infeasible in terms of memory, and learning each value individually takes too long. Function approximation addresses this by generalizing from states the agent has visited to states it has not.

The simplest form of function approximation is linear approximation. The action-value function Q(s, a) is approximated by the inner product of a feature vector and a parameter vector:

Q(s, a) ≈ φ(s, a)ᵀθ

Here, φ(s, a) represents the feature vector corresponding to the state-action pair, and θ is the parameter vector (weights) to be learned.
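
As a minimal sketch of this formula, the snippet below evaluates Q(s, a) for a single state-action pair. The feature values and weights are made-up numbers chosen only for illustration; they do not come from any particular environment.

```python
import numpy as np

# Hypothetical 4-dimensional feature vector phi(s, a) and weight vector theta.
phi = np.array([1.0, 0.0, 0.5, 0.0])
theta = np.array([0.2, -1.0, 0.4, 3.0])

# Q(s, a) is the inner product of the features and the weights.
q_value = float(phi @ theta)
print(q_value)  # approximately 0.4 (1.0 * 0.2 + 0.5 * 0.4)
```

Only the features that are non-zero for the pair contribute to the estimate, which is why sparse feature vectors keep both evaluation and updates cheap.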

Feature Representation Strategies

Tabular Features: This is a special case of function approximation where the feature vector is a one-hot encoded vector. If the state space has size |S| and the action space has size |A|, the feature vector has a dimension of |S| × |A|. For a specific state-action pair (s, a), only the element at the corresponding index is 1, and all others are 0. While this allows for exact value estimation, it does not resolve the memory issue for large state spaces.

Fixed Sparse Representation (FSR): To reduce the dimensionality, FSR maps states to features more compactly. For a grid-world environment of size n × n, instead of using n² features, one can use 2n features (n for the x-coordinate and n for the y-coordinate). The feature vector for a state-action pair activates the indices that match the state's coordinates, which reduces the per-action parameter count from quadratic (n²) to linear (2n) in the grid side length n.
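
The difference between the two representations can be sketched for a 10 × 10 grid. The helper names `tabular_features` and `fsr_features` are hypothetical, and for brevity they encode the state only (the action dimension is left out). The sketch assumes states are numbered row by row, so that y = state // n and x = state % n.

```python
import numpy as np

def tabular_features(state, n):
    """One-hot vector of length n*n for a state index in an n x n grid."""
    phi = np.zeros(n * n)
    phi[state] = 1.0
    return phi

def fsr_features(state, n):
    """Length-2n vector: one active entry for x, one for y."""
    y, x = divmod(state, n)
    phi = np.zeros(2 * n)
    phi[x] = 1.0
    phi[n + y] = 1.0
    return phi

n = 10
s1 = 3 * n + 2  # cell (x=2, y=3)
s2 = 7 * n + 2  # cell (x=2, y=7), same column as s1

print(tabular_features(s1, n).size, fsr_features(s1, n).size)  # 100 20
# The inner product counts the features shared by the two states.
print(tabular_features(s1, n) @ tabular_features(s2, n))  # 0.0
print(fsr_features(s1, n) @ fsr_features(s2, n))          # 1.0
```

Because the two cells share the x feature under FSR, an update for one also shifts the estimate for the other. That sharing is the source of generalization, and it is also why FSR cannot represent an arbitrary value for every cell the way the one-hot encoding can.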

Parameter Optimization

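The goal is to find the parameter vector θ that minimizes the mean squared error between the approximated value and the target value. The loss function is defined as:

L(θ) = E[(Q_target - φ(s, a)ᵀθ)²]

For Q-learning, the target is Q_target = r + γ max_{a'} φ(s', a')ᵀθ. Treating this target as a constant when differentiating (a semi-gradient step) and absorbing the constant factor 2 into the learning rate α, the update rule for the parameters becomes:

θ ← θ + α [r + γ max_{a'} φ(s', a')ᵀθ - φ(s, a)ᵀθ] φ(s, a)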

This approach effectively applies Q-learning updates to the weight vector rather than a table entry.
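
A single update step can be worked through numerically. This is a minimal sketch with made-up feature vectors, a made-up reward, and two candidate next actions; none of these values come from the article's environment.

```python
import numpy as np

alpha, gamma = 0.1, 0.9

theta = np.zeros(4)                      # weights, initialised to zero
phi_sa = np.array([1.0, 0.0, 1.0, 0.0])  # phi(s, a) for the action taken
# phi(s', a') for each of two candidate next actions (one per row).
phi_next = np.array([[0.0, 1.0, 1.0, 0.0],
                     [0.0, 0.0, 0.0, 1.0]])
reward = 1.0

q_next_max = np.max(phi_next @ theta)  # max over a' of Q(s', a')
td_error = reward + gamma * q_next_max - phi_sa @ theta
theta = theta + alpha * td_error * phi_sa

# With zero initial weights the TD error is 1.0, so only the two active
# features move, each by alpha * td_error = 0.1.
print(theta)
```

With one-hot features this reduces to the familiar tabular update, since exactly one weight (the table entry) changes per step.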

Implementing Linear Q-Learning

The following Python implementation demonstrates a Q-learning agent using linear function approximation. We define a class LinearQLearner that supports both tabular features and fixed sparse representations.

import numpy as np
import random

class LinearQLearner:
    def __init__(self, env, feature_type='sparse'):
        self.env = env
        self.gamma = 0.9
        self.feature_type = feature_type
        
        # Determine feature dimension based on representation type
        if feature_type == 'tabular':
            self.feature_dim = env.state_num * env.action_num
        else: # Fixed Sparse Representation
            # Example: 10 (x) + 10 (y) per action = 20 * 4 actions = 80
            self.feature_dim = (env.grid_size + env.grid_size) * env.action_num
            
        self.weights = np.zeros((self.feature_dim, 1))

    def get_features(self, state, action):
        """Constructs the feature vector phi(s, a)."""
        phi = np.zeros((self.feature_dim, 1))
        
        if self.feature_type == 'tabular':
            # One-hot encoding for state-action pair
            index = action * self.env.state_num + state
            phi[index] = 1
        else:
            # Fixed Sparse Representation
            # Assuming state is a row-major index, convert to coordinates
            y = state // self.env.grid_size
            x = state % self.env.grid_size
            
            # Activate features for x and y coordinates for the given action
            base_idx = action * (self.env.grid_size * 2)
            phi[base_idx + x] = 1
            phi[base_idx + self.env.grid_size + y] = 1
            
        return phi

    def get_q_value(self, state, action):
        phi = self.get_features(state, action)
        return np.dot(phi.T, self.weights)[0, 0]

    def greedy_policy(self, state):
        q_values = [self.get_q_value(state, a) for a in range(self.env.action_num)]
        return np.argmax(q_values)

    def epsilon_greedy_action(self, state, epsilon):
        if random.random() < 1 - epsilon:
            return self.greedy_policy(state)
        return random.randint(0, self.env.action_num - 1)

    def train(self, num_episodes, alpha, epsilon):
        for episode in range(num_episodes):
            state = 0 # Initial state
            done = False
            steps = 0
            
            while not done and steps < 100:
                action = self.epsilon_greedy_action(state, epsilon)
                next_state, reward, done = self.env.step(state, action)
                
                # Calculate TD Target
                phi_current = self.get_features(state, action)
                q_current = np.dot(phi_current.T, self.weights)[0, 0]
                
                if done:
                    td_target = reward
                else:
                    # Max Q value for next state
                    next_action = self.greedy_policy(next_state)
                    q_next = self.get_q_value(next_state, next_action)
                    td_target = reward + self.gamma * q_next
                
                # Gradient Descent Update
                error = td_target - q_current
                self.weights += alpha * error * phi_current
                
                state = next_state
                steps += 1
        return self.weights
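
The class above expects an environment object with the attributes state_num, action_num, and grid_size, plus a step(state, action) method that returns (next_state, reward, done). The article does not show that environment, so the following is only a hypothetical stand-in: the class name GridEnv, the action ordering, the goal location, and the reward values are all assumptions made for illustration.

```python
class GridEnv:
    """Hypothetical n x n grid world matching the interface used above."""

    # Action index -> (dx, dy): up, down, left, right.
    MOVES = [(0, -1), (0, 1), (-1, 0), (1, 0)]

    def __init__(self, grid_size=10):
        self.grid_size = grid_size
        self.state_num = grid_size * grid_size
        self.action_num = len(self.MOVES)
        self.goal = self.state_num - 1  # bottom-right corner

    def step(self, state, action):
        """Return (next_state, reward, done) for a move from `state`."""
        y, x = divmod(state, self.grid_size)
        dx, dy = self.MOVES[action]
        # Moves that would leave the grid keep the agent on the edge cell.
        x = min(max(x + dx, 0), self.grid_size - 1)
        y = min(max(y + dy, 0), self.grid_size - 1)
        next_state = y * self.grid_size + x
        done = next_state == self.goal
        reward = 1.0 if done else -0.01
        return next_state, reward, done


env = GridEnv(grid_size=4)
print(env.step(0, 3))   # (1, -0.01, False): move right from the top-left cell
print(env.step(14, 3))  # (15, 1.0, True): move right into the goal cell
```

With such an environment, training could be started as LinearQLearner(GridEnv(), feature_type='sparse').train(num_episodes=500, alpha=0.1, epsilon=0.1); the episode count and step sizes here are placeholders rather than tuned values.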

Nonlinear Approximation with Deep Q-Networks (DQN)

While linear function approximation is simple and easier to analyze than nonlinear alternatives, its representational capacity is limited. Complex environments often require nonlinear function approximation to capture intricate patterns. Deep Q-Networks (DQN) use Convolutional Neural Networks (CNNs) to approximate the Q-function directly from raw pixel inputs.

DQN introduces two critical innovations to stabilize training:

  1. Experience Replay: Transitions (s, a, r, s') are stored in a replay memory. During training, mini-batches are sampled randomly from this memory. This breaks the correlation between consecutive samples and improves data efficiency.
  2. Target Network: A separate target network is used to calculate the TD target values. This network is updated less frequently (or via soft updates) than the main training network, preventing the optimization process from chasing a moving target.
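
Both mechanisms can be sketched without a deep learning library. In the snippet below the "networks" are plain lists of weights and the transitions are dummy tuples, so it only illustrates the bookkeeping; the function names hard_update and soft_update are hypothetical.

```python
import random
from collections import deque

# Experience replay: a bounded buffer with uniform random sampling.
replay = deque(maxlen=1000)
for t in range(50):
    # Dummy transitions (s, a, r, s'); real ones come from the environment.
    replay.append((t, 0, 0.0, t + 1))
batch = random.sample(replay, 8)  # a mini-batch that mixes up time steps

# Target network: stand-in weight lists for the online and target networks.
online = [1.0, -2.0, 0.5]
target = [0.0, 0.0, 0.0]

def hard_update(target, online):
    """Periodic update: replace the target weights with a copy of the online ones."""
    return list(online)

def soft_update(target, online, tau=0.01):
    """Move each target weight a fraction tau toward the online weight."""
    return [tau * w + (1.0 - tau) * tw for tw, w in zip(target, online)]

target = soft_update(target, online, tau=0.01)
print(len(batch), target)
```

A small tau makes the target weights trail the online weights slowly, which serves the same purpose as copying them only every few thousand steps.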

DQN Implementation with PyTorch

The following code implements a DQN agent using PyTorch. The architecture consists of convolutional layers to process image input and fully connected layers to output Q-values for each action.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

class ConvQNetwork(nn.Module):
    def __init__(self, input_channels, num_actions):
        super(ConvQNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        # The flattened size after the conv layers depends on the input resolution.
        # This example assumes 4x84x84 input, resulting in 64x7x7 = 3136 features
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        # Normalize pixel values
        x = x.float() / 255.0
        conv_out = self.conv(x)
        # Flatten for fully connected layers
        flat_out = conv_out.view(x.size(0), -1)
        return self.fc(flat_out)

class DQNAgent:
    def __init__(self, input_shape, num_actions, lr=1e-6):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_actions = num_actions
        
        # Evaluation Network
        self.q_net = ConvQNetwork(input_shape[0], num_actions).to(self.device)
        # Target Network
        self.target_net = ConvQNetwork(input_shape[0], num_actions).to(self.device)
        self.target_net.load_state_dict(self.q_net.state_dict())
        
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(50000)
        self.batch_size = 32
        self.gamma = 0.99
        self.tau = 0.01 # Soft update parameter

    def select_action(self, state, epsilon):
        # Epsilon-greedy strategy
        if random.random() > epsilon:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            with torch.no_grad():  # no gradients are needed to pick an action
                q_values = self.q_net(state_tensor)
            return q_values.argmax(1).item()
        return random.randrange(self.num_actions)

    def update_model(self):
        if len(self.memory) < self.batch_size:
            return
            
        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        
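        # Explicit dtypes avoid surprises with boolean done flags and float64 rewards
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)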

        # Current Q values
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1))
        
        # Target Q values
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        
        # Loss and Backpropagation (squeeze only dim 1 so a batch of one keeps its shape)
        loss = nn.MSELoss()(q_values.squeeze(1), target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Soft update of target network
        for target_param, local_param in zip(self.target_net.parameters(), self.q_net.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

    def train(self, env, max_steps):
        epsilon = 1.0
        state = env.reset()
        
        for step in range(max_steps):
            action = self.select_action(state, epsilon)
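            # Assumes the classic Gym API: reset() returns only the observation and
            # step() returns (obs, reward, done, info). Gymnasium returns five values.
            next_state, reward, done, _ = env.step(action)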
            
            self.memory.push(state, action, reward, next_state, done)
            state = next_state
            
            self.update_model()
            
            if done:
                state = env.reset()
            
            # Epsilon decay
            epsilon = max(0.1, epsilon - 1e-5)
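
The input size hard-coded in the first fully connected layer of ConvQNetwork, 64 * 7 * 7, can be checked with the standard output-size formula for an unpadded convolution. The helper names below are hypothetical; the kernel sizes and strides are the ones used in the network above.

```python
def conv_out(size, kernel, stride):
    """Output length of an unpadded convolution along one spatial dimension."""
    return (size - kernel) // stride + 1

def flattened_features(side, channels=64):
    """Flattened size after the three conv layers used in ConvQNetwork."""
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
        side = conv_out(side, kernel, stride)
    return channels * side * side

print(flattened_features(84))  # 3136 (84 -> 20 -> 9 -> 7, then 64 * 7 * 7)
print(flattened_features(80))  # 2304 (80 -> 19 -> 8 -> 6, then 64 * 6 * 6)
```

An 84 × 84 input therefore matches nn.Linear(64 * 7 * 7, 512), whereas an 80 × 80 input would need that layer's input size changed to 64 * 6 * 6 = 2304.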
