""" DQN Agent for SUMO VSL Control """ import torch import torch.nn as nn import torch.optim as optim import numpy as np from collections import deque import random class DQNNetwork(nn.Module): def __init__(self, state_dim: int, num_edges: int, num_actions_per_edge: int, hidden_dim: int = 256): super().__init__() self.num_edges = num_edges self.num_actions_per_edge = num_actions_per_edge self.network = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Linear(hidden_dim // 2, num_edges * num_actions_per_edge) ) def forward(self, x): q_values = self.network(x) return q_values.view(-1, self.num_edges, self.num_actions_per_edge) class DQNAgent: def __init__( self, state_dim: int, num_edges: int, num_actions_per_edge: int, hidden_dim: int = 256, learning_rate: float = 1e-3, gamma: float = 0.99, epsilon_start: float = 1.0, epsilon_end: float = 0.01, epsilon_decay: int = 200, buffer_size: int = 10000, batch_size: int = 64, target_update: int = 10, device: str = "cuda" ): self.device = torch.device(device if torch.cuda.is_available() else "cpu") self.num_edges = num_edges self.num_actions_per_edge = num_actions_per_edge self.gamma = gamma self.epsilon_start = epsilon_start self.epsilon_end = epsilon_end self.epsilon_decay = epsilon_decay self.batch_size = batch_size self.target_update = target_update self.policy_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device) self.target_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate) self.memory = deque(maxlen=buffer_size) self.steps_done = 0 def select_action(self, state: np.ndarray, deterministic: bool = False): epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \ np.exp(-1. * self.steps_done / self.epsilon_decay) self.steps_done += 1 if deterministic or random.random() > epsilon: with torch.no_grad(): state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) q_values = self.policy_net(state_tensor) # shape: (1, num_edges, num_actions_per_edge) actions = q_values.argmax(dim=2).squeeze(0).cpu().numpy() # shape: (num_edges,) return actions else: return np.array([random.randrange(self.num_actions_per_edge) for _ in range(self.num_edges)]) def store_transition(self, state, action, reward, next_state, done): self.memory.append((state, action, reward, next_state, done)) def update(self): if len(self.memory) < self.batch_size: return {} batch = random.sample(self.memory, self.batch_size) states, actions, rewards, next_states, dones = zip(*batch) states = torch.FloatTensor(np.array(states)).to(self.device) actions = torch.LongTensor(np.array(actions)).to(self.device) # shape: (batch, num_edges) rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device) next_states = torch.FloatTensor(np.array(next_states)).to(self.device) dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device) # Q值: (batch, num_edges, num_actions_per_edge) current_q_all = self.policy_net(states) # 选择每条边对应动作的Q值: (batch, num_edges) current_q = current_q_all.gather(2, actions.unsqueeze(2)).squeeze(2) # 平均所有边的Q值作为状态价值 current_q = current_q.mean(dim=1, keepdim=True) with torch.no_grad(): next_q_all = self.target_net(next_states) # 每条边选最大Q值,然后平均 next_q = next_q_all.max(dim=2)[0].mean(dim=1, keepdim=True) target_q = rewards + (1 - dones) * self.gamma * next_q loss = nn.MSELoss()(current_q, target_q) self.optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0) self.optimizer.step() return {"loss": loss.item()} def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path: str): torch.save({ 'policy_net': self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(), 'optimizer': self.optimizer.state_dict(), }, path) def load(self, path: str): checkpoint = torch.load(path, map_location=self.device, weights_only=False) self.policy_net.load_state_dict(checkpoint['policy_net']) self.target_net.load_state_dict(checkpoint['target_net']) self.optimizer.load_state_dict(checkpoint['optimizer'])