# ctm-dqn/training/train_gpro.py
#
# File-viewer metadata: 298 lines, 12 KiB, Python

"""GRPO-inspired PPO training entrypoint for corridor VSL control."""
import os
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.gpro_agent import GPROAgent
from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
def _print_gpro_header(env, agent_config, group_size):
    """Print the run banner: environment dimensions and key agent settings."""
    print("=" * 70)
    print("GPRO-PPO training - SUMO+TraCI VSL environment")
    print("=" * 70)
    print(f" State dim: {env.state_dim}")
    print(f" Controlled edges: {env.num_edges}")
    print(f" Actions per edge: {env.action_dim}")
    print(f" Episode steps: {env.episode_length}")
    print(f" Control interval: {env.control_interval}s")
    print(f" Hidden layers: {agent_config.get('hidden_layers', [256, 256])}")
    print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}")
    print(f" Group size: {group_size}")
    print(f" Group advantage coef: {agent_config.get('group_advantage_coef', 0.35)}")
    print(f" Device: {agent_config.get('device', 'cuda')}")
    print()


def _run_gpro_episode(env, agent, episode, num_episodes, seed):
    """Roll out one stochastic episode, storing every transition in ``agent``.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``episode_length``.
        agent: Agent exposing ``select_action``, ``store_transition`` and
            ``finish_episode``.
        episode: 1-based episode index (used for the progress-bar label and
            stored in the returned row).
        num_episodes: Total number of episodes (progress-bar label only).
        seed: Seed passed to ``env.reset``.

    Returns:
        A dict with the episode index, the summed reward, per-step averages
        of the ``info`` metrics returned by ``env.step`` and the summed
        hard-brake count.
    """
    state = env.reset(seed=seed)
    reward_total = 0.0
    # Per-step sums, divided by the step count at the end of the episode.
    sums = {
        "avg_tp": 0.0,
        "avg_speed": 0.0,
        "avg_speed_std": 0.0,
        "avg_r_flow": 0.0,
        "avg_r_var": 0.0,
        "avg_r_brake": 0.0,
        "avg_r_penalty": 0.0,
    }
    hard_brakes = 0
    steps = 0
    done = False
    # Context manager closes the bar even if the rollout is interrupted.
    with tqdm(
        total=env.episode_length,
        desc=f"Ep {episode}/{num_episodes}",
        leave=False,
    ) as pbar:
        while not done:
            action, log_prob, value = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, value, log_prob, done)
            reward_total += reward
            sums["avg_tp"] += info["throughput"]
            sums["avg_speed"] += info["mean_speed_kmh"]
            # Scaled by 3.6 so it is reported in km/h next to mean_speed_kmh
            # (presumably the env reports speed_std in m/s - confirm).
            sums["avg_speed_std"] += info["speed_std"] * 3.6
            sums["avg_r_flow"] += info["r_flow"]
            sums["avg_r_var"] += info["r_var"]
            sums["avg_r_brake"] += info["r_brake"]
            sums["avg_r_penalty"] += info["r_penalty"]
            hard_brakes += info["num_hard_brakes"]
            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{reward_total:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)
    agent.finish_episode(reward_total)

    denom = max(steps, 1)  # guard against an episode that ends with zero steps
    row = {"episode": episode, "reward": reward_total}
    row.update({key: total / denom for key, total in sums.items()})
    row["episode_brakes"] = hard_brakes
    return row


def _episode_summary(row, group_seed):
    """Build the plain-Python-typed summary dict saved with an episode's artifacts."""
    return {
        "episode": row["episode"],
        "reward": float(row["reward"]),
        "avg_throughput": float(row["avg_tp"]),
        "avg_mean_speed_kmh": float(row["avg_speed"]),
        "avg_speed_std_kmh": float(row["avg_speed_std"]),
        "avg_r_flow": float(row["avg_r_flow"]),
        "avg_r_var": float(row["avg_r_var"]),
        "avg_r_brake": float(row["avg_r_brake"]),
        "avg_r_penalty": float(row["avg_r_penalty"]),
        "hard_brakes": int(row["episode_brakes"]),
        "group_seed": int(group_seed),
    }


def _print_episode_progress(row, num_episodes, recent_rewards):
    """Print a multi-line progress summary for one finished episode."""
    print(f"\nEpisode {row['episode']}/{num_episodes}")
    print(f" Reward: {row['reward']:.2f} (Avg: {np.mean(recent_rewards):.2f})")
    print(f" Throughput: {row['avg_tp']:.1f} veh/h")
    print(f" Mean Speed: {row['avg_speed']:.1f} km/h")
    print(f" Speed Std: {row['avg_speed_std']:.2f} km/h")
    print(
        " R(flow/var/brake/pen): "
        f"{row['avg_r_flow']:.3f} / {row['avg_r_var']:.3f} / "
        f"{row['avg_r_brake']:.3f} / {row['avg_r_penalty']:.3f}"
    )


def _log_episode_rows(logger, rows, train_stats=None):
    """Write buffered per-episode rows to ``logger`` and empty ``rows``.

    Loss/entropy columns are added only when ``train_stats`` is truthy,
    i.e. when a group update actually produced statistics.
    """
    loss_kwargs = {}
    if train_stats:
        loss_kwargs = {
            "policy_loss": train_stats["policy_loss"],
            "value_loss": train_stats["value_loss"],
            "entropy": train_stats["entropy"],
        }
    for row in rows:
        logger.log(
            row["episode"],
            row["reward"],
            row["avg_tp"],
            row["avg_speed"],
            speed_std=row["avg_speed_std"],
            r_flow=row["avg_r_flow"],
            r_var=row["avg_r_var"],
            r_brake=row["avg_r_brake"],
            r_penalty=row["avg_r_penalty"],
            hard_brakes=row["episode_brakes"],
            **loss_kwargs,
        )
    rows.clear()


def train_sumo_gpro(
    log_dir=None,
    checkpoint_dir=None,
    run_timestamp=None,
    config_path="config_sumo_vsl.yaml",
):
    """Train grouped relative PPO under the SUMO+TraCI VSL environment.

    Episodes are collected in groups of ``group_size``. Every episode in a
    group is reset with the same seed, and the agent is updated once per
    group through ``agent.update()``.

    Args:
        log_dir: Optional log-directory override forwarded to
            ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint-directory override forwarded to
            ``resolve_run_dirs``.
        run_timestamp: Optional run timestamp forwarded to
            ``resolve_run_dirs`` and ``write_shared_run_config``.
        config_path: YAML config file to load (default
            ``"config_sumo_vsl.yaml"``, resolved against the working
            directory).

    Checkpoints written to the resolved ``checkpoint_dir``:
        * ``model_best.pt`` whenever an episode beats the best reward so far;
        * ``model_ep{N}.pt`` every ``save_freq`` episodes, and once more for
          the last episode when training finishes normally;
        * ``model_interrupted.pt`` when training is stopped with Ctrl+C.

    Raises:
        ValueError: if the configured ``group_size`` is smaller than 1.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "gpro")
    train_config = get_training_config(config)

    # Validate before any directory or environment is created. A zero step
    # would make range() raise after the env started, and a negative one
    # would silently run no episodes.
    group_size = int(agent_config.get("group_size", 4))
    if group_size < 1:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "gpro",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    # The environment and the shared run config both receive a copy of the
    # config with the resolved output directory filled in.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    write_shared_run_config(
        runtime_config,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    logger = TrainingLogger(log_dir, "gpro")

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    policy_losses = []
    value_losses = []
    best_reward = -float("inf")
    # Rows of the current group wait here until agent.update() produces the
    # loss statistics that are logged alongside them.
    pending_log_rows = []
    interrupted = False

    env = SUMOEdgeVSLEnvironment(runtime_config)
    # Everything after env construction runs inside try/finally so the
    # environment is closed even if building the agent or training fails.
    try:
        _print_gpro_header(env, agent_config, group_size)
        agent = GPROAgent(
            state_dim=env.state_dim,
            action_dims=[env.action_dim] * env.num_edges,
            hidden_layers=agent_config.get("hidden_layers", [256, 256]),
            learning_rate=agent_config.get("learning_rate", 3e-4),
            gamma=agent_config.get("gamma", 0.99),
            gae_lambda=agent_config.get("gae_lambda", 0.95),
            clip_epsilon=agent_config.get("clip_epsilon", 0.2),
            value_coef=agent_config.get("value_coef", 0.5),
            entropy_coef=agent_config.get("entropy_coef", 0.01),
            max_grad_norm=agent_config.get("max_grad_norm", 0.5),
            ppo_epochs=agent_config.get("ppo_epochs", 4),
            minibatch_size=agent_config.get("batch_size", 64),
            group_size=group_size,
            group_advantage_coef=agent_config.get("group_advantage_coef", 0.35),
            advantage_epsilon=agent_config.get("advantage_epsilon", 1e-8),
            device=agent_config.get("device", "cuda"),
            lr_schedule=agent_config.get("lr_schedule", "cosine"),
            total_episodes=num_episodes,
        )
        print("Starting training...\n")
        try:
            for group_start in range(1, num_episodes + 1, group_size):
                # All episodes of a group share one reset seed, presumably so
                # the group-relative scoring compares rollouts from the same
                # starting conditions.
                group_seed = base_seed + ((group_start - 1) // group_size) + 1
                group_end = min(group_start + group_size - 1, num_episodes)
                for episode in range(group_start, group_end + 1):
                    row = _run_gpro_episode(env, agent, episode, num_episodes, group_seed)
                    episode_reward = row["reward"]
                    episode_rewards.append(episode_reward)
                    episode_throughputs.append(row["avg_tp"])
                    episode_mean_speeds.append(row["avg_speed"])
                    episode_speed_stds.append(row["avg_speed_std"])
                    episode_hard_brakes.append(row["episode_brakes"])
                    pending_log_rows.append(row)
                    save_training_episode_artifacts(
                        log_dir=log_dir,
                        episode=episode,
                        episode_metrics=env.episode_metrics,
                        control_edges=env.control_edges,
                        summary=_episode_summary(row, group_seed),
                    )
                    if episode_reward > best_reward:
                        best_reward = episode_reward
                        agent.save(os.path.join(checkpoint_dir, "model_best.pt"))
                    # A falsy frequency (0/None) disables the action instead
                    # of raising ZeroDivisionError in the modulo.
                    if log_freq and episode % log_freq == 0:
                        _print_episode_progress(
                            row, num_episodes, episode_rewards[-log_freq:]
                        )
                    if save_freq and episode % save_freq == 0:
                        agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

                train_stats = agent.update()
                if train_stats:
                    policy_losses.append(train_stats["policy_loss"])
                    value_losses.append(train_stats["value_loss"])
                    print(
                        f"\nGroup update episodes {group_start}-{group_end} | "
                        f"seed={group_seed} | "
                        f"policy_loss={train_stats['policy_loss']:.4f} | "
                        f"entropy={train_stats['entropy']:.4f} | "
                        f"group_score_std={train_stats['group_score_std']:.4f}"
                    )
                _log_episode_rows(logger, pending_log_rows, train_stats)
        except KeyboardInterrupt:
            interrupted = True
            print("\nTraining interrupted, saving current model...")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
            # Episodes finished in the incomplete group are already in the
            # plotted curves; write them to the log too (no loss columns).
            _log_episode_rows(logger, pending_log_rows)
        else:
            # Only a run that finished every episode gets the final-episode
            # checkpoint name; interrupted runs keep model_interrupted.pt.
            agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))
    finally:
        env.close()

    plot_training_curves(
        episode_rewards,
        episode_throughputs,
        episode_mean_speeds,
        episode_speed_stds,
        episode_hard_brakes,
        policy_losses,
        value_losses,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )
    print("=" * 70)
    print("Training interrupted" if interrupted else "Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Checkpoints: {checkpoint_dir}")
    print(f" Logs: {log_dir}")
    print("=" * 70)