ctm-dqn/training/train_value_based.py

269 lines
10 KiB
Python

"""Shared training loop for value-based VSL agents."""
from __future__ import annotations
import copy
import inspect
import os
from typing import Callable
import matplotlib
import numpy as np
import yaml
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals
from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
# Select the non-interactive Agg backend so the training curves can be
# rendered straight to image files.
matplotlib.use("Agg")
def _build_value_based_agent(agent_builder: Callable[..., object], env, agent_config: dict):
    """Instantiate an agent, passing only the keyword arguments its builder names.

    A superset of constructor arguments is assembled from the environment's
    dimensions and from ``agent_config`` (with built-in fallbacks), then
    narrowed to the parameter names reported by ``inspect.signature`` for
    ``agent_builder``. This lets agents with different constructor signatures
    share one call site.

    Args:
        agent_builder: Callable (class or factory) that creates the agent.
        env: Environment exposing ``state_dim``, ``num_controlled_edges``,
            ``action_dim``, ``features_per_edge``, ``num_edges`` and
            ``controlled_edge_start_index``. All of them are read even if the
            builder ends up not using them.
        agent_config: Mapping of hyper-parameter overrides.

    Returns:
        Whatever ``agent_builder`` returns.

    Note:
        Matching is done purely by parameter name, so a builder that only
        declares ``**kwargs`` receives no arguments at all.
    """
    setting = agent_config.get
    hidden_width = setting("hidden_dim", 256)

    # Keep this ordering stable: it is the order in which keyword arguments
    # are handed to the builder.
    available = dict(
        state_dim=env.state_dim,
        num_edges=env.num_controlled_edges,
        num_actions_per_edge=env.action_dim,
        hidden_dim=hidden_width,
        # The mixing network falls back to the main hidden width.
        mixing_hidden_dim=setting("mixing_hidden_dim", hidden_width),
        learning_rate=setting("learning_rate", 1e-3),
        gamma=setting("gamma", 0.99),
        epsilon_start=setting("epsilon_start", 1.0),
        epsilon_end=setting("epsilon_end", 0.01),
        epsilon_decay=setting("epsilon_decay", 200),
        buffer_size=setting("buffer_size", 10000),
        batch_size=setting("batch_size", 64),
        target_update=setting("target_update", 10),
        device=setting("device", "cuda"),
        edge_feature_dim=env.features_per_edge,
        time_feature_dim=3,
        total_edge_count=env.num_edges,
        controlled_start_index=env.controlled_edge_start_index,
        num_corridor_blocks=setting("num_corridor_blocks", 2),
        corridor_kernel_size=setting("corridor_kernel_size", 5),
        corridor_dropout=setting("corridor_dropout", 0.05),
    )

    declared = inspect.signature(agent_builder).parameters
    selected = {}
    for name, value in available.items():
        if name in declared:
            selected[name] = value
    return agent_builder(**selected)
def train_sumo_value_based(
    model_key: str,
    model_label: str,
    agent_builder: Callable[..., object],
    *,
    log_dir=None,
    checkpoint_dir=None,
    run_timestamp=None,
    config_path: str = "config_sumo_vsl.yaml",
):
    """Train a value-based agent on the SUMO edge VSL environment.

    Loads the YAML config, seeds the run, resolves the output directories,
    then runs ``num_episodes`` episodes. Each step selects an action, stores
    the transition and calls ``agent.update()``. Per-episode metrics are sent
    to the ``TrainingLogger`` and to ``save_training_episode_artifacts``.
    Checkpoints are written for the best episode reward, every ``save_freq``
    episodes, on interruption, and at the end of a run that finished.

    Args:
        model_key: Key used to select the agent config section, name the run
            directories and tag the logger.
        model_label: Human-readable name printed in the console banner.
        agent_builder: Callable that creates the agent; it only receives the
            keyword arguments it declares (see ``_build_value_based_agent``).
        log_dir: Optional log directory, forwarded to ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory, forwarded to
            ``resolve_run_dirs``.
        run_timestamp: Optional run timestamp, forwarded to
            ``resolve_run_dirs`` and ``write_shared_run_config``.
        config_path: Path of the YAML configuration file. The default keeps
            the historical behaviour of reading ``config_sumo_vsl.yaml`` from
            the current working directory.

    Returns:
        None. Results are written to ``log_dir`` and ``checkpoint_dir``.

    Raises:
        OSError: If ``config_path`` cannot be opened.
        KeyError: If the training config has no ``num_episodes`` entry
            (assuming ``get_training_config`` returns a plain mapping).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = get_agent_config(config, model_key)
    train_config = get_training_config(config)
    base_seed = resolve_base_seed(train_config)
    set_global_seed(base_seed)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        model_key,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Work on a copy so the loaded config is not mutated by the runtime
    # overrides below.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    runtime_config["runtime"]["evaluation_mode"] = False
    write_shared_run_config(
        runtime_config,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )

    logger = TrainingLogger(log_dir, model_key)

    # Read the schedule before the environment exists so that a missing
    # ``num_episodes`` fails fast, with nothing to clean up.
    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_variance_norms = []
    episode_ttc_risks = []
    value_losses = []
    best_reward = -float("inf")
    # Becomes True only when every scheduled episode has run; it gates the
    # final checkpoint and the closing banner.
    completed = False

    env = SUMOEdgeVSLEnvironment(runtime_config)
    try:
        # Everything that touches the environment lives inside this ``try`` so
        # that ``env.close()`` runs even if building the agent fails.
        state_dim = env.state_dim
        num_edges = env.num_controlled_edges
        num_actions_per_edge = env.action_dim

        print("=" * 70)
        print(f"{model_label} training - SUMO VSL environment")
        print("=" * 70)
        print(f" State dim: {state_dim}")
        print(f" Controlled edges: {num_edges}")
        print(f" Actions per edge: {num_actions_per_edge}")
        print(f" Episode steps: {env.episode_length}")
        print(f" Control interval: {env.control_interval}s")
        print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}")
        print(f" LR: {agent_config.get('learning_rate', 1e-3)}")
        print(f" Device: {agent_config.get('device', 'cuda')}")
        print()
        print(f" Global seed: {base_seed if base_seed is not None else 'None (random)'}")
        print()

        agent = _build_value_based_agent(agent_builder, env, agent_config)

        print("Starting training...\n")
        try:
            for episode in range(1, num_episodes + 1):
                seed = derive_seed(base_seed, episode)
                state = env.reset(seed=seed)

                episode_reward = 0.0
                episode_throughput = 0.0
                episode_speed = 0.0
                episode_speed_variance_norm = 0.0
                episode_reward_components = init_reward_component_totals()
                episode_ttc_risk = 0.0
                done = False
                step = 0

                # The context manager closes the bar on interrupts and errors
                # as well as on the normal path.
                with tqdm(
                    total=env.episode_length,
                    desc=f"Ep {episode}/{num_episodes}",
                    leave=False,
                ) as pbar:
                    while not done:
                        action, _, _ = agent.select_action(state, deterministic=False)
                        next_state, reward, done, info = env.step(action)
                        agent.store_transition(state, action, reward, next_state, done)

                        # ``update`` may return a falsy value (no training
                        # step taken); only real stats are recorded.
                        train_stats = agent.update()
                        if train_stats:
                            value_losses.append(train_stats["value_loss"])

                        episode_reward += reward
                        episode_throughput += info["throughput"]
                        episode_speed += info["mean_speed_kmh"]
                        episode_speed_variance_norm += info["speed_variance_norm"]
                        for column in REWARD_COMPONENT_COLUMNS:
                            episode_reward_components[column] += float(info.get(column, 0.0))
                        episode_ttc_risk += float(info.get("ttc_risk_rate", 0.0))

                        state = next_state
                        step += 1
                        pbar.set_postfix(
                            r=f"{episode_reward:.1f}",
                            tp=f"{info['throughput']:.0f}",
                            v=f"{info['mean_speed_kmh']:.1f}",
                        )
                        pbar.update(1)

                # Per-step averages; ``max(step, 1)`` guards a zero-step episode.
                avg_tp = episode_throughput / max(step, 1)
                avg_speed = episode_speed / max(step, 1)
                avg_speed_variance_norm = episode_speed_variance_norm / max(step, 1)
                avg_reward_components = average_reward_components(episode_reward_components, step)

                episode_rewards.append(episode_reward)
                episode_throughputs.append(avg_tp)
                episode_mean_speeds.append(avg_speed)
                episode_speed_variance_norms.append(avg_speed_variance_norm)
                # TTC risk is tracked as the episode total, not a per-step mean.
                episode_ttc_risks.append(episode_ttc_risk)

                # Reported loss is the mean over the last 100 recorded updates.
                loss_val = np.mean(value_losses[-100:]) if value_losses else None

                logger.log(
                    episode,
                    episode_reward,
                    avg_tp,
                    avg_speed,
                    speed_variance_norm=avg_speed_variance_norm,
                    reward_components=avg_reward_components,
                    ttc_risk=episode_ttc_risk,
                    value_loss=loss_val,
                )

                episode_summary = {
                    "episode": episode,
                    "reward": float(episode_reward),
                    "avg_throughput": float(avg_tp),
                    "avg_mean_speed_kmh": float(avg_speed),
                    "avg_speed_variance_norm": float(avg_speed_variance_norm),
                    "ttc_risk": float(episode_ttc_risk),
                    "value_loss": float(loss_val) if loss_val is not None else None,
                }
                for column, value in avg_reward_components.items():
                    episode_summary[f"avg_{column}"] = float(value)

                save_training_episode_artifacts(
                    log_dir=log_dir,
                    episode=episode,
                    episode_metrics=env.episode_metrics,
                    control_edges=env.control_edges,
                    summary=episode_summary,
                )

                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

                # A falsy frequency (0 / null in the config) disables the
                # periodic action instead of failing on the modulo.
                if log_freq and episode % log_freq == 0:
                    recent_rewards = episode_rewards[-log_freq:]
                    print(f"\nEpisode {episode}/{num_episodes}")
                    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                    print(f" Throughput: {avg_tp:.1f} veh/h")
                    print(f" Mean Speed: {avg_speed:.1f} km/h")
                    print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}")
                    print(
                        " Reward Components: "
                        + ", ".join(
                            f"{column}={avg_reward_components[column]:.3f}"
                            for column in REWARD_COMPONENT_COLUMNS
                        )
                    )
                    if loss_val is not None:
                        print(f" Value Loss: {loss_val:.4f}")

                if save_freq and episode % save_freq == 0:
                    agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

            completed = True
        except KeyboardInterrupt:
            print("\nTraining interrupted, saving current model...")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # An interrupted run already wrote ``model_interrupted.pt``; writing the
    # final-episode file as well would mislabel a partially trained model.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # Curves are plotted for interrupted runs too, from the episodes that finished.
    plot_training_curves(
        episode_rewards,
        episode_throughputs,
        episode_mean_speeds,
        episode_speed_variance_norms,
        episode_ttc_risks,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )

    print("=" * 70)
    print("Training complete" if completed else "Training interrupted")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)