122 lines
4.3 KiB
Python
122 lines
4.3 KiB
Python
"""
|
|
DQN Agent for SUMO VSL Control
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
from collections import deque
|
|
import random
|
|
|
|
|
|
class DQNNetwork(nn.Module):
    """Fully connected Q-value network.

    Maps an input whose last dimension is ``state_dim`` to one Q-value per
    discrete action, through three ReLU-activated hidden layers of widths
    ``hidden_dim``, ``hidden_dim`` and ``hidden_dim // 2``.

    All layers live in one ``nn.Sequential`` stored as ``self.network``, so
    state-dict keys are ``network.0.*``, ``network.2.*``, ``network.4.*`` and
    ``network.6.*`` (the ReLUs at odd indices have no parameters).
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        # Input width of each hidden Linear layer followed by its output width;
        # consecutive pairs give (in_features, out_features).
        widths = [state_dim, hidden_dim, hidden_dim, hidden_dim // 2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: raw (unactivated) Q-values, one per action.
        layers.append(nn.Linear(widths[-1], num_actions))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values computed by ``self.network`` for ``x``."""
        q_values = self.network(x)
        return q_values
|
|
|
|
|
|
class DQNAgent:
    """Deep Q-Network agent with experience replay and a target network.

    Actions are chosen epsilon-greedily from ``policy_net``. ``update`` fits
    ``policy_net`` to one-step TD targets bootstrapped from ``target_net``.
    Epsilon decays exponentially from ``epsilon_start`` towards
    ``epsilon_end`` with time constant ``epsilon_decay``, measured in calls
    to ``select_action``.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda",
    ):
        """Build the networks, optimizer and replay buffer.

        Args:
            state_dim: Input size of the Q-networks.
            num_actions: Number of discrete actions.
            hidden_dim: Hidden width passed to ``DQNNetwork``.
            learning_rate: Adam learning rate for ``policy_net``.
            gamma: Discount factor used in the TD target.
            epsilon_start: Exploration rate at ``steps_done == 0``.
            epsilon_end: Asymptotic exploration rate.
            epsilon_decay: Time constant of the exponential epsilon decay.
            buffer_size: Replay capacity; the oldest transitions are dropped
                first once full.
            batch_size: Minibatch size, and the minimum buffer fill before
                ``update`` trains.
            target_update: Stored only; this class never reads it. The caller
                decides when to call ``update_target_network``.
            device: Requested torch device. Falls back to CPU when CUDA is
                unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_actions = num_actions
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        # Not used inside this class; kept for the training loop that owns the
        # target-sync schedule.
        self.target_update = target_update

        self.policy_net = DQNNetwork(state_dim, num_actions, hidden_dim).to(self.device)
        self.target_net = DQNNetwork(state_dim, num_actions, hidden_dim).to(self.device)
        # Start the target as an exact copy. It is refreshed only by
        # update_target_network() and is never optimized directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Bounded FIFO of (state, action, reward, next_state, done) tuples.
        self.memory = deque(maxlen=buffer_size)
        # Number of select_action() calls so far; drives the epsilon schedule.
        self.steps_done = 0

    def _current_epsilon(self) -> float:
        """Return the exploration rate for the current ``steps_done``."""
        decay = np.exp(-1.0 * self.steps_done / self.epsilon_decay)
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * decay

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Pick an action index for ``state`` epsilon-greedily.

        Args:
            state: A single observation convertible to a float32 tensor. A
                batch dimension of 1 is added before the forward pass.
            deterministic: If True, always take the greedy action.

        Returns:
            int: Index of the chosen action in ``[0, num_actions)``.
        """
        epsilon = self._current_epsilon()
        # NOTE(review): the counter advances even when deterministic=True, so
        # greedy/evaluation calls also decay epsilon. Confirm this is intended.
        self.steps_done += 1

        if deterministic or random.random() > epsilon:
            with torch.no_grad():
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=self.device
                ).unsqueeze(0)
                q_values = self.policy_net(state_tensor)
                return q_values.argmax().item()
        return random.randrange(self.num_actions)

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (evicting the oldest if full)."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one gradient step on a random minibatch from the replay buffer.

        Returns:
            dict: ``{}`` if fewer than ``batch_size`` transitions are stored;
            otherwise ``{"loss": float}`` holding the minibatch MSE loss.
        """
        if len(self.memory) < self.batch_size:
            return {}

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        # Shape (batch, 1) so these line up with gather()'s output column.
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.long, device=self.device).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(rewards), dtype=torch.float32, device=self.device).unsqueeze(1)
        dones = torch.as_tensor(np.asarray(dones), dtype=torch.float32, device=self.device).unsqueeze(1)

        # Q(s, a) for the actions that were actually taken.
        current_q = self.policy_net(states).gather(1, actions)

        with torch.no_grad():
            # Greedy next-state value from the frozen target network.
            next_q = self.target_net(next_states).max(dim=1, keepdim=True).values
            # done == 1 removes the bootstrap term for terminal transitions.
            target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = nn.functional.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 to bound the update size.
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        return {"loss": loss.item()}

    def update_target_network(self):
        """Copy ``policy_net`` weights into ``target_net`` (hard update)."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path: str):
        """Write network, optimizer and epsilon-schedule state to ``path``."""
        torch.save({
            "policy_net": self.policy_net.state_dict(),
            "target_net": self.target_net.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            # Persist the step counter so a resumed run continues the epsilon
            # decay instead of restarting exploration at epsilon_start.
            "steps_done": self.steps_done,
        }, path)

    def load(self, path: str):
        """Restore state previously written by ``save``.

        Checkpoints without a ``"steps_done"`` entry (older saves) leave the
        current step counter unchanged.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects,
        # so only load checkpoints from trusted sources. The payload written by
        # save() is state dicts plus an int and may load with weights_only=True
        # on recent torch versions. Verify before switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy_net.load_state_dict(checkpoint["policy_net"])
        self.target_net.load_state_dict(checkpoint["target_net"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.steps_done = int(checkpoint.get("steps_done", self.steps_done))
|