ctm-dqn/agents/qmix_agent.py

340 lines
14 KiB
Python

"""QMIX baseline agent for cooperative corridor VSL control."""
from __future__ import annotations
import random
from collections import deque
from typing import Deque, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
class SharedUtilityNetwork(nn.Module):
    """Per-agent utility (Q-value) network shared by all homogeneous corridor controllers.

    A two-hidden-layer MLP; each hidden projection is followed by LayerNorm and
    ReLU. The final linear layer emits one utility per discrete action.

    Args:
        local_obs_dim: Width of one agent's local observation vector.
        num_actions: Number of discrete actions per agent (output width).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        layers = [
            nn.Linear(local_obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        ]
        # Attribute name and Sequential indices define the state_dict keys
        # ("network.0.weight", ...), so checkpoints depend on this layout.
        self.network = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer, then shrink the output head."""
        relu_gain = np.sqrt(2)
        linear_layers = [layer for layer in self.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=relu_gain)
            nn.init.constant_(layer.bias, 0.0)
        # A tiny gain on the head (bias already zeroed) keeps the initial
        # utilities close to zero.
        output_head = self.network[-1]
        nn.init.orthogonal_(output_head.weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Map local observations ``(..., local_obs_dim)`` to utilities ``(..., num_actions)``."""
        utilities = self.network(local_obs)
        return utilities
class QMixer(nn.Module):
    """Monotonic mixing network conditioned on the global state.

    Hypernetworks map the flattened state to the weights and biases of a
    two-layer mixer. The mixing weights pass through ``abs`` and the hidden
    activation (ELU) is monotonic, so the joint value is non-decreasing in
    every agent's utility.

    Args:
        num_agents: Number of per-agent utilities being mixed.
        state_dim: Width of the flattened global state.
        mixing_hidden_dim: Hidden width of the mixer and its hypernetworks.
    """

    def __init__(
        self,
        num_agents: int,
        state_dim: int,
        mixing_hidden_dim: int = 256,
    ):
        super().__init__()
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.mixing_hidden_dim = mixing_hidden_dim
        # Registration order (w1, b1, w2, b2) fixes both the state_dict keys
        # and the order _init_weights visits layers, so it is kept as is.
        self.hyper_w1 = self._hypernetwork(num_agents * mixing_hidden_dim)
        self.hyper_b1 = nn.Linear(state_dim, mixing_hidden_dim)
        self.hyper_w2 = self._hypernetwork(mixing_hidden_dim)
        self.hyper_b2 = self._hypernetwork(1)
        self.activation = nn.ELU()
        self._init_weights()

    def _hypernetwork(self, out_features: int) -> nn.Sequential:
        """Build a state -> hidden -> ``out_features`` MLP with a ReLU in between."""
        return nn.Sequential(
            nn.Linear(self.state_dim, self.mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.mixing_hidden_dim, out_features),
        )

    def _init_weights(self):
        """Orthogonally initialise all linear layers; the final-bias head uses unit gain."""
        relu_gain = np.sqrt(2)
        for layer in (m for m in self.modules() if isinstance(m, nn.Linear)):
            nn.init.orthogonal_(layer.weight, gain=relu_gain)
            nn.init.constant_(layer.bias, 0.0)
        nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0)

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent utilities into the joint value.

        Args:
            agent_qs: ``(batch, num_agents)`` chosen-action utilities.
            state: Global state, flattened per sample to ``(batch, state_dim)``.

        Returns:
            ``(batch, 1)`` joint value ``Q_tot``.
        """
        n = agent_qs.size(0)
        hidden_width = self.mixing_hidden_dim
        s = state.view(n, -1)

        # First mixing layer: (n, 1, agents) x (n, agents, H) -> (n, 1, H).
        w1 = self.hyper_w1(s).abs().view(n, self.num_agents, hidden_width)
        b1 = self.hyper_b1(s).view(n, 1, hidden_width)
        hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)

        # Second mixing layer: (n, 1, H) x (n, H, 1) -> (n, 1, 1).
        w2 = self.hyper_w2(s).abs().view(n, hidden_width, 1)
        b2 = self.hyper_b2(s).view(n, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2
        return q_tot.squeeze(1)
class QMIXAgent:
    """Standard QMIX with parameter-shared utilities for cooperative discrete VSL.

    Each controlled edge is one agent. All agents share a single utility
    network whose input is a per-agent local observation sliced out of the flat
    global state (see ``_build_local_obs``). The chosen-action utilities are
    combined by a state-conditioned monotonic ``QMixer`` into ``Q_tot`` and
    trained with a one-step TD target from separate target networks.

    The flat state must start with this layout. Any extra trailing features are
    ignored when building local observations but are still fed to the mixer::

        [edge features : total_edge_count * edge_feature_dim]
        [speed limits  : total_edge_count]
        [time features : time_feature_dim]
        [last reward   : 1]

    Only the contiguous edge slice
    ``[controlled_start_index, controlled_start_index + num_edges)`` is controlled.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        mixing_hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
        max_grad_norm: float = 1.0,
    ):
        """Build networks, optimizer and replay buffer.

        Args:
            state_dim: Width of the flat global state (fed whole to the mixer).
            num_edges: Number of controlled edges, i.e. agents.
            num_actions_per_edge: Discrete actions available to each agent.
            hidden_dim: Hidden width of the shared utility network.
            mixing_hidden_dim: Hidden width of the mixer and its hypernetworks.
            learning_rate: Adam learning rate.
            gamma: Discount factor.
            epsilon_start: Initial exploration rate.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in env steps) of the exponential decay.
            buffer_size: Replay capacity; the oldest transitions are evicted.
            batch_size: Minibatch size; updates are skipped until the buffer
                holds at least this many transitions.
            target_update: Sync target networks every this many gradient
                updates (clamped to at least 1).
            device: Requested torch device. A CUDA request falls back to CPU
                when CUDA is unavailable; other devices are honoured as given.
            edge_feature_dim: Features per edge in the state's leading block.
            time_feature_dim: Width of the time-feature block.
            total_edge_count: Edges present in the state (defaults to num_edges).
            controlled_start_index: First edge index controlled by this agent.
            max_grad_norm: Global gradient-norm clipping threshold.

        Raises:
            ValueError: If agent/action counts are not positive, the controlled
                slice does not lie inside ``[0, total_edge_count]``, or
                ``state_dim`` is too small for the expected state layout.
        """
        self.device = self._resolve_device(device)
        self.state_dim = state_dim
        self.num_agents = int(num_edges)
        self.num_actions_per_agent = int(num_actions_per_edge)
        if self.num_agents < 1:
            raise ValueError("num_edges must be at least 1")
        if self.num_actions_per_agent < 1:
            raise ValueError("num_actions_per_edge must be at least 1")
        self.gamma = float(gamma)
        self.epsilon_start = float(epsilon_start)
        self.epsilon_end = float(epsilon_end)
        self.epsilon_decay = int(epsilon_decay)
        self.batch_size = int(batch_size)
        self.target_update = max(int(target_update), 1)
        self.max_grad_norm = float(max_grad_norm)
        self.edge_feature_dim = int(edge_feature_dim)
        self.time_feature_dim = int(time_feature_dim)
        # Fixed widths of the per-edge speed limit, the scalar last reward and
        # the scalar agent-id feature in each local observation.
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.total_edge_count = int(total_edge_count if total_edge_count is not None else num_edges)
        self.controlled_start_index = int(controlled_start_index)
        if self.controlled_start_index < 0:
            # A negative start would make the edge slice empty or alias the
            # wrong edges instead of failing loudly.
            raise ValueError("controlled_start_index must be non-negative")
        self.controlled_end_index = self.controlled_start_index + self.num_agents
        if self.controlled_end_index > self.total_edge_count:
            raise ValueError("controlled action slice exceeds total edge count")
        min_state_dim = (
            self.total_edge_count * (self.edge_feature_dim + self.speed_feature_dim)
            + self.time_feature_dim
            + self.last_reward_dim
        )
        if int(state_dim) < min_state_dim:
            raise ValueError(
                f"state_dim={state_dim} is smaller than the {min_state_dim} "
                "features required by the local-observation layout"
            )
        self.local_obs_dim = (
            self.edge_feature_dim
            + self.speed_feature_dim
            + self.time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )
        # Construction order is kept stable so seeded runs initialise
        # identically.
        self.utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.target_utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self.target_mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self._sync_target_networks()
        self.target_utility_net.eval()
        self.target_mixer.eval()
        self.optimizer = optim.Adam(
            list(self.utility_net.parameters()) + list(self.mixer.parameters()),
            lr=learning_rate,
        )
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0
        self.update_steps = 0
        # One scalar id per agent, evenly spaced in [0, 1], so the shared
        # network can tell agents apart. Shape (1, num_agents, 1).
        agent_ids = np.linspace(0.0, 1.0, self.num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, self.num_agents, 1)

    @staticmethod
    def _resolve_device(device: str | torch.device) -> torch.device:
        """Return the requested device, or CPU when CUDA was requested but is unavailable."""
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return requested

    def _build_utility_network(self, hidden_dim: int) -> nn.Module:
        """Create a fresh shared per-agent utility network."""
        return SharedUtilityNetwork(
            local_obs_dim=self.local_obs_dim,
            num_actions=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
        )

    def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module:
        """Create a fresh state-conditioned mixing network."""
        return QMixer(
            num_agents=self.num_agents,
            state_dim=self.state_dim,
            mixing_hidden_dim=mixing_hidden_dim,
        )

    def _current_epsilon(self) -> float:
        """Exploration rate ``end + (start - end) * exp(-env_steps / decay)``."""
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(
            -float(self.env_steps) / max(float(self.epsilon_decay), 1.0)
        )

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Slice per-agent local observations out of flat global states.

        Args:
            state_tensor: ``(state_dim,)`` or ``(batch, state_dim)`` flat states.

        Returns:
            ``(batch, num_agents, local_obs_dim)``. Each row concatenates the
            agent's edge features, its speed limit, the shared time and
            last-reward features, and its agent id.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_start = edge_block
        speed_block_end = speed_block_start + self.total_edge_count
        global_block_start = speed_block_end
        global_block_end = global_block_start + self.time_feature_dim + self.last_reward_dim
        edge_features = state_tensor[:, :edge_block].view(batch_size, self.total_edge_count, self.edge_feature_dim)
        edge_features = edge_features[:, self.controlled_start_index:self.controlled_end_index, :]
        local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].view(batch_size, self.total_edge_count, 1)
        local_speed_limits = local_speed_limits[:, self.controlled_start_index:self.controlled_end_index, :]
        # Time + last-reward features are global; broadcast them to every agent.
        global_features = state_tensor[:, global_block_start:global_block_end].unsqueeze(1)
        global_features = global_features.expand(-1, self.num_agents, -1)
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1)

    def _compute_agent_q_values_with_net(
        self,
        state_tensor: torch.Tensor,
        utility_net: nn.Module,
    ) -> torch.Tensor:
        """Evaluate ``utility_net`` for every agent: ``(batch, num_agents, num_actions)``."""
        local_obs = self._build_local_obs(state_tensor)
        batch_size = local_obs.size(0)
        # Fold agents into the batch dimension so the shared net runs once.
        q_values = utility_net(local_obs.view(batch_size * self.num_agents, self.local_obs_dim))
        return q_values.view(batch_size, self.num_agents, self.num_actions_per_agent)

    def _compute_agent_q_values(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Per-agent utilities from the online utility network."""
        return self._compute_agent_q_values_with_net(state_tensor, self.utility_net)

    def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor:
        """Per-agent max utilities ``(batch, num_agents)`` from the target utility network."""
        target_next_q_all = self._compute_agent_q_values_with_net(next_states_t, self.target_utility_net)
        return target_next_q_all.max(dim=2)[0]

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return the per-agent argmax action for a single state as an int64 array."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self._compute_agent_q_values(state_tensor)
        return q_values.argmax(dim=2).squeeze(0).cpu().numpy().astype(np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose one action per agent.

        With ``deterministic=True`` the greedy action is returned and the
        exploration step counter is not advanced. Otherwise epsilon-greedy is
        applied jointly: either every agent acts uniformly at random or every
        agent acts greedily.

        Returns:
            ``(actions, 0.0, 0.0)``. The two trailing floats are always zero;
            presumably they are placeholders for a shared agent interface
            (e.g. log-prob/value) — confirm against the caller.
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0
        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition (coerced to float32/int64/float/bool) to the replay buffer."""
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_networks(self):
        """Hard-copy online utility and mixer weights into the target networks."""
        self.target_utility_net.load_state_dict(self.utility_net.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())

    def _sample_batch(self):
        """Uniformly sample ``batch_size`` transitions and stack them as device tensors.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. Rewards and dones
            have shape ``(batch, 1)``.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one QMIX gradient step.

        Returns:
            ``{}`` when the buffer holds fewer than ``batch_size`` transitions.
            Otherwise ``{"loss": ..., "value_loss": ...}``, where both values
            are the same TD MSE.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}
        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()
        current_agent_q_all = self._compute_agent_q_values(states_t)
        current_agent_q = current_agent_q_all.gather(2, actions_t.unsqueeze(2)).squeeze(2)
        current_q_tot = self.mixer(current_agent_q, states_t)
        with torch.no_grad():
            target_next_agent_q = self._compute_target_next_agent_q(next_states_t)
            # Terminal transitions do not bootstrap.
            target_q_tot = rewards_t + (1.0 - dones_t) * self.gamma * self.target_mixer(
                target_next_agent_q,
                next_states_t,
            )
        loss = nn.MSELoss()(current_q_tot, target_q_tot)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.utility_net.parameters()) + list(self.mixer.parameters()),
            self.max_grad_norm,
        )
        self.optimizer.step()
        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_networks()
        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write online/target weights, optimizer state and step counters to ``path``."""
        torch.save(
            {
                "utility_net": self.utility_net.state_dict(),
                "target_utility_net": self.target_utility_net.state_dict(),
                "mixer": self.mixer.state_dict(),
                "target_mixer": self.target_mixer.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        Optimizer state and step counters are optional in the checkpoint:
        missing optimizer state is skipped, and missing counters reset to 0.
        """
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from trusted
        # sources. The payload written by save() (state dicts + ints) looks
        # compatible with weights_only=True; consider switching after testing.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.utility_net.load_state_dict(checkpoint["utility_net"])
        self.target_utility_net.load_state_dict(checkpoint["target_utility_net"])
        self.mixer.load_state_dict(checkpoint["mixer"])
        self.target_mixer.load_state_dict(checkpoint["target_mixer"])
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))