ctm-dqn/dqn_agent.py

208 lines
6.5 KiB
Python

"""
Deep Q-Network (DQN) Agent for speed limit control.
"""
import random
from collections import deque
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class QNetwork(nn.Module):
    """Fully connected network mapping a state to one Q-value per action."""

    def __init__(self, state_dim: int, action_dim: int, hidden_layers: List[int]):
        """
        Build the stack of Linear/ReLU layers.

        Args:
            state_dim: Number of input features.
            action_dim: Number of outputs (one Q-value per discrete action).
            hidden_layers: Widths of the hidden layers, in order; every hidden
                Linear layer is followed by a ReLU. An empty list yields a
                single Linear layer from state_dim to action_dim.
        """
        super().__init__()
        widths = [state_dim, *hidden_layers]
        modules: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head: no activation, raw Q-values.
        modules.append(nn.Linear(widths[-1], action_dim))
        self.network = nn.Sequential(*modules)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the Q-values produced by the layer stack for ``state``."""
        return self.network(state)
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity: int):
        """Create an empty buffer holding at most ``capacity`` transitions.

        Once full, appending a new transition evicts the oldest one.
        """
        self.buffer = deque(maxlen=capacity)

    def push(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Append a single transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> Tuple:
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``; each
            element is a numpy array built from that field of every sampled
            transition, in sample order.

        Raises:
            ValueError: If ``batch_size`` is larger than the number of stored
                transitions (raised by ``random.sample``), or if it is 0
                (nothing to unpack into the five fields).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in zip(*picked)
        )
        return states, actions, rewards, next_states, dones

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.buffer)
class DQNAgent:
    """DQN agent for speed limit control.

    Holds an online Q-network trained from an experience replay buffer, a
    target network that is re-synced from it every ``target_update_freq``
    episodes, and an epsilon-greedy exploration rate that decays once per
    episode.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_layers: Optional[List[int]] = None,
        learning_rate: float = 0.0001,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: float = 0.995,
        buffer_size: int = 50000,
        batch_size: int = 64,
        target_update_freq: int = 10,
        device: str = "cuda",
    ):
        """
        Initialize the DQN agent.

        Args:
            state_dim: Size of the state vector fed to the Q-network.
            action_dim: Number of discrete actions (Q-network outputs).
            hidden_layers: Hidden layer widths for both networks. ``None``
                (the default) means ``[128, 128]``.
            learning_rate: Adam learning rate for the online network.
            gamma: Discount factor used in the TD target.
            epsilon_start: Initial exploration rate.
            epsilon_end: Lower bound epsilon never decays below.
            epsilon_decay: Factor epsilon is multiplied by at each episode end.
            buffer_size: Replay buffer capacity.
            batch_size: Number of transitions sampled per training step.
            target_update_freq: Copy online weights into the target network
                every this many episodes.
            device: Requested torch device. CPU is used whenever CUDA is not
                available, regardless of this value.
        """
        if hidden_layers is None:
            # Build a fresh list per call; a list literal as the default would
            # be one object shared by every agent instance.
            hidden_layers = [128, 128]
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.q_network = QNetwork(state_dim, action_dim, hidden_layers).to(
            self.device
        )
        self.target_network = QNetwork(state_dim, action_dim, hidden_layers).to(
            self.device
        )
        # Start the target network as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.replay_buffer = ReplayBuffer(buffer_size)
        self.episode_count = 0

    def select_action(self, state: np.ndarray, training: bool = True) -> int:
        """Select an action with an epsilon-greedy policy.

        When ``training`` is true, a uniformly random action index is returned
        with probability ``self.epsilon``. Otherwise (or when ``training`` is
        false) the index of the largest Q-value from the online network is
        returned.
        """
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        with torch.no_grad():
            # unsqueeze(0) adds a batch dimension. NOTE(review): presumably
            # ``state`` is a single 1-D state vector; argmax() flattens, so a
            # batched input would not yield a per-row action — confirm.
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.q_network(state_tensor)
            return q_values.argmax().item()

    def store_transition(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Store one transition in the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def train(self) -> float:
        """Run one gradient step on a random minibatch from the replay buffer.

        The TD target is ``r + gamma * (1 - done) * max_a' Q_target(s', a')``
        and the loss is the mean squared error against the online network's
        ``Q(s, a)`` for the stored actions.

        Returns:
            The loss value, or ``0.0`` (without training) while the buffer
            holds fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return 0.0
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            self.batch_size
        )
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        # Q(s, a) for the actions actually taken, shape (batch,). squeeze(1)
        # rather than squeeze() keeps the batch dimension when batch_size == 1,
        # so prediction and target shapes match and MSELoss does not broadcast.
        current_q_values = (
            self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        )
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            # Transitions with done == 1 do not bootstrap from the next state.
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        loss = nn.MSELoss()(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update_target_network(self):
        """Copy the online Q-network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def update_epsilon(self):
        """Decay epsilon multiplicatively, never going below ``epsilon_end``."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

    def end_episode(self):
        """Bookkeeping to call at the end of each episode.

        Increments the episode counter, syncs the target network every
        ``target_update_freq`` episodes, and decays epsilon.
        """
        self.episode_count += 1
        if self.episode_count % self.target_update_freq == 0:
            self.update_target_network()
        self.update_epsilon()

    def save(self, path: str):
        """Save a checkpoint with ``torch.save``.

        Stores both networks' weights, the optimizer state, epsilon and the
        episode counter. The replay buffer is not saved.
        """
        torch.save(
            {
                "q_network": self.q_network.state_dict(),
                "target_network": self.target_network.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "epsilon": self.epsilon,
                "episode_count": self.episode_count,
            },
            path,
        )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``, mapping tensors to ``self.device``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint["q_network"])
        self.target_network.load_state_dict(checkpoint["target_network"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.epsilon = checkpoint["epsilon"]
        self.episode_count = checkpoint["episode_count"]