# ctm-dqn/training/train_td3.py
#
# 208 lines
# 7.3 KiB
# Python

"""
TD3 training script for the SUMO + TraCI environment.
Uses the TD3 algorithm from Stable-Baselines3.
"""
import argparse
import os
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.td3_agent import TD3Agent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def _build_td3_agent(agent_config, state_dim, action_dims):
    """Build the TD3 agent, falling back to defaults for missing config keys."""
    return TD3Agent(
        state_dim=state_dim,
        action_dims=action_dims,
        learning_rate=agent_config.get("learning_rate", 3e-4),
        buffer_size=agent_config.get("buffer_size", 100000),
        learning_starts=agent_config.get("learning_starts", 1000),
        batch_size=agent_config.get("batch_size", 256),
        tau=agent_config.get("tau", 0.005),
        gamma=agent_config.get("gamma", 0.99),
        policy_delay=agent_config.get("policy_delay", 2),
        exploration_sigma=agent_config.get("exploration_sigma", 0.1),
        device=agent_config.get("device", "cuda"),
    )


def _run_training_episode(env, agent, episode, num_episodes, seed):
    """Run one training episode and return its aggregated statistics.

    The agent acts non-deterministically, stores every transition and is
    updated after every environment step.

    Returns:
        dict with the summed ``reward``, the summed ``hard_brakes`` count and
        the per-step means ``throughput``, ``mean_speed``, ``speed_std``,
        ``r_flow``, ``r_var``, ``r_brake`` and ``r_penalty``.
    """
    state = env.reset(seed=seed)
    episode_reward = 0
    episode_throughput = 0
    episode_speed = 0
    episode_speed_std = 0
    episode_r_flow = 0
    episode_r_var = 0
    episode_r_brake = 0
    episode_r_penalty = 0
    episode_brakes = 0
    done = False
    step = 0

    # Context manager: the bar is closed even if a step raises or the user
    # presses Ctrl-C in the middle of an episode.
    with tqdm(
        total=env.episode_length,
        desc=f"Ep {episode}/{num_episodes}",
        leave=False,
    ) as pbar:
        while not done:
            action, _, _ = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()

            episode_reward += reward
            episode_throughput += info["throughput"]
            episode_speed += info["mean_speed_kmh"]
            # Scaled by 3.6; the value is reported as km/h in the console log
            # (presumably the env reports it in m/s -- confirm against the env).
            episode_speed_std += info["speed_std"] * 3.6
            episode_r_flow += info["r_flow"]
            episode_r_var += info["r_var"]
            episode_r_brake += info["r_brake"]
            episode_r_penalty += info["r_penalty"]
            episode_brakes += info["num_hard_brakes"]

            state = next_state
            step += 1
            pbar.set_postfix(
                r=f"{episode_reward:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)

    steps = max(step, 1)  # guard against division by zero
    return {
        "reward": episode_reward,
        "throughput": episode_throughput / steps,
        "mean_speed": episode_speed / steps,
        "speed_std": episode_speed_std / steps,
        "r_flow": episode_r_flow / steps,
        "r_var": episode_r_var / steps,
        "r_brake": episode_r_brake / steps,
        "r_penalty": episode_r_penalty / steps,
        "hard_brakes": episode_brakes,
    }


def train_sumo_td3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
    """Train a TD3 agent in the SUMO environment.

    Loads ``config_sumo_vsl.yaml`` from the current working directory, writes
    a copy of the effective config into the checkpoint directory, trains for
    ``num_episodes`` episodes, saves checkpoints (best, periodic, final) and
    plots the training curves into the log directory.

    Args:
        log_dir: Forwarded to ``resolve_run_dirs``, which returns the log
            directory actually used.
        checkpoint_dir: Forwarded to ``resolve_run_dirs``, which returns the
            checkpoint directory actually used.
        run_timestamp: Forwarded to ``resolve_run_dirs``.

    Notes:
        * ``num_episodes`` is read from the training config with direct
          indexing, so it must be present.
        * On ``KeyboardInterrupt`` the current model is saved as
          ``model_interrupted``; the ``model_ep{num_episodes}`` checkpoint is
          then *not* written, because the model has not been trained for that
          many episodes. Curves and the summary are still produced.
        * The environment is closed on every exit path, including failures
          while building the agent.
    """
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "td3")
    train_config = get_training_config(config)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "td3",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Work on a copy so the loaded config is not mutated; the copy stored next
    # to the checkpoints records the output directory used for this run.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "td3")
    env = SUMOEdgeVSLEnvironment(runtime_config)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    best_reward = -float("inf")
    interrupted = False

    # Everything after the environment exists runs under try/finally so that
    # env.close() is reached even when setup (not only training) fails.
    try:
        state_dim = env.state_dim
        action_dims = [env.action_dim] * env.num_edges

        print("=" * 70)
        print("TD3训练 - SUMO+TraCI VSL 环境")
        print("=" * 70)
        print(f" 状态维度: {state_dim}")
        print(f" 动作空间: {action_dims}")
        print(f" Episode 步数: {env.episode_length}")
        print(f" 控制间隔: {env.control_interval}s")
        print(f" 学习率: {agent_config.get('learning_rate', 3e-4)}")
        print(f" 设备: {agent_config.get('device', 'cuda')}")
        print()

        agent = _build_td3_agent(agent_config, state_dim, action_dims)

        num_episodes = train_config["num_episodes"]
        save_freq = train_config.get("save_freq", 50)
        log_freq = train_config.get("log_freq", 10)
        base_seed = train_config.get("random_seed", 42)

        print("开始训练...\n")
        try:
            for episode in range(1, num_episodes + 1):
                # A distinct seed per episode, derived from the base seed.
                stats = _run_training_episode(
                    env, agent, episode, num_episodes, seed=base_seed + episode
                )
                episode_reward = stats["reward"]
                avg_tp = stats["throughput"]
                avg_speed = stats["mean_speed"]
                avg_speed_std = stats["speed_std"]
                avg_r_flow = stats["r_flow"]
                avg_r_var = stats["r_var"]
                avg_r_brake = stats["r_brake"]
                avg_r_penalty = stats["r_penalty"]
                episode_brakes = stats["hard_brakes"]

                episode_rewards.append(episode_reward)
                episode_throughputs.append(avg_tp)
                episode_mean_speeds.append(avg_speed)
                episode_speed_stds.append(avg_speed_std)
                episode_hard_brakes.append(episode_brakes)

                logger.log(
                    episode, episode_reward, avg_tp, avg_speed,
                    speed_std=avg_speed_std,
                    r_flow=avg_r_flow,
                    r_var=avg_r_var,
                    r_brake=avg_r_brake,
                    r_penalty=avg_r_penalty,
                    hard_brakes=episode_brakes,
                )

                if episode_reward > best_reward:
                    best_reward = episode_reward
                    agent.save(os.path.join(checkpoint_dir, "model_best"))

                if episode % log_freq == 0:
                    recent_rewards = episode_rewards[-log_freq:]
                    print(f"\nEpisode {episode}/{num_episodes}")
                    print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                    print(f" Throughput: {avg_tp:.1f} veh/h")
                    print(f" Mean Speed: {avg_speed:.1f} km/h")
                    print(f" Speed Std: {avg_speed_std:.2f} km/h")
                    print(
                        f" R(flow/var/brake/pen): {avg_r_flow:.3f} / {avg_r_var:.3f}"
                        f" / {avg_r_brake:.3f} / {avg_r_penalty:.3f}"
                    )

                if episode % save_freq == 0:
                    agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}"))
        except KeyboardInterrupt:
            interrupted = True
            print("\n训练被中断,保存当前模型...")
            agent.save(os.path.join(checkpoint_dir, "model_interrupted"))
    finally:
        env.close()

    # Only label a checkpoint with the final episode number when training
    # actually reached it; an interrupted run already has model_interrupted.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))

    plot_training_curves(
        episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
if __name__ == "__main__":
    # Script entry point: parse the run-directory options from the command
    # line and hand them to the trainer unchanged.
    cli_args = add_run_dir_args(argparse.ArgumentParser()).parse_args()
    train_sumo_td3(
        **{
            option: getattr(cli_args, option)
            for option in ("log_dir", "checkpoint_dir", "run_timestamp")
        }
    )