ctm-dqn/training/train_dqn.py

229 lines
8.4 KiB
Python

"""
DQN训练脚本 - SUMO VSL环境
"""
import os
import sys
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.dqn_agent import DQNAgent
from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs
def _run_episode(env, agent, losses, seed, desc):
    """Roll out one training episode, training the agent after every step.

    Args:
        env: The VSL environment. Used via ``reset(seed=...)``,
            ``step(action) -> (next_state, reward, done, info)`` and
            ``episode_length`` (progress-bar total only).
        agent: The DQN agent. Used via ``select_action``,
            ``store_transition`` and ``update``.
        losses: Container that receives ``train_stats["loss"]`` for every
            ``agent.update()`` call that returns statistics (appended in place).
        seed: Seed forwarded to ``env.reset``.
        desc: Label for the tqdm progress bar.

    Returns:
        ``(totals, steps)``: ``totals`` maps each tracked quantity to its sum
        over the episode; ``steps`` is the number of environment steps taken.
    """
    totals = dict.fromkeys(
        ("reward", "throughput", "speed", "speed_std",
         "r_flow", "r_var", "r_brake", "r_penalty", "brakes"),
        0,
    )
    steps = 0
    state = env.reset(seed=seed)
    done = False
    # The context manager closes the bar even if the rollout is interrupted
    # (e.g. Ctrl+C), not only on the normal exit path.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            # A single agent selects the actions for all controlled edges.
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)

            # Store the transition and run one learning update.
            agent.store_transition(state, action, reward, next_state, done)
            train_stats = agent.update()
            if train_stats:
                losses.append(train_stats["loss"])

            totals["reward"] += reward
            totals["throughput"] += info["throughput"]
            totals["speed"] += info["mean_speed_kmh"]
            # x3.6 converts to km/h (presumably info["speed_std"] is in m/s),
            # matching the km/h units of the other speed metrics.
            totals["speed_std"] += info["speed_std"] * 3.6
            totals["r_flow"] += info["r_flow"]
            totals["r_var"] += info["r_var"]
            totals["r_brake"] += info["r_brake"]
            totals["r_penalty"] += info["r_penalty"]
            totals["brakes"] += info["num_hard_brakes"]

            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{totals['reward']:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)
    return totals, steps


def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None,
                   config_path="config_sumo_vsl.yaml"):
    """Train a single DQN agent on the SUMO edge VSL environment.

    Loads the YAML config and resolves the run directories via
    ``resolve_run_dirs``. It writes a copy of the runtime config to
    ``checkpoint_dir/config.yaml`` and then trains for ``num_episodes``
    episodes (taken from the training config). Per-episode metrics go to a
    ``TrainingLogger`` and per-episode artifacts to ``log_dir``. At the end a
    training-curve plot is written.

    Args:
        log_dir: Optional log directory, forwarded to ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory, forwarded to
            ``resolve_run_dirs``.
        run_timestamp: Optional run timestamp, forwarded to
            ``resolve_run_dirs``.
        config_path: Path of the YAML config file. Defaults to
            ``"config_sumo_vsl.yaml"``, resolved relative to the current
            working directory.

    Checkpoints written to ``checkpoint_dir``:
        ``model_best.pt``: saved whenever the episode reward improves.
        ``model_ep{N}.pt``: saved every ``save_freq`` episodes.
        ``model_interrupted.pt``: saved on Ctrl+C.
        ``model_ep{num_episodes}.pt``: saved only if every episode completed.
    """
    from collections import deque

    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "dqn")
    train_config = get_training_config(config)
    start_episode = 1

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "dqn",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Work on a copy so the loaded config is left untouched.
    # runtime.output_dir is set to this run's log dir (presumably where the
    # environment writes its own outputs). The resulting config is snapshotted
    # next to the checkpoints.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "dqn")
    env = SUMOEdgeVSLEnvironment(runtime_config)
    state_dim = env.state_dim
    # DQN uses a single network to handle all controlled edges.
    num_edges = env.num_edges
    num_actions_per_edge = env.action_dim

    print("=" * 70)
    print("DQN训练 - SUMO VSL环境")
    print("=" * 70)
    print(f" 状态维度: {state_dim}")
    print(f" 控制边数: {num_edges}")
    print(f" 每边动作数: {num_actions_per_edge}")
    print(f" Episode步数: {env.episode_length}")
    print()

    # Create a single DQN agent for all edges.
    agent = DQNAgent(
        state_dim=state_dim,
        num_edges=num_edges,
        num_actions_per_edge=num_actions_per_edge,
        hidden_dim=agent_config.get("hidden_dim", 256),
        learning_rate=agent_config.get("learning_rate", 1e-3),
        gamma=agent_config.get("gamma", 0.99),
        epsilon_start=agent_config.get("epsilon_start", 1.0),
        epsilon_end=agent_config.get("epsilon_end", 0.01),
        epsilon_decay=agent_config.get("epsilon_decay", 200),
        buffer_size=agent_config.get("buffer_size", 10000),
        batch_size=agent_config.get("batch_size", 64),
        target_update=agent_config.get("target_update", 10),
        device=agent_config.get("device", "cuda"),
    )

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    # Only the most recent 100 losses are ever reported. A bounded window
    # avoids a list that grows by one entry per update step for the whole run.
    losses = deque(maxlen=100)
    best_reward = -float("inf")
    completed = False

    print("开始训练...\n")
    try:
        for episode in range(start_episode, num_episodes + 1):
            totals, step = _run_episode(
                env,
                agent,
                losses,
                seed=base_seed + episode,
                desc=f"Ep {episode}/{num_episodes}",
            )
            if episode % agent.target_update == 0:
                agent.update_target_network()

            # Guard against division by zero if an episode takes no steps.
            n = max(step, 1)
            episode_reward = totals["reward"]
            episode_brakes = totals["brakes"]
            avg_tp = totals["throughput"] / n
            avg_speed = totals["speed"] / n
            avg_speed_std = totals["speed_std"] / n
            avg_r_flow = totals["r_flow"] / n
            avg_r_var = totals["r_var"] / n
            avg_r_brake = totals["r_brake"] / n
            avg_r_penalty = totals["r_penalty"] / n

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_speed_stds.append(avg_speed_std)
            episode_hard_brakes.append(episode_brakes)

            loss_val = np.mean(losses) if losses else None
            logger.log(
                episode, episode_reward, avg_tp, avg_speed,
                speed_std=avg_speed_std,
                r_flow=avg_r_flow,
                r_var=avg_r_var,
                r_brake=avg_r_brake,
                r_penalty=avg_r_penalty,
                hard_brakes=episode_brakes,
                value_loss=loss_val,
            )
            episode_summary = {
                "episode": episode,
                "reward": float(episode_reward),
                "avg_throughput": float(avg_tp),
                "avg_mean_speed_kmh": float(avg_speed),
                "avg_speed_std_kmh": float(avg_speed_std),
                "avg_r_flow": float(avg_r_flow),
                "avg_r_var": float(avg_r_var),
                "avg_r_brake": float(avg_r_brake),
                "avg_r_penalty": float(avg_r_penalty),
                "hard_brakes": int(episode_brakes),
                "loss": float(loss_val) if loss_val is not None else None,
            }
            save_training_episode_artifacts(
                log_dir=log_dir,
                episode=episode,
                episode_metrics=env.episode_metrics,
                control_edges=env.control_edges,
                summary=episode_summary,
            )

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            # A frequency of 0 disables the action instead of raising ZeroDivisionError.
            if log_freq and episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                print(f" Speed Std: {avg_speed_std:.2f} km/h")
                print(f" R(flow/var/brake/pen): {avg_r_flow:.3f} / {avg_r_var:.3f} / {avg_r_brake:.3f} / {avg_r_penalty:.3f}")
                if loss_val is not None:
                    print(f" Loss: {loss_val:.4f}")

            if save_freq and episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))
        completed = True
    except KeyboardInterrupt:
        print("\n训练被中断")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Use the final-episode checkpoint name only when every episode actually
    # ran. An interrupted run is already saved as model_interrupted.pt and must
    # not be mislabelled as the fully trained model.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plot the training curves; skip when no episode finished (nothing to plot).
    if episode_rewards:
        plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)