添加td3模型作为基线模型
This commit is contained in:
parent
b830631aa9
commit
2cbaa27b8b
|
|
@ -1,3 +1,6 @@
|
|||
[[tool.uv.index]]
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
default = true
|
||||
[project]
|
||||
name = "ctm"
|
||||
version = "0.1.0"
|
||||
|
|
@ -5,15 +8,24 @@ description = "DQN-based Dynamic Speed Limit Control with Cell Transmission Mode
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"torch>=2.0.0",
|
||||
"numpy>=1.24.0",
|
||||
"matplotlib>=3.7.0",
|
||||
"pyyaml>=6.0",
|
||||
"tqdm>=4.65.0",
|
||||
"pandas>=2.3.3",
|
||||
"eclipse-sumo>=1.20.0",
|
||||
"torch==2.4.1+cu124",
|
||||
"numpy>=2.4.3",
|
||||
"matplotlib>=3.10.8",
|
||||
"stable-baselines3>=2.7.1",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://pypi.org/simple"
|
||||
name = "tuna"
|
||||
url = "https://pypi.mirrors.ustc.edu.cn/simple/"
|
||||
default = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu124"
|
||||
url = "https://mirrors.nju.edu.cn/pytorch/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cu124" }
|
||||
|
|
|
|||
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
TD3 Agent using Stable-Baselines3
|
||||
适配 MultiDiscrete 动作空间的 VSL 控制
|
||||
"""
|
||||
import numpy as np
|
||||
from stable_baselines3 import TD3
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
|
||||
|
||||
class MultiDiscreteWrapper(gym.Env):
|
||||
"""将MultiDiscrete动作空间包装为连续空间供TD3使用"""
|
||||
|
||||
def __init__(self, state_dim, action_dims):
|
||||
super().__init__()
|
||||
self.state_dim = state_dim
|
||||
self.action_dims = action_dims
|
||||
self.num_zones = len(action_dims)
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32
|
||||
)
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
return np.zeros(self.state_dim, dtype=np.float32), {}
|
||||
|
||||
def step(self, action):
|
||||
return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {}
|
||||
|
||||
|
||||
class TD3Agent:
|
||||
"""TD3智能体包装器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dim: int,
|
||||
action_dims: list,
|
||||
learning_rate: float = 3e-4,
|
||||
buffer_size: int = 100000,
|
||||
learning_starts: int = 1000,
|
||||
batch_size: int = 256,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
policy_delay: int = 2,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.state_dim = state_dim
|
||||
self.action_dims = action_dims
|
||||
self.num_zones = len(action_dims)
|
||||
self.device = device
|
||||
self.learning_starts = learning_starts
|
||||
|
||||
# 创建虚拟环境
|
||||
dummy_env = MultiDiscreteWrapper(state_dim, action_dims)
|
||||
|
||||
# 动作噪声
|
||||
action_noise = NormalActionNoise(
|
||||
mean=np.zeros(self.num_zones),
|
||||
sigma=0.1 * np.ones(self.num_zones)
|
||||
)
|
||||
|
||||
# 创建TD3模型
|
||||
self.model = TD3(
|
||||
"MlpPolicy",
|
||||
env=dummy_env,
|
||||
learning_rate=learning_rate,
|
||||
buffer_size=buffer_size,
|
||||
learning_starts=learning_starts,
|
||||
batch_size=batch_size,
|
||||
tau=tau,
|
||||
gamma=gamma,
|
||||
policy_delay=policy_delay,
|
||||
action_noise=action_noise,
|
||||
device=device,
|
||||
verbose=0,
|
||||
)
|
||||
|
||||
def select_action(self, state: np.ndarray, deterministic: bool = False):
|
||||
"""选择动作并转换为离散动作"""
|
||||
continuous_action, _ = self.model.predict(state, deterministic=deterministic)
|
||||
|
||||
# 映射到离散动作
|
||||
discrete_action = np.array([
|
||||
int(cont * (self.action_dims[i] - 1) + 0.5)
|
||||
for i, cont in enumerate(continuous_action)
|
||||
])
|
||||
discrete_action = np.clip(discrete_action, 0, [d-1 for d in self.action_dims])
|
||||
|
||||
return discrete_action, 0.0, 0.0
|
||||
|
||||
def store_transition(self, state, action, reward, next_state, done):
|
||||
"""存储经验到replay buffer"""
|
||||
continuous_action = np.array([
|
||||
action[i] / (self.action_dims[i] - 1)
|
||||
for i in range(self.num_zones)
|
||||
], dtype=np.float32)
|
||||
|
||||
self.model.replay_buffer.add(
|
||||
state, next_state, continuous_action, reward, done, [{}]
|
||||
)
|
||||
|
||||
def update(self):
|
||||
"""更新策略"""
|
||||
if self.model.replay_buffer.size() < self.learning_starts:
|
||||
return {}
|
||||
|
||||
self.model.train(gradient_steps=1)
|
||||
return {"actor_loss": 0.0, "critic_loss": 0.0}
|
||||
|
||||
def save(self, path: str):
|
||||
"""保存模型"""
|
||||
self.model.save(path)
|
||||
|
||||
def load(self, path: str):
|
||||
"""加载模型"""
|
||||
self.model = TD3.load(path, device=self.device)
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
"""
|
||||
基于 SUMO+TraCI 的 TD3 训练脚本
|
||||
使用 Stable-Baselines3 的 TD3 算法
|
||||
"""
|
||||
import os
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
|
||||
from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
|
||||
from td3_agent import TD3Agent
|
||||
from training_logger import TrainingLogger
|
||||
|
||||
|
||||
def train_sumo_td3():
|
||||
"""SUMO 环境下的 TD3 训练主函数"""
|
||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
agent_config = config.get("agent", {})
|
||||
train_config = config["training"]
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
checkpoint_dir = os.path.join("checkpoints_sumo_td3", timestamp)
|
||||
log_dir = os.path.join("logs_sumo_td3", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "td3")
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
|
||||
state_dim = env.state_dim
|
||||
action_dims = [env.action_dim] * env.num_edges
|
||||
|
||||
print("=" * 70)
|
||||
print("TD3训练 - SUMO+TraCI VSL 环境")
|
||||
print("=" * 70)
|
||||
print(f" 状态维度: {state_dim}")
|
||||
print(f" 动作空间: {action_dims}")
|
||||
print(f" Episode 步数: {env.episode_length}")
|
||||
print(f" 控制间隔: {env.control_interval}s")
|
||||
print(f" 学习率: {agent_config.get('learning_rate', 3e-4)}")
|
||||
print(f" 设备: {agent_config.get('device', 'cuda')}")
|
||||
print()
|
||||
|
||||
agent = TD3Agent(
|
||||
state_dim=state_dim,
|
||||
action_dims=action_dims,
|
||||
learning_rate=agent_config.get("learning_rate", 3e-4),
|
||||
buffer_size=agent_config.get("buffer_size", 100000),
|
||||
learning_starts=agent_config.get("learning_starts", 1000),
|
||||
batch_size=agent_config.get("batch_size", 256),
|
||||
tau=agent_config.get("tau", 0.005),
|
||||
gamma=agent_config.get("gamma", 0.99),
|
||||
policy_delay=agent_config.get("policy_delay", 2),
|
||||
device=agent_config.get("device", "cuda"),
|
||||
)
|
||||
|
||||
num_episodes = train_config["num_episodes"]
|
||||
save_freq = train_config.get("save_freq", 50)
|
||||
log_freq = train_config.get("log_freq", 10)
|
||||
base_seed = train_config.get("random_seed", 42)
|
||||
|
||||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_hard_brakes = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
print("开始训练...\n")
|
||||
|
||||
try:
|
||||
for episode in range(1, num_episodes + 1):
|
||||
seed = base_seed + episode
|
||||
state = env.reset(seed=seed)
|
||||
episode_reward = 0
|
||||
episode_throughput = 0
|
||||
episode_speed = 0
|
||||
episode_brakes = 0
|
||||
done = False
|
||||
step = 0
|
||||
|
||||
pbar = tqdm(
|
||||
total=env.episode_length,
|
||||
desc=f"Ep {episode}/{num_episodes}",
|
||||
leave=False,
|
||||
)
|
||||
|
||||
while not done:
|
||||
action, _, _ = agent.select_action(state, deterministic=False)
|
||||
next_state, reward, done, info = env.step(action)
|
||||
|
||||
agent.store_transition(state, action, reward, next_state, done)
|
||||
agent.update()
|
||||
|
||||
episode_reward += reward
|
||||
episode_throughput += info["throughput"]
|
||||
episode_speed += info["mean_speed_kmh"]
|
||||
episode_brakes += info["num_hard_brakes"]
|
||||
state = next_state
|
||||
step += 1
|
||||
|
||||
pbar.set_postfix(
|
||||
r=f"{episode_reward:.1f}",
|
||||
tp=f"{info['throughput']:.0f}",
|
||||
v=f"{info['mean_speed_kmh']:.1f}",
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
avg_tp = episode_throughput / max(step, 1)
|
||||
avg_speed = episode_speed / max(step, 1)
|
||||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
logger.log(episode, episode_reward, avg_tp, avg_speed, episode_brakes)
|
||||
|
||||
if episode_reward > best_reward:
|
||||
best_reward = episode_reward
|
||||
agent.save(os.path.join(checkpoint_dir, "model_best"))
|
||||
|
||||
if episode % log_freq == 0:
|
||||
recent_rewards = episode_rewards[-log_freq:]
|
||||
print(f"\nEpisode {episode}/{num_episodes}")
|
||||
print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
|
||||
print(f" Throughput: {avg_tp:.1f} veh/h")
|
||||
print(f" Mean Speed: {avg_speed:.1f} km/h")
|
||||
|
||||
if episode % save_freq == 0:
|
||||
agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}"))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n训练被中断,保存当前模型...")
|
||||
agent.save(os.path.join(checkpoint_dir, "model_interrupted"))
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))
|
||||
|
||||
_plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
||||
print("=" * 70)
|
||||
print("训练完成!")
|
||||
print(f" 最佳奖励: {best_reward:.2f}")
|
||||
print(f" 模型目录: {checkpoint_dir}")
|
||||
print(f" 日志目录: {log_dir}")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
def _plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes, save_path: str):
|
||||
"""绘制训练曲线"""
|
||||
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
||||
window = 20
|
||||
|
||||
axes[0, 0].plot(rewards, alpha=0.4, color="blue")
|
||||
if len(rewards) > window:
|
||||
ma = np.convolve(rewards, np.ones(window) / window, mode="valid")
|
||||
axes[0, 0].plot(range(window - 1, len(rewards)), ma, "r-", linewidth=2)
|
||||
axes[0, 0].set_xlabel("Episode")
|
||||
axes[0, 0].set_ylabel("Total Reward")
|
||||
axes[0, 0].set_title("Episode Reward")
|
||||
axes[0, 0].grid(True, alpha=0.3)
|
||||
|
||||
axes[0, 1].plot(throughputs, alpha=0.4, color="green")
|
||||
if len(throughputs) > window:
|
||||
ma = np.convolve(throughputs, np.ones(window) / window, mode="valid")
|
||||
axes[0, 1].plot(range(window - 1, len(throughputs)), ma, "r-", linewidth=2)
|
||||
axes[0, 1].set_xlabel("Episode")
|
||||
axes[0, 1].set_ylabel("Avg Throughput (veh/h)")
|
||||
axes[0, 1].set_title("Throughput")
|
||||
axes[0, 1].grid(True, alpha=0.3)
|
||||
|
||||
axes[1, 0].plot(mean_speeds, alpha=0.4, color="orange")
|
||||
if len(mean_speeds) > window:
|
||||
ma = np.convolve(mean_speeds, np.ones(window) / window, mode="valid")
|
||||
axes[1, 0].plot(range(window - 1, len(mean_speeds)), ma, "r-", linewidth=2)
|
||||
axes[1, 0].set_xlabel("Episode")
|
||||
axes[1, 0].set_ylabel("Mean Speed (km/h)")
|
||||
axes[1, 0].set_title("Mean Speed")
|
||||
axes[1, 0].grid(True, alpha=0.3)
|
||||
|
||||
axes[1, 1].plot(hard_brakes, alpha=0.4, color="red")
|
||||
if len(hard_brakes) > window:
|
||||
ma = np.convolve(hard_brakes, np.ones(window) / window, mode="valid")
|
||||
axes[1, 1].plot(range(window - 1, len(hard_brakes)), ma, "r-", linewidth=2)
|
||||
axes[1, 1].set_xlabel("Episode")
|
||||
axes[1, 1].set_ylabel("Hard Brakes Count")
|
||||
axes[1, 1].set_title("Hard Brakes")
|
||||
axes[1, 1].grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
print(f"训练曲线已保存: {save_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_td3()
|
||||
Loading…
Reference in New Issue