From 43c7cb1dfa2595cab408f291b41bd88aa989e090 Mon Sep 17 00:00:00 2001 From: Maple-YZ Date: Thu, 9 Apr 2026 02:25:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0MAPPO=E5=AF=B9=E6=AF=94?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=EF=BC=8C=E4=BF=AE=E5=A4=8Dsumo=E4=BB=BF?= =?UTF-8?q?=E7=9C=9F=E6=97=B6=E8=BD=A6=E6=A3=80=E5=99=A8=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agents/mappo_agent.py | 306 ++++++++++++++++++++++++++++++++++++++++ config_sumo_vsl.yaml | 14 ++ envs/edge_vsl_env.py | 52 ++++++- run_all_training.py | 2 +- training/train_appo.py | 11 +- training/train_ddpg.py | 11 +- training/train_dqn.py | 11 +- training/train_mappo.py | 238 +++++++++++++++++++++++++++++++ training/train_ppo.py | 11 +- training/train_td3.py | 11 +- utils/plot.py | 63 +++++++-- 11 files changed, 697 insertions(+), 33 deletions(-) create mode 100644 agents/mappo_agent.py create mode 100644 training/train_mappo.py diff --git a/agents/mappo_agent.py b/agents/mappo_agent.py new file mode 100644 index 0000000..6197fb0 --- /dev/null +++ b/agents/mappo_agent.py @@ -0,0 +1,306 @@ +""" +MAPPO agent for SUMO VSL. 
# This implementation uses parameter sharing across edge-agents:
# - Actor: decentralized, one shared policy over per-edge local observations
# - Critic: centralized, one value head over the global state
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Tuple


class SharedActor(nn.Module):
    """Policy network shared by all edge-agents.

    Maps one agent's local observation vector to unnormalized logits over
    ``num_actions`` discrete actions.
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(local_obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Orthogonal init; the output layer is re-initialized with a small gain."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
        # A tiny gain on the logits layer keeps the initial policy close to uniform.
        nn.init.orthogonal_(self.net[-1].weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Return action logits for a batch of local observations."""
        return self.net(local_obs)


class CentralizedCritic(nn.Module):
    """Value network over the full global state (one scalar per state)."""

    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Orthogonal init; the value head is re-initialized with unit gain."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                nn.init.constant_(module.bias, 0)
        nn.init.orthogonal_(self.net[-1].weight, gain=1.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` tensor of state values."""
        return self.net(state)


class MAPPOAgent:
    """Parameter-sharing MAPPO for edge-wise VSL control.

    The flat global state is sliced by ``_build_local_obs`` as::

        [ num_agents * edge_feature_dim | num_agents | time_feature_dim + 1 | ... ]
          per-edge features              per-edge      shared "global" block
                                         speed value   (time features, last reward)

    Any trailing entries beyond that prefix are seen by the critic only.
    Transitions are buffered per call to ``store_transition`` and consumed
    (then cleared) by ``update``.
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions: int,
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        hidden_dim: int = 256,
        critic_hidden_dim: int = 256,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 4,
        minibatch_size: int = 15,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        """Build the shared actor, the centralized critic and their optimizer.

        Raises:
            ValueError: if ``num_agents`` or ``minibatch_size`` is not positive,
                or if ``state_dim`` is too small to hold the per-agent and
                global blocks that ``_build_local_obs`` slices out.
        """
        if num_agents < 1:
            raise ValueError(f"num_agents must be >= 1, got {num_agents}")
        if minibatch_size < 1:
            raise ValueError(f"minibatch_size must be >= 1, got {minibatch_size}")

        # Falls back to CPU whenever CUDA is unavailable, whatever was requested.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.num_agents = num_agents
        self.num_actions = num_actions
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size

        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.local_obs_dim = (
            edge_feature_dim
            + self.speed_feature_dim
            + time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )

        # Fail early with a readable message instead of an opaque reshape
        # error on the first forward pass.
        required_state_dim = (
            num_agents * (edge_feature_dim + self.speed_feature_dim)
            + time_feature_dim
            + self.last_reward_dim
        )
        if state_dim < required_state_dim:
            raise ValueError(
                f"state_dim={state_dim} is too small for num_agents={num_agents}, "
                f"edge_feature_dim={edge_feature_dim}, time_feature_dim={time_feature_dim}: "
                f"at least {required_state_dim} entries are required"
            )

        self.actor = SharedActor(self.local_obs_dim, num_actions, hidden_dim).to(self.device)
        self.critic = CentralizedCritic(state_dim, critic_hidden_dim).to(self.device)
        self.optimizer = optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters()),
            lr=learning_rate,
            eps=1e-5,
        )

        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_episodes, eta_min=learning_rate * 0.1
            )
        else:
            self.scheduler = None

        # Scalar agent identifier in [0, 1] so the shared actor can tell edges apart.
        agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1)
        self.reset_buffers()

    def reset_buffers(self) -> None:
        """Drop all buffered transitions."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Split global states into per-agent observations.

        Args:
            state_tensor: ``(state_dim,)`` or ``(batch, state_dim)`` tensor.

        Returns:
            ``(batch, num_agents, local_obs_dim)`` tensor: each agent's own edge
            features and speed entry, the shared global block, and its agent id.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_block = self.num_agents * self.edge_feature_dim
        speed_block_end = edge_block + self.num_agents
        global_block_end = speed_block_end + self.time_feature_dim + self.last_reward_dim

        edge_features = state_tensor[:, :edge_block].reshape(
            batch_size, self.num_agents, self.edge_feature_dim
        )
        local_speed_limits = state_tensor[:, edge_block:speed_block_end].reshape(
            batch_size, self.num_agents, 1
        )
        # The global block is identical for every agent, so broadcast it.
        global_features = state_tensor[:, speed_block_end:global_block_end].unsqueeze(1)
        global_features = global_features.expand(-1, self.num_agents, -1)
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)

        return torch.cat([edge_features, local_speed_limits, global_features, agent_ids], dim=-1)

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """Choose one discrete action per agent for a single global state.

        Args:
            state: flat global state of length ``state_dim``.
            deterministic: take the arg-max action instead of sampling.

        Returns:
            ``(actions, log_probs, value)`` where ``actions`` is an int64 array
            and ``log_probs`` a float32 array, both of shape ``(num_agents,)``,
            and ``value`` is the critic's estimate for ``state``.
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        local_obs = self._build_local_obs(state_tensor)

        with torch.no_grad():
            logits = self.actor(local_obs.reshape(self.num_agents, self.local_obs_dim))
            value = self.critic(state_tensor)
            # One batched distribution over all agents replaces the per-agent loop.
            dist = torch.distributions.Categorical(logits=logits)
            actions = logits.argmax(dim=-1) if deterministic else dist.sample()
            log_probs = dist.log_prob(actions)

        return (
            actions.cpu().numpy().astype(np.int64),
            log_probs.cpu().numpy().astype(np.float32),
            value.item(),
        )

    def get_value(self, state: np.ndarray) -> float:
        """Return the critic's value estimate for a single global state."""
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            value = self.critic(state_tensor)
        return value.item()

    def store_transition(self, state, action, reward, value, log_prob, done) -> None:
        """Append one environment step to the rollout buffer."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float):
        """Compute GAE(lambda) advantages and returns for the buffered rollout.

        Args:
            next_value: bootstrap value for the state following the last
                buffered transition.

        Returns:
            ``(advantages, returns)`` as float32 arrays with one entry per
            buffered step; ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        # Pre-allocated and filled back-to-front: O(n) instead of the O(n^2)
        # cost of repeatedly inserting at the head of a list.
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0
        next_val = next_value

        for t in reversed(range(num_steps)):
            not_done = 1.0 - float(self.dones[t])
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = self.values[t]

        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run the PPO update on the buffered rollout and clear the buffer.

        The single team advantage of each step is shared by every agent.
        The LR scheduler (if any) is advanced once per call.

        Args:
            next_value: bootstrap value passed on to ``compute_gae``.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch steps, or an empty dict when nothing was buffered.
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.as_tensor(np.asarray(self.states), dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(np.asarray(self.actions), dtype=torch.long, device=self.device)
        old_log_probs = torch.as_tensor(
            np.asarray(self.log_probs), dtype=torch.float32, device=self.device
        )
        advantages_t = torch.as_tensor(advantages, dtype=torch.float32, device=self.device)
        returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)

        parameters = list(self.actor.parameters()) + list(self.critic.parameters())
        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0

        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                batch_idx = torch.as_tensor(
                    indices[start_idx:start_idx + self.minibatch_size], device=self.device
                )
                batch_len = batch_idx.numel()

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                batch_local_obs = self._build_local_obs(batch_states)
                logits = self.actor(
                    batch_local_obs.reshape(batch_len * self.num_agents, self.local_obs_dim)
                ).view(batch_len, self.num_agents, self.num_actions)
                dist = torch.distributions.Categorical(logits=logits)

                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Broadcast the per-step team advantage to every agent.
                expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents)
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * expanded_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * expanded_adv
                policy_loss = -torch.min(surr1, surr2).mean()

                values = self.critic(batch_states).squeeze(-1)
                value_loss = nn.functional.mse_loss(values, batch_ret)

                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        if self.scheduler is not None:
            self.scheduler.step()

        self.reset_buffers()
        denominator = max(update_count, 1)
        return {
            "policy_loss": total_policy_loss / denominator,
            "value_loss": total_value_loss / denominator,
            "entropy": total_entropy / denominator,
        }

    def save(self, path: str) -> None:
        """Write actor, critic, optimizer and (if present) scheduler state to ``path``."""
        checkpoint = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        # Persist the LR schedule position so a resumed run continues the
        # cosine decay instead of restarting it.
        if self.scheduler is not None:
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str) -> None:
        """Restore a checkpoint written by ``save``.

        Optimizer and scheduler state are restored only when present, so
        checkpoints written before scheduler state was saved still load.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        if "optimizer_state_dict" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler_state = checkpoint.get("scheduler_state_dict")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)
a/config_sumo_vsl.yaml b/config_sumo_vsl.yaml index 780b576..6256c99 100644 --- a/config_sumo_vsl.yaml +++ b/config_sumo_vsl.yaml @@ -113,6 +113,20 @@ agents: batch_size: 15 lr_schedule: "cosine" + mappo: + hidden_dim: 256 + critic_hidden_dim: 256 + learning_rate: 0.0003 + gamma: 0.99 + gae_lambda: 0.95 + clip_epsilon: 0.2 + value_coef: 0.5 + entropy_coef: 0.01 + max_grad_norm: 0.5 + ppo_epochs: 4 + batch_size: 15 + lr_schedule: "cosine" + ddpg: learning_rate: 0.0003 gamma: 0.99 diff --git a/envs/edge_vsl_env.py b/envs/edge_vsl_env.py index 15fce4c..b14e8c2 100644 --- a/envs/edge_vsl_env.py +++ b/envs/edge_vsl_env.py @@ -7,6 +7,7 @@ import os import sys import numpy as np +import xml.etree.ElementTree as ET from typing import Tuple, Dict, List, Optional try: @@ -33,11 +34,18 @@ class SUMOEdgeVSLEnvironment: self.route_file = sumo_cfg["route_file"] self.detector_add_file = sumo_cfg["detector_add_file"] self.enex_add_file = sumo_cfg["enex_add_file"] + self._detector_add_template = self.detector_add_file + self._enex_add_template = self.enex_add_file self.step_length = sumo_cfg["step_length"] self.begin_time = sumo_cfg["begin_time"] self.end_time = sumo_cfg["end_time"] self.use_gui = sumo_cfg.get("gui", False) self.no_warnings = sumo_cfg.get("no_warnings", True) + runtime_cfg = config.get("runtime", {}) + self.runtime_output_dir = runtime_cfg.get("output_dir") + self.runtime_metrics_subdir = runtime_cfg.get("metrics_subdir", "sumo_metrics") + self.runtime_detector_add_file: Optional[str] = None + self.runtime_enex_add_file: Optional[str] = None # 环境参数 self.control_interval = env_cfg["control_interval"] @@ -72,7 +80,7 @@ class SUMOEdgeVSLEnvironment: # 解析网络 self.parser = SUMONetworkParser( - detector_add_file=self.detector_add_file, + detector_add_file=self._detector_add_template, net_file=self.net_file, ) @@ -132,6 +140,8 @@ class SUMOEdgeVSLEnvironment: if self._sumo_running: self._close_sumo() + self._prepare_runtime_additional_files() + binary_name = "sumo-gui" if 
self.use_gui else "sumo" try: import sumolib @@ -139,9 +149,11 @@ class SUMOEdgeVSLEnvironment: except Exception: sumo_binary = binary_name + detector_add_file = self.runtime_detector_add_file or self.detector_add_file + enex_add_file = self.runtime_enex_add_file or self.enex_add_file cmd = [ sumo_binary, "-n", self.net_file, "-r", self.route_file, - "-a", f"{self.detector_add_file},{self.enex_add_file}", + "-a", f"{detector_add_file},{enex_add_file}", "--step-length", str(self.step_length), "-b", str(self.begin_time), "-e", str(self.end_time), "--collision.action", "warn", "--quit-on-end", "true", @@ -156,6 +168,42 @@ class SUMOEdgeVSLEnvironment: traci.start(cmd, label=f"vsl_{self._episode_count}") self._sumo_running = True + @staticmethod + def _to_sumo_path(path: str) -> str: + return os.path.abspath(path).replace("\\", "/") + + def _rewrite_additional_file(self, template_path: str, runtime_add_path: str, output_xml_path: str): + tree = ET.parse(template_path) + root = tree.getroot() + for elem in root.iter(): + if "file" in elem.attrib: + elem.set("file", output_xml_path) + tree.write(runtime_add_path, encoding="utf-8", xml_declaration=True) + + def _prepare_runtime_additional_files(self): + if not self.runtime_output_dir: + self.runtime_detector_add_file = None + self.runtime_enex_add_file = None + return + + output_dir = os.path.join( + os.path.abspath(self.runtime_output_dir), + self.runtime_metrics_subdir, + ) + os.makedirs(output_dir, exist_ok=True) + + suffix = f"ep{self._episode_count:04d}" + detector_output_file = self._to_sumo_path(os.path.join(output_dir, f"metrics_il_output_{suffix}.xml")) + enex_output_file = self._to_sumo_path(os.path.join(output_dir, f"metrics_enex_output_{suffix}.xml")) + detector_add_file = os.path.join(output_dir, f"runtime_metrics_il_{suffix}.add.xml") + enex_add_file = os.path.join(output_dir, f"runtime_metrics_enex_{suffix}.add.xml") + + self._rewrite_additional_file(self._detector_add_template, detector_add_file, 
detector_output_file) + self._rewrite_additional_file(self._enex_add_template, enex_add_file, enex_output_file) + + self.runtime_detector_add_file = detector_add_file + self.runtime_enex_add_file = enex_add_file + def _close_sumo(self): if self._sumo_running: try: diff --git a/run_all_training.py b/run_all_training.py index a6b4730..ed10498 100644 --- a/run_all_training.py +++ b/run_all_training.py @@ -3,7 +3,7 @@ import subprocess import sys from datetime import datetime -AGENTS = ["ppo", "appo", "dqn", "ddpg", "td3"] +AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"] processes = {} for agent in AGENTS: diff --git a/training/train_appo.py b/training/train_appo.py index d580066..288e701 100644 --- a/training/train_appo.py +++ b/training/train_appo.py @@ -4,6 +4,7 @@ """ import os import sys +import copy import yaml import numpy as np import matplotlib @@ -34,11 +35,13 @@ def train_sumo_appo(): log_dir = os.path.join("logs", "appo", timestamp) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + runtime_config = copy.deepcopy(config) + runtime_config.setdefault("runtime", {})["output_dir"] = log_dir with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(config, f) + yaml.dump(runtime_config, f) logger = TrainingLogger(log_dir, "appo") - env = SUMOEdgeVSLEnvironment(config) + env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_edges @@ -85,6 +88,7 @@ def train_sumo_appo(): episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] + episode_speed_stds = [] episode_hard_brakes = [] policy_losses = [] value_losses = [] @@ -164,6 +168,7 @@ def train_sumo_appo(): episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) + episode_speed_stds.append(avg_speed_std) episode_hard_brakes.append(episode_brakes) if train_stats: @@ -227,7 +232,7 @@ def train_sumo_appo(): # 绘制训练曲线 
plot_training_curves( - episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes, policy_losses, value_losses, save_path=os.path.join(log_dir, "training_curves.png"), ) diff --git a/training/train_ddpg.py b/training/train_ddpg.py index 0901a57..258c5e0 100644 --- a/training/train_ddpg.py +++ b/training/train_ddpg.py @@ -3,6 +3,7 @@ 使用 Stable-Baselines3 的 DDPG 算法 """ import os +import copy import yaml import numpy as np import matplotlib @@ -30,11 +31,13 @@ def train_sumo_ddpg(): log_dir = os.path.join("logs", "ddpg", timestamp) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + runtime_config = copy.deepcopy(config) + runtime_config.setdefault("runtime", {})["output_dir"] = log_dir with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(config, f) + yaml.dump(runtime_config, f) logger = TrainingLogger(log_dir, "ddpg") - env = SUMOEdgeVSLEnvironment(config) + env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_edges @@ -71,6 +74,7 @@ def train_sumo_ddpg(): episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] + episode_speed_stds = [] episode_hard_brakes = [] best_reward = -float("inf") @@ -128,6 +132,7 @@ def train_sumo_ddpg(): episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) + episode_speed_stds.append(avg_speed_std) episode_hard_brakes.append(episode_brakes) logger.log( @@ -165,7 +170,7 @@ def train_sumo_ddpg(): agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}")) plot_training_curves( - episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes, save_path=os.path.join(log_dir, "training_curves.png"), 
) diff --git a/training/train_dqn.py b/training/train_dqn.py index 9d477f3..ff07d31 100644 --- a/training/train_dqn.py +++ b/training/train_dqn.py @@ -3,6 +3,7 @@ DQN训练脚本 - SUMO VSL环境 """ import os import sys +import copy import yaml import numpy as np import matplotlib @@ -31,11 +32,13 @@ def train_sumo_dqn(): log_dir = os.path.join("logs", "dqn", timestamp) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + runtime_config = copy.deepcopy(config) + runtime_config.setdefault("runtime", {})["output_dir"] = log_dir with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(config, f) + yaml.dump(runtime_config, f) logger = TrainingLogger(log_dir, "dqn") - env = SUMOEdgeVSLEnvironment(config) + env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim # DQN使用单个网络处理所有边 @@ -76,6 +79,7 @@ def train_sumo_dqn(): episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] + episode_speed_stds = [] episode_hard_brakes = [] losses = [] best_reward = -float("inf") @@ -143,6 +147,7 @@ def train_sumo_dqn(): episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) + episode_speed_stds.append(avg_speed_std) episode_hard_brakes.append(episode_brakes) loss_val = np.mean(losses[-100:]) if losses else None @@ -185,7 +190,7 @@ def train_sumo_dqn(): # 绘制训练曲线 plot_training_curves( - episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes, save_path=os.path.join(log_dir, "training_curves.png"), ) diff --git a/training/train_mappo.py b/training/train_mappo.py new file mode 100644 index 0000000..8e48d60 --- /dev/null +++ b/training/train_mappo.py @@ -0,0 +1,238 @@ +""" +MAPPO training script for SUMO + TraCI VSL. 
import os
import copy
import yaml
import numpy as np
import matplotlib

matplotlib.use("Agg")
from datetime import datetime
from tqdm import tqdm

from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
from agents.mappo_agent import MAPPOAgent
from utils.config import get_agent_config, get_training_config
from utils.logger import TrainingLogger
from utils.plot import plot_training_curves


def train_sumo_mappo(config_path: str = "config_sumo_vsl.yaml"):
    """Train the MAPPO agent on the SUMO + TraCI VSL environment.

    Loads the YAML config, creates timestamped checkpoint/log directories,
    runs one PPO update per episode, checkpoints the best and periodic models,
    and finally plots the training curves.

    Args:
        config_path: path of the YAML configuration file to load.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = get_agent_config(config, "mappo")
    train_config = get_training_config(config)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp)
    log_dir = os.path.join("logs", "mappo", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # The environment reads runtime.output_dir to redirect SUMO detector
    # output into this run's log directory; the effective config is archived
    # next to the checkpoints.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "mappo")
    env = SUMOEdgeVSLEnvironment(runtime_config)

    print("=" * 70)
    print("MAPPO training - SUMO+TraCI VSL environment")
    print("=" * 70)
    print(f" State dim: {env.state_dim}")
    print(f" Agents: {env.num_edges}")
    print(f" Actions per agent: {env.action_dim}")
    print(f" Episode steps: {env.episode_length}")
    print(f" Control interval: {env.control_interval}s")
    print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}")
    print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
    print(f" Device: {agent_config.get('device', 'cuda')}")
    print()

    agent = MAPPOAgent(
        state_dim=env.state_dim,
        num_agents=env.num_edges,
        num_actions=env.action_dim,
        edge_feature_dim=env.features_per_edge,
        hidden_dim=agent_config.get("hidden_dim", 256),
        critic_hidden_dim=agent_config.get("critic_hidden_dim", 256),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.01),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 4),
        minibatch_size=agent_config.get("batch_size", 15),
        device=agent_config.get("device", "cuda"),
        lr_schedule=agent_config.get("lr_schedule", "cosine"),
        total_episodes=train_config["num_episodes"],
    )

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    # Keys of the step `info` dict that are averaged over each episode as-is.
    step_metric_keys = ("throughput", "mean_speed_kmh", "r_flow", "r_var", "r_brake", "r_penalty")

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    policy_losses = []
    value_losses = []
    entropies = []
    best_reward = -float("inf")
    interrupted = False

    print("Starting training...\n")

    try:
        for episode in range(1, num_episodes + 1):
            state = env.reset(seed=base_seed + episode)
            episode_reward = 0.0
            episode_speed_std = 0.0
            episode_brakes = 0
            sums = dict.fromkeys(step_metric_keys, 0.0)
            done = False
            step = 0

            # Context manager guarantees the bar is closed even when the
            # episode is aborted by an exception or Ctrl-C.
            with tqdm(
                total=env.episode_length, desc=f"Ep {episode}/{num_episodes}", leave=False
            ) as pbar:
                while not done:
                    action, log_probs, value = agent.select_action(state, deterministic=False)
                    next_state, reward, done, info = env.step(action)

                    agent.store_transition(state, action, reward, value, log_probs, done)

                    episode_reward += reward
                    for key in step_metric_keys:
                        sums[key] += info[key]
                    # info["speed_std"] is multiplied by 3.6 to report km/h
                    # (presumably m/s -> km/h; confirm against the env).
                    episode_speed_std += info["speed_std"] * 3.6
                    episode_brakes += info["num_hard_brakes"]
                    state = next_state
                    step += 1

                    pbar.set_postfix(
                        r=f"{episode_reward:.1f}",
                        tp=f"{info['throughput']:.0f}",
                        v=f"{info['mean_speed_kmh']:.1f}",
                    )
                    pbar.update(1)

            # The rollout loop only exits once `done` is true, so the episode
            # is always treated as terminal and nothing is bootstrapped.
            train_stats = agent.update(0.0)

            steps = max(step, 1)
            avg_tp = sums["throughput"] / steps
            avg_speed = sums["mean_speed_kmh"] / steps
            avg_speed_std = episode_speed_std / steps
            avg_r_flow = sums["r_flow"] / steps
            avg_r_var = sums["r_var"] / steps
            avg_r_brake = sums["r_brake"] / steps
            avg_r_penalty = sums["r_penalty"] / steps

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_speed_stds.append(avg_speed_std)
            episode_hard_brakes.append(episode_brakes)

            log_metrics = dict(
                speed_std=avg_speed_std,
                r_flow=avg_r_flow,
                r_var=avg_r_var,
                r_brake=avg_r_brake,
                r_penalty=avg_r_penalty,
                hard_brakes=episode_brakes,
            )
            if train_stats:
                policy_losses.append(train_stats["policy_loss"])
                value_losses.append(train_stats["value_loss"])
                entropies.append(train_stats["entropy"])
                log_metrics.update(
                    policy_loss=train_stats["policy_loss"],
                    value_loss=train_stats["value_loss"],
                    entropy=train_stats["entropy"],
                )
            logger.log(episode, episode_reward, avg_tp, avg_speed, **log_metrics)

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            if episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                print(f" Speed Std: {avg_speed_std:.2f} km/h")
                print(
                    f" R(flow/var/brake/pen): "
                    f"{avg_r_flow:.3f} / {avg_r_var:.3f} / {avg_r_brake:.3f} / {avg_r_penalty:.3f}"
                )
                if train_stats:
                    print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                    print(f" Value Loss: {train_stats['value_loss']:.4f}")
                    print(f" Entropy: {train_stats['entropy']:.4f}")

            if episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

    except KeyboardInterrupt:
        interrupted = True
        print("\nTraining interrupted, saving current model...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # An interrupted run has already been written to model_interrupted.pt;
    # do not also label it as the fully trained final model.
    if not interrupted:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    # plot_training_curves evaluates max(rewards), which raises ValueError on
    # an empty list (e.g. when interrupted during the first episode).
    if episode_rewards:
        plot_training_curves(
            episode_rewards,
            episode_throughputs,
            episode_mean_speeds,
            episode_speed_stds,
            episode_hard_brakes,
            policy_losses,
            value_losses,
            save_path=os.path.join(log_dir, "training_curves.png"),
        )
    else:
        print("No completed episodes; skipping training curves.")

    print("=" * 70)
    print("Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)


if __name__ == "__main__":
    train_sumo_mappo()
SUMOEdgeVSLEnvironment(config) + env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_edges @@ -88,6 +91,7 @@ def train_sumo_ppo(): episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] + episode_speed_stds = [] episode_hard_brakes = [] policy_losses = [] value_losses = [] @@ -167,6 +171,7 @@ def train_sumo_ppo(): episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) + episode_speed_stds.append(avg_speed_std) episode_hard_brakes.append(episode_brakes) if train_stats: @@ -230,7 +235,7 @@ def train_sumo_ppo(): # 绘制训练曲线 plot_training_curves( - episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes, policy_losses, value_losses, save_path=os.path.join(log_dir, "training_curves.png"), ) diff --git a/training/train_td3.py b/training/train_td3.py index a368dd7..3e429ee 100644 --- a/training/train_td3.py +++ b/training/train_td3.py @@ -3,6 +3,7 @@ 使用 Stable-Baselines3 的 TD3 算法 """ import os +import copy import yaml import numpy as np import matplotlib @@ -31,11 +32,13 @@ def train_sumo_td3(): log_dir = os.path.join("logs", "td3", timestamp) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + runtime_config = copy.deepcopy(config) + runtime_config.setdefault("runtime", {})["output_dir"] = log_dir with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f: - yaml.dump(config, f) + yaml.dump(runtime_config, f) logger = TrainingLogger(log_dir, "td3") - env = SUMOEdgeVSLEnvironment(config) + env = SUMOEdgeVSLEnvironment(runtime_config) state_dim = env.state_dim action_dims = [env.action_dim] * env.num_edges @@ -73,6 +76,7 @@ def train_sumo_td3(): episode_rewards = [] episode_throughputs = [] episode_mean_speeds = [] + episode_speed_stds = [] episode_hard_brakes = [] 
best_reward = -float("inf") @@ -138,6 +142,7 @@ def train_sumo_td3(): episode_rewards.append(episode_reward) episode_throughputs.append(avg_tp) episode_mean_speeds.append(avg_speed) + episode_speed_stds.append(avg_speed_std) episode_hard_brakes.append(episode_brakes) logger.log( @@ -175,7 +180,7 @@ def train_sumo_td3(): agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}")) plot_training_curves( - episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes, + episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes, save_path=os.path.join(log_dir, "training_curves.png"), ) diff --git a/utils/plot.py b/utils/plot.py index 9e4f32c..538824b 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -1,15 +1,24 @@ -"""共享训练曲线绘图工具""" +"""Shared training-curve plotting utilities.""" import numpy as np import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt -def plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes, - policy_losses=None, value_losses=None, save_path="training_curves.png"): +def plot_training_curves( + rewards, + throughputs, + mean_speeds, + speed_stds, + hard_brakes, + policy_losses=None, + value_losses=None, + save_path="training_curves.png", +): window = 20 has_losses = policy_losses and value_losses - ncols = 4 if has_losses else 2 + ncols = 4 if has_losses else 3 nrows = 2 fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 10)) @@ -23,26 +32,50 @@ def plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes, ax.set_title(title) ax.grid(True, alpha=0.3) + summary = ( + f"Episodes: {len(rewards)}\n" + f"Best: {max(rewards):.2f}\n" + f"Avg(last20): {np.mean(rewards[-20:]):.2f}" + ) + if has_losses: _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward") _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)") _plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)") - _plot(axes[0, 3], 
hard_brakes, "red", "Hard Brakes", "Hard Brakes Count") - axes[1, 0].plot(policy_losses, "b-", alpha=0.6) - axes[1, 0].set_title("Policy Loss"); axes[1, 0].grid(True, alpha=0.3) - axes[1, 1].plot(value_losses, "r-", alpha=0.6) - axes[1, 1].set_title("Value Loss"); axes[1, 1].grid(True, alpha=0.3) - summary = (f"Episodes: {len(rewards)}\nBest: {max(rewards):.2f}\n" - f"Avg(last20): {np.mean(rewards[-20:]):.2f}") - axes[1, 2].axis("off") - axes[1, 2].text(0.1, 0.5, summary, fontsize=12, family="monospace", - verticalalignment="center", transform=axes[1, 2].transAxes) + _plot(axes[0, 3], speed_stds, "purple", "Speed Std", "Speed Std (km/h)") + _plot(axes[1, 0], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count") + axes[1, 1].plot(policy_losses, "b-", alpha=0.6) + axes[1, 1].set_title("Policy Loss") + axes[1, 1].grid(True, alpha=0.3) + axes[1, 2].plot(value_losses, "r-", alpha=0.6) + axes[1, 2].set_title("Value Loss") + axes[1, 2].grid(True, alpha=0.3) axes[1, 3].axis("off") + axes[1, 3].text( + 0.1, + 0.5, + summary, + fontsize=12, + family="monospace", + verticalalignment="center", + transform=axes[1, 3].transAxes, + ) else: _plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward") _plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)") - _plot(axes[1, 0], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)") + _plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)") + _plot(axes[1, 0], speed_stds, "purple", "Speed Std", "Speed Std (km/h)") _plot(axes[1, 1], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count") + axes[1, 2].axis("off") + axes[1, 2].text( + 0.1, + 0.5, + summary, + fontsize=12, + family="monospace", + verticalalignment="center", + transform=axes[1, 2].transAxes, + ) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight")