# ctm-dqn/train_sumo_dqn.py
#
# NOTE: the hosting web viewer warned that this file contains ambiguous Unicode
# characters (characters that might be confused with other characters). This is
# most likely due to the Chinese text in comments and messages; verify if in doubt.

"""
DQN训练脚本 - SUMO VSL环境
"""
import csv
import os
import re
import sys
from datetime import datetime

import yaml
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tqdm import tqdm

from sumo_edge_vsl_environment import SUMOEdgeVSLEnvironment
from dqn_agent import DQNAgent
from training_logger import TrainingLogger
def _parse_checkpoint_episode(checkpoint_path):
    """Return the episode number ``N`` encoded as ``ep{N}`` in a checkpoint filename.

    Periodic checkpoints are named ``model_ep{N}_edge{i}.pt``. Returns ``None`` when
    the filename carries no such tag (e.g. ``model_best_edge0.pt`` or
    ``model_interrupted_edge0.pt``).
    """
    # Anchor on start-of-name or "_" so "ep" inside other words never matches.
    match = re.search(r"(?:^|_)ep(\d+)", os.path.basename(checkpoint_path))
    return int(match.group(1)) if match else None


def _edge_checkpoint_path(checkpoint_path, edge_index):
    """Return the path of edge ``edge_index``'s checkpoint next to ``checkpoint_path``.

    Per-edge checkpoints differ only in their ``_edge{i}`` suffix, so that suffix is
    rewritten in the filename; the directory part is left untouched. If the filename
    has no ``_edge{i}`` suffix, ``checkpoint_path`` is returned unchanged.
    """
    directory, filename = os.path.split(checkpoint_path)
    return os.path.join(directory, re.sub(r"_edge\d+", f"_edge{edge_index}", filename))


def _build_agents(num_agents, state_dim, num_actions, agent_config, target_update):
    """Create ``num_agents`` independent DQN agents with one shared hyper-parameter set.

    ``hidden_dim``, ``learning_rate``, ``gamma``, ``batch_size`` and ``device`` come
    from ``agent_config`` (with defaults); the epsilon schedule and replay-buffer
    size are fixed here.
    """
    return [
        DQNAgent(
            state_dim=state_dim,
            num_actions=num_actions,
            hidden_dim=agent_config.get("hidden_dim", 256),
            learning_rate=agent_config.get("learning_rate", 1e-3),
            gamma=agent_config.get("gamma", 0.99),
            epsilon_start=1.0,
            epsilon_end=0.01,
            epsilon_decay=200,
            buffer_size=10000,
            batch_size=agent_config.get("batch_size", 64),
            target_update=target_update,
            device=agent_config.get("device", "cuda"),
        )
        for _ in range(num_agents)
    ]


def _load_agents(agents, resume_checkpoint):
    """Load agent ``i`` from the ``_edge{i}`` sibling of ``resume_checkpoint``.

    A missing file is reported, and that agent keeps its fresh initialization.
    """
    for i, agent in enumerate(agents):
        checkpoint_path = _edge_checkpoint_path(resume_checkpoint, i)
        if os.path.exists(checkpoint_path):
            agent.load(checkpoint_path)
            print(f"已加载模型: {checkpoint_path}")
        else:
            print(f"警告: 未找到模型 {checkpoint_path}, edge{i} 将从头开始训练")


def _save_agents(agents, checkpoint_dir, tag):
    """Save every agent as ``model_{tag}_edge{i}.pt`` inside ``checkpoint_dir``."""
    for i, agent in enumerate(agents):
        agent.save(os.path.join(checkpoint_dir, f"model_{tag}_edge{i}.pt"))


def _load_history(log_dir):
    """Reload per-episode metrics from ``dqn_training_log.csv`` in ``log_dir``.

    Returns four lists: (rewards, throughputs, mean_speeds, losses). All are empty
    if the log file does not exist. Rows with an empty ``value_loss`` cell
    contribute no loss entry.
    """
    rewards, throughputs, speeds, losses = [], [], [], []
    log_file = os.path.join(log_dir, "dqn_training_log.csv")
    if not os.path.exists(log_file):
        return rewards, throughputs, speeds, losses
    with open(log_file, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            rewards.append(float(row["reward"]))
            throughputs.append(float(row["throughput"]))
            speeds.append(float(row["mean_speed"]))
            # Presumably empty for episodes logged with value_loss=None (no
            # updates yet) -- TODO confirm against TrainingLogger's CSV format.
            loss = row.get("value_loss")
            if loss:
                losses.append(float(loss))
    return rewards, throughputs, speeds, losses


def _run_episode(env, agents, seed, desc, losses):
    """Run one training episode and return ``(total_reward, avg_throughput, avg_speed)``.

    Every agent's training loss (when ``agent.update()`` returns stats) is appended
    to ``losses`` in place. Averages are per environment step.
    """
    state = env.reset(seed=seed)
    total_reward = 0
    throughput_sum = 0
    speed_sum = 0
    steps = 0
    done = False
    # Context manager so the bar is closed even if the episode is interrupted.
    with tqdm(total=env.episode_length, desc=desc, leave=False) as pbar:
        while not done:
            # Each edge's agent picks its own action from the shared global state.
            action = np.array([agent.select_action(state) for agent in agents])
            next_state, reward, done, info = env.step(action)
            # All agents learn from the same global reward.
            for i, agent in enumerate(agents):
                agent.store_transition(state, action[i], reward, next_state, done)
                train_stats = agent.update()
                if train_stats:
                    losses.append(train_stats["loss"])
            total_reward += reward
            throughput_sum += info["throughput"]
            speed_sum += info["mean_speed_kmh"]
            state = next_state
            steps += 1
            pbar.set_postfix(r=f"{total_reward:.1f}", tp=f"{info['throughput']:.0f}",
                             v=f"{info['mean_speed_kmh']:.1f}")
            pbar.update(1)
    return total_reward, throughput_sum / max(steps, 1), speed_sum / max(steps, 1)


def _plot_training_curves(rewards, throughputs, speeds, losses, out_path, window=20):
    """Save a 2x2 figure (reward + moving average, throughput, speed, loss) to ``out_path``."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    ax = axes[0, 0]
    ax.plot(rewards, alpha=0.6)
    if len(rewards) > window:
        # 'valid' convolution yields len - window + 1 points, aligned to each window's end.
        ma = np.convolve(rewards, np.ones(window) / window, mode="valid")
        ax.plot(range(window - 1, len(rewards)), ma, "r-", linewidth=2)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Total Reward")
    ax.set_title("DQN Training Reward")
    ax.grid(True, alpha=0.3)

    ax = axes[0, 1]
    ax.plot(throughputs, "g-", alpha=0.6)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Avg Throughput (veh/h)")
    ax.set_title("Throughput")
    ax.grid(True, alpha=0.3)

    ax = axes[1, 0]
    ax.plot(speeds, "orange", alpha=0.6)
    ax.set_xlabel("Episode")
    ax.set_ylabel("Mean Speed (km/h)")
    ax.set_title("Mean Speed")
    ax.grid(True, alpha=0.3)

    if losses:
        ax = axes[1, 1]
        ax.plot(losses, "b-", alpha=0.6)
        ax.set_xlabel("Update Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training Loss")
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close(fig)


def train_sumo_dqn(resume_checkpoint=None):
    """Train one independent DQN agent per controlled edge in the SUMO VSL environment.

    Reads ``config_sumo_vsl.yaml`` from the working directory. A fresh run writes
    checkpoints to ``checkpoints_sumo_dqn/<timestamp>/`` and logs/plots to
    ``logs_sumo_dqn/<timestamp>/``. All agents observe the same global state and are
    trained on the same global reward. Their per-edge actions are stacked into one
    array for ``env.step``.

    Args:
        resume_checkpoint: Optional path to a saved per-edge checkpoint, e.g.
            ``checkpoints_sumo_dqn/xxx/model_ep500_edge0.pt``.
            - The other edges' files are found by rewriting the ``_edge{i}`` suffix.
            - Training resumes after the ``ep{N}`` tag, if present.
            - Metric history is reloaded from the mirrored ``logs_sumo_dqn`` dir.
    """
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    agent_config = config.get("agent", {})
    train_config = config["training"]

    start_episode = 1
    if resume_checkpoint:
        # Keep writing into the run being resumed; a bare filename means the cwd
        # (os.makedirs("") would raise).
        checkpoint_dir = os.path.dirname(resume_checkpoint) or "."
        log_dir = checkpoint_dir.replace("checkpoints_sumo_dqn", "logs_sumo_dqn")
        saved_episode = _parse_checkpoint_episode(resume_checkpoint)
        if saved_episode is not None:
            start_episode = saved_episode + 1
            print(f"从 episode {start_episode} 继续训练")
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_dir = os.path.join("checkpoints_sumo_dqn", timestamp)
        log_dir = os.path.join("logs_sumo_dqn", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    # Snapshot the config next to the weights (overwritten when resuming).
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    logger = TrainingLogger(log_dir, "dqn", resume=bool(resume_checkpoint))

    env = SUMOEdgeVSLEnvironment(config)
    state_dim = env.state_dim
    # Independent Q-network per controlled edge, 5 discrete actions each.
    num_edges = env.num_edges
    num_actions_per_edge = 5
    # Target-network sync period in episodes; also handed to every agent.
    # NOTE(review): if DQNAgent syncs on its own using this value, targets are
    # updated twice -- confirm in dqn_agent.
    target_update = 10

    print("=" * 70)
    print("DQN训练 - SUMO VSL环境")
    print("=" * 70)
    print(f" 状态维度: {state_dim}")
    print(f" 控制边数: {num_edges}")
    print(f" 每边动作数: {num_actions_per_edge}")
    print(f" Episode步数: {env.episode_length}")
    print()

    agents = _build_agents(
        num_edges, state_dim, num_actions_per_edge, agent_config, target_update
    )
    if resume_checkpoint:
        _load_agents(agents, resume_checkpoint)

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    if resume_checkpoint:
        # Reloaded losses are per-episode means; new ones are per-update values.
        episode_rewards, episode_throughputs, episode_mean_speeds, losses = (
            _load_history(log_dir)
        )
        print(f"已加载 {len(episode_rewards)} 条历史记录")
    else:
        episode_rewards, episode_throughputs, episode_mean_speeds, losses = [], [], [], []
    best_reward = max(episode_rewards) if episode_rewards else -float("inf")

    print("开始训练...\n")
    # Last episode that fully finished in this process; used to label the final
    # checkpoint truthfully when training stops early.
    last_completed = start_episode - 1
    try:
        for episode in range(start_episode, num_episodes + 1):
            episode_reward, avg_tp, avg_speed = _run_episode(
                env, agents, base_seed + episode, f"Ep {episode}/{num_episodes}", losses
            )
            if episode % target_update == 0:
                for agent in agents:
                    agent.update_target_network()

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            loss_val = np.mean(losses[-100:]) if losses else None
            logger.log(episode, episode_reward, avg_tp, avg_speed, value_loss=loss_val)

            if episode_reward > best_reward:
                best_reward = episode_reward
                _save_agents(agents, checkpoint_dir, "best")
            # A frequency of 0 disables the action instead of dividing by zero.
            if log_freq and episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                if losses:
                    print(f" Loss: {np.mean(losses[-100:]):.4f}")
            if save_freq and episode % save_freq == 0:
                _save_agents(agents, checkpoint_dir, f"ep{episode}")
            last_completed = episode
    except KeyboardInterrupt:
        print("\n训练被中断")
        _save_agents(agents, checkpoint_dir, "interrupted")
    finally:
        env.close()
        # Tag the final weights with the last finished episode (not num_episodes),
        # so resuming from it continues where training actually stopped.
        if last_completed >= start_episode:
            _save_agents(agents, checkpoint_dir, f"ep{last_completed}")

    curves_path = os.path.join(log_dir, "training_curves.png")
    _plot_training_curves(
        episode_rewards, episode_throughputs, episode_mean_speeds, losses, curves_path
    )
    print(f"训练曲线已保存: {curves_path}")
    print("=" * 70)
    print("训练完成!")
    print(f" 最佳奖励: {best_reward:.2f}")
    print(f" 模型目录: {checkpoint_dir}")
    print(f" 日志目录: {log_dir}")
    print("=" * 70)
if __name__ == "__main__":
    # Command-line entry point: optional --resume path to a per-edge checkpoint.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--resume",
        type=str,
        help="从checkpoint继续训练例如: checkpoints_sumo_dqn/xxx/model_ep500_edge0.pt",
    )
    cli_args = cli.parse_args()
    train_sumo_dqn(resume_checkpoint=cli_args.resume)