"""
PPO training script built on SUMO + TraCI.

Trains a VSL control policy using a microscopic simulation environment.
"""

import csv
import os
import re
import sys
from datetime import datetime

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from tqdm import tqdm

from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from ppo_agent import PPOAgent
from training_logger import TrainingLogger


def train_sumo_ppo(resume_checkpoint=None):
    """Run PPO training against the SUMO edge-VSL environment.

    Reads ``config_sumo_vsl.yaml`` from the current working directory, builds
    the environment, the PPO agent and the logger, then trains for
    ``training.num_episodes`` episodes. The best-scoring model, periodic
    checkpoints and a training-curve figure are written to disk.

    Args:
        resume_checkpoint: Optional path to a saved model. When given, the
            agent weights are loaded from it, the checkpoint/log directories
            are derived from its location, the starting episode is taken from
            an ``ep<number>`` token in the file name (if present), and earlier
            metrics are re-read from ``ppo_training_log.csv`` in the log
            directory.
    """
    # Load configuration.
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = config.get("agent", {})
    train_config = config["training"]

    # Resolve output directories and the first episode to run.
    start_episode = 1
    if resume_checkpoint:
        # A bare file name has an empty dirname; fall back to the current
        # directory so that os.makedirs() below does not fail on "".
        checkpoint_dir = os.path.dirname(resume_checkpoint) or "."
        # The log directory mirrors the checkpoint directory by name
        # substitution; if the substring is absent both directories coincide.
        log_dir = checkpoint_dir.replace("checkpoints_sumo_vsl", "logs_sumo_vsl")
        # Extract the episode number from the file name. A regex tolerates
        # names that contain "ep" elsewhere (e.g. "deep_model_ep500.pt") and
        # names without any episode number (e.g. "model_best.pt").
        episode_match = re.search(r"ep(\d+)", os.path.basename(resume_checkpoint))
        if episode_match:
            start_episode = int(episode_match.group(1)) + 1
        print(f"从 episode {start_episode} 继续训练")
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_dir = os.path.join(train_config["checkpoint_dir"], timestamp)
        log_dir = os.path.join(train_config["log_dir"], timestamp)

    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    # Keep a copy of the configuration used for this run next to the models.
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    logger = TrainingLogger(log_dir, "ppo", resume=bool(resume_checkpoint))

    # Create the environment.
    env = SUMOEdgeVSLEnvironment(config)

    state_dim = env.state_dim
    # One action head per controlled edge, each with env.action_dim choices.
    action_dims = [env.action_dim] * env.num_edges

    print("=" * 70)
    print("PPO 训练 - SUMO+TraCI VSL 环境")
    print("=" * 70)
    print(f" 状态维度: {state_dim}")
    print(f" 控制边数: {env.num_edges}")
    print(f" 每边动作数: {env.action_dim}")
    print(f" Episode 步数: {env.episode_length}")
    print(f" 控制间隔: {env.control_interval}s")
    print(f" 隐藏层: {agent_config.get('hidden_layers', [512, 256])}")
    print(f" 学习率: {agent_config.get('learning_rate', 3e-4)}")
    print(f" 设备: {agent_config.get('device', 'cuda')}")
    print()

    # Create the PPO agent. Note that the minibatch size is read from the
    # "batch_size" key of the agent configuration.
    agent = PPOAgent(
        state_dim=state_dim,
        action_dims=action_dims,
        hidden_layers=agent_config.get("hidden_layers", [512, 256]),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.02),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 10),
        minibatch_size=agent_config.get("batch_size", 64),
        device=agent_config.get("device", "cuda"),
    )

    # Load the checkpoint weights.
    if resume_checkpoint:
        agent.load(resume_checkpoint)
        print(f"已加载模型: {resume_checkpoint}")

    # Training parameters.
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    # Running statistics (one entry per episode).
    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    policy_losses = []
    value_losses = []
    best_reward = -float("inf")

    # Restore metrics of earlier episodes so curves and best_reward continue.
    if resume_checkpoint:
        log_file = os.path.join(log_dir, "ppo_training_log.csv")
        if os.path.exists(log_file):
            (
                episode_rewards,
                episode_throughputs,
                episode_mean_speeds,
                policy_losses,
                value_losses,
            ) = _load_training_history(log_file)
            best_reward = max(episode_rewards) if episode_rewards else -float("inf")
            print(f"已加载 {len(episode_rewards)} 条历史记录")

    print("开始训练...\n")

    interrupted = False
    try:
        for episode in range(start_episode, num_episodes + 1):
            # Use a different seed per episode to introduce randomness.
            seed = base_seed + episode
            state = env.reset(seed=seed)
            episode_reward = 0
            episode_throughput = 0
            episode_speed = 0
            done = False
            step = 0

            # The context manager closes the bar even if a step raises.
            with tqdm(
                total=env.episode_length,
                desc=f"Ep {episode}/{num_episodes}",
                leave=False,
            ) as pbar:
                while not done:
                    action, log_prob, value = agent.select_action(state, deterministic=False)
                    next_state, reward, done, info = env.step(action)

                    agent.store_transition(state, action, reward, value, log_prob, done)

                    episode_reward += reward
                    episode_throughput += info["throughput"]
                    episode_speed += info["mean_speed_kmh"]
                    state = next_state
                    step += 1

                    pbar.set_postfix(
                        r=f"{episode_reward:.1f}",
                        tp=f"{info['throughput']:.0f}",
                        v=f"{info['mean_speed_kmh']:.1f}",
                    )
                    pbar.update(1)

            # GAE computation and policy update. The loop above only exits
            # once `done` is truthy, so the value after the last state is 0.
            next_value = 0.0
            train_stats = agent.update(next_value)

            # Record per-episode statistics (per-step averages for tp/speed).
            avg_tp = episode_throughput / max(step, 1)
            avg_speed = episode_speed / max(step, 1)
            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)

            if train_stats:
                policy_losses.append(train_stats["policy_loss"])
                value_losses.append(train_stats["value_loss"])
                logger.log(
                    episode, episode_reward, avg_tp, avg_speed,
                    train_stats["policy_loss"],
                    train_stats["value_loss"],
                    train_stats["entropy"],
                )
            else:
                logger.log(episode, episode_reward, avg_tp, avg_speed)

            # Save the best model so far.
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            # Periodic console report.
            if episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                if train_stats:
                    print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                    print(f" Value Loss: {train_stats['value_loss']:.4f}")
                    print(f" Entropy: {train_stats['entropy']:.4f}")

            # Periodic checkpoint.
            if episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

    except KeyboardInterrupt:
        interrupted = True
        print("\n训练被中断,保存当前模型...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Final save. Skipped after an interrupt: writing model_ep{num_episodes}.pt
    # for a partial run would mislabel it and make a later --resume on that
    # file start past the last episode.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plot training curves (nothing to plot if no episode was recorded).
    if episode_rewards:
        _plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds,
            policy_losses, value_losses,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)


def _load_training_history(log_file):
    """Read per-episode metrics back from an existing training-log CSV.

    The file must have ``reward``, ``throughput``, ``mean_speed``,
    ``policy_loss`` and ``value_loss`` columns. Empty loss cells are skipped,
    so the loss lists may be shorter than the other three.

    Returns:
        Tuple ``(rewards, throughputs, mean_speeds, policy_losses,
        value_losses)`` of float lists in file order.
    """
    rewards, throughputs, mean_speeds = [], [], []
    policy_losses, value_losses = [], []
    with open(log_file, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            rewards.append(float(row["reward"]))
            throughputs.append(float(row["throughput"]))
            mean_speeds.append(float(row["mean_speed"]))
            if row["policy_loss"]:
                policy_losses.append(float(row["policy_loss"]))
            if row["value_loss"]:
                value_losses.append(float(row["value_loss"]))
    return rewards, throughputs, mean_speeds, policy_losses, value_losses


def _plot_training_curves(
    rewards, throughputs, mean_speeds, policy_losses, value_losses,
    save_path: str,
):
    """Draw the training curves on a 2x3 grid and save the figure.

    Top row: episode reward, average throughput and mean speed, each with a
    20-episode moving average once more than 20 points exist. Bottom row:
    policy loss, value loss and a text summary.

    Args:
        rewards: Total reward per episode.
        throughputs: Average throughput per episode.
        mean_speeds: Mean speed per episode.
        policy_losses: Policy-loss values; panel left empty if none.
        value_losses: Value-loss values; panel left empty if none.
        save_path: Destination file for the figure.

    Empty metric sequences are tolerated: the summary then shows "n/a"
    instead of raising on ``max()`` of an empty sequence.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    window = 20

    def draw_metric(ax, values, color, ylabel, title):
        # Raw per-episode series plus, when long enough, its moving average.
        ax.plot(values, alpha=0.4, color=color)
        if len(values) > window:
            smoothed = np.convolve(values, np.ones(window) / window, mode="valid")
            # "valid" mode drops the first window-1 points, so shift the x axis.
            ax.plot(range(window - 1, len(values)), smoothed, "r-", linewidth=2)
        ax.set_xlabel("Episode")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    def draw_loss(ax, values, style, label):
        # Loss panels keep their labels even when there is nothing to plot.
        if values:
            ax.plot(values, style, alpha=0.6)
        ax.set_xlabel("Episode")
        ax.set_ylabel(label)
        ax.set_title(label)
        ax.grid(True, alpha=0.3)

    def tail_mean(values, spec):
        # Mean over the last `window` entries, or "n/a" for an empty series.
        if not len(values):
            return "n/a"
        return format(np.mean(values[-window:]), spec)

    try:
        draw_metric(axes[0, 0], rewards, "blue", "Total Reward", "Episode Reward")
        draw_metric(axes[0, 1], throughputs, "green", "Avg Throughput (veh/h)", "Throughput")
        draw_metric(axes[0, 2], mean_speeds, "orange", "Mean Speed (km/h)", "Mean Speed")

        draw_loss(axes[1, 0], policy_losses, "b-", "Policy Loss")
        draw_loss(axes[1, 1], value_losses, "r-", "Value Loss")

        # Summary text
        best_text = f"{max(rewards):.2f}" if len(rewards) else "n/a"
        axes[1, 2].axis("off")
        summary = (
            f"Training Summary\n"
            f"{'='*30}\n"
            f"Episodes: {len(rewards)}\n"
            f"Best Reward: {best_text}\n"
            f"Final Avg Reward: {tail_mean(rewards, '.2f')}\n"
            f"Final Avg Throughput: {tail_mean(throughputs, '.1f')}\n"
            f"Final Avg Speed: {tail_mean(mean_speeds, '.1f')} km/h"
        )
        axes[1, 2].text(0.1, 0.5, summary, fontsize=12, family="monospace",
                        verticalalignment="center", transform=axes[1, 2].transAxes)

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure; pyplot otherwise keeps it alive until exit.
        plt.close(fig)

    print(f"训练曲线已保存: {save_path}")


if __name__ == "__main__":
    # Command-line entry point: optionally continue training from a checkpoint.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--resume",
        type=str,
        help="从checkpoint继续训练,例如: checkpoints_sumo_vsl/xxx/model_ep500.pt",
    )
    options = cli.parse_args()
    train_sumo_ppo(resume_checkpoint=options.resume)