208 lines
6.5 KiB
Python
208 lines
6.5 KiB
Python
"""
Deep Q-Network (DQN) Agent for speed limit control.
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
from collections import deque
|
|
import random
|
|
from typing import Tuple, List
|
|
|
|
|
|
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action."""

    def __init__(self, state_dim: int, action_dim: int, hidden_layers: List[int]):
        """
        Build the layer stack.

        Args:
            state_dim: Size of the input state vector.
            action_dim: Number of outputs (one Q-value per action).
            hidden_layers: Width of each hidden layer, in order. An empty
                list yields a single linear layer from state to actions.
        """
        super().__init__()

        # widths[i] -> widths[i + 1] describes each hidden Linear layer.
        widths = [state_dim, *hidden_layers]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Every hidden layer is a Linear followed by a ReLU.
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]

        # Output head has no activation: raw Q-values.
        modules.append(nn.Linear(widths[-1], action_dim))

        self.network = nn.Sequential(*modules)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the Q-values produced by the layer stack for ``state``."""
        q_values = self.network(state)
        return q_values
|
|
|
|
|
|
class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling."""

    def __init__(self, capacity: int):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        # A bounded deque drops the oldest entry once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> Tuple:
        """
        Draw ``batch_size`` distinct transitions at random.

        Returns:
            A 5-tuple of NumPy arrays
            ``(states, actions, rewards, next_states, dones)``, each with
            the sampled transitions stacked along the first axis.
        """
        drawn = random.sample(self.buffer, batch_size)

        # Transpose the list of transitions into five per-field columns.
        states, actions, rewards, next_states, dones = zip(*drawn)
        columns = (states, actions, rewards, next_states, dones)

        return tuple(np.array(column) for column in columns)

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return len(self.buffer)
|
|
|
|
|
|
class DQNAgent:
    """
    DQN agent for speed limit control.

    Holds an online Q-network, a target network that is periodically
    synchronised with it, an Adam optimizer, and a replay buffer. Actions are
    chosen epsilon-greedily; epsilon is decayed once per episode.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_layers: List[int] = None,
        learning_rate: float = 0.0001,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: float = 0.995,
        buffer_size: int = 50000,
        batch_size: int = 64,
        target_update_freq: int = 10,
        device: str = "cuda",
    ):
        """
        Initialize DQN agent.

        Args:
            state_dim: Dimension of state space.
            action_dim: Number of discrete actions.
            hidden_layers: Hidden layer sizes for both networks. ``None``
                (the default) means ``[128, 128]``.
            learning_rate: Adam learning rate.
            gamma: Discount factor applied to the bootstrapped next-state value.
            epsilon_start: Initial exploration probability.
            epsilon_end: Lower bound for epsilon.
            epsilon_decay: Multiplicative decay applied by ``update_epsilon``.
            buffer_size: Replay buffer capacity.
            batch_size: Number of transitions per training step.
            target_update_freq: Target network is synchronised every this
                many episodes (see ``end_episode``).
            device: Requested torch device string. CPU is used whenever CUDA
                is not available.
        """
        # A list literal in the signature would be one object shared by every
        # call; build a fresh list per agent instead.
        if hidden_layers is None:
            hidden_layers = [128, 128]

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update_freq = target_update_freq
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        self.q_network = QNetwork(state_dim, action_dim, hidden_layers).to(
            self.device
        )
        self.target_network = QNetwork(state_dim, action_dim, hidden_layers).to(
            self.device
        )
        # Start the target network as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)
        self.replay_buffer = ReplayBuffer(buffer_size)

        self.episode_count = 0

    def select_action(self, state: np.ndarray, training: bool = True) -> int:
        """
        Select action using epsilon-greedy policy.

        Args:
            state: Current state vector.
            training: When True, a uniformly random action is returned with
                probability ``self.epsilon``; when False the choice is
                always greedy.

        Returns:
            Index of the chosen action.
        """
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        with torch.no_grad():
            # Add a leading batch dimension of size 1 for the network.
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            q_values = self.q_network(state_tensor)
            return q_values.argmax().item()

    def store_transition(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Store transition in replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def train(self) -> float:
        """
        Train the agent on a batch of experiences.

        Returns:
            The MSE loss of the step as a Python float, or ``0.0`` when the
            replay buffer holds fewer than ``batch_size`` transitions and no
            update was performed.
        """
        if len(self.replay_buffer) < self.batch_size:
            return 0.0

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(
            self.batch_size
        )

        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(
            next_states, dtype=torch.float32, device=self.device
        )
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)

        # Q(s, a) for the action actually taken. Squeeze only the gathered
        # column: a bare squeeze() would also drop the batch axis when
        # batch_size == 1 and mismatch the target's shape.
        current_q_values = (
            self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        )

        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            # (1 - dones) zeroes the bootstrap term on terminal transitions.
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = nn.MSELoss()(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        """Update target network with current Q-network weights."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def update_epsilon(self):
        """Decay epsilon for exploration, never going below ``epsilon_end``."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

    def end_episode(self):
        """
        Called at the end of each episode.

        Increments the episode counter, synchronises the target network every
        ``target_update_freq`` episodes, and decays epsilon.
        """
        self.episode_count += 1
        if self.episode_count % self.target_update_freq == 0:
            self.update_target_network()
        self.update_epsilon()

    def save(self, path: str):
        """Save networks, optimizer state, epsilon and episode count to ``path``."""
        torch.save(
            {
                "q_network": self.q_network.state_dict(),
                "target_network": self.target_network.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "epsilon": self.epsilon,
                "episode_count": self.episode_count,
            },
            path,
        )

    def load(self, path: str):
        """
        Load model checkpoint written by ``save``.

        NOTE(security): ``torch.load`` is pickle-based; only load checkpoint
        files from trusted sources.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint["q_network"])
        self.target_network.load_state_dict(checkpoint["target_network"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.epsilon = checkpoint["epsilon"]
        self.episode_count = checkpoint["episode_count"]
|