Part 20 - Implementing Deep Q-Networks (DQN) in Python
Machine Learning Algorithms Series - Reinforcement Learning with PyTorch
This article explains how to implement Deep Q-Networks (DQN), a reinforcement learning algorithm, in Python using PyTorch. DQN combines Q-learning with deep neural networks, using a neural network to approximate the Q-values for each action in a given state. This allows it to handle environments with high-dimensional and continuous state spaces. DQN uses experience replay (storing past experiences and training on random batches) and a target network to stabilize training. The article covers importing necessary libraries, defining the DQN model, setting hyperparameters, initializing the environment and DQN model, implementing experience replay, and running the training loop.
Step-by-Step Implementation
Importing Libraries:
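Import `gym` (version 0.26 or later, whose `reset()` returns `(observation, info)` and whose `step()` returns five values) for the environment interface.
Import `torch`, `torch.nn` (as `nn`), and `torch.optim` (as `optim`).
Import `numpy` for numerical operations.
Import `random` for the epsilon-greedy coin flip and for sampling mini-batches.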
Defining the DQN Model:
Create a class `DQN` that inherits from `nn.Module`.
Initialize the DQN model with three fully connected linear layers:
The first layer takes the input size (number of state features) and outputs 64 features.
The second layer maps 64 features to 32 features.
The final layer maps 32 features to output size values (one Q-value per possible action).
Create a `forward` method that defines the forward pass of the model. Apply ReLU activation after the first and second linear layers, and no activation after the final layer, since Q-values can be any real number.
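To make the layer sizes concrete, here is a NumPy-only sketch of the same computation with random, untrained weights. It is not the PyTorch model itself. It assumes CartPole's sizes (4 state features, 2 actions). `relu`, `init_layer`, and `dqn_forward` are hypothetical helper names, and weights are stored here as `(n_in, n_out)`, whereas `nn.Linear` stores them transposed.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def init_layer(rng, n_in, n_out):
    # Random weights for illustration only (not PyTorch's initialization scheme)
    return rng.normal(0.0, 0.1, size=(n_in, n_out)), np.zeros(n_out)

def dqn_forward(params, x):
    """Mirror of DQN.forward: Linear -> ReLU -> Linear -> ReLU -> Linear."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h1 = relu(x @ w1 + b1)   # input_size -> 64
    h2 = relu(h1 @ w2 + b2)  # 64 -> 32
    return h2 @ w3 + b3      # 32 -> output_size, no activation on Q-values

rng = np.random.default_rng(0)
input_size, output_size = 4, 2  # assumed CartPole-v1 sizes
params = [init_layer(rng, input_size, 64),
          init_layer(rng, 64, 32),
          init_layer(rng, 32, output_size)]

batch = rng.normal(size=(5, input_size))  # 5 made-up states
q = dqn_forward(params, batch)
print(q.shape)  # (5, 2): one Q-value per action for each state
```

The same function accepts a single state of shape `(4,)` and returns 2 Q-values, which is how the training loop queries the network when picking an action.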
Setting Hyperparameters:
Define hyperparameters such as:
`env_name`: Environment name (e.g., "CartPole-v1").
`learning_rate`: Learning rate for the Adam optimizer (e.g., 0.001).
`gamma`: Discount factor for future rewards (e.g., 0.99).
`buffer_size`: Maximum number of experiences stored in the replay buffer (e.g., 10000).
`batch_size`: Number of experiences sampled from the replay buffer per training step (e.g., 32).
`epsilon`: Exploration rate for epsilon-greedy action selection (e.g., 0.1).
`target_update_frequency`: Number of steps between target network updates (e.g., 100).
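A minimal sketch of epsilon-greedy selection in plain Python, assuming the Q-values arrive as a list. The helper name `epsilon_greedy` is hypothetical, and `rng.randrange` stands in for `env.action_space.sample()`.

```python
import random

def epsilon_greedy(q_values, epsilon, rng=random):
    """Return a random action index with probability epsilon, else the argmax."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit

rng = random.Random(0)
q = [0.2, 1.5]  # made-up Q-values for two actions
picks = [epsilon_greedy(q, 0.1, rng) for _ in range(10_000)]
print(picks.count(1) / len(picks))  # close to 0.95
```

With two actions and epsilon = 0.1, the greedy action is chosen about 0.9 + 0.1 × 0.5 = 0.95 of the time, because a random pick can also land on it.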
Initializing the Environment and DQN Model:
Create the Gym environment using `gym.make(env_name)`.
Set the input size for the DQN model to the state dimension, `env.observation_space.shape[0]` (4 for CartPole). `shape` itself is a tuple, which `nn.Linear` does not accept.
Set the output size for the DQN model to the number of possible actions, `env.action_space.n` (2 for CartPole).
Determine the computation device (GPU if available, otherwise CPU).
Initialize the policy net (the main DQN model, used for action selection and trained at every step) and the target net (a second DQN model used only to compute target Q-values).
Copy the policy net's weights into the target net with `target_net.load_state_dict(policy_net.state_dict())`, then set the target net to evaluation mode.
Create the Adam optimizer for updating the policy net's weights.
Define the mean squared error (MSE) loss criterion for training.
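`load_state_dict` copies parameter values into the target net rather than sharing them, so later optimizer steps on the policy net leave the target net unchanged until the next sync. A toy NumPy sketch of that hard update, with plain dicts standing in for state dicts (the `hard_update` helper is hypothetical):

```python
import numpy as np

def hard_update(target_params, policy_params):
    """Copy every policy parameter into the target (like load_state_dict)."""
    for name, value in policy_params.items():
        target_params[name] = value.copy()  # independent copy, not a shared reference

policy = {"linear1.weight": np.ones((2, 2))}
target = {}
hard_update(target, policy)

policy["linear1.weight"] -= 0.5  # pretend an optimizer step happened (in place)
print(target["linear1.weight"][0, 0])  # 1.0: the target stays frozen until the next sync
```

Holding the targets fixed between syncs is what keeps the regression target from shifting with every gradient step.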
Implementing Experience Replay:
Create an empty replay buffer to store past experiences.
Experiences are stored as tuples containing state, action, reward, next state, and done.
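A sketch of the same buffer built on `collections.deque` with `maxlen`, which discards the oldest experience automatically in constant time. This swaps in a different container than the plain list (with `pop(0)`) used in the complete code example. The `ReplayBuffer` class and its method names are hypothetical.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity):
        # deque(maxlen=...) drops the oldest item on overflow in O(1),
        # unlike list.pop(0), which shifts every remaining element.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks up correlated consecutive steps
        batch = random.sample(self.buffer, batch_size)
        return tuple(zip(*batch))  # (states, actions, rewards, next_states, dones)

    def __len__(self):
        return len(self.buffer)

buf = ReplayBuffer(capacity=3)
for t in range(5):
    buf.push([t], t % 2, 1.0, [t + 1], t == 4)
print(len(buf))                     # 3: capacity reached, oldest two dropped
print([s for s, *_ in buf.buffer])  # [[2], [3], [4]]
```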
Running the Training Function:
Implement the training loop that trains the agent over a specified number of episodes.
Reset the environment at the start of each episode.
Use an epsilon-greedy policy for action selection:
With probability epsilon, choose a random action.
Otherwise, choose the action with the highest predicted Q-value using the policy net.
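Take the selected action and observe the reward, the next state, and whether the episode terminated (reached a terminal state) or was truncated (hit a time limit).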
Store the experience in the replay buffer.
If the replay buffer is larger than the buffer size, remove the oldest experience.
Sample a mini-batch from the replay buffer if it contains enough samples.
Convert the batch components (states, actions, rewards, next states, dones) to PyTorch tensors.
Compute the current Q-values and target Q-values:
Compute the Q-values for actions taken using the policy net.
Compute the maximum Q-value of each next state using the target net.
Calculate the target Q-values as the reward plus `gamma` times that maximum, dropping the future term for transitions that ended in a terminal state.
Compute the loss and update the policy network:
Calculate the MSE loss between the current Q-values and the target Q-values.
Clear previous gradients.
Compute the gradients.
Update the model parameters using the optimizer.
Periodically update the target network with the policy net's weights for stability.
Print the episode number and total reward.
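The target computed in these steps is y = r + gamma × (1 − done) × max over a' of Q_target(s', a'), and the loss is the mean squared difference between y and Q_policy(s, a). A NumPy sketch on a made-up batch of three transitions shows the arithmetic; the numbers are illustrative only, not outputs of a real network.

```python
import numpy as np

gamma = 0.99
# A made-up mini-batch of 3 transitions with 2 actions
rewards = np.array([1.0, 1.0, 1.0])
dones = np.array([0.0, 0.0, 1.0])      # last transition ended in a terminal state
actions = np.array([0, 1, 1])
q_policy = np.array([[0.5, 0.7],       # policy net Q(s, .)
                     [0.2, 0.9],
                     [0.4, 0.1]])
q_target_next = np.array([[1.0, 2.0],  # target net Q(s', .)
                          [3.0, 0.5],
                          [5.0, 4.0]])

# Q(s, a) for the actions actually taken (what .gather(1, actions) does)
current_q = q_policy[np.arange(len(actions)), actions]
# r + gamma * max_a' Q_target(s', a'), with the future term zeroed when done
target_q = rewards + gamma * q_target_next.max(axis=1) * (1.0 - dones)
loss = np.mean((current_q - target_q) ** 2)  # MSE, as nn.MSELoss averages by default

print(current_q)  # approximately [0.5, 0.9, 0.1]
print(target_q)   # approximately [2.98, 3.97, 1.0]
```

Note the third transition: its large next-state Q-values (5.0 and 4.0) are ignored because the episode ended there, so its target is just the reward.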
Complete Code Example
```python
# Import necessary libraries
import gym  # gym >= 0.26: reset() returns (obs, info), step() returns 5 values
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

# Define the DQN model
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.linear1 = nn.Linear(input_size, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, output_size)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        x = self.linear3(x)  # raw Q-values, one per action
        return x

# Define hyperparameters
env_name = "CartPole-v1"
learning_rate = 0.001
gamma = 0.99
buffer_size = 10000
batch_size = 32
epsilon = 0.1
target_update_frequency = 100

# Initialize environment and DQN model
env = gym.make(env_name)
input_size = env.observation_space.shape[0]  # state dimension (4 for CartPole)
output_size = env.action_space.n  # number of discrete actions (2 for CartPole)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = DQN(input_size, output_size).to(device)
target_net = DQN(input_size, output_size).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Create replay buffer
replay_buffer = []

# Training function
def train(num_episodes):
    step_count = 0
    for episode in range(num_episodes):
        state, _ = env.reset()
        state = np.array(state, dtype=np.float32)
        done = False
        total_reward = 0
        while not done:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
                with torch.no_grad():
                    q_values = policy_net(state_tensor)
                action = torch.argmax(q_values).item()

            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated  # end the episode on either condition
            next_state = np.array(next_state, dtype=np.float32)
            total_reward += reward

            # Store only `terminated`: a time-limit truncation is not a true terminal state
            replay_buffer.append((state, action, reward, next_state, terminated))
            if len(replay_buffer) > buffer_size:
                replay_buffer.pop(0)  # drop the oldest experience
            state = next_state

            if len(replay_buffer) >= batch_size:
                batch = random.sample(replay_buffer, batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)
                # Stack into NumPy arrays first; building a tensor from a tuple of arrays is slow
                states = torch.tensor(np.array(states), dtype=torch.float32, device=device)
                actions = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
                rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
                next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
                dones = torch.tensor(dones, dtype=torch.float32, device=device)

                # Q(s, a) for the actions actually taken
                current_q_values = policy_net(states).gather(1, actions)
                # max over a' of Q_target(s', a'); .max(1) returns (values, indices)
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1).values
                target_q_values = rewards + gamma * next_q_values * (1 - dones)

                loss = criterion(current_q_values, target_q_values.unsqueeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Periodically sync the target network with the policy network
            step_count += 1
            if step_count % target_update_frequency == 0:
                target_net.load_state_dict(policy_net.state_dict())

        print(f"Episode {episode}, Total Reward: {total_reward}")

# Run the training
train(num_episodes=2000)
```
Conclusion
This article demonstrates how to implement the DQN algorithm using Python and PyTorch. By combining deep neural networks with Q-learning and employing techniques such as experience replay and target networks, DQN can effectively learn to make decisions in complex environments. This approach is a foundational method in reinforcement learning and can be extended to more sophisticated algorithms and applications.