# ctm-dqn/agents/madqn_agent.py
# (206 lines, 7.6 KiB, Python)

"""Independent multi-agent DQN agent for corridor VSL control."""
from __future__ import annotations
import random
from collections import deque
from typing import Deque, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class MADQNNetwork(nn.Module):
    """Feed-forward Q-network owned by a single agent.

    The network consumes the full (global) state vector and emits one
    Q-value per local action. It has three ReLU hidden layers of widths
    ``hidden_dim``, ``hidden_dim`` and ``hidden_dim // 2``, followed by a
    linear output head.
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        half_width = hidden_dim // 2
        widths = [state_dim, hidden_dim, hidden_dim, half_width]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(half_width, num_actions))
        # Keep the attribute name ``network`` and this exact layer order:
        # state_dict keys are ``network.<index>.<param>``, so checkpoints
        # depend on both.
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map states ``x`` (last dim ``state_dim``) to Q-values (last dim ``num_actions``)."""
        return self.network(x)
class MADQNAgent:
    """Independent DQN ensemble with one Q-network per controlled edge.

    Each agent (one per edge) owns an online ``MADQNNetwork`` and a target
    copy. Every network sees the same global state vector, and agent ``i``
    picks one of ``num_actions_per_edge`` discrete actions. The agents share
    a single replay buffer, a single scalar reward per transition, and one
    Adam optimizer over all online networks.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
    ):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            state_dim: Length of the global state vector fed to every network.
            num_edges: Number of independent agents (one network pair each).
            num_actions_per_edge: Size of each agent's discrete action set.
            hidden_dim: Hidden width passed to ``MADQNNetwork``.
            learning_rate: Adam learning rate shared by all online networks.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate at zero environment steps.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant, in environment steps, of the
                exponential epsilon decay (values below 1 are treated as 1).
            buffer_size: Maximum number of stored transitions; the oldest
                transitions are evicted first.
            batch_size: Minibatch size sampled by each ``update`` call.
            target_update: Number of ``update`` calls between target-network
                syncs (clamped to at least 1).
            device: Requested torch device. CPU is used instead whenever CUDA
                is not available.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_agents = num_edges
        self.num_actions_per_agent = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = max(int(target_update), 1)
        self.q_nets = nn.ModuleList(
            [
                MADQNNetwork(
                    state_dim=state_dim,
                    num_actions=num_actions_per_edge,
                    hidden_dim=hidden_dim,
                )
                for _ in range(num_edges)
            ]
        ).to(self.device)
        self.target_q_nets = nn.ModuleList(
            [
                MADQNNetwork(
                    state_dim=state_dim,
                    num_actions=num_actions_per_edge,
                    hidden_dim=hidden_dim,
                )
                for _ in range(num_edges)
            ]
        ).to(self.device)
        # Targets start as exact copies of the online networks.
        self._sync_target_network()
        self.target_q_nets.eval()
        self.optimizer = optim.Adam(self.q_nets.parameters(), lr=learning_rate)
        # Each entry is (state, joint_action, reward, next_state, done).
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0  # counts non-deterministic select_action calls
        self.update_steps = 0  # counts gradient updates actually performed

    def _current_epsilon(self) -> float:
        """Return the exploration rate for the current number of environment steps.

        The rate is ``end + (start - end) * exp(-env_steps / max(decay, 1))``.
        """
        decay = -float(self.env_steps) / max(float(self.epsilon_decay), 1.0)
        # np.exp returns np.float64; convert so the value matches the annotation.
        return float(self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(decay))

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return each agent's argmax action for a single ``state``, as an int64 array."""
        # Add a leading batch dimension of size 1.
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        actions = []
        with torch.no_grad():
            for q_net in self.q_nets:
                q_values = q_net(state_tensor)
                actions.append(int(q_values.argmax(dim=1).item()))
        return np.asarray(actions, dtype=np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose a joint action (one entry per agent) for ``state``.

        With ``deterministic=True``, every agent acts greedily and the step
        counter is left unchanged. Otherwise the step counter is advanced and,
        with probability epsilon, *all* agents act uniformly at random
        together. The remaining draws use the greedy joint action.

        Returns:
            ``(action, 0.0, 0.0)``. The two zeros are constant — presumably
            placeholders for an (action, log_prob, value) interface shared
            with other agents; confirm against callers.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0
        # Epsilon is computed before the counter advances, so the first call uses epsilon_start.
        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer, normalizing dtypes.

        States are stored as float32 arrays and the joint action as an int64
        array. The reward is stored as a Python float and ``done`` as a bool.
        """
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_network(self):
        """Copy every online network's weights into its target network."""
        for q_net, target_q_net in zip(self.q_nets, self.target_q_nets):
            target_q_net.load_state_dict(q_net.state_dict())

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly without replacement.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` tensors on
            ``self.device``. ``rewards`` and ``dones`` have shape
            ``(batch, 1)``, and ``actions`` holds one column per agent.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one gradient step on all online networks.

        Each agent regresses ``Q_i(s, a_i)`` towards
        ``r + (1 - done) * gamma * max_a' Q_target_i(s', a')``, using the
        shared scalar reward. The per-agent errors are averaged into a single
        MSE loss. Target networks are re-synced every ``target_update`` calls.

        Returns:
            ``{}`` when the buffer holds fewer than ``batch_size`` transitions.
            Otherwise ``{"loss": L, "value_loss": L}``, where ``L`` is the MSE.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}
        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()
        current_q_values = []
        target_q_values = []
        for agent_idx, (q_net, target_q_net) in enumerate(zip(self.q_nets, self.target_q_nets)):
            # Q-value of the action this agent actually took.
            current_q = q_net(states_t).gather(1, actions_t[:, agent_idx].unsqueeze(1))
            with torch.no_grad():
                next_q = target_q_net(next_states_t).max(dim=1, keepdim=True)[0]
                target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q
            current_q_values.append(current_q)
            target_q_values.append(target_q)
        current_q_tensor = torch.cat(current_q_values, dim=1)
        target_q_tensor = torch.cat(target_q_values, dim=1)
        loss = nn.MSELoss()(current_q_tensor, target_q_tensor)
        self.optimizer.zero_grad()
        loss.backward()
        # NOTE(review): one gradient norm is clipped jointly across all agents,
        # so a large gradient for one agent also shrinks the others' steps.
        torch.nn.utils.clip_grad_norm_(self.q_nets.parameters(), 1.0)
        self.optimizer.step()
        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()
        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write the network weights, the optimizer state and the step counters to ``path``."""
        torch.save(
            {
                "q_nets": [q_net.state_dict() for q_net in self.q_nets],
                "target_q_nets": [target_q_net.state_dict() for target_q_net in self.target_q_nets],
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore state written by ``save`` from ``path``.

        The optimizer state is restored only if present. Missing step
        counters default to 0.

        Raises:
            ValueError: If the checkpoint holds a different number of online
                or target networks than this agent. In that case nothing is
                loaded.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code; only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        pairs = (
            ("q_nets", self.q_nets),
            ("target_q_nets", self.target_q_nets),
        )
        # Validate every count before loading so a mismatch can't leave the
        # agent half-restored; zip() alone would silently truncate.
        for key, nets in pairs:
            if len(checkpoint[key]) != len(nets):
                raise ValueError(
                    f"checkpoint '{key}' has {len(checkpoint[key])} networks, "
                    f"but this agent has {len(nets)}"
                )
        for key, nets in pairs:
            for net, state_dict in zip(nets, checkpoint[key]):
                net.load_state_dict(state_dict)
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))