"""Shared-parameter multi-head DDQN agent for corridor VSL control."""

from __future__ import annotations

from typing import Dict

import torch
import torch.nn as nn

from agents.dqn_agent import DQNAgent


class DDQNAgent(DQNAgent):
    """Shared-parameter Double DQN with one action head per controlled edge.

    Inherits everything from ``DQNAgent`` except :meth:`update`, which is
    overridden to build Double-DQN targets: the online network (``q_net``)
    selects the greedy next action per head and the target network
    (``target_q_net``) evaluates it.
    """

    def update(self) -> Dict[str, float]:
        """Run one Double-DQN gradient step on a replay mini-batch.

        Returns:
            ``{}`` when the replay buffer holds fewer than ``batch_size``
            transitions (no step is taken). Otherwise a dict with keys
            ``"loss"`` and ``"value_loss"``, both holding the scalar MSE
            loss of this step.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        # The Q-networks return a rank-3 tensor whose last dim indexes actions
        # (gather/argmax are applied on dim 2); presumably
        # (batch, num_heads, num_actions) — TODO confirm against DQNAgent.
        current_q_all = self.q_net(states_t)
        current_q = current_q_all.gather(2, actions_t.unsqueeze(2)).squeeze(2)

        with torch.no_grad():
            # Double DQN: online net picks the next action per head...
            online_next_q_all = self.q_net(next_states_t)
            next_actions = online_next_q_all.argmax(dim=2, keepdim=True)
            # ...and the target net scores that action.
            target_next_q_all = self.target_q_net(next_states_t)
            next_q = target_next_q_all.gather(2, next_actions).squeeze(2)

            # next_q is rank 2 (batch, num_heads). A per-sample reward/done
            # vector of shape (batch,) would broadcast along the LAST axis:
            # an error when batch != num_heads, or a silent (batch, batch)
            # target when num_heads == 1. Lift 1-D inputs to (batch, 1) so
            # they broadcast across heads; 2-D inputs are left unchanged.
            if rewards_t.dim() == 1:
                rewards_t = rewards_t.unsqueeze(1)
            if dones_t.dim() == 1:
                dones_t = dones_t.unsqueeze(1)
            target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q

        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm of the online network to 1.0.
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.optimizer.step()

        self.update_steps += 1
        # Refresh the target network once every `target_update` gradient steps.
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}