ctm-dqn/training/train_ddpg.py

198 lines
7.1 KiB
Python

"""
基于 SUMO+TraCI 的 DDPG 训练脚本
使用 Stable-Baselines3 的 DDPG 算法
"""
import argparse
import os
import copy
import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
from datetime import datetime
from tqdm import tqdm
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.ddpg_agent import DDPGAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
def _run_episode(env, agent, seed, desc):
    """Roll out one exploratory training episode, updating the agent every step.

    Args:
        env: The SUMO VSL environment. ``reset(seed=...)`` is called once and
            ``step(action)`` must return ``(next_state, reward, done, info)``.
        agent: The DDPG agent. Every transition is stored and followed by a
            single ``update()`` call.
        seed: Seed forwarded to ``env.reset``.
        desc: Label for the per-episode tqdm progress bar.

    Returns:
        ``(episode_reward, steps, totals)``: the summed reward, the number of
        control steps taken, and a dict of per-step metrics summed over the
        episode (keys: throughput, speed, speed_std, r_flow, r_var, r_brake,
        r_penalty, hard_brakes).
    """
    state = env.reset(seed=seed)
    episode_reward = 0
    steps = 0
    totals = dict.fromkeys(
        ("throughput", "speed", "speed_std", "r_flow", "r_var",
         "r_brake", "r_penalty", "hard_brakes"),
        0,
    )
    done = False
    # The context manager closes the bar even if the rollout is interrupted
    # (e.g. Ctrl+C), so no dangling progress bar is left on the terminal.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            action, _, _ = agent.select_action(state, deterministic=False)
            next_state, reward, done, info = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            agent.update()

            episode_reward += reward
            totals["throughput"] += info["throughput"]
            totals["speed"] += info["mean_speed_kmh"]
            # speed_std is presumably in m/s; x3.6 converts it to km/h so it
            # matches mean_speed_kmh in the reported statistics.
            totals["speed_std"] += info["speed_std"] * 3.6
            for key in ("r_flow", "r_var", "r_brake", "r_penalty"):
                totals[key] += info[key]
            totals["hard_brakes"] += info["num_hard_brakes"]

            state = next_state
            steps += 1
            pbar.set_postfix(
                r=f"{episode_reward:.1f}",
                tp=f"{info['throughput']:.0f}",
                v=f"{info['mean_speed_kmh']:.1f}",
            )
            pbar.update(1)
    return episode_reward, steps, totals


def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None,
                    config_path="config_sumo_vsl.yaml"):
    """Train a DDPG agent on the SUMO+TraCI VSL environment.

    Args:
        log_dir: Optional log directory override, resolved via ``resolve_run_dirs``.
        checkpoint_dir: Optional checkpoint directory override, resolved the same way.
        run_timestamp: Optional run timestamp forwarded to ``resolve_run_dirs``.
        config_path: YAML config file to load. Defaults to
            ``config_sumo_vsl.yaml`` relative to the current working directory.

    Side effects:
        Writes ``config.yaml`` and model checkpoints into the checkpoint
        directory: ``model_best`` whenever the episode reward improves,
        ``model_ep{N}`` every ``save_freq`` episodes, ``model_ep{num_episodes}``
        only after a complete run, and ``model_interrupted`` on Ctrl+C.
        Writes ``training_curves.png`` into the log directory when at least one
        episode finished.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = get_agent_config(config, "ddpg")
    train_config = get_training_config(config)

    _, checkpoint_dir, log_dir = resolve_run_dirs(
        "ddpg",
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
    )
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Persist the effective config (including this run's output dir) next to
    # the checkpoints; allow_unicode keeps non-ASCII values human-readable.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f, allow_unicode=True)

    logger = TrainingLogger(log_dir, "ddpg")

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)
    learning_rate = agent_config.get("learning_rate", 3e-4)
    device = agent_config.get("device", "cuda")

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    best_reward = -float("inf")
    agent = None
    finished = False

    env = SUMOEdgeVSLEnvironment(runtime_config)
    # Everything after the environment exists runs under try/finally so the
    # SUMO connection is closed even if agent construction or training fails.
    try:
        state_dim = env.state_dim
        action_dims = [env.action_dim] * env.num_edges
        print("=" * 70)
        print("DDPG训练 - SUMO+TraCI VSL 环境")
        print("=" * 70)
        print(f" 状态维度: {state_dim}")
        print(f" 动作空间: {action_dims}")
        print(f" Episode 步数: {env.episode_length}")
        print(f" 控制间隔: {env.control_interval}s")
        print(f" 学习率: {learning_rate}")
        print(f" 设备: {device}")
        print()

        agent = DDPGAgent(
            state_dim=state_dim,
            action_dims=action_dims,
            learning_rate=learning_rate,
            buffer_size=agent_config.get("buffer_size", 100000),
            learning_starts=agent_config.get("learning_starts", 1000),
            batch_size=agent_config.get("batch_size", 256),
            tau=agent_config.get("tau", 0.005),
            gamma=agent_config.get("gamma", 0.99),
            exploration_sigma=agent_config.get("exploration_sigma", 0.1),
            device=device,
        )

        print("开始训练...\n")
        for episode in range(1, num_episodes + 1):
            episode_reward, steps, totals = _run_episode(
                env,
                agent,
                seed=base_seed + episode,
                desc=f"Ep {episode}/{num_episodes}",
            )
            # Per-step averages; max(..., 1) guards against a zero-step episode.
            n = max(steps, 1)
            avg_tp = totals["throughput"] / n
            avg_speed = totals["speed"] / n
            avg_speed_std = totals["speed_std"] / n
            avg_r_flow = totals["r_flow"] / n
            avg_r_var = totals["r_var"] / n
            avg_r_brake = totals["r_brake"] / n
            avg_r_penalty = totals["r_penalty"] / n
            episode_brakes = totals["hard_brakes"]

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_speed_stds.append(avg_speed_std)
            episode_hard_brakes.append(episode_brakes)

            logger.log(
                episode, episode_reward, avg_tp, avg_speed,
                speed_std=avg_speed_std,
                r_flow=avg_r_flow,
                r_var=avg_r_var,
                r_brake=avg_r_brake,
                r_penalty=avg_r_penalty,
                hard_brakes=episode_brakes,
            )

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best"))

            # A non-positive frequency disables periodic logging/saving instead
            # of raising ZeroDivisionError.
            if log_freq > 0 and episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                print(f" Speed Std: {avg_speed_std:.2f} km/h")
                print(f" R(flow/var/brake/pen): {avg_r_flow:.3f} / {avg_r_var:.3f} / {avg_r_brake:.3f} / {avg_r_penalty:.3f}")

            if save_freq > 0 and episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}"))
        finished = True
    except KeyboardInterrupt:
        print("\n训练被中断,保存当前模型...")
        if agent is not None:
            agent.save(os.path.join(checkpoint_dir, "model_interrupted"))
    finally:
        env.close()

    # Save under the final-episode name only when every episode actually ran,
    # so an interrupted or crashed run is never mislabelled as a complete one.
    if finished:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))

    if episode_rewards:
        plot_training_curves(
            episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )

    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
if __name__ == "__main__":
    # Script entry point: parse the shared run-directory CLI flags and forward
    # them to the trainer.
    cli_args = add_run_dir_args(argparse.ArgumentParser()).parse_args()
    train_sumo_ddpg(
        log_dir=cli_args.log_dir,
        checkpoint_dir=cli_args.checkpoint_dir,
        run_timestamp=cli_args.run_timestamp,
    )