ctm-dqn/agents/dqn_agent.py

135 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
DQN Agent for SUMO VSL Control
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
class DQNNetwork(nn.Module):
    """MLP mapping a state vector to one Q-value per (edge, action) pair.

    Three hidden ReLU layers of widths ``hidden_dim``, ``hidden_dim`` and
    ``hidden_dim // 2`` feed a linear head with
    ``num_edges * num_actions_per_edge`` outputs. :meth:`forward` reshapes
    that flat head output to ``(batch, num_edges, num_actions_per_edge)``.
    """

    def __init__(self, state_dim: int, num_edges: int, num_actions_per_edge: int, hidden_dim: int = 256):
        super().__init__()
        self.num_edges = num_edges
        self.num_actions_per_edge = num_actions_per_edge

        # Widths of the hidden trunk; every consecutive pair becomes Linear + ReLU.
        widths = (state_dim, hidden_dim, hidden_dim, hidden_dim // 2)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: flattened Q-values for every (edge, action) pair.
        layers.append(nn.Linear(widths[-1], num_edges * num_actions_per_edge))
        # A single Sequential named ``network`` keeps state_dict keys
        # (``network.0.weight``, ...) identical for saved checkpoints.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values shaped ``(batch, num_edges, num_actions_per_edge)``.

        A 1-D input of length ``state_dim`` produces a leading batch dimension of 1.
        """
        flat_q = self.network(x)
        per_edge = (self.num_edges, self.num_actions_per_edge)
        return flat_q.view(-1, *per_edge)
class DQNAgent:
    """DQN agent that selects one discrete action per edge.

    The policy network produces Q-values shaped
    ``(batch, num_edges, num_actions_per_edge)``. In :meth:`update`, the value
    of the joint action is approximated by the mean over edges of each edge's
    chosen-action Q-value. The bootstrap target likewise uses the mean over
    edges of the target network's per-edge maxima.

    Exploration is epsilon-greedy with an exponentially decaying epsilon
    (see :meth:`current_epsilon`), driven by ``steps_done``.

    ``target_update`` is stored but not read by this class. Callers are
    expected to invoke :meth:`update_target_network` themselves.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda"
    ):
        # Any requested device is replaced by CPU when CUDA is unavailable.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_edges = num_edges
        self.num_actions_per_edge = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update

        self.policy_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device)
        self.target_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()  # the target network is only used without gradients

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Replay buffer; the oldest transitions are dropped once buffer_size is reached.
        self.memory = deque(maxlen=buffer_size)
        # Count of exploratory (non-deterministic) action selections; drives epsilon.
        self.steps_done = 0

    def current_epsilon(self) -> float:
        """Return the exploration rate for the current ``steps_done``.

        ``epsilon_end + (epsilon_start - epsilon_end) * exp(-steps_done / epsilon_decay)``
        """
        decay = np.exp(-1.0 * self.steps_done / self.epsilon_decay)
        return float(self.epsilon_end + (self.epsilon_start - self.epsilon_end) * decay)

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Choose one action index per edge.

        Args:
            state: Observation vector of length ``state_dim``.
            deterministic: If True, always act greedily and leave the
                exploration schedule untouched (e.g. for evaluation).

        Returns:
            Integer array of shape ``(num_edges,)``.
        """
        if not deterministic:
            epsilon = self.current_epsilon()
            # Only exploratory calls advance the schedule, so evaluation
            # rollouts do not shrink epsilon for subsequent training.
            self.steps_done += 1
            if random.random() <= epsilon:
                return np.array([random.randrange(self.num_actions_per_edge) for _ in range(self.num_edges)])

        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.policy_net(state_tensor)  # (1, num_edges, num_actions_per_edge)
            return q_values.argmax(dim=2).squeeze(0).cpu().numpy()  # (num_edges,)

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and stack them as tensors on ``self.device``."""
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Explicit numpy dtypes avoid the slow list-of-arrays path of the
        # legacy torch.FloatTensor / LongTensor constructors.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=self.device)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64), device=self.device)  # (batch, num_edges)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device).unsqueeze(1)
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=self.device)
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=self.device).unsqueeze(1)
        return states, actions, rewards, next_states, dones

    def update(self) -> dict:
        """Run one gradient step on a sampled minibatch.

        Returns:
            ``{}`` if the buffer holds fewer than ``batch_size`` transitions,
            otherwise ``{"loss": float}`` with the MSE TD loss.
        """
        if len(self.memory) < self.batch_size:
            return {}
        states, actions, rewards, next_states, dones = self._sample_batch()

        # Q(s, a_e) for each edge's chosen action: (batch, num_edges).
        chosen_q = self.policy_net(states).gather(2, actions.unsqueeze(2)).squeeze(2)
        # Joint-action value approximated by the mean over edges: (batch, 1).
        current_q = chosen_q.mean(dim=1, keepdim=True)

        with torch.no_grad():
            # Greedy per-edge value from the target network, averaged over edges.
            next_q = self.target_net(next_states).max(dim=2).values.mean(dim=1, keepdim=True)
            target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = nn.functional.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        return {"loss": loss.item()}

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path: str):
        """Write network, optimizer and exploration-schedule state to ``path``."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Persisted so resumed training continues the epsilon schedule
            # instead of restarting at epsilon_start.
            'steps_done': self.steps_done,
        }, path)

    def load(self, path: str, weights_only: bool = False):
        """Restore state written by :meth:`save`.

        Args:
            path: Checkpoint file.
            weights_only: Forwarded to ``torch.load``.
                SECURITY: the default ``False`` unpickles arbitrary objects and
                can execute code embedded in the file. Only load trusted
                checkpoints. Checkpoints from :meth:`save` hold only state dicts
                and an int, so ``True`` should work for them.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=weights_only)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # Checkpoints written before 'steps_done' was saved keep the current counter.
        self.steps_done = checkpoint.get('steps_done', self.steps_done)