340 lines
14 KiB
Python
340 lines
14 KiB
Python
"""QMIX baseline agent for cooperative corridor VSL control."""
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
from collections import deque
|
|
from typing import Deque, Dict, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
|
|
class SharedUtilityNetwork(nn.Module):
    """Per-agent utility MLP whose parameters are shared by all agents.

    The model is a stack of two hidden ``Linear`` layers, each followed by
    ``LayerNorm`` and ``ReLU``, and a final ``Linear`` head that emits one
    value per discrete action.
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        """Build the MLP and apply the custom weight initialisation.

        Args:
            local_obs_dim: Width of one agent's observation vector.
            num_actions: Number of outputs produced by the final layer.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        layers = [
            nn.Linear(local_obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        ]
        self.network = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer and zero its bias.

        All ``Linear`` layers first receive gain ``sqrt(2)``; the output head
        is then re-initialised with the much smaller gain ``0.01``.
        """
        hidden_gain = np.sqrt(2)
        linear_layers = [layer for layer in self.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=hidden_gain)
            nn.init.constant_(layer.bias, 0.0)

        output_head = self.network[-1]
        nn.init.orthogonal_(output_head.weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``local_obs`` (last dim ``local_obs_dim``)."""
        utilities = self.network(local_obs)
        return utilities
|
|
|
|
|
|
class QMixer(nn.Module):
    """Monotonic mixing network conditioned on the global state.

    Hypernetworks map the global state to the weights and biases of a
    two-layer mixing network. The generated weights pass through ``abs`` so
    they are non-negative, which (together with the monotone ``ELU``) makes
    the mixed value non-decreasing in every per-agent input.
    """

    def __init__(
        self,
        num_agents: int,
        state_dim: int,
        mixing_hidden_dim: int = 256,
    ):
        """Create the hypernetworks.

        Args:
            num_agents: Number of per-agent values mixed together.
            state_dim: Size of the flattened global state.
            mixing_hidden_dim: Width of the mixing network's hidden layer
                (also used as the hidden width of the hypernetworks).
        """
        super().__init__()
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.mixing_hidden_dim = mixing_hidden_dim

        # First mixing layer: weights of shape (num_agents, mixing_hidden_dim).
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, num_agents * mixing_hidden_dim),
        )
        self.hyper_b1 = nn.Linear(state_dim, mixing_hidden_dim)
        # Second mixing layer: weights of shape (mixing_hidden_dim, 1).
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, mixing_hidden_dim),
        )
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, mixing_hidden_dim),
            nn.ReLU(),
            nn.Linear(mixing_hidden_dim, 1),
        )
        self.activation = nn.ELU()
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain ``sqrt(2)``) with zero biases for all linears.

        The last layer of ``hyper_b2`` is re-initialised with gain ``1.0``.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0.0)
        nn.init.orthogonal_(self.hyper_b2[-1].weight, gain=1.0)

    def forward(self, agent_qs: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        """Mix per-agent values into a joint value.

        Args:
            agent_qs: Tensor of shape ``(batch, num_agents)``.
            state: Global state whose leading dimension is ``batch``; all
                remaining dimensions are flattened and must total
                ``state_dim`` elements per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with the mixed value.
        """
        batch_size = agent_qs.size(0)
        # reshape (not view) so non-contiguous states, e.g. transposed or
        # permuted tensors, are flattened instead of raising.
        flat_state = state.reshape(batch_size, -1)

        # abs() keeps the mixing weights non-negative (monotonicity).
        w1 = torch.abs(self.hyper_w1(flat_state)).view(
            batch_size,
            self.num_agents,
            self.mixing_hidden_dim,
        )
        b1 = self.hyper_b1(flat_state).view(batch_size, 1, self.mixing_hidden_dim)
        # (batch, 1, num_agents) @ (batch, num_agents, hidden) -> (batch, 1, hidden)
        hidden = self.activation(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)

        w2 = torch.abs(self.hyper_w2(flat_state)).view(
            batch_size,
            self.mixing_hidden_dim,
            1,
        )
        b2 = self.hyper_b2(flat_state).view(batch_size, 1, 1)
        # (batch, 1, hidden) @ (batch, hidden, 1) -> (batch, 1, 1) -> (batch, 1)
        return torch.bmm(hidden, w2).squeeze(1) + b2.squeeze(1)
|
|
|
|
|
|
class QMIXAgent:
    """Standard QMIX with parameter-shared utilities for cooperative discrete VSL.

    Each controlled edge is treated as one agent. All agents share a single
    utility network; a state-conditioned monotonic mixer combines the chosen
    per-agent values into a joint value trained with a one-step TD target.

    The flat state vector is sliced by ``_build_local_obs`` as::

        [ edge features   : total_edge_count * edge_feature_dim
        | per-edge speed  : total_edge_count * speed_feature_dim
        | time features   : time_feature_dim
        | last reward     : last_reward_dim
        | (anything else is ignored by the utilities, but seen by the mixer) ]
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        mixing_hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        total_edge_count: int | None = None,
        controlled_start_index: int = 0,
    ):
        """Build online/target networks, the optimizer and the replay buffer.

        Args:
            state_dim: Length of the flat global state vector.
            num_edges: Number of controlled edges (one agent per edge).
            num_actions_per_edge: Size of each agent's discrete action set.
            hidden_dim: Hidden width of the shared utility network.
            mixing_hidden_dim: Hidden width of the mixer.
            learning_rate: Adam learning rate.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate at ``env_steps == 0``.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in env steps) of the exponential
                epsilon decay.
            buffer_size: Maximum number of stored transitions.
            batch_size: Number of transitions sampled per update.
            target_update: Target networks are synced every this many
                updates (clamped to at least 1).
            device: Requested torch device. A CUDA request falls back to CPU
                when CUDA is unavailable.
            edge_feature_dim: Features per edge in the state's edge block.
            time_feature_dim: Number of time features in the state.
            total_edge_count: Number of edges described by the state;
                defaults to ``num_edges``.
            controlled_start_index: Index of the first controlled edge
                within the ``total_edge_count`` edges.

        Raises:
            ValueError: If the controlled slice is negative or exceeds
                ``total_edge_count``, or if ``state_dim`` is too small for
                the layout sliced by ``_build_local_obs``.
        """
        self.device = self._resolve_device(device)
        self.state_dim = state_dim
        self.num_agents = int(num_edges)
        self.num_actions_per_agent = int(num_actions_per_edge)
        self.gamma = float(gamma)
        self.epsilon_start = float(epsilon_start)
        self.epsilon_end = float(epsilon_end)
        self.epsilon_decay = int(epsilon_decay)
        self.batch_size = int(batch_size)
        self.target_update = max(int(target_update), 1)
        self.edge_feature_dim = int(edge_feature_dim)
        self.time_feature_dim = int(time_feature_dim)
        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.total_edge_count = int(total_edge_count if total_edge_count is not None else num_edges)
        self.controlled_start_index = int(controlled_start_index)
        self.controlled_end_index = self.controlled_start_index + self.num_agents

        # Validate the layout up front; otherwise these mistakes only surface
        # later as opaque tensor shape errors (or silently wrong slices).
        if self.controlled_start_index < 0:
            raise ValueError("controlled_start_index must be non-negative")
        if self.controlled_end_index > self.total_edge_count:
            raise ValueError("controlled action slice exceeds total edge count")
        required_state_dim = (
            self.total_edge_count * (self.edge_feature_dim + self.speed_feature_dim)
            + self.time_feature_dim
            + self.last_reward_dim
        )
        if int(state_dim) < required_state_dim:
            raise ValueError(
                f"state_dim={state_dim} is smaller than the {required_state_dim} "
                "features required by the local observation layout"
            )

        self.local_obs_dim = (
            self.edge_feature_dim
            + self.speed_feature_dim
            + self.time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )

        self.utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.target_utility_net = self._build_utility_network(hidden_dim).to(self.device)
        self.mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self.target_mixer = self._build_mixer(mixing_hidden_dim).to(self.device)
        self._sync_target_networks()
        self.target_utility_net.eval()
        self.target_mixer.eval()

        # One list shared by the optimizer and by gradient clipping.
        self._trainable_params = list(self.utility_net.parameters()) + list(self.mixer.parameters())
        self.optimizer = optim.Adam(self._trainable_params, lr=learning_rate)
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=int(buffer_size)
        )
        self.env_steps = 0
        self.update_steps = 0

        # Scalar agent identifier in [0, 1], appended to every local observation.
        agent_ids = np.linspace(0.0, 1.0, self.num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, self.num_agents, 1)

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Return the requested device, falling back to CPU only for CUDA.

        Non-CUDA requests (e.g. ``"cpu"``) are honoured as given instead of
        being tied to CUDA availability.
        """
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return requested

    def _build_utility_network(self, hidden_dim: int) -> nn.Module:
        """Create a fresh shared utility network sized for this agent."""
        return SharedUtilityNetwork(
            local_obs_dim=self.local_obs_dim,
            num_actions=self.num_actions_per_agent,
            hidden_dim=hidden_dim,
        )

    def _build_mixer(self, mixing_hidden_dim: int) -> nn.Module:
        """Create a fresh mixer sized for this agent."""
        return QMixer(
            num_agents=self.num_agents,
            state_dim=self.state_dim,
            mixing_hidden_dim=mixing_hidden_dim,
        )

    def _current_epsilon(self) -> float:
        """Exploration rate: exponential decay from start to end in ``env_steps``."""
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(
            -float(self.env_steps) / max(float(self.epsilon_decay), 1.0)
        )

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Slice a flat state into per-agent observations.

        Args:
            state_tensor: ``(state_dim,)`` or ``(batch, state_dim)`` tensor.

        Returns:
            Tensor of shape ``(batch, num_agents, local_obs_dim)`` holding, per
            controlled edge: its edge features, its speed value, the shared
            time/last-reward features and the agent id feature.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_block = self.total_edge_count * self.edge_feature_dim
        speed_block_start = edge_block
        speed_block_end = speed_block_start + self.total_edge_count * self.speed_feature_dim
        global_block_start = speed_block_end
        global_block_end = global_block_start + self.time_feature_dim + self.last_reward_dim
        controlled = slice(self.controlled_start_index, self.controlled_end_index)

        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.total_edge_count, self.edge_feature_dim
        )[:, controlled, :]
        local_speed_limits = state_tensor[:, speed_block_start:speed_block_end].reshape(
            batch_size, self.total_edge_count, self.speed_feature_dim
        )[:, controlled, :]
        # Time/last-reward features are identical for every agent.
        global_features = state_tensor[:, global_block_start:global_block_end].unsqueeze(1)
        global_features = global_features.expand(-1, self.num_agents, -1)
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)
        return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1)

    def _compute_agent_q_values_with_net(
        self,
        state_tensor: torch.Tensor,
        utility_net: nn.Module,
    ) -> torch.Tensor:
        """Evaluate ``utility_net`` for every agent.

        Returns:
            Tensor of shape ``(batch, num_agents, num_actions_per_agent)``.
        """
        local_obs = self._build_local_obs(state_tensor)
        batch_size = local_obs.size(0)
        # Fold agents into the batch dimension: the network is shared.
        q_values = utility_net(local_obs.reshape(batch_size * self.num_agents, self.local_obs_dim))
        return q_values.view(batch_size, self.num_agents, self.num_actions_per_agent)

    def _compute_agent_q_values(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Per-agent action values from the online utility network."""
        return self._compute_agent_q_values_with_net(state_tensor, self.utility_net)

    def _compute_target_next_agent_q(self, next_states_t: torch.Tensor) -> torch.Tensor:
        """Max-over-actions per-agent values from the target utility network."""
        target_next_q_all = self._compute_agent_q_values_with_net(next_states_t, self.target_utility_net)
        return target_next_q_all.max(dim=2)[0]

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return the argmax action index of every agent as an ``int64`` array."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            q_values = self._compute_agent_q_values(state_tensor)
        return q_values.argmax(dim=2).squeeze(0).cpu().numpy().astype(np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose a joint action, epsilon-greedy unless ``deterministic``.

        With probability epsilon the whole joint action is sampled uniformly;
        otherwise every agent acts greedily. ``env_steps`` is advanced only
        on non-deterministic calls.

        Returns:
            ``(action, 0.0, 0.0)``. The two trailing floats are always zero
            here (presumably placeholders matching other agents' interface;
            confirm with callers).
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0

        epsilon = self._current_epsilon()
        self.env_steps += 1
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer with normalised dtypes."""
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_networks(self):
        """Hard-copy the online weights into both target networks."""
        self.target_utility_net.load_state_dict(self.utility_net.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and move them to the device.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` tensors; rewards
            and dones have shape ``(batch, 1)``.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one TD update on a sampled batch.

        Returns:
            ``{}`` when the buffer holds fewer than ``batch_size`` transitions,
            otherwise ``{"loss": ..., "value_loss": ...}`` (same value twice).
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        current_agent_q_all = self._compute_agent_q_values(states_t)
        # Pick the value of the action each agent actually took.
        current_agent_q = current_agent_q_all.gather(2, actions_t.unsqueeze(2)).squeeze(2)
        current_q_tot = self.mixer(current_agent_q, states_t)

        with torch.no_grad():
            target_next_agent_q = self._compute_target_next_agent_q(next_states_t)
            # (1 - done) removes the bootstrap term on terminal transitions.
            target_q_tot = rewards_t + (1.0 - dones_t) * self.gamma * self.target_mixer(
                target_next_agent_q,
                next_states_t,
            )

        loss = nn.functional.mse_loss(current_q_tot, target_q_tot)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self._trainable_params, 1.0)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_networks()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write network, optimizer and step-counter state to ``path``."""
        torch.save(
            {
                "utility_net": self.utility_net.state_dict(),
                "target_utility_net": self.target_utility_net.state_dict(),
                "mixer": self.mixer.state_dict(),
                "target_mixer": self.target_mixer.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        The optimizer state and step counters are optional in the checkpoint;
        missing counters default to 0.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code from a malicious file. Only load trusted checkpoints;
        # consider weights_only=True if all existing checkpoints permit it.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.utility_net.load_state_dict(checkpoint["utility_net"])
        self.target_utility_net.load_state_dict(checkpoint["target_utility_net"])
        self.mixer.load_state_dict(checkpoint["mixer"])
        self.target_mixer.load_state_dict(checkpoint["target_mixer"])
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))
|