From 2cbaa27b8bd249afbd7a234619d92a37aecf6044 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Wed, 1 Apr 2026 00:18:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0td3=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E4=BD=9C=E4=B8=BA=E5=9F=BA=E7=BA=BF=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 22 +++-- td3_agent.py | 120 ++++++++++++++++++++++++++ train_sumo_td3.py | 209 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 346 insertions(+), 5 deletions(-) create mode 100644 td3_agent.py create mode 100644 train_sumo_td3.py diff --git a/pyproject.toml b/pyproject.toml index 91c5e4b..082e335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[[tool.uv.index]] +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +default = true [project] name = "ctm" version = "0.1.0" @@ -5,15 +8,24 @@ description = "DQN-based Dynamic Speed Limit Control with Cell Transmission Mode readme = "README.md" requires-python = ">=3.12" dependencies = [ - "torch>=2.0.0", - "numpy>=1.24.0", - "matplotlib>=3.7.0", "pyyaml>=6.0", "tqdm>=4.65.0", - "pandas>=2.3.3", "eclipse-sumo>=1.20.0", + "torch==2.4.1+cu124", + "numpy>=2.4.3", + "matplotlib>=3.10.8", + "stable-baselines3>=2.7.1", ] [[tool.uv.index]] -url = "https://pypi.org/simple" +name = "tuna" +url = "https://pypi.mirrors.ustc.edu.cn/simple/" default = true + +[[tool.uv.index]] +name = "pytorch-cu124" +url = "https://mirrors.nju.edu.cn/pytorch/whl/cu124" +explicit = true + +[tool.uv.sources] +torch = { index = "pytorch-cu124" } diff --git a/td3_agent.py b/td3_agent.py new file mode 100644 index 0000000..cd46c32 --- /dev/null +++ b/td3_agent.py @@ -0,0 +1,120 @@ +""" +TD3 Agent using Stable-Baselines3 +适配 MultiDiscrete 动作空间的 VSL 控制 +""" +import numpy as np +from stable_baselines3 import TD3 +from stable_baselines3.common.noise import NormalActionNoise +import gymnasium as gym +from gymnasium import spaces + + +class MultiDiscreteWrapper(gym.Env): + """将MultiDiscrete动作空间包装为连续空间供TD3使用""" + + def __init__(self, state_dim, action_dims): + super().__init__() + self.state_dim = state_dim + self.action_dims = action_dims + self.num_zones = len(action_dims) + + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32 + ) + self.action_space = spaces.Box( + low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32 + ) + + def reset(self, seed=None, options=None): + return np.zeros(self.state_dim, dtype=np.float32), {} + + def step(self, action): + return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {} + + +class TD3Agent: + """TD3智能体包装器""" + + def __init__( + self, + state_dim: int, + action_dims: list, + learning_rate: float = 3e-4, + buffer_size: int = 100000, + learning_starts: int = 1000, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + policy_delay: int = 2, + device: str = "cuda", + ): + self.state_dim = state_dim + self.action_dims = action_dims + self.num_zones = len(action_dims) + self.device = device + self.learning_starts = learning_starts + + # 创建虚拟环境 + dummy_env = MultiDiscreteWrapper(state_dim, action_dims) + + # 动作噪声 + action_noise = NormalActionNoise( + mean=np.zeros(self.num_zones), + sigma=0.1 * np.ones(self.num_zones) + ) + + # 创建TD3模型 + self.model = TD3( + "MlpPolicy", + env=dummy_env, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + policy_delay=policy_delay, + action_noise=action_noise, + device=device, + verbose=0, + ) + + def select_action(self, state: np.ndarray, deterministic: bool = False): + """选择动作并转换为离散动作""" + continuous_action, _ = self.model.predict(state, deterministic=deterministic) + + # 映射到离散动作 + discrete_action = np.array([ + int(cont * (self.action_dims[i] - 1) + 0.5) + for i, cont in enumerate(continuous_action) + ]) + discrete_action = np.clip(discrete_action, 0, [d-1 for d in self.action_dims]) + + return discrete_action, 0.0, 0.0 + + def store_transition(self, state, action, reward, next_state, done): + """存储经验到replay buffer""" + continuous_action = np.array([ + action[i] / (self.action_dims[i] - 1) + for i in range(self.num_zones) + ], dtype=np.float32) + + self.model.replay_buffer.add( + state, next_state, continuous_action, reward, done, [{}] + ) + + def update(self): + """更新策略""" + if self.model.replay_buffer.size() < self.learning_starts: + return {} + + self.model.train(gradient_steps=1) + return {"actor_loss": 0.0, "critic_loss": 0.0} + + def save(self, path: str): + """保存模型""" + self.model.save(path) + + def load(self, path: str): + """加载模型""" + self.model = TD3.load(path, device=self.device) diff --git a/train_sumo_td3.py b/train_sumo_td3.py new file mode 100644 index 0000000..20ac763 --- /dev/null +++ b/train_sumo_td3.py @@ -0,0 +1,209 @@ +""" +基于 SUMO+TraCI 的 TD3 训练脚本 +使用 Stable-Baselines3 的 TD3 算法 +""" +import os +import yaml +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from datetime import datetime +from tqdm import tqdm + +from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment +from td3_agent import TD3Agent +from training_logger import TrainingLogger + + +def train_sumo_td3(): + """SUMO 环境下的 TD3 训练主函数""" + with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + + agent_config = config.get("agent", {}) + train_config = config["training"] + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + checkpoint_dir = os.path.join("checkpoints_sumo_td3", timestamp) + log_dir = os.path.join("logs_sumo_td3", timestamp) + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: + yaml.dump(config, f) + + logger = TrainingLogger(log_dir, "td3") + env = SUMOEdgeVSLEnvironment(config) + + state_dim = env.state_dim + action_dims = [env.action_dim] * env.num_edges + + print("=" * 70) + print("TD3训练 - SUMO+TraCI VSL 环境") + print("=" * 70) + print(f" 状态维度: {state_dim}") + print(f" 动作空间: {action_dims}") + print(f" Episode 步数: {env.episode_length}") + print(f" 控制间隔: {env.control_interval}s") + print(f" 学习率: {agent_config.get('learning_rate', 3e-4)}") + print(f" 设备: {agent_config.get('device', 'cuda')}") + print() + + agent = TD3Agent( + state_dim=state_dim, + action_dims=action_dims, + learning_rate=agent_config.get("learning_rate", 3e-4), + buffer_size=agent_config.get("buffer_size", 100000), + learning_starts=agent_config.get("learning_starts", 1000), + batch_size=agent_config.get("batch_size", 256), + tau=agent_config.get("tau", 0.005), + gamma=agent_config.get("gamma", 0.99), + policy_delay=agent_config.get("policy_delay", 2), + device=agent_config.get("device", "cuda"), + ) + + num_episodes = train_config["num_episodes"] + save_freq = train_config.get("save_freq", 50) + log_freq = train_config.get("log_freq", 10) + base_seed = train_config.get("random_seed", 42) + + episode_rewards = [] + episode_throughputs = [] + episode_mean_speeds = [] + episode_hard_brakes = [] + best_reward = -float("inf") + + print("开始训练...\n") + + try: + for episode in range(1, num_episodes + 1): + seed = base_seed + episode + state = env.reset(seed=seed) + episode_reward = 0 + episode_throughput = 0 + episode_speed = 0 + episode_brakes = 0 + done = False + step = 0 + + pbar = tqdm( + total=env.episode_length, + desc=f"Ep {episode}/{num_episodes}", + leave=False, + ) + + while not done: + action, _, _ = agent.select_action(state, deterministic=False) + next_state, reward, done, info = env.step(action) + + agent.store_transition(state, action, reward, next_state, done) + agent.update() + + episode_reward += reward + episode_throughput += info["throughput"] + episode_speed += info["mean_speed_kmh"] + episode_brakes += info["num_hard_brakes"] + state = next_state + step += 1 + + pbar.set_postfix( + r=f"{episode_reward:.1f}", + tp=f"{info['throughput']:.0f}", + v=f"{info['mean_speed_kmh']:.1f}", + ) + pbar.update(1) + + pbar.close() + + avg_tp = episode_throughput / max(step, 1) + avg_speed = episode_speed / max(step, 1) + episode_rewards.append(episode_reward) + episode_throughputs.append(avg_tp) + episode_mean_speeds.append(avg_speed) + episode_hard_brakes.append(episode_brakes) + + logger.log(episode, episode_reward, avg_tp, avg_speed, episode_brakes) + + if episode_reward > best_reward: + best_reward = episode_reward + agent.save(os.path.join(checkpoint_dir, "model_best")) + + if episode % log_freq == 0: + recent_rewards = episode_rewards[-log_freq:] + print(f"\nEpisode {episode}/{num_episodes}") + print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})") + print(f" Throughput: {avg_tp:.1f} veh/h") + print(f" Mean Speed: {avg_speed:.1f} km/h") + + if episode % save_freq == 0: + agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}")) + + except KeyboardInterrupt: + print("\n训练被中断,保存当前模型...") + agent.save(os.path.join(checkpoint_dir, "model_interrupted")) + finally: + env.close() + + agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}")) + + _plot_training_curves( + episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + save_path=os.path.join(log_dir, "training_curves.png"), + ) + + print("=" * 70) + print("训练完成!") + print(f" 最佳奖励: {best_reward:.2f}") + print(f" 模型目录: {checkpoint_dir}") + print(f" 日志目录: {log_dir}") + print("=" * 70) + + +def _plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes, save_path: str): + """绘制训练曲线""" + fig, axes = plt.subplots(2, 2, figsize=(15, 10)) + window = 20 + + axes[0, 0].plot(rewards, alpha=0.4, color="blue") + if len(rewards) > window: + ma = np.convolve(rewards, np.ones(window) / window, mode="valid") + axes[0, 0].plot(range(window - 1, len(rewards)), ma, "r-", linewidth=2) + axes[0, 0].set_xlabel("Episode") + axes[0, 0].set_ylabel("Total Reward") + axes[0, 0].set_title("Episode Reward") + axes[0, 0].grid(True, alpha=0.3) + + axes[0, 1].plot(throughputs, alpha=0.4, color="green") + if len(throughputs) > window: + ma = np.convolve(throughputs, np.ones(window) / window, mode="valid") + axes[0, 1].plot(range(window - 1, len(throughputs)), ma, "r-", linewidth=2) + axes[0, 1].set_xlabel("Episode") + axes[0, 1].set_ylabel("Avg Throughput (veh/h)") + axes[0, 1].set_title("Throughput") + axes[0, 1].grid(True, alpha=0.3) + + axes[1, 0].plot(mean_speeds, alpha=0.4, color="orange") + if len(mean_speeds) > window: + ma = np.convolve(mean_speeds, np.ones(window) / window, mode="valid") + axes[1, 0].plot(range(window - 1, len(mean_speeds)), ma, "r-", linewidth=2) + axes[1, 0].set_xlabel("Episode") + axes[1, 0].set_ylabel("Mean Speed (km/h)") + axes[1, 0].set_title("Mean Speed") + axes[1, 0].grid(True, alpha=0.3) + + axes[1, 1].plot(hard_brakes, alpha=0.4, color="red") + if len(hard_brakes) > window: + ma = np.convolve(hard_brakes, np.ones(window) / window, mode="valid") + axes[1, 1].plot(range(window - 1, len(hard_brakes)), ma, "r-", linewidth=2) + axes[1, 1].set_xlabel("Episode") + axes[1, 1].set_ylabel("Hard Brakes Count") + axes[1, 1].set_title("Hard Brakes") + axes[1, 1].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=150, bbox_inches="tight") + print(f"训练曲线已保存: {save_path}") + + +if __name__ == "__main__": + train_sumo_td3()