# ctm-dqn/training/train_td3.py

"""
基于 SUMO+TraCI 的 TD3 训练脚本 (TD3 training script built on SUMO+TraCI).
使用 Stable-Baselines3 的 TD3 算法 (uses the TD3 algorithm from Stable-Baselines3).
"""
import os
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from envs.reward_system import REWARD_COMPONENT_COLUMNS, average_reward_components, init_reward_component_totals
from agents.td3_agent import TD3Agent
from utils.config import get_agent_config, get_training_config
from utils.episode_artifacts import save_training_episode_artifacts
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import resolve_run_dirs, write_shared_run_config
from utils.seeding import derive_seed, resolve_base_seed, set_global_seed
def _build_agent_kwargs(agent_config, env, config_key, state_dim, action_dims, base_seed):
    """Build the keyword arguments passed to the agent constructor.

    Args:
        agent_config: Agent section of the YAML config (accessed via ``.get``);
            the literal fallbacks below apply when a key is missing.
        env: Environment instance. Its edge-layout attributes are read only
            for the ``"sctd3"`` variant.
        config_key: Agent config key; ``"sctd3"`` adds the spatial
            feature-extractor settings.
        state_dim: Observation dimension.
        action_dims: Action dimension for each controlled edge.
        base_seed: Seed forwarded to the agent.

    Returns:
        dict of constructor keyword arguments.
    """
    kwargs = dict(
        state_dim=state_dim,
        action_dims=action_dims,
        learning_rate=agent_config.get("learning_rate", 3e-4),
        buffer_size=agent_config.get("buffer_size", 100000),
        learning_starts=agent_config.get("learning_starts", 1000),
        batch_size=agent_config.get("batch_size", 256),
        tau=agent_config.get("tau", 0.005),
        gamma=agent_config.get("gamma", 0.99),
        exploration_sigma=agent_config.get("exploration_sigma", 0.1),
        device=agent_config.get("device", "cuda"),
        actor_hidden_dims=agent_config.get("actor_hidden_dims"),
        critic_hidden_dims=agent_config.get("critic_hidden_dims"),
        activation_fn=agent_config.get("activation_fn", "relu"),
    )
    # Only forward policy_delay when it is configured; otherwise the agent
    # class presumably falls back to its own default.
    if "policy_delay" in agent_config:
        kwargs["policy_delay"] = agent_config.get("policy_delay", 2)
    if config_key == "sctd3":
        kwargs.update(
            edge_feature_dim=env.features_per_edge,
            total_edge_count=env.num_edges,
            controlled_start_index=env.controlled_edge_start_index,
            extractor_feature_dim=agent_config.get("extractor_feature_dim", 128),
            extractor_edge_hidden_dim=agent_config.get("extractor_edge_hidden_dim", 16),
            extractor_global_hidden_dim=agent_config.get("extractor_global_hidden_dim", 32),
            extractor_spatial_blocks=agent_config.get("extractor_spatial_blocks", 1),
            extractor_kernel_size=agent_config.get("extractor_kernel_size", 3),
        )
    kwargs["seed"] = base_seed
    return kwargs


def _run_training_episode(env, agent, seed, desc):
    """Roll out one exploratory episode, updating the agent after every step.

    Actions are sampled with ``deterministic=False``. Every transition is
    stored and followed by one ``agent.update()`` call.

    Args:
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``episode_length``.
        agent: Agent exposing ``select_action``, ``store_transition`` and
            ``update``.
        seed: Seed passed to ``env.reset``.
        desc: Label for the progress bar.

    Returns:
        dict with keys:
          ``reward`` (summed over steps), ``ttc_risk`` (sum of per-step
          ``info["ttc_risk_rate"]``), the per-step averages
          ``avg_throughput``, ``avg_speed``, ``avg_speed_variance_norm``,
          ``avg_reward_components``, and ``steps``.
    """
    state = env.reset(seed=seed)
    total_reward = 0
    throughput_sum = 0
    speed_sum = 0
    speed_variance_sum = 0.0
    ttc_risk_sum = 0.0
    component_totals = init_reward_component_totals()
    steps = 0
    done = False

    # The context manager closes the bar even if the episode is interrupted.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            action, _, _ = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()

            total_reward += reward
            throughput_sum += info["throughput"]
            speed_sum += info["mean_speed_kmh"]
            speed_variance_sum += info["speed_variance_norm"]
            for column in REWARD_COMPONENT_COLUMNS:
                component_totals[column] += float(info.get(column, 0.0))
            ttc_risk_sum += float(info.get("ttc_risk_rate", 0.0))

            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{total_reward:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)

    divisor = max(steps, 1)  # guard against a zero-length episode
    return {
        "reward": total_reward,
        "avg_throughput": throughput_sum / divisor,
        "avg_speed": speed_sum / divisor,
        "avg_speed_variance_norm": speed_variance_sum / divisor,
        "avg_reward_components": average_reward_components(component_totals, steps),
        "ttc_risk": ttc_risk_sum,
        "steps": steps,
    }


def train_sumo_td3(
    log_dir=None,
    checkpoint_dir=None,
    run_timestamp=None,
    model_name: str = "td3",
    config_key: str = "td3",
    display_name: str = "TD3",
    agent_class=TD3Agent,
    config_path: str = "config_sumo_vsl.yaml",
):
    """Train TD3 (or a TD3-style agent class) on the SUMO VSL environment.

    Args:
        log_dir: Optional log directory, resolved via ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory, resolved the same way.
        run_timestamp: Optional run timestamp; the resolved value is the one
            recorded in the runtime config and the shared run config.
        model_name: Name used for the run directories and the logger.
        config_key: Section of the YAML config holding the agent settings;
            ``"sctd3"`` also enables the spatial feature-extractor settings.
        display_name: Name shown in the console banner.
        agent_class: Agent class instantiated with the assembled kwargs.
        config_path: Path to the YAML config (defaults to the previous
            hard-coded ``config_sumo_vsl.yaml``, relative to the CWD).

    Side effects (all under the resolved directories):
        * ``model_best`` whenever an episode reward exceeds the best so far.
        * ``model_ep{N}`` every ``save_freq`` episodes.
        * ``model_ep{num_episodes}`` only when every episode has completed.
        * ``model_interrupted`` on Ctrl-C, or on any other exception, which
          is then re-raised.
        * Per-episode artifacts and ``training_curves.png`` in ``log_dir``.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, config_key)
    train_config = get_training_config(config)
    base_seed = resolve_base_seed(train_config)
    set_global_seed(base_seed)

    resolved_run_timestamp, checkpoint_dir, log_dir = resolve_run_dirs(
        model_name,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    runtime_config = copy.deepcopy(config)
    runtime = runtime_config.setdefault("runtime", {})
    runtime["output_dir"] = log_dir
    runtime["evaluation_mode"] = False
    runtime["run_timestamp"] = resolved_run_timestamp
    # Record the resolved timestamp, not the possibly-None argument, so the
    # shared config agrees with runtime.run_timestamp and the run dirs.
    write_shared_run_config(
        runtime_config,
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=resolved_run_timestamp,
    )

    logger = TrainingLogger(log_dir, model_name)
    env = SUMOEdgeVSLEnvironment(runtime_config)
    state_dim = env.state_dim
    action_dims = [env.action_dim] * env.num_controlled_edges

    print("=" * 70)
    print(f"{display_name} training - SUMO+TraCI VSL")
    print("=" * 70)
    print(f" State dim: {state_dim}")
    print(f" Action dims: {action_dims}")
    print(f" Episode length: {env.episode_length}")
    print(f" Control interval: {env.control_interval}s")
    print(f" Learning rate: {agent_config.get('learning_rate', 3e-4)}")
    print(f" Device: {agent_config.get('device', 'cuda')}")
    print()

    agent = agent_class(
        **_build_agent_kwargs(
            agent_config, env, config_key, state_dim, action_dims, base_seed
        )
    )

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_variance_norms = []
    episode_ttc_risks = []
    best_reward = -float("inf")
    completed = False

    print("Starting training...\n")
    try:
        for episode in range(1, num_episodes + 1):
            stats = _run_training_episode(
                env,
                agent,
                derive_seed(base_seed, episode),
                f"Ep {episode}/{num_episodes}",
            )
            episode_reward = stats["reward"]
            avg_tp = stats["avg_throughput"]
            avg_speed = stats["avg_speed"]
            avg_speed_variance_norm = stats["avg_speed_variance_norm"]
            avg_reward_components = stats["avg_reward_components"]
            episode_ttc_risk = stats["ttc_risk"]

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_speed_variance_norms.append(avg_speed_variance_norm)
            episode_ttc_risks.append(episode_ttc_risk)

            logger.log(
                episode, episode_reward, avg_tp, avg_speed,
                speed_variance_norm=avg_speed_variance_norm,
                reward_components=avg_reward_components,
                ttc_risk=episode_ttc_risk,
            )
            episode_summary = {
                "episode": episode,
                "reward": float(episode_reward),
                "avg_throughput": float(avg_tp),
                "avg_mean_speed_kmh": float(avg_speed),
                "avg_speed_variance_norm": float(avg_speed_variance_norm),
                "ttc_risk": float(episode_ttc_risk),
            }
            for column, value in avg_reward_components.items():
                episode_summary[f"avg_{column}"] = float(value)
            save_training_episode_artifacts(
                log_dir=log_dir,
                episode=episode,
                episode_metrics=env.episode_metrics,
                control_edges=env.control_edges,
                summary=episode_summary,
            )

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best"))

            # A non-positive frequency disables periodic logging/saving
            # instead of raising ZeroDivisionError.
            if log_freq > 0 and episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                print(f" Normalized Speed Variance: {avg_speed_variance_norm:.4f}")
                print(
                    " Reward Components: "
                    + ", ".join(
                        f"{column}={avg_reward_components[column]:.3f}"
                        for column in REWARD_COMPONENT_COLUMNS
                    )
                )
            if save_freq > 0 and episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}"))
        completed = True
    except KeyboardInterrupt:
        print("\nTraining interrupted, saving current model...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted"))
    except Exception:
        # Keep the progress of a crashed run, but never under the
        # final-episode name; then propagate the error.
        agent.save(os.path.join(checkpoint_dir, "model_interrupted"))
        raise
    finally:
        env.close()

    # Only a run that finished every episode gets the final-episode checkpoint.
    if completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))

    if episode_rewards:
        plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds,
            episode_speed_variance_norms, episode_ttc_risks,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)