"""
DQN Agent for SUMO VSL Control
"""
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
import numpy as np
|
||
from collections import deque
|
||
import random
|
||
|
||
|
||
class DQNNetwork(nn.Module):
    """Fully connected Q-network that emits one row of Q-values per edge.

    Layer widths are ``state_dim -> hidden_dim -> hidden_dim -> hidden_dim // 2``
    (each followed by a ReLU), then a final linear layer with
    ``num_edges * num_actions_per_edge`` outputs that ``forward`` reshapes to
    ``(-1, num_edges, num_actions_per_edge)``.
    """

    def __init__(self, state_dim: int, num_edges: int, num_actions_per_edge: int, hidden_dim: int = 256):
        super().__init__()
        self.num_edges = num_edges
        self.num_actions_per_edge = num_actions_per_edge

        # Hidden Linear+ReLU pairs followed by a bare output Linear. Everything lives in
        # one nn.Sequential called `network`, so the state_dict keys stay
        # network.0 ... network.6 and existing checkpoints remain loadable.
        widths = [state_dim, hidden_dim, hidden_dim, hidden_dim // 2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        layers.append(nn.Linear(widths[-1], num_edges * num_actions_per_edge))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values shaped ``(-1, num_edges, num_actions_per_edge)``.

        All leading dimensions of the flat output are folded into the first
        axis, so a single unbatched state yields a batch of one.
        """
        flat_q = self.network(x)
        return flat_q.view(-1, self.num_edges, self.num_actions_per_edge)
|
||
|
||
|
||
class DQNAgent:
    """Epsilon-greedy DQN agent that picks one discrete action per edge.

    The policy network outputs Q-values of shape
    ``(batch, num_edges, num_actions_per_edge)``. Each edge takes the argmax of
    its own row. For training, the Q-values of the chosen per-edge actions are
    averaged over edges, so the joint action is scored against one scalar
    reward per transition.
    """

    def __init__(
        self,
        state_dim: int,
        num_edges: int,
        num_actions_per_edge: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
    ):
        """Build policy/target networks, the Adam optimizer and the replay buffer.

        Args:
            state_dim: Length of the flat state vector fed to the network.
            num_edges: Number of edges that each receive an action.
            num_actions_per_edge: Number of discrete actions available per edge.
            hidden_dim: Width of the network's hidden layers.
            learning_rate: Adam learning rate.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate at step 0.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant (in exploratory steps) of the
                exponential decay from ``epsilon_start`` toward ``epsilon_end``.
            buffer_size: Replay capacity; the oldest transitions are evicted first.
            batch_size: Minibatch size, and the minimum buffer fill before
                ``update`` trains.
            target_update: Stored for the caller; this class does not read it
                (the target network is synced only via ``update_target_network``).
            device: Requested torch device; CPU is used whenever CUDA is unavailable.
        """
        # Any requested device falls back to CPU when CUDA is not available.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_edges = num_edges
        self.num_actions_per_edge = num_actions_per_edge
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update

        self.policy_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device)
        self.target_net = DQNNetwork(state_dim, num_edges, num_actions_per_edge, hidden_dim).to(self.device)
        # Start the target as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Bounded FIFO replay buffer of (state, action, reward, next_state, done) tuples.
        self.memory = deque(maxlen=buffer_size)
        # Number of exploratory (non-deterministic) action selections so far;
        # drives the epsilon schedule.
        self.steps_done = 0

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Choose one action index per edge.

        Args:
            state: A single (unbatched) state vector.
            deterministic: If True, act greedily. Greedy calls neither draw an
                exploration sample nor advance the epsilon schedule, so
                evaluation does not disturb training exploration.

        Returns:
            Integer array of shape ``(num_edges,)`` with values in
            ``[0, num_actions_per_edge)``.
        """
        if not deterministic:
            epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                np.exp(-1. * self.steps_done / self.epsilon_decay)
            self.steps_done += 1
            if random.random() <= epsilon:
                # Explore: an independent uniform action for every edge.
                return np.array([random.randrange(self.num_actions_per_edge) for _ in range(self.num_edges)])

        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.policy_net(state_tensor)  # (1, num_edges, num_actions_per_edge)
            return q_values.argmax(dim=2).squeeze(0).cpu().numpy()  # (num_edges,)

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (evicting the oldest when full)."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one gradient step on a random minibatch from the replay buffer.

        Returns:
            ``{}`` while the buffer holds fewer than ``batch_size`` transitions,
            otherwise ``{"loss": float}`` with the minibatch MSE TD loss.
        """
        if len(self.memory) < self.batch_size:
            return {}

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Explicit dtypes plus reshape make scalar, bool (for `done`) and
        # 1-element-array entries all land in the intended 2-D shapes. Without the
        # reshape, (batch, 1, 1) rewards would silently broadcast target_q to
        # (batch, batch, 1) and corrupt the loss.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=self.device)
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=self.device)
        actions = torch.as_tensor(
            np.asarray(actions, dtype=np.int64), device=self.device
        ).reshape(-1, self.num_edges)  # (batch, num_edges)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=self.device).reshape(-1, 1)
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=self.device).reshape(-1, 1)

        # Q-values for all actions: (batch, num_edges, num_actions_per_edge).
        current_q_all = self.policy_net(states)
        # Q-value of the action taken on each edge: (batch, num_edges).
        current_q = current_q_all.gather(2, actions.unsqueeze(2)).squeeze(2)
        # Average over edges to get one value per transition: (batch, 1).
        current_q = current_q.mean(dim=1, keepdim=True)

        with torch.no_grad():
            next_q_all = self.target_net(next_states)
            # Per-edge max over actions, then averaged over edges: (batch, 1).
            next_q = next_q_all.max(dim=2).values.mean(dim=1, keepdim=True)
            # Terminal transitions (done == 1) do not bootstrap.
            target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm of the policy network at 1.0.
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        return {"loss": loss.item()}

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path: str):
        """Write network, optimizer and epsilon-schedule state to ``path`` with ``torch.save``."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            # Persist the schedule position so a resumed run keeps its exploration
            # rate instead of restarting from epsilon_start.
            'steps_done': self.steps_done,
        }, path)

    def load(self, path: str):
        """Restore state written by ``save``, mapping tensors onto this agent's device."""
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load trusted checkpoint files. Checkpoints from save()
        # hold only state dicts and an int, so weights_only=True should also work.
        # TODO: confirm against existing checkpoints before switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints have no 'steps_done'; keep the current counter then.
        self.steps_done = checkpoint.get('steps_done', self.steps_done)
|