44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
"""Shared-parameter multi-head DDQN agent for corridor VSL control."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from agents.dqn_agent import DQNAgent
|
|
|
|
|
|
class DDQNAgent(DQNAgent):
    """Shared-parameter Double DQN with one action head per controlled edge.

    Only :meth:`update` is overridden here. The networks (``q_net``,
    ``target_q_net``), optimizer, replay buffer, hyperparameters and the
    ``_sample_batch`` / ``_sync_target_network`` helpers all come from
    ``DQNAgent``.

    The override implements the Double DQN target: the online network
    selects the greedy next action per head, and the target network
    evaluates that action.
    """

    def update(self) -> Dict[str, float]:
        """Run one Double DQN gradient step on a replay minibatch.

        Returns:
            ``{}`` if the replay buffer holds fewer than ``batch_size``
            transitions (no gradient step is taken). Otherwise, a dict with
            the scalar training loss stored under both ``"loss"`` and
            ``"value_loss"``.
        """
        if len(self.replay_buffer) < self.batch_size:
            return {}

        states_t, actions_t, rewards_t, next_states_t, dones_t = self._sample_batch()

        # Gathering on dim 2 with actions_t.unsqueeze(2) implies q_net returns
        # (batch, heads, actions) and actions_t holds (batch, heads) integer
        # action indices -- NOTE(review): confirm against DQNAgent._sample_batch.
        current_q = self.q_net(states_t).gather(2, actions_t.unsqueeze(2)).squeeze(2)

        # A per-transition (1-D) reward or done flag must be shared by every
        # head. Left as (batch,), it would broadcast against (batch, heads)
        # along the heads axis. That raises when batch != heads. When
        # batch == heads, or heads == 1, it silently builds a wrong target.
        # Already 2-D inputs are left untouched.
        if rewards_t.dim() == 1:
            rewards_t = rewards_t.unsqueeze(1)
        if dones_t.dim() == 1:
            dones_t = dones_t.unsqueeze(1)

        with torch.no_grad():
            # Double DQN: choose next actions with the online network...
            next_actions = self.q_net(next_states_t).argmax(dim=2, keepdim=True)
            # ...but score them with the target network.
            next_q = self.target_q_net(next_states_t).gather(2, next_actions).squeeze(2)
            # (1 - done) drops the bootstrap term for terminal transitions.
            target_q = rewards_t + (1.0 - dones_t) * self.gamma * next_q

        # Same as nn.MSELoss() (mean reduction) without building a module per call.
        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm of the online network to 1.0.
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 1.0)
        self.optimizer.step()

        self.update_steps += 1
        # Sync the target network every `target_update` gradient steps.
        if self.update_steps % self.target_update == 0:
            self._sync_target_network()

        value_loss = float(loss.item())
        return {"loss": value_loss, "value_loss": value_loss}
|