# ctm-dqn/agents/dqn_agent.py
# (source-viewer metadata: 197 lines, 7.0 KiB, Python)

"""Shared-parameter multi-head DQN agent for corridor VSL control."""
from __future__ import annotations
import random
from collections import deque
from typing import Deque, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class DQNNetwork(nn.Module):
    """Multi-head Q-network whose MLP trunk is shared by every controlled edge.

    The MLP maps a state vector to ``num_agents * num_actions_per_agent``
    outputs. ``forward`` regroups them into one row of Q-values per agent.

    Args:
        state_dim: Length of the input state vector (last input dimension).
        num_agents: Number of action heads (one per controlled edge).
        num_actions_per_agent: Number of discrete actions in each head.
        hidden_dim: Width of the first two hidden layers; the third hidden
            layer is ``hidden_dim // 2`` wide.
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions_per_agent: int,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.num_actions_per_agent = num_actions_per_agent

        narrow_width = hidden_dim // 2
        output_width = num_agents * num_actions_per_agent
        # The layer order fixes the state_dict keys (``network.<idx>.*``) and the
        # sequence of weight initialisations, so keep it stable for checkpoints.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, narrow_width),
            nn.ReLU(),
            nn.Linear(narrow_width, output_width),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values shaped ``(batch, num_agents, num_actions_per_agent)``.

        The leading dimension is inferred by ``view(-1, ...)``. A single
        unbatched state vector therefore comes back as a batch of one.
        """
        flat_q = self.network(x)
        head_shape = (-1, self.num_agents, self.num_actions_per_agent)
        return flat_q.view(*head_shape)
class DQNAgent:
    """Shared-parameter multi-head DQN with epsilon-greedy exploration.

    A single ``DQNNetwork`` outputs one row of Q-values per controlled edge.
    The greedy joint action is the per-row argmax. Training samples uniformly
    from a FIFO replay buffer, regresses onto a TD target computed with a
    periodically synchronised target network (MSE loss), and clips the global
    gradient norm.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        max_grad_norm: float = 1.0,
    ):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            state_dim: Length of the state vector fed to the Q-network.
            num_edges: Number of controlled edges, i.e. action heads.
            num_actions_per_edge: Number of discrete actions per edge.
            hidden_dim: Hidden width forwarded to ``DQNNetwork``.
            learning_rate: Adam learning rate.
            gamma: Discount factor in the TD target.
            epsilon_start: Exploration rate when ``env_steps == 0``.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant, in exploratory ``select_action``
                calls, of the exponential epsilon decay. Values below 1 are
                treated as 1.
            buffer_size: Replay capacity. The oldest transitions are dropped
                once it is full.
            batch_size: Transitions sampled per ``update`` call.
            target_update: Number of ``update`` calls between target-network
                syncs, clamped to at least 1.
            device: Requested torch device string.
            max_grad_norm: Threshold passed to ``clip_grad_norm_`` on every
                update.
        """
        # NOTE(review): any requested device (not only CUDA ones) is replaced by
        # "cpu" whenever CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_agents = num_edges
        self.num_actions_per_agent = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = max(int(target_update), 1)
        self.max_grad_norm = float(max_grad_norm)

        self.q_net = DQNNetwork(
            state_dim=state_dim,
            num_agents=num_edges,
            num_actions_per_agent=num_actions_per_edge,
            hidden_dim=hidden_dim,
        ).to(self.device)
        self.target_q_net = DQNNetwork(
            state_dim=state_dim,
            num_agents=num_edges,
            num_actions_per_agent=num_actions_per_edge,
            hidden_dim=hidden_dim,
        ).to(self.device)
        # The target network starts as an exact copy and is never trained directly.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.target_q_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=learning_rate)
        # Each entry is (state, action[num_agents], reward, next_state, done).
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0  # exploratory select_action calls; drives epsilon
        self.update_steps = 0  # gradient updates; drives target syncs

    def _current_epsilon(self) -> float:
        """Return ``end + (start - end) * exp(-env_steps / max(decay, 1))``."""
        return float(
            self.epsilon_end
            + (self.epsilon_start - self.epsilon_end)
            * np.exp(-float(self.env_steps) / max(float(self.epsilon_decay), 1.0))
        )

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return the per-edge argmax action of the online network as int64 ``(num_agents,)``."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(state_tensor)
        return q_values.argmax(dim=2).squeeze(0).cpu().numpy().astype(np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose a joint action for all edges.

        With ``deterministic=True`` the greedy action is returned and
        ``env_steps`` is left unchanged. Otherwise epsilon is evaluated at the
        current ``env_steps``, the counter is incremented, and with probability
        epsilon every edge gets an independent uniformly random action.

        Returns:
            ``(action, 0.0, 0.0)`` where ``action`` is an int64 array of shape
            ``(num_agents,)``. The two zeros are placeholders. Presumably they
            keep the signature compatible with agents that also return a
            log-prob and a value; confirm with callers.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0
        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer.

        ``action`` is reshaped to ``(num_agents,)``. A scalar action (the
        single-edge case) or one with an extra leading axis therefore trains
        correctly. An action whose size is not ``num_agents`` raises
        ``ValueError`` here, instead of an indexing error later in ``update``.
        """
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64).reshape(self.num_agents),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_q_net.load_state_dict(self.q_net.state_dict())

    def _sample_batch(self):
        """Draw ``batch_size`` transitions uniformly without replacement as device tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. ``actions`` is
            ``(B, num_agents)`` long. ``rewards`` and ``dones`` are ``(B, 1)``
            float so that they broadcast across the action heads.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one DQN gradient step.

        Returns:
            ``{}`` when the buffer holds fewer than ``batch_size`` transitions.
            Otherwise ``{"loss": L, "value_loss": L}`` with the MSE TD loss.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}
        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        # Q(s, a_i) for the action actually taken by each head: (B, num_agents).
        current_q_all = self.q_net(states_t)
        current_q = current_q_all.gather(2, actions_t.unsqueeze(2)).squeeze(2)

        with torch.no_grad():
            next_q_all = self.target_q_net(next_states_t)
            next_q = next_q_all.max(dim=2)[0]
            # The shared scalar reward and done flag (B, 1) broadcast over every head.
            target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q

        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write networks, optimizer state and step counters to ``path`` via ``torch.save``."""
        torch.save(
            {
                "q_net": self.q_net.state_dict(),
                "target_q_net": self.target_q_net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore state written by ``save``.

        The optimizer state and the step counters are optional in the
        checkpoint. Missing counters default to 0.
        """
        # SECURITY(review): weights_only=False fully unpickles the file and can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        # ``save`` writes only state_dicts and ints, so weights_only=True should
        # also work; confirm before switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.q_net.load_state_dict(checkpoint["q_net"])
        self.target_q_net.load_state_dict(checkpoint["target_q_net"])
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))