ctm-dqn/train_sumo_ppo.py

316 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
基于 SUMO+TraCI 的 PPO 训练脚本
使用微观仿真环境训练 VSL 控制策略
"""
import os
import sys
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
import torch
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from ppo_agent import PPOAgent
from training_logger import TrainingLogger
def _parse_resume_episode(filename):
    """Return the episode number from an ``ep<digits>`` token in a checkpoint name.

    ``model_ep500.pt`` -> 500. Returns None when the name carries no such
    token (e.g. ``model_best.pt`` or ``model_interrupted.pt``). Names that
    merely contain the letters "ep" elsewhere (e.g. ``step_final.pt``) also
    return None instead of raising.
    """
    import re

    match = re.search(r"ep(\d+)", filename)
    return int(match.group(1)) if match else None


def _load_history(log_file):
    """Read per-episode metrics back from a previous run's CSV log.

    Returns five lists: rewards, throughputs, mean speeds, policy losses and
    value losses. Empty loss cells are skipped, so the loss lists may be
    shorter than the reward list.
    """
    import csv

    rewards, throughputs, speeds = [], [], []
    policy_losses, value_losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            rewards.append(float(row["reward"]))
            throughputs.append(float(row["throughput"]))
            speeds.append(float(row["mean_speed"]))
            if row["policy_loss"]:
                policy_losses.append(float(row["policy_loss"]))
            if row["value_loss"]:
                value_losses.append(float(row["value_loss"]))
    return rewards, throughputs, speeds, policy_losses, value_losses


def _run_episode(env, agent, episode, num_episodes, seed):
    """Roll out one full episode, storing every transition in ``agent``.

    Returns ``(episode_reward, avg_throughput, avg_speed_kmh)``. The averages
    are per control step, computed from the ``info["throughput"]`` and
    ``info["mean_speed_kmh"]`` values reported by ``env.step``.
    """
    state = env.reset(seed=seed)
    episode_reward = 0
    throughput_sum = 0
    speed_sum = 0
    step = 0
    done = False
    # ``with`` closes the progress bar even if the rollout raises or is interrupted.
    with tqdm(
        total=env.episode_length,
        desc=f"Ep {episode}/{num_episodes}",
        leave=False,
    ) as pbar:
        while not done:
            action, log_prob, value = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, value, log_prob, done)
            episode_reward += reward
            throughput_sum += info["throughput"]
            speed_sum += info["mean_speed_kmh"]
            state = next_state
            step += 1
            pbar.set_postfix(
                r=f"{episode_reward:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)
    avg_tp = throughput_sum / max(step, 1)
    avg_speed = speed_sum / max(step, 1)
    return episode_reward, avg_tp, avg_speed


def train_sumo_ppo(resume_checkpoint=None):
    """Train a PPO variable-speed-limit policy in the SUMO+TraCI environment.

    Reads ``config_sumo_vsl.yaml`` from the working directory. It then builds
    a ``SUMOEdgeVSLEnvironment`` and a ``PPOAgent``, whose action dimensions
    are ``env.action_dim`` for each of the ``env.num_edges`` controlled edges.
    Training runs up to ``training.num_episodes`` episodes, with one agent
    update after each episode.

    The following outputs are written to per-run checkpoint/log directories:
    - the best, periodic and final (or interrupted) checkpoints;
    - a copy of the config;
    - per-episode metrics, via ``TrainingLogger``;
    - a training-curve plot.

    Args:
        resume_checkpoint: Optional path to a saved agent checkpoint. When
            given, its weights are loaded and outputs go to the checkpoint's
            directory. Logs go to the same path with ``checkpoints_sumo_vsl``
            replaced by ``logs_sumo_vsl``. Previous metrics are reloaded from
            ``ppo_training_log.csv``. Training continues from the episode
            after the ``ep<N>`` number in the filename, or from episode 1 if
            the name has none.
    """
    # Load configuration.
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = config.get("agent", {})
    train_config = config["training"]

    # Resolve output directories and the first episode to run.
    start_episode = 1
    if resume_checkpoint:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        checkpoint_dir = os.path.dirname(resume_checkpoint) or "."
        # NOTE(review): hard-codes the default root names instead of using
        # training.checkpoint_dir / training.log_dir from the config.
        log_dir = checkpoint_dir.replace("checkpoints_sumo_vsl", "logs_sumo_vsl")
        last_episode = _parse_resume_episode(os.path.basename(resume_checkpoint))
        if last_episode is not None:
            start_episode = last_episode + 1
            print(f"从 episode {start_episode} 继续训练")
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_dir = os.path.join(train_config["checkpoint_dir"], timestamp)
        log_dir = os.path.join(train_config["log_dir"], timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    # Keep a copy of the configuration next to the checkpoints.
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    logger = TrainingLogger(log_dir, "ppo", resume=bool(resume_checkpoint))

    # Create the environment.
    env = SUMOEdgeVSLEnvironment(config)
    state_dim = env.state_dim
    action_dims = [env.action_dim] * env.num_edges
    hidden_layers = agent_config.get("hidden_layers", [512, 256])
    learning_rate = agent_config.get("learning_rate", 3e-4)
    device = agent_config.get("device", "cuda")

    print("=" * 70)
    print("PPO 训练 - SUMO+TraCI VSL 环境")
    print("=" * 70)
    print(f" 状态维度: {state_dim}")
    print(f" 控制边数: {env.num_edges}")
    print(f" 每边动作数: {env.action_dim}")
    print(f" Episode 步数: {env.episode_length}")
    print(f" 控制间隔: {env.control_interval}s")
    print(f" 隐藏层: {hidden_layers}")
    print(f" 学习率: {learning_rate}")
    print(f" 设备: {device}")
    print()

    # Create the PPO agent.
    agent = PPOAgent(
        state_dim=state_dim,
        action_dims=action_dims,
        hidden_layers=hidden_layers,
        learning_rate=learning_rate,
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.02),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 10),
        minibatch_size=agent_config.get("batch_size", 64),
        device=device,
    )
    if resume_checkpoint:
        agent.load(resume_checkpoint)
        print(f"已加载模型: {resume_checkpoint}")

    # Training parameters. A log_freq/save_freq of 0 (or null) disables that feature.
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    # Per-episode statistics, used for console summaries and the final plot.
    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    policy_losses = []
    value_losses = []
    best_reward = -float("inf")

    # On resume, restore the history so curves and best_reward span the whole run.
    if resume_checkpoint:
        log_file = os.path.join(log_dir, "ppo_training_log.csv")
        if os.path.exists(log_file):
            (episode_rewards, episode_throughputs, episode_mean_speeds,
             policy_losses, value_losses) = _load_history(log_file)
            best_reward = max(episode_rewards) if episode_rewards else -float("inf")
            print(f"已加载 {len(episode_rewards)} 条历史记录")

    print("开始训练...\n")
    completed = False
    try:
        for episode in range(start_episode, num_episodes + 1):
            # A different seed per episode adds randomness while staying reproducible.
            episode_reward, avg_tp, avg_speed = _run_episode(
                env, agent, episode, num_episodes, seed=base_seed + episode
            )
            # _run_episode returns only after a step reported done, so the
            # rollout ends in a terminal state and the bootstrap value is 0.
            # NOTE(review): a time-limit cut-off signalled via ``done`` is
            # also treated as terminal here.
            train_stats = agent.update(0.0)

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            if train_stats:
                policy_losses.append(train_stats["policy_loss"])
                value_losses.append(train_stats["value_loss"])
                logger.log(episode, episode_reward, avg_tp, avg_speed,
                           train_stats["policy_loss"], train_stats["value_loss"],
                           train_stats["entropy"])
            else:
                logger.log(episode, episode_reward, avg_tp, avg_speed)

            # Save the best model seen so far.
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            # Periodic console summary.
            if log_freq and episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                if train_stats:
                    print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                    print(f" Value Loss: {train_stats['value_loss']:.4f}")
                    print(f" Entropy: {train_stats['entropy']:.4f}")

            # Periodic checkpoint.
            if save_freq and episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))
        completed = True
    except KeyboardInterrupt:
        print("\n训练被中断,保存当前模型...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Only a run that reached num_episodes gets the final model_ep{N} name.
    # An interrupted run is already saved as model_interrupted.pt. Labelling it
    # model_ep{num_episodes} would make a later resume skip the remaining episodes.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plotting needs at least one episode (e.g. not interrupted during episode 1).
    if episode_rewards:
        _plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds,
            policy_losses, value_losses,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )
    else:
        print("无训练记录,跳过绘制训练曲线")

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
def _plot_series_with_moving_average(ax, values, color, window, ylabel, title):
    """Draw per-episode ``values`` faintly, plus a red ``window``-episode moving average.

    The moving average is drawn only when there are more than ``window`` points.
    """
    ax.plot(values, alpha=0.4, color=color)
    if len(values) > window:
        moving_avg = np.convolve(values, np.ones(window) / window, mode="valid")
        # mode="valid" yields len(values) - window + 1 points. Each point is
        # placed at the last episode of its window.
        ax.plot(range(window - 1, len(values)), moving_avg, "r-", linewidth=2)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


def _plot_loss_series(ax, losses, fmt, label):
    """Draw a loss curve (if any values were recorded) and label the axes."""
    if len(losses) > 0:
        ax.plot(losses, fmt, alpha=0.6)
    # NOTE: one point per recorded update, so the index may not equal the
    # episode number when some episodes produced no update.
    ax.set_xlabel("Episode")
    ax.set_ylabel(label)
    ax.set_title(label)
    ax.grid(True, alpha=0.3)


def _plot_training_curves(
    rewards, throughputs, mean_speeds, policy_losses, value_losses,
    save_path: str,
    window: int = 20,
):
    """Plot training curves in a 2x3 grid and save the figure to ``save_path``.

    The top row shows per-episode reward, average throughput and mean speed.
    Each has a ``window``-episode moving average overlaid. The bottom row
    shows policy loss, value loss and a text summary. The summary lists the
    episode count, the best reward, and averages over the last ``window``
    episodes. Empty inputs are allowed; the summary then reports 0 episodes.

    Args:
        rewards: Total reward per episode.
        throughputs: Average throughput per episode (veh/h).
        mean_speeds: Average speed per episode (km/h).
        policy_losses: Policy loss per recorded update.
        value_losses: Value loss per recorded update.
        save_path: Output image path.
        window: Moving-average / summary window in episodes (default 20).

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    try:
        _plot_series_with_moving_average(
            axes[0, 0], rewards, "blue", window, "Total Reward", "Episode Reward")
        _plot_series_with_moving_average(
            axes[0, 1], throughputs, "green", window, "Avg Throughput (veh/h)", "Throughput")
        _plot_series_with_moving_average(
            axes[0, 2], mean_speeds, "orange", window, "Mean Speed (km/h)", "Mean Speed")
        _plot_loss_series(axes[1, 0], policy_losses, "b-", "Policy Loss")
        _plot_loss_series(axes[1, 1], value_losses, "r-", "Value Loss")

        # Summary text panel. max()/mean() need at least one episode.
        axes[1, 2].axis("off")
        if len(rewards) > 0:
            summary = (
                f"Training Summary\n"
                f"{'='*30}\n"
                f"Episodes: {len(rewards)}\n"
                f"Best Reward: {max(rewards):.2f}\n"
                f"Final Avg Reward: {np.mean(rewards[-window:]):.2f}\n"
                f"Final Avg Throughput: {np.mean(throughputs[-window:]):.1f}\n"
                f"Final Avg Speed: {np.mean(mean_speeds[-window:]):.1f} km/h"
            )
        else:
            summary = f"Training Summary\n{'='*30}\nEpisodes: 0"
        axes[1, 2].text(0.1, 0.5, summary, fontsize=12, family="monospace",
                        verticalalignment="center", transform=axes[1, 2].transAxes)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    print(f"训练曲线已保存: {save_path}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: ``--resume PATH`` continues from a saved
    # checkpoint; without it a fresh, timestamped run is started.
    _cli = argparse.ArgumentParser()
    _cli.add_argument(
        "--resume",
        type=str,
        help="从checkpoint继续训练例如: checkpoints_sumo_vsl/xxx/model_ep500.pt",
    )
    train_sumo_ppo(resume_checkpoint=_cli.parse_args().resume)