ctm-dqn/training/train_ppo.py

277 lines
11 KiB
Python

"""
PPO training script based on SUMO+TraCI.
Trains a VSL (variable speed limit) control policy in the microscopic simulation environment.
"""
import os
import sys
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
import torch
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals
from agents.ppo_agent import PPOAgent
from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
def _build_ppo_agent(agent_config, train_config, state_dim, action_dims):
    """Create the PPOAgent described by the ``ppo`` section of the config.

    Each hyperparameter falls back to the literal default shown here when the
    key is absent from ``agent_config``.
    """
    return PPOAgent(
        state_dim=state_dim,
        action_dims=action_dims,
        hidden_layers=agent_config.get("hidden_layers", [512, 256]),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.02),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 10),
        # The config key is ``batch_size``; PPOAgent names it ``minibatch_size``.
        minibatch_size=agent_config.get("batch_size", 64),
        device=agent_config.get("device", "cuda"),
        lr_schedule=agent_config.get("lr_schedule", "cosine"),
        total_episodes=train_config["num_episodes"],
    )


def _run_episode(env, agent, seed, desc):
    """Roll out one stochastic episode, storing every transition in ``agent``.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``episode_length``.
        agent: PPO agent whose ``store_transition`` receives each step.
        seed: Seed passed to ``env.reset``.
        desc: Label shown on the progress bar.

    Returns:
        ``(totals, reward_component_totals, steps)``.
        ``totals`` holds the per-step sums of ``reward``, ``throughput``,
        ``mean_speed_kmh``, ``speed_variance_norm`` and ``ttc_risk``;
        ``ttc_risk`` is read from ``info["ttc_risk_rate"]``, defaulting to 0.
        ``reward_component_totals`` holds the summed reward components.
        ``steps`` is the number of environment steps taken.
    """
    state = env.reset(seed=seed)
    totals = {
        "reward": 0,
        "throughput": 0,
        "mean_speed_kmh": 0,
        "speed_variance_norm": 0.0,
        "ttc_risk": 0.0,
    }
    reward_component_totals = init_reward_component_totals()
    steps = 0
    done = False
    # The context manager closes the progress bar even if the rollout is interrupted.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            action, log_prob, value = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, value, log_prob, done)
            totals["reward"] += reward
            totals["throughput"] += info["throughput"]
            totals["mean_speed_kmh"] += info["mean_speed_kmh"]
            totals["speed_variance_norm"] += info["speed_variance_norm"]
            for column in REWARD_COMPONENT_COLUMNS:
                reward_component_totals[column] += float(info.get(column, 0.0))
            totals["ttc_risk"] += float(info.get("ttc_risk_rate", 0.0))
            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{totals['reward']:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)
    return totals, reward_component_totals, steps


def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None, config_path="config_sumo_vsl.yaml"):
    """Train a PPO agent on the SUMO+TraCI VSL environment.

    Args:
        log_dir: Directory for logs, per-episode artifacts and training curves;
            resolved by ``resolve_run_dirs`` when ``None``.
        checkpoint_dir: Directory for model checkpoints; resolved by
            ``resolve_run_dirs`` when ``None``.
        run_timestamp: Optional run timestamp; resolved by ``resolve_run_dirs``
            when ``None``.
        config_path: YAML configuration file. Defaults to
            ``config_sumo_vsl.yaml``, relative to the current working directory.

    Checkpoints written to ``checkpoint_dir``:
        - ``model_best.pt`` whenever an episode's reward beats the best so far.
        - ``model_ep{N}.pt`` every ``save_freq`` episodes, and after the last
          episode when training runs to completion.
        - ``model_interrupted.pt`` when training is stopped with Ctrl+C.
    """
    # Load configuration.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "ppo")
    train_config = get_training_config(config)
    base_seed = resolve_base_seed(train_config)
    set_global_seed(base_seed)
    start_episode = 1

    resolved_run_timestamp, checkpoint_dir, log_dir = resolve_run_dirs(
        "ppo",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Per-run settings are layered onto a copy of the loaded config; that copy
    # is what gets persisted and handed to the environment.
    runtime_config = copy.deepcopy(config)
    runtime = runtime_config.setdefault("runtime", {})
    runtime["output_dir"] = log_dir
    runtime["evaluation_mode"] = False
    runtime["run_timestamp"] = resolved_run_timestamp
    # NOTE(review): this passes the caller's ``run_timestamp`` (possibly None),
    # not ``resolved_run_timestamp`` -- confirm that is what the helper expects.
    write_shared_run_config(
        runtime_config,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    logger = TrainingLogger(log_dir, "ppo")

    # Create the environment. Everything after this point is guarded by
    # ``finally`` so the simulator is closed even if agent construction fails.
    env = SUMOEdgeVSLEnvironment(runtime_config)
    try:
        state_dim = env.state_dim
        # One action dimension of size env.action_dim per controlled edge.
        action_dims = [env.action_dim] * env.num_controlled_edges
        print("=" * 70)
        print("PPO training - SUMO+TraCI VSL")
        print("=" * 70)
        print(f" State dim: {state_dim}")
        print(f" Action dims: {action_dims}")
        print(f" Episode length: {env.episode_length}")
        print(f" Control interval: {env.control_interval}s")
        print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}")
        print(f" Device: {agent_config.get('device', 'cuda')}")
        print()

        # Create the PPO agent.
        agent = _build_ppo_agent(agent_config, train_config, state_dim, action_dims)

        # Training parameters. A falsy (0 / null) frequency disables that action.
        num_episodes = train_config["num_episodes"]
        save_freq = train_config.get("save_freq", 50)
        log_freq = train_config.get("log_freq", 10)

        # Per-episode statistics collected for the final training curves.
        episode_rewards = []
        episode_throughputs = []
        episode_mean_speeds = []
        episode_speed_variance_norms = []
        episode_ttc_risks = []
        policy_losses = []
        value_losses = []
        entropies = []
        best_reward = -float("inf")
        interrupted = False
        loss_keys = ("policy_loss", "value_loss", "entropy")

        print("Starting training...\n")
        try:
            for episode in range(start_episode, num_episodes + 1):
                # Use a different, reproducible seed per episode to vary the rollouts.
                seed = derive_seed(base_seed, episode)
                totals, reward_component_totals, steps = _run_episode(
                    env, agent, seed, desc=f"Ep {episode}/{num_episodes}"
                )

                # GAE and policy update. The rollout stops only once ``done`` is
                # set, so the last state is terminal and the bootstrap value is 0.
                # NOTE(review): if ``done`` can also mean time-limit truncation,
                # bootstrapping from the final state's value would be less biased.
                train_stats = agent.update(0.0)

                # Record statistics: step-level metrics are averaged per step,
                # while the TTC risk is kept as an episode sum.
                step_divisor = max(steps, 1)
                episode_reward = totals["reward"]
                episode_ttc_risk = totals["ttc_risk"]
                avg_tp = totals["throughput"] / step_divisor
                avg_speed = totals["mean_speed_kmh"] / step_divisor
                avg_speed_variance_norm = totals["speed_variance_norm"] / step_divisor
                avg_reward_components = average_reward_components(reward_component_totals, steps)

                episode_rewards.append(episode_reward)
                episode_throughputs.append(avg_tp)
                episode_mean_speeds.append(avg_speed)
                episode_speed_variance_norms.append(avg_speed_variance_norm)
                episode_ttc_risks.append(episode_ttc_risk)

                # Loss statistics exist only when the update returned any.
                loss_stats = {}
                if train_stats:
                    loss_stats = {key: train_stats[key] for key in loss_keys}
                    policy_losses.append(train_stats["policy_loss"])
                    value_losses.append(train_stats["value_loss"])
                    entropies.append(train_stats["entropy"])
                logger.log(
                    episode, episode_reward, avg_tp, avg_speed,
                    speed_variance_norm=avg_speed_variance_norm,
                    reward_components=avg_reward_components,
                    ttc_risk=episode_ttc_risk,
                    **loss_stats,
                )

                # Persist the per-episode artifacts.
                episode_summary = {
                    "episode": episode,
                    "reward": float(episode_reward),
                    "avg_throughput": float(avg_tp),
                    "avg_mean_speed_kmh": float(avg_speed),
                    "avg_speed_variance_norm": float(avg_speed_variance_norm),
                    "ttc_risk": float(episode_ttc_risk),
                }
                episode_summary.update(
                    (f"avg_{column}", float(value))
                    for column, value in avg_reward_components.items()
                )
                episode_summary.update((key, float(value)) for key, value in loss_stats.items())
                save_training_episode_artifacts(
                    log_dir=log_dir,
                    episode=episode,
                    episode_metrics=env.episode_metrics,
                    control_edges=env.control_edges,
                    summary=episode_summary,
                )

                # Save the best model.
                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

                # Periodic console report.
                if log_freq and episode % log_freq == 0:
                    recent_rewards = episode_rewards[-log_freq:]
                    print(f"\nEpisode {episode}/{num_episodes}")
                    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                    print(f" Throughput: {avg_tp:.1f} veh/h")
                    print(f" Mean Speed: {avg_speed:.1f} km/h")
                    print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}")
                    print(
                        " Reward Components: "
                        + ", ".join(
                            f"{column}={avg_reward_components[column]:.3f}"
                            for column in REWARD_COMPONENT_COLUMNS
                        )
                    )
                    if train_stats:
                        print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                        print(f" Value Loss: {train_stats['value_loss']:.4f}")
                        print(f" Entropy: {train_stats['entropy']:.4f}")

                # Periodic checkpoint.
                if save_freq and episode % save_freq == 0:
                    agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))
        except KeyboardInterrupt:
            interrupted = True
            print("\nTraining interrupted, saving current model...")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # Final save. An interrupted run is already in model_interrupted.pt and must
    # not be mislabeled as the checkpoint of the last configured episode.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Plot the training curves for the episodes that finished.
    plot_training_curves(
        episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_variance_norms, episode_ttc_risks,
        policy_losses, value_losses,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )
    print("=" * 70)
    print("Training interrupted" if interrupted else "Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)