278 lines
10 KiB
Python
278 lines
10 KiB
Python
"""
|
|
基于 SUMO+TraCI 的 APPO 训练脚本
|
|
使用微观仿真环境训练 VSL 控制策略
|
|
"""
|
|
import os
|
|
import sys
|
|
import copy
|
|
import yaml
|
|
import numpy as np
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from datetime import datetime
|
|
from tqdm import tqdm
|
|
import torch
|
|
|
|
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
|
|
from agents.appo_agent import APPOAgent
|
|
from utils.config import get_agent_config, get_training_config
|
|
from utils.episode_artifacts import save_training_episode_artifacts
|
|
from utils.logger import TrainingLogger
|
|
from utils.plot import plot_training_curves
|
|
from utils.run_dirs import resolve_run_dirs
|
|
|
|
|
|
def _periodic(episode, freq):
    """Return True when ``episode`` is a multiple of ``freq``.

    A falsy (0 / None) or negative ``freq`` disables the periodic action instead
    of raising ZeroDivisionError / TypeError on the modulo.
    """
    return bool(freq) and freq > 0 and episode % freq == 0


def _build_appo_agent(state_dim, action_dims, edge_feature_dim, agent_config, total_episodes):
    """Construct the APPOAgent from the ``appo`` section of the config (with script defaults)."""
    return APPOAgent(
        state_dim=state_dim,
        action_dims=action_dims,
        edge_feature_dim=edge_feature_dim,
        hidden_dim=agent_config.get("hidden_dim", 128),
        num_heads=agent_config.get("num_heads", 4),
        num_layers=agent_config.get("num_layers", 2),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.02),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 10),
        # The config key is ``batch_size`` but the agent calls it ``minibatch_size``.
        minibatch_size=agent_config.get("batch_size", 64),
        device=agent_config.get("device", "cuda"),
        lr_schedule=agent_config.get("lr_schedule", "cosine"),
        total_episodes=total_episodes,
    )


def _run_episode(env, agent, seed, desc):
    """Roll out one stochastic episode, storing every transition in ``agent``.

    Returns:
        (episode_reward, totals, steps) where ``totals`` maps metric name to the
        per-step sum of that metric over the episode and ``steps`` is the number
        of env.step calls made.
    """
    state = env.reset(seed=seed)
    episode_reward = 0
    totals = dict.fromkeys(
        ("throughput", "speed", "speed_std", "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes"),
        0,
    )
    steps = 0
    done = False

    # Context manager closes the progress bar even if the rollout is interrupted.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            action, log_prob, value = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, value, log_prob, done)

            episode_reward += reward
            totals["throughput"] += info["throughput"]
            totals["speed"] += info["mean_speed_kmh"]
            # x3.6 is the m/s -> km/h factor; presumably info["speed_std"] is in m/s
            # (the result is reported as km/h alongside mean_speed_kmh).
            totals["speed_std"] += info["speed_std"] * 3.6
            totals["r_flow"] += info["r_flow"]
            totals["r_var"] += info["r_var"]
            totals["r_brake"] += info["r_brake"]
            totals["r_penalty"] += info["r_penalty"]
            totals["hard_brakes"] += info["num_hard_brakes"]
            state = next_state
            steps += 1

            pbar.set_postfix(
                r=f"{episode_reward:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)

    return episode_reward, totals, steps


def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None, config_path="config_sumo_vsl.yaml"):
    """Main APPO training loop in the SUMO+TraCI VSL environment.

    Args:
        log_dir: Optional log directory; resolved via ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory; resolved via ``resolve_run_dirs``.
        run_timestamp: Optional run timestamp forwarded to ``resolve_run_dirs``.
        config_path: YAML config file to load (defaults to ``config_sumo_vsl.yaml``).

    Side effects:
        Writes ``config.yaml`` (with ``runtime.output_dir`` set) and model
        checkpoints under ``checkpoint_dir``. Writes per-episode artifacts and
        ``training_curves.png`` under ``log_dir``. On Ctrl+C the current weights
        are saved as ``model_interrupted.pt`` and no final-episode checkpoint is written.

    Returns:
        None.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = get_agent_config(config, "appo")
    train_config = get_training_config(config)

    # Training parameters
    start_episode = 1
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "appo",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Persist the exact config used for this run, including the resolved output dir.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "appo")
    env = SUMOEdgeVSLEnvironment(runtime_config)

    # Statistics
    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    policy_losses = []
    value_losses = []
    entropies = []
    best_reward = -float("inf")
    interrupted = False

    # The outer try/finally guarantees the simulator is closed even if agent
    # construction fails, not only when the training loop exits.
    try:
        state_dim = env.state_dim
        action_dims = [env.action_dim] * env.num_edges

        print("=" * 70)
        print("APPO训练 - SUMO+TraCI VSL 环境")
        print("=" * 70)
        print(f" 状态维度: {state_dim}")
        print(f" 动作空间: {action_dims}")
        print(f" Episode 步数: {env.episode_length}")
        print(f" 控制间隔: {env.control_interval}s")
        print(f" 隐藏维度: {agent_config.get('hidden_dim', 128)}")
        print(f" 学习率: {agent_config.get('learning_rate', 3e-4)}")
        print(f" 设备: {agent_config.get('device', 'cuda')}")
        print()

        agent = _build_appo_agent(
            state_dim, action_dims, env.features_per_edge, agent_config, num_episodes
        )

        print("开始训练...\n")

        try:
            for episode in range(start_episode, num_episodes + 1):
                # A different seed per episode injects randomness while staying reproducible.
                episode_reward, totals, step = _run_episode(
                    env, agent, seed=base_seed + episode, desc=f"Ep {episode}/{num_episodes}"
                )

                # GAE + policy update. The rollout only ends once ``done`` is set, so
                # the last state is terminal and there is no value to bootstrap from.
                train_stats = agent.update(0.0)

                # Per-step averages of the accumulated metrics
                n_steps = max(step, 1)
                avg_tp = totals["throughput"] / n_steps
                avg_speed = totals["speed"] / n_steps
                avg_speed_std = totals["speed_std"] / n_steps
                avg_r_flow = totals["r_flow"] / n_steps
                avg_r_var = totals["r_var"] / n_steps
                avg_r_brake = totals["r_brake"] / n_steps
                avg_r_penalty = totals["r_penalty"] / n_steps
                episode_brakes = totals["hard_brakes"]

                episode_rewards.append(episode_reward)
                episode_throughputs.append(avg_tp)
                episode_mean_speeds.append(avg_speed)
                episode_speed_stds.append(avg_speed_std)
                episode_hard_brakes.append(episode_brakes)

                # Loss stats only exist when the agent actually performed an update.
                loss_stats = {}
                if train_stats:
                    loss_stats = {
                        key: train_stats[key] for key in ("policy_loss", "value_loss", "entropy")
                    }
                    policy_losses.append(loss_stats["policy_loss"])
                    value_losses.append(loss_stats["value_loss"])
                    entropies.append(loss_stats["entropy"])

                logger.log(
                    episode, episode_reward, avg_tp, avg_speed,
                    speed_std=avg_speed_std,
                    r_flow=avg_r_flow,
                    r_var=avg_r_var,
                    r_brake=avg_r_brake,
                    r_penalty=avg_r_penalty,
                    hard_brakes=episode_brakes,
                    **loss_stats,
                )

                episode_summary = {
                    "episode": episode,
                    "reward": float(episode_reward),
                    "avg_throughput": float(avg_tp),
                    "avg_mean_speed_kmh": float(avg_speed),
                    "avg_speed_std_kmh": float(avg_speed_std),
                    "avg_r_flow": float(avg_r_flow),
                    "avg_r_var": float(avg_r_var),
                    "avg_r_brake": float(avg_r_brake),
                    "avg_r_penalty": float(avg_r_penalty),
                    "hard_brakes": int(episode_brakes),
                }
                episode_summary.update({key: float(val) for key, val in loss_stats.items()})
                save_training_episode_artifacts(
                    log_dir=log_dir,
                    episode=episode,
                    episode_metrics=env.episode_metrics,
                    control_edges=env.control_edges,
                    summary=episode_summary,
                )

                # Save best model
                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

                # Periodic console log
                if _periodic(episode, log_freq):
                    recent_rewards = episode_rewards[-log_freq:]
                    print(f"\nEpisode {episode}/{num_episodes}")
                    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                    print(f" Throughput: {avg_tp:.1f} veh/h")
                    print(f" Mean Speed: {avg_speed:.1f} km/h")
                    print(f" Speed Std: {avg_speed_std:.2f} km/h")
                    print(f" R(flow/var/brake/pen): {avg_r_flow:.3f} / {avg_r_var:.3f} / {avg_r_brake:.3f} / {avg_r_penalty:.3f}")
                    if train_stats:
                        print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                        print(f" Value Loss: {train_stats['value_loss']:.4f}")
                        print(f" Entropy: {train_stats['entropy']:.4f}")

                # Periodic checkpoint
                if _periodic(episode, save_freq):
                    agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

        except KeyboardInterrupt:
            interrupted = True
            print("\n训练被中断,保存当前模型...")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Final checkpoint: only a completed run may be labelled with the last episode
    # number; an interrupted run is already saved as model_interrupted.pt.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Training curves (skip when no episode finished - there is nothing to plot)
    if episode_rewards:
        plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
            policy_losses, value_losses,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("训练已中断!" if interrupted else "训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
|