# ctm-dqn/train_sumo_dqn.py
#
# 215 lines
# 7.3 KiB
# Python

"""
DQN训练脚本 - SUMO VSL环境
"""
import os
import sys
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from dqn_agent import DQNAgent
from training_logger import TrainingLogger
def _build_agent(state_dim, num_edges, num_actions_per_edge, agent_config):
    """Construct the single DQN agent that selects actions for every edge.

    Every hyperparameter can be overridden from the ``agent`` config section.
    The defaults are hidden_dim=256, learning_rate=1e-3, gamma=0.99,
    epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=200,
    buffer_size=10000, batch_size=64, target_update=10 and device="cuda".
    """
    return DQNAgent(
        state_dim=state_dim,
        num_edges=num_edges,
        num_actions_per_edge=num_actions_per_edge,
        hidden_dim=agent_config.get("hidden_dim", 256),
        learning_rate=agent_config.get("learning_rate", 1e-3),
        gamma=agent_config.get("gamma", 0.99),
        epsilon_start=agent_config.get("epsilon_start", 1.0),
        epsilon_end=agent_config.get("epsilon_end", 0.01),
        epsilon_decay=agent_config.get("epsilon_decay", 200),
        buffer_size=agent_config.get("buffer_size", 10000),
        batch_size=agent_config.get("batch_size", 64),
        target_update=agent_config.get("target_update", 10),
        device=agent_config.get("device", "cuda"),
    )


def _run_episode(env, agent, seed, desc, losses):
    """Roll out one episode, storing every transition and calling ``agent.update()``.

    Args:
        env: Environment; ``reset(seed=...)`` and a 4-tuple ``step(action)``.
        agent: Agent that picks actions for all edges at once.
        seed: Seed passed to ``env.reset``.
        desc: Progress-bar label.
        losses: List that every ``train_stats["loss"]`` is appended to (in place).

    Returns:
        ``(total_reward, throughput_sum, speed_sum, hard_brakes, steps)``, where
        the ``*_sum`` values are sums of the per-step ``info`` readings.
    """
    state = env.reset(seed=seed)
    total_reward = 0
    throughput_sum = 0
    speed_sum = 0
    hard_brakes = 0
    steps = 0
    done = False
    # The context manager closes the bar even on Ctrl-C or an env error.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            # One agent chooses the actions for all edges.
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            # Store the transition, then attempt a learning step.
            agent.store_transition(state, action, reward, next_state, done)
            train_stats = agent.update()
            # update() may return a falsy value when no training step ran
            # (presumably while the replay buffer fills -- TODO confirm).
            if train_stats:
                losses.append(train_stats["loss"])
            total_reward += reward
            throughput_sum += info["throughput"]
            speed_sum += info["mean_speed_kmh"]
            hard_brakes += info["num_hard_brakes"]
            state = next_state
            steps += 1
            pbar.set_postfix(r=f"{total_reward:.1f}", tp=f"{info['throughput']:.0f}",
                             v=f"{info['mean_speed_kmh']:.1f}")
            pbar.update(1)
    return total_reward, throughput_sum, speed_sum, hard_brakes, steps


def _plot_training_curves(episode_rewards, episode_throughputs, episode_mean_speeds,
                          episode_hard_brakes, losses, save_path, window=20):
    """Render the training curves as a 2x3 grid and save them to *save_path*.

    The panels are:

    * reward, with a ``window``-episode moving average;
    * average throughput;
    * mean speed;
    * hard brakes, with a moving average;
    * per-update loss;
    * an empty sixth panel.

    The figure is always closed afterwards.
    """
    def add_moving_average(ax, values, style):
        # A 'valid' convolution yields len(values) - window + 1 points.
        # Each point is plotted at the last episode of its window.
        if len(values) > window:
            ma = np.convolve(values, np.ones(window) / window, mode="valid")
            ax.plot(range(window - 1, len(values)), ma, style, linewidth=2)

    def decorate(ax, xlabel, ylabel, title):
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    try:
        axes[0, 0].plot(episode_rewards, alpha=0.6)
        add_moving_average(axes[0, 0], episode_rewards, 'r-')
        decorate(axes[0, 0], 'Episode', 'Total Reward', 'DQN Training Reward')

        axes[0, 1].plot(episode_throughputs, 'g-', alpha=0.6)
        decorate(axes[0, 1], 'Episode', 'Avg Throughput (veh/h)', 'Throughput')

        axes[0, 2].plot(episode_mean_speeds, 'orange', alpha=0.6)
        decorate(axes[0, 2], 'Episode', 'Mean Speed (km/h)', 'Mean Speed')

        axes[1, 0].plot(episode_hard_brakes, 'r-', alpha=0.6)
        add_moving_average(axes[1, 0], episode_hard_brakes, 'b-')
        decorate(axes[1, 0], 'Episode', 'Hard Brakes Count', 'Hard Brakes')

        if losses:
            axes[1, 1].plot(losses, 'b-', alpha=0.6)
        decorate(axes[1, 1], 'Update Step', 'Loss', 'Training Loss')

        axes[1, 2].axis('off')
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # pyplot keeps a reference to every open figure until it is closed.
        plt.close(fig)


def train_sumo_dqn():
    """Train one DQN agent on the SUMO edge-level VSL environment.

    Reads ``config_sumo_vsl.yaml`` from the current working directory. Creates
    ``checkpoints_sumo_dqn/<timestamp>`` and ``logs_sumo_dqn/<timestamp>``,
    copies the config into the checkpoint directory, and runs
    ``training.num_episodes`` episodes.

    Checkpoints written to the checkpoint directory:

    * ``model_best.pt`` whenever an episode's total reward beats the best so far;
    * ``model_ep{N}.pt`` every ``training.save_freq`` episodes (default 50);
    * ``model_interrupted.pt`` if training is stopped with Ctrl-C;
    * ``model_ep{num_episodes}.pt`` only after all episodes complete.

    The target network is synced every ``training.target_update_freq``
    episodes (default 10). Training curves are saved to
    ``<log_dir>/training_curves.png``. The environment is closed on every
    exit path.
    """
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = config.get("agent", {})
    train_config = config["training"]
    # Read required settings before starting the simulator, so a bad config
    # fails fast without leaking an environment.
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)
    target_sync_freq = train_config.get("target_update_freq", 10)
    start_episode = 1

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join("checkpoints_sumo_dqn", timestamp)
    log_dir = os.path.join("logs_sumo_dqn", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    # Snapshot the exact config used for this run next to its checkpoints.
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    logger = TrainingLogger(log_dir, "dqn")
    env = SUMOEdgeVSLEnvironment(config)
    state_dim = env.state_dim
    # DQN uses a single network to handle all edges.
    num_edges = env.num_edges
    num_actions_per_edge = 5

    print("=" * 70)
    print("DQN训练 - SUMO VSL环境")
    print("=" * 70)
    print(f" 状态维度: {state_dim}")
    print(f" 控制边数: {num_edges}")
    print(f" 每边动作数: {num_actions_per_edge}")
    print(f" Episode步数: {env.episode_length}")
    print()

    try:
        agent = _build_agent(state_dim, num_edges, num_actions_per_edge, agent_config)
    except BaseException:
        # The training try/finally below has not started yet, so close the
        # environment here before propagating the error.
        env.close()
        raise

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_hard_brakes = []
    losses = []
    best_reward = -float("inf")
    completed = False

    print("开始训练...\n")
    try:
        for episode in range(start_episode, num_episodes + 1):
            episode_reward, tp_sum, speed_sum, episode_brakes, steps = _run_episode(
                env, agent,
                seed=base_seed + episode,
                desc=f"Ep {episode}/{num_episodes}",
                losses=losses,
            )
            if episode % target_sync_freq == 0:
                agent.update_target_network()

            avg_tp = tp_sum / max(steps, 1)
            avg_speed = speed_sum / max(steps, 1)
            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_hard_brakes.append(episode_brakes)

            # Mean of the most recent (up to) 100 update losses.
            loss_val = np.mean(losses[-100:]) if losses else None
            logger.log(episode, episode_reward, avg_tp, avg_speed, episode_brakes, value_loss=loss_val)

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            if episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                if loss_val is not None:
                    print(f" Loss: {loss_val:.4f}")

            if episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))
        completed = True
    except KeyboardInterrupt:
        print("\n训练被中断")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Save under the final-episode name only when every episode actually
    # ran; an interrupted run already has model_interrupted.pt.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    curves_path = os.path.join(log_dir, "training_curves.png")
    _plot_training_curves(episode_rewards, episode_throughputs, episode_mean_speeds,
                          episode_hard_brakes, losses, curves_path)
    print(f"训练曲线已保存: {curves_path}")

    print("=" * 70)
    print("训练完成!" if completed else "训练被中断")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
if __name__ == "__main__":
    # Start training only when run as a script; importing this module has no side effects.
    train_sumo_dqn()