# ctm-dqn/dqn_agent.py (122 lines, 4.3 KiB, Python)

"""
DQN Agent for SUMO VSL Control
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
class DQNNetwork(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    Layer widths: state_dim -> hidden_dim -> hidden_dim -> hidden_dim // 2 -> num_actions,
    with a ReLU after every Linear layer except the last (raw Q-values are returned).
    """

    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        half_width = hidden_dim // 2
        # (fan_in, fan_out) of each hidden Linear layer, in forward order.
        hidden_shapes = [
            (state_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, half_width),
        ]
        layers = []
        for fan_in, fan_out in hidden_shapes:
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(half_width, num_actions))
        # A single nn.Sequential named ``network`` keeps state_dict keys as
        # network.0.weight, network.2.weight, ... so saved weights stay loadable.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for the state(s) in ``x``."""
        return self.network(x)
class DQNAgent:
    """Deep Q-Network agent with an experience-replay buffer and a target network.

    Exploration is epsilon-greedy with an exponentially decaying rate::

        eps(t) = epsilon_end + (epsilon_start - epsilon_end) * exp(-t / epsilon_decay)

    where ``t`` is ``steps_done``. That counter counts non-deterministic
    (exploratory) calls to :meth:`select_action`.

    Args:
        state_dim: Length of the state vector fed to the Q-network.
        num_actions: Number of discrete actions.
        hidden_dim: Hidden-layer width of the Q-network.
        learning_rate: Adam learning rate.
        gamma: Discount factor used in the TD target.
        epsilon_start: Exploration rate when ``steps_done == 0``.
        epsilon_end: Asymptotic exploration rate.
        epsilon_decay: Time constant (in exploratory action selections) of the decay.
        buffer_size: Replay-buffer capacity; the oldest transitions are dropped first.
        batch_size: Minibatch size sampled by :meth:`update`.
        target_update: Stored as ``self.target_update`` but never read by this class.
            :meth:`update` does not sync the target network, so the caller must call
            :meth:`update_target_network` itself. Presumably the caller uses this
            value as the sync interval; confirm against the training loop.
        device: Requested torch device. It is used only when CUDA is available;
            otherwise the agent runs on CPU.
    """

    def __init__(
        self,
        state_dim: int,
        num_actions: int,
        hidden_dim: int = 256,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.01,
        epsilon_decay: int = 200,
        buffer_size: int = 10000,
        batch_size: int = 64,
        target_update: int = 10,
        device: str = "cuda"
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.num_actions = num_actions
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update
        self.policy_net = DQNNetwork(state_dim, num_actions, hidden_dim).to(self.device)
        self.target_net = DQNNetwork(state_dim, num_actions, hidden_dim).to(self.device)
        # Target starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        # Bounded FIFO replay buffer of (state, action, reward, next_state, done).
        self.memory = deque(maxlen=buffer_size)
        self.steps_done = 0

    def _current_epsilon(self) -> float:
        """Return the exploration rate for the current ``steps_done``."""
        decay = np.exp(-1. * self.steps_done / self.epsilon_decay)
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * decay

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Choose an action for a single state.

        Args:
            state: One unbatched state; a batch dimension is added before the
                forward pass.
            deterministic: If True, always act greedily. Greedy calls do not
                advance ``steps_done``, so evaluating the policy does not use up
                the exploration schedule.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if not deterministic:
            # Epsilon is computed from the count *before* this step, as before.
            epsilon = self._current_epsilon()
            self.steps_done += 1
            if random.random() <= epsilon:
                return random.randrange(self.num_actions)
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.policy_net(state_tensor)
            return q_values.argmax().item()

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (evicting the oldest when full)."""
        self.memory.append((state, action, reward, next_state, done))

    def update(self):
        """Run one gradient step on a uniformly sampled replay minibatch.

        The TD target is ``r + gamma * (1 - done) * max_a' Q_target(s', a')``. It is
        regressed with MSE, and gradients are clipped to a global L2 norm of 1.0.
        This method does not touch the target network.

        Returns:
            ``{"loss": float}`` after a step, or ``{}`` when the buffer holds fewer
            than ``batch_size`` transitions (no step is taken).
        """
        if len(self.memory) < self.batch_size:
            return {}
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)
        # Q(s, a) of the actions actually taken, shape (batch_size, 1).
        current_q = self.policy_net(states).gather(1, actions)
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0].unsqueeze(1)
            # (1 - done) drops the bootstrap term for terminal transitions.
            target_q = rewards + (1 - dones) * self.gamma * next_q
        loss = nn.functional.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()
        return {"loss": loss.item()}

    def update_target_network(self):
        """Hard-copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save(self, path: str):
        """Write both networks, the optimizer state and ``steps_done`` to ``path``.

        ``steps_done`` is included so that a resumed run continues the epsilon
        schedule instead of restarting it at ``epsilon_start``.
        """
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'steps_done': self.steps_done,
        }, path)

    def load(self, path: str):
        """Restore a checkpoint written by :meth:`save`.

        Older checkpoints without a ``'steps_done'`` entry are still accepted; in
        that case the current ``steps_done`` is left unchanged.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects, so only
        # load checkpoints from trusted sources. save() writes only state dicts and
        # an int, so weights_only=True is likely sufficient. Verify against existing
        # checkpoints before switching.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.steps_done = checkpoint.get('steps_done', self.steps_done)