206 lines
7.6 KiB
Python
206 lines
7.6 KiB
Python
"""Independent multi-agent DQN agent for corridor VSL control."""
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
from collections import deque
|
|
from typing import Deque, Dict, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
|
|
class MADQNNetwork(nn.Module):
    """Q-network for a single agent.

    Consumes the (global) state vector and produces one Q-value per discrete
    action of that agent. Layout: three ReLU hidden layers of widths
    ``hidden_dim``, ``hidden_dim`` and ``hidden_dim // 2``, then a linear
    output head of width ``num_actions``.
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        half_width = hidden_dim // 2
        widths = [state_dim, hidden_dim, hidden_dim, half_width]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(half_width, num_actions))
        # The attribute name and layer order determine the state_dict keys
        # ("network.0.weight", ...), so keep them stable for saved checkpoints.
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map input states ``x`` to per-action Q-values."""
        return self.network(x)
|
|
|
|
|
|
class MADQNAgent:
    """Independent DQN ensemble with one Q-network per controlled edge.

    Every agent (edge) owns its own online and target ``MADQNNetwork``. All
    agents observe the same global state and share the scalar reward stored
    with each transition. They are trained jointly by a single Adam optimizer
    on the mean-squared TD error, averaged over agents and batch samples.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
    ):
        """Build the online/target networks, the optimizer and the replay buffer.

        Args:
            state_dim: Length of the global state vector fed to every agent.
            num_edges: Number of controlled edges, i.e. independent agents.
            num_actions_per_edge: Size of each agent's discrete action space.
            hidden_dim: Hidden width passed to each ``MADQNNetwork``.
            learning_rate: Adam learning rate shared by all agents.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate before any environment step.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in env steps) of the exponential decay.
            buffer_size: Maximum number of transitions kept for replay.
            batch_size: Number of transitions sampled per gradient update.
            target_update: Gradient updates between target-network syncs
                (values below 1 are clamped to 1).
            device: Requested torch device. CPU is used instead whenever CUDA
                is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_agents = num_edges
        self.num_actions_per_agent = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = max(int(target_update), 1)

        self.q_nets = nn.ModuleList(
            [
                MADQNNetwork(
                    state_dim=state_dim,
                    num_actions=num_actions_per_edge,
                    hidden_dim=hidden_dim,
                )
                for _ in range(num_edges)
            ]
        ).to(self.device)
        self.target_q_nets = nn.ModuleList(
            [
                MADQNNetwork(
                    state_dim=state_dim,
                    num_actions=num_actions_per_edge,
                    hidden_dim=hidden_dim,
                )
                for _ in range(num_edges)
            ]
        ).to(self.device)
        self._sync_target_network()
        self.target_q_nets.eval()
        # Target nets are only ever evaluated under no_grad and updated through
        # load_state_dict, so they never need gradients.
        for param in self.target_q_nets.parameters():
            param.requires_grad_(False)

        # One optimizer over every agent's online network.
        self.optimizer = optim.Adam(self.q_nets.parameters(), lr=learning_rate)
        self.replay_buffer: Deque[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = deque(
            maxlen=buffer_size
        )
        self.env_steps = 0
        self.update_steps = 0

    def _current_epsilon(self) -> float:
        """Return the exploration rate for the current ``env_steps``.

        Decays exponentially from ``epsilon_start`` towards ``epsilon_end``,
        with time constant ``epsilon_decay`` clamped to at least 1.
        """
        decay = max(float(self.epsilon_decay), 1.0)
        fraction = np.exp(-float(self.env_steps) / decay)
        # Cast to a builtin float so the return type matches the annotation.
        return float(self.epsilon_end + (self.epsilon_start - self.epsilon_end) * fraction)

    def _greedy_action(self, state: np.ndarray) -> np.ndarray:
        """Return each agent's argmax action for a single ``state`` as an int64 array."""
        # Add a leading batch dimension of size 1 for the networks.
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        actions = []
        with torch.no_grad():
            for q_net in self.q_nets:
                q_values = q_net(state_tensor)
                actions.append(int(q_values.argmax(dim=1).item()))
        return np.asarray(actions, dtype=np.int64)

    def select_action(
        self,
        state: np.ndarray,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, float, float]:
        """Choose one action per agent.

        Args:
            state: Global state observed by every agent.
            deterministic: If True, act greedily and do not advance ``env_steps``.

        Returns:
            ``(actions, 0.0, 0.0)``. ``actions`` is an int64 array with one
            entry per agent. The two trailing floats are always 0.0
            (placeholders, presumably to match a shared agent interface;
            confirm with callers).
        """
        if deterministic:
            return self._greedy_action(state), 0.0, 0.0

        # Epsilon is evaluated before the step counter advances.
        epsilon = self._current_epsilon()
        self.env_steps += 1

        # A single coin flip decides whether *all* agents explore this step.
        if random.random() < epsilon:
            action = np.array(
                [random.randrange(self.num_actions_per_agent) for _ in range(self.num_agents)],
                dtype=np.int64,
            )
        else:
            action = self._greedy_action(state)
        return action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer, normalizing dtypes.

        ``reward`` is converted to a single float, so it must be a scalar.
        The oldest transitions are evicted once ``buffer_size`` is reached.
        """
        transition = (
            np.asarray(state, dtype=np.float32),
            np.asarray(action, dtype=np.int64),
            float(reward),
            np.asarray(next_state, dtype=np.float32),
            bool(done),
        )
        self.replay_buffer.append(transition)

    def _sync_target_network(self):
        """Copy every online network's weights into its target network."""
        for q_net, target_q_net in zip(self.q_nets, self.target_q_nets):
            target_q_net.load_state_dict(q_net.state_dict())

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly (without replacement) as tensors.

        Rewards and done flags are returned with a trailing dimension of 1 so
        they broadcast against per-agent ``(batch, 1)`` Q-value columns.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones_t = torch.as_tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1)
        return states_t, actions_t, rewards_t, next_states_t, dones_t

    def update(self) -> Dict[str, float]:
        """Run one gradient step on a replay minibatch.

        Returns:
            ``{}`` while the buffer holds fewer than ``batch_size`` transitions.
            Otherwise the MSE TD loss under both the ``"loss"`` and
            ``"value_loss"`` keys.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        current_q_values = []
        target_q_values = []
        for agent_idx, (q_net, target_q_net) in enumerate(zip(self.q_nets, self.target_q_nets)):
            # Q-value of the action this agent actually took.
            current_q = q_net(states_t).gather(1, actions_t[:, agent_idx].unsqueeze(1))
            with torch.no_grad():
                next_q = target_q_net(next_states_t).max(dim=1, keepdim=True)[0]
                # Terminal transitions (done == 1) do not bootstrap.
                target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q
            current_q_values.append(current_q)
            target_q_values.append(target_q)

        current_q_tensor = torch.cat(current_q_values, dim=1)
        target_q_tensor = torch.cat(target_q_values, dim=1)
        loss = nn.MSELoss()(current_q_tensor, target_q_tensor)

        self.optimizer.zero_grad()
        loss.backward()
        # The gradient norm is clipped jointly across all agents' parameters.
        torch.nn.utils.clip_grad_norm_(self.q_nets.parameters(), 1.0)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}

    def save(self, path: str):
        """Write network, optimizer and step-counter state to ``path`` via ``torch.save``."""
        torch.save(
            {
                "q_nets": [q_net.state_dict() for q_net in self.q_nets],
                "target_q_nets": [target_q_net.state_dict() for target_q_net in self.target_q_nets],
                "optimizer": self.optimizer.state_dict(),
                "env_steps": int(self.env_steps),
                "update_steps": int(self.update_steps),
            },
            path,
        )

    def load(self, path: str):
        """Restore state previously written by :meth:`save`.

        A missing ``"target_q_nets"`` entry is tolerated by syncing the targets
        from the freshly loaded online networks. A missing ``"optimizer"``
        entry leaves the optimizer untouched. Missing step counters reset to 0.

        Raises:
            ValueError: If the checkpoint holds a different number of networks
                than this agent has. Nothing is modified in that case.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
        # checkpoints from trusted sources. The payload written by save() holds
        # only state dicts and ints, so weights_only=True is likely sufficient;
        # confirm before switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        q_states = checkpoint["q_nets"]
        target_states = checkpoint.get("target_q_nets")
        # Validate everything up front. Previously zip() silently truncated
        # mismatched lists, leaving some agents with stale weights.
        if len(q_states) != len(self.q_nets):
            raise ValueError(
                f"Checkpoint contains {len(q_states)} Q-networks but this agent "
                f"has {len(self.q_nets)}"
            )
        if target_states is not None and len(target_states) != len(self.target_q_nets):
            raise ValueError(
                f"Checkpoint contains {len(target_states)} target Q-networks but "
                f"this agent has {len(self.target_q_nets)}"
            )

        for q_net, state_dict in zip(self.q_nets, q_states):
            q_net.load_state_dict(state_dict)
        if target_states is None:
            self._sync_target_network()
        else:
            for target_q_net, state_dict in zip(self.target_q_nets, target_states):
                target_q_net.load_state_dict(state_dict)

        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            self.optimizer.load_state_dict(optimizer_state)
        self.env_steps = int(checkpoint.get("env_steps", 0))
        self.update_steps = int(checkpoint.get("update_steps", 0))
|