增加MAPPO对比模型,修复sumo仿真时车检器记录文件问题
This commit is contained in:
parent
87d292b2b0
commit
43c7cb1dfa
|
|
@ -0,0 +1,306 @@
|
|||
"""
|
||||
MAPPO agent for SUMO VSL.
|
||||
|
||||
This implementation uses parameter sharing across edge-agents:
|
||||
- Actor: decentralized, one shared policy over per-edge local observations
|
||||
- Critic: centralized, one value head over the global state
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
class SharedActor(nn.Module):
    """Policy network whose weights are shared by every edge-agent.

    A local observation goes through two ``Linear -> LayerNorm -> ReLU``
    stages and a final linear head that emits one unnormalised logit per
    discrete action.
    """

    def __init__(self, local_obs_dim: int, num_actions: int, hidden_dim: int = 256):
        """Build the MLP and apply orthogonal initialisation.

        Args:
            local_obs_dim: Size of one agent's observation vector.
            num_actions: Number of discrete actions (size of the logit vector).
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        stack = [
            nn.Linear(local_obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        ]
        self.net = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal weights (gain sqrt(2)) and zero biases on every Linear layer."""
        relu_gain = np.sqrt(2)
        linear_layers = [layer for layer in self.net if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=relu_gain)
            nn.init.constant_(layer.bias, 0)
        # The output head is re-drawn with a tiny gain so the initial logits
        # stay close to zero.
        nn.init.orthogonal_(linear_layers[-1].weight, gain=0.01)

    def forward(self, local_obs: torch.Tensor) -> torch.Tensor:
        """Return action logits for ``local_obs`` (last dim = ``local_obs_dim``)."""
        logits = self.net(local_obs)
        return logits
|
||||
|
||||
|
||||
class CentralizedCritic(nn.Module):
    """Value network that scores the full global state with a single scalar.

    Same two-stage ``Linear -> LayerNorm -> ReLU`` trunk as the actor, followed
    by a one-unit linear head.
    """

    def __init__(self, state_dim: int, hidden_dim: int = 256):
        """Build the MLP and apply orthogonal initialisation.

        Args:
            state_dim: Size of the flat global state vector.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        stack = [
            nn.Linear(state_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal weights (gain sqrt(2)) and zero biases on every Linear layer."""
        relu_gain = np.sqrt(2)
        linear_layers = [layer for layer in self.net if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.orthogonal_(layer.weight, gain=relu_gain)
            nn.init.constant_(layer.bias, 0)
        # The value head is re-drawn with unit gain.
        nn.init.orthogonal_(linear_layers[-1].weight, gain=1.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the value estimate, shape ``(..., 1)``, for ``state``."""
        value = self.net(state)
        return value
|
||||
|
||||
|
||||
class MAPPOAgent:
    """Parameter-sharing MAPPO for edge-wise VSL control.

    One ``SharedActor`` is evaluated once per agent on a local observation
    sliced out of the flat global state; one ``CentralizedCritic`` scores the
    whole state. Both networks are trained by a single Adam optimizer using the
    clipped PPO surrogate, an MSE value loss and an entropy bonus.

    The flat state is sliced as::

        [ edge features   : num_agents * edge_feature_dim
        | per-agent value : num_agents            (named "speed limit" here)
        | global features : time_feature_dim + 1  (time features, last reward)
        | anything else   : seen by the critic only ]
    """

    def __init__(
        self,
        state_dim: int,
        num_agents: int,
        num_actions: int,
        edge_feature_dim: int = 3,
        time_feature_dim: int = 3,
        hidden_dim: int = 256,
        critic_hidden_dim: int = 256,
        learning_rate: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_epsilon: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        ppo_epochs: int = 4,
        minibatch_size: int = 15,
        device: str = "cuda",
        lr_schedule: str = "cosine",
        total_episodes: int = 300,
    ):
        """Create the networks, optimizer, optional LR schedule and buffers.

        Args:
            state_dim: Length of the flat global state vector.
            num_agents: Number of agents (one per controlled edge).
            num_actions: Number of discrete actions per agent.
            edge_feature_dim: Features per edge in the leading state block.
            time_feature_dim: Number of time features in the global block.
            hidden_dim: Hidden width of the shared actor.
            critic_hidden_dim: Hidden width of the centralized critic.
            learning_rate: Initial Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.
            clip_epsilon: PPO ratio clipping range.
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
            max_grad_norm: Global gradient-norm clipping threshold.
            ppo_epochs: Passes over the rollout per ``update`` call.
            minibatch_size: Number of time steps per minibatch.
            device: Requested torch device; replaced by CPU whenever CUDA is
                not available.
            lr_schedule: ``"cosine"`` enables cosine annealing; any other
                value disables scheduling.
            total_episodes: Number of scheduler steps (one per ``update``)
                over which the learning rate is annealed.

        Raises:
            ValueError: If ``state_dim`` is too small to contain the blocks
                that ``_build_local_obs`` slices out.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.state_dim = state_dim
        self.num_agents = num_agents
        self.num_actions = num_actions
        self.edge_feature_dim = edge_feature_dim
        self.time_feature_dim = time_feature_dim
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size

        self.speed_feature_dim = 1
        self.last_reward_dim = 1
        self.agent_id_dim = 1
        self.local_obs_dim = (
            edge_feature_dim
            + self.speed_feature_dim
            + time_feature_dim
            + self.last_reward_dim
            + self.agent_id_dim
        )

        # Fail early with a readable message instead of a shape error deep
        # inside the first forward pass.
        required_state_dim = (
            num_agents * (edge_feature_dim + self.speed_feature_dim)
            + time_feature_dim
            + self.last_reward_dim
        )
        if state_dim < required_state_dim:
            raise ValueError(
                f"state_dim={state_dim} is smaller than the {required_state_dim} "
                f"entries required by num_agents={num_agents}, "
                f"edge_feature_dim={edge_feature_dim}, "
                f"time_feature_dim={time_feature_dim}"
            )

        self.actor = SharedActor(self.local_obs_dim, num_actions, hidden_dim).to(self.device)
        self.critic = CentralizedCritic(state_dim, critic_hidden_dim).to(self.device)
        self.optimizer = optim.Adam(
            list(self.actor.parameters()) + list(self.critic.parameters()),
            lr=learning_rate,
            eps=1e-5,
        )

        if lr_schedule == "cosine":
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                # T_max == 0 would divide by zero on the first scheduler step.
                T_max=max(total_episodes, 1),
                eta_min=learning_rate * 0.1,
            )
        else:
            self.scheduler = None

        # Scalar agent identifier in [0, 1], appended to every local
        # observation so the shared policy can tell agents apart.
        agent_ids = np.linspace(0.0, 1.0, num_agents, dtype=np.float32)
        self.agent_id_features = torch.tensor(agent_ids, device=self.device).view(1, num_agents, 1)
        self.reset_buffers()

    def reset_buffers(self):
        """Drop every stored transition."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

    def _build_local_obs(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Slice a flat state batch into per-agent local observations.

        Args:
            state_tensor: Tensor of shape ``(state_dim,)`` or
                ``(batch, state_dim)``.

        Returns:
            Tensor of shape ``(batch, num_agents, local_obs_dim)`` holding, per
            agent: its edge features, its entry of the speed block, the shared
            global features (broadcast to every agent) and its id feature.
        """
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)

        batch_size = state_tensor.size(0)
        edge_end = self.num_agents * self.edge_feature_dim
        speed_end = edge_end + self.num_agents
        global_end = speed_end + self.time_feature_dim + self.last_reward_dim

        edge_features = state_tensor[:, :edge_end].view(
            batch_size, self.num_agents, self.edge_feature_dim
        )
        local_speed_limits = state_tensor[:, edge_end:speed_end].view(
            batch_size, self.num_agents, 1
        )
        global_features = (
            state_tensor[:, speed_end:global_end]
            .unsqueeze(1)
            .expand(-1, self.num_agents, -1)
        )
        agent_ids = self.agent_id_features.expand(batch_size, -1, -1)

        return torch.cat(
            [edge_features, local_speed_limits, global_features, agent_ids], dim=-1
        )

    def select_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """Choose one action per agent for a single state.

        Args:
            state: Flat global state of length ``state_dim``.
            deterministic: Take the arg-max action instead of sampling.

        Returns:
            ``(actions, log_probs, value)``: ``actions`` is an int64 array of
            shape ``(num_agents,)``, ``log_probs`` the float32 log-probability
            of each chosen action, and ``value`` the critic's scalar estimate.
        """
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)

        with torch.no_grad():
            local_obs = self._build_local_obs(state_tensor)
            logits = self.actor(local_obs.reshape(self.num_agents, self.local_obs_dim))
            value = self.critic(state_tensor)

            # One batched distribution over all agents replaces a Python loop
            # with two host synchronisations (.item()) per agent.
            dist = torch.distributions.Categorical(logits=logits)
            actions = logits.argmax(dim=-1) if deterministic else dist.sample()
            log_probs = dist.log_prob(actions)

        return (
            actions.cpu().numpy().astype(np.int64),
            log_probs.cpu().numpy().astype(np.float32),
            value.item(),
        )

    def get_value(self, state: np.ndarray) -> float:
        """Return the critic's scalar value estimate for a single state."""
        state_tensor = torch.as_tensor(
            state, dtype=torch.float32, device=self.device
        ).unsqueeze(0)
        with torch.no_grad():
            value = self.critic(state_tensor)
        return value.item()

    def store_transition(self, state, action, reward, value, log_prob, done):
        """Append one environment step to the rollout buffers."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.values.append(value)
        self.log_probs.append(log_prob)
        self.dones.append(done)

    def compute_gae(self, next_value: float):
        """Compute GAE advantages and returns for the buffered rollout.

        Args:
            next_value: Bootstrap value for the state following the last
                stored transition (ignored when that transition is terminal).

        Returns:
            ``(advantages, returns)`` as float32 arrays with one entry per
            stored transition; ``returns = advantages + values``.
        """
        num_steps = len(self.rewards)
        # Filled back-to-front in place; list.insert(0, ...) would be O(n^2).
        advantages = np.zeros(num_steps, dtype=np.float32)
        gae = 0.0
        next_val = next_value

        for t in reversed(range(num_steps)):
            not_done = 1 - self.dones[t]
            delta = self.rewards[t] + self.gamma * next_val * not_done - self.values[t]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages[t] = gae
            next_val = self.values[t]

        returns = advantages + np.asarray(self.values, dtype=np.float32)
        return advantages, returns

    def update(self, next_value: float) -> Dict[str, float]:
        """Run the PPO update on the buffered rollout, then clear the buffers.

        The advantage is shared by all agents of a time step; the ratio and
        clipping are applied per agent. The LR scheduler (if any) is stepped
        once per call, up to its ``T_max``.

        Args:
            next_value: Bootstrap value passed on to ``compute_gae``.

        Returns:
            Mean ``policy_loss``, ``value_loss`` and ``entropy`` over all
            minibatch updates, or an empty dict when the buffer is empty.
        """
        if not self.states:
            return {}

        advantages, returns = self.compute_gae(next_value)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states = torch.as_tensor(np.array(self.states), dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(np.array(self.actions), dtype=torch.long, device=self.device)
        old_log_probs = torch.as_tensor(
            np.array(self.log_probs), dtype=torch.float32, device=self.device
        )
        advantages_t = torch.as_tensor(advantages, dtype=torch.float32, device=self.device)
        returns_t = torch.as_tensor(returns, dtype=torch.float32, device=self.device)

        trainable_params = list(self.actor.parameters()) + list(self.critic.parameters())

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        update_count = 0

        dataset_size = len(self.states)
        for _ in range(self.ppo_epochs):
            indices = np.random.permutation(dataset_size)
            for start_idx in range(0, dataset_size, self.minibatch_size):
                # Slicing clamps at the end, so the last minibatch may be short.
                batch_idx = torch.as_tensor(
                    indices[start_idx:start_idx + self.minibatch_size], device=self.device
                )
                batch_len = batch_idx.numel()

                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_lp = old_log_probs[batch_idx]
                batch_adv = advantages_t[batch_idx]
                batch_ret = returns_t[batch_idx]

                batch_local_obs = self._build_local_obs(batch_states)
                logits = self.actor(
                    batch_local_obs.reshape(batch_len * self.num_agents, self.local_obs_dim)
                ).view(batch_len, self.num_agents, self.num_actions)
                dist = torch.distributions.Categorical(logits=logits)

                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                # Every agent of a time step shares that step's advantage.
                expanded_adv = batch_adv.unsqueeze(1).expand(-1, self.num_agents)
                ratio = torch.exp(new_log_probs - batch_old_lp)
                surr1 = ratio * expanded_adv
                surr2 = (
                    torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
                    * expanded_adv
                )
                policy_loss = -torch.min(surr1, surr2).mean()

                values = self.critic(batch_states).squeeze(-1)
                value_loss = nn.functional.mse_loss(values, batch_ret)

                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(trainable_params, self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                update_count += 1

        # Cosine annealing is periodic: stepping past T_max would raise the
        # learning rate again, so hold it at eta_min once the budget is used.
        if self.scheduler is not None and self.scheduler.last_epoch < self.scheduler.T_max:
            self.scheduler.step()

        self.reset_buffers()
        denominator = max(update_count, 1)
        return {
            "policy_loss": total_policy_loss / denominator,
            "value_loss": total_value_loss / denominator,
            "entropy": total_entropy / denominator,
        }

    def save(self, path: str):
        """Write actor, critic, optimizer and (if present) scheduler state to ``path``."""
        checkpoint = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.scheduler is not None:
            # Without this a resumed run would restart the LR schedule.
            checkpoint["scheduler_state_dict"] = self.scheduler.state_dict()
        torch.save(checkpoint, path)

    def load(self, path: str):
        """Restore a checkpoint written by ``save``.

        Checkpoints without a ``scheduler_state_dict`` entry (written before
        the scheduler was persisted) still load; the schedule is then left at
        its current position.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(checkpoint["actor_state_dict"])
        self.critic.load_state_dict(checkpoint["critic_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler_state = checkpoint.get("scheduler_state_dict")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)
|
||||
|
|
@ -113,6 +113,20 @@ agents:
|
|||
batch_size: 15
|
||||
lr_schedule: "cosine"
|
||||
|
||||
mappo:
|
||||
hidden_dim: 256
|
||||
critic_hidden_dim: 256
|
||||
learning_rate: 0.0003
|
||||
gamma: 0.99
|
||||
gae_lambda: 0.95
|
||||
clip_epsilon: 0.2
|
||||
value_coef: 0.5
|
||||
entropy_coef: 0.01
|
||||
max_grad_norm: 0.5
|
||||
ppo_epochs: 4
|
||||
batch_size: 15
|
||||
lr_schedule: "cosine"
|
||||
|
||||
ddpg:
|
||||
learning_rate: 0.0003
|
||||
gamma: 0.99
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Tuple, Dict, List, Optional
|
||||
|
||||
try:
|
||||
|
|
@ -33,11 +34,18 @@ class SUMOEdgeVSLEnvironment:
|
|||
self.route_file = sumo_cfg["route_file"]
|
||||
self.detector_add_file = sumo_cfg["detector_add_file"]
|
||||
self.enex_add_file = sumo_cfg["enex_add_file"]
|
||||
self._detector_add_template = self.detector_add_file
|
||||
self._enex_add_template = self.enex_add_file
|
||||
self.step_length = sumo_cfg["step_length"]
|
||||
self.begin_time = sumo_cfg["begin_time"]
|
||||
self.end_time = sumo_cfg["end_time"]
|
||||
self.use_gui = sumo_cfg.get("gui", False)
|
||||
self.no_warnings = sumo_cfg.get("no_warnings", True)
|
||||
runtime_cfg = config.get("runtime", {})
|
||||
self.runtime_output_dir = runtime_cfg.get("output_dir")
|
||||
self.runtime_metrics_subdir = runtime_cfg.get("metrics_subdir", "sumo_metrics")
|
||||
self.runtime_detector_add_file: Optional[str] = None
|
||||
self.runtime_enex_add_file: Optional[str] = None
|
||||
|
||||
# 环境参数
|
||||
self.control_interval = env_cfg["control_interval"]
|
||||
|
|
@ -72,7 +80,7 @@ class SUMOEdgeVSLEnvironment:
|
|||
|
||||
# 解析网络
|
||||
self.parser = SUMONetworkParser(
|
||||
detector_add_file=self.detector_add_file,
|
||||
detector_add_file=self._detector_add_template,
|
||||
net_file=self.net_file,
|
||||
)
|
||||
|
||||
|
|
@ -132,6 +140,8 @@ class SUMOEdgeVSLEnvironment:
|
|||
if self._sumo_running:
|
||||
self._close_sumo()
|
||||
|
||||
self._prepare_runtime_additional_files()
|
||||
|
||||
binary_name = "sumo-gui" if self.use_gui else "sumo"
|
||||
try:
|
||||
import sumolib
|
||||
|
|
@ -139,9 +149,11 @@ class SUMOEdgeVSLEnvironment:
|
|||
except Exception:
|
||||
sumo_binary = binary_name
|
||||
|
||||
detector_add_file = self.runtime_detector_add_file or self.detector_add_file
|
||||
enex_add_file = self.runtime_enex_add_file or self.enex_add_file
|
||||
cmd = [
|
||||
sumo_binary, "-n", self.net_file, "-r", self.route_file,
|
||||
"-a", f"{self.detector_add_file},{self.enex_add_file}",
|
||||
"-a", f"{detector_add_file},{enex_add_file}",
|
||||
"--step-length", str(self.step_length),
|
||||
"-b", str(self.begin_time), "-e", str(self.end_time),
|
||||
"--collision.action", "warn", "--quit-on-end", "true",
|
||||
|
|
@ -156,6 +168,42 @@ class SUMOEdgeVSLEnvironment:
|
|||
traci.start(cmd, label=f"vsl_{self._episode_count}")
|
||||
self._sumo_running = True
|
||||
|
||||
@staticmethod
def _to_sumo_path(path: str) -> str:
    """Return ``path`` as an absolute path written with forward slashes only."""
    absolute = os.path.abspath(path)
    # Swap every backslash for a forward slash.
    return "/".join(absolute.split("\\"))
|
||||
|
||||
def _rewrite_additional_file(self, template_path: str, runtime_add_path: str, output_xml_path: str):
    """Copy an additional-file template, redirecting every ``file`` attribute.

    Args:
        template_path: XML file to read; it is not modified.
        runtime_add_path: Where the rewritten XML document is written.
        output_xml_path: Value assigned to the ``file`` attribute of every
            element (root included) that already has one.
    """
    document = ET.parse(template_path)
    elements_with_file = (
        node for node in document.getroot().iter() if "file" in node.attrib
    )
    for node in elements_with_file:
        node.set("file", output_xml_path)
    document.write(runtime_add_path, encoding="utf-8", xml_declaration=True)
|
||||
|
||||
def _prepare_runtime_additional_files(self):
    """Generate per-episode copies of the detector and entry/exit additional files.

    When no runtime output directory is configured, both runtime overrides are
    cleared and nothing is written. Otherwise the two templates are rewritten
    into ``<output_dir>/<metrics_subdir>/`` so that their ``file`` attributes
    point at output files tagged with the current episode count, and the paths
    of the rewritten files are stored on the instance.
    """
    if not self.runtime_output_dir:
        self.runtime_detector_add_file = self.runtime_enex_add_file = None
        return

    metrics_dir = os.path.join(
        os.path.abspath(self.runtime_output_dir),
        self.runtime_metrics_subdir,
    )
    os.makedirs(metrics_dir, exist_ok=True)

    episode_tag = f"ep{self._episode_count:04d}"

    def in_metrics_dir(filename: str) -> str:
        return os.path.join(metrics_dir, filename)

    # Paths embedded in the XML are normalised; the additional-file paths
    # themselves stay in native form.
    il_output = self._to_sumo_path(in_metrics_dir(f"metrics_il_output_{episode_tag}.xml"))
    enex_output = self._to_sumo_path(in_metrics_dir(f"metrics_enex_output_{episode_tag}.xml"))
    il_add_path = in_metrics_dir(f"runtime_metrics_il_{episode_tag}.add.xml")
    enex_add_path = in_metrics_dir(f"runtime_metrics_enex_{episode_tag}.add.xml")

    self._rewrite_additional_file(self._detector_add_template, il_add_path, il_output)
    self._rewrite_additional_file(self._enex_add_template, enex_add_path, enex_output)

    self.runtime_detector_add_file = il_add_path
    self.runtime_enex_add_file = enex_add_path
|
||||
|
||||
def _close_sumo(self):
|
||||
if self._sumo_running:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import subprocess
|
|||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
AGENTS = ["ppo", "appo", "dqn", "ddpg", "td3"]
|
||||
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
|
||||
|
||||
processes = {}
|
||||
for agent in AGENTS:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
"""
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
|
@ -34,11 +35,13 @@ def train_sumo_appo():
|
|||
log_dir = os.path.join("logs", "appo", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
yaml.dump(runtime_config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "appo")
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
||||
state_dim = env.state_dim
|
||||
action_dims = [env.action_dim] * env.num_edges
|
||||
|
|
@ -85,6 +88,7 @@ def train_sumo_appo():
|
|||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_speed_stds = []
|
||||
episode_hard_brakes = []
|
||||
policy_losses = []
|
||||
value_losses = []
|
||||
|
|
@ -164,6 +168,7 @@ def train_sumo_appo():
|
|||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_speed_stds.append(avg_speed_std)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
if train_stats:
|
||||
|
|
@ -227,7 +232,7 @@ def train_sumo_appo():
|
|||
|
||||
# 绘制训练曲线
|
||||
plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
|
||||
policy_losses, value_losses,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
使用 Stable-Baselines3 的 DDPG 算法
|
||||
"""
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
|
@ -30,11 +31,13 @@ def train_sumo_ddpg():
|
|||
log_dir = os.path.join("logs", "ddpg", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
yaml.dump(runtime_config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "ddpg")
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
||||
state_dim = env.state_dim
|
||||
action_dims = [env.action_dim] * env.num_edges
|
||||
|
|
@ -71,6 +74,7 @@ def train_sumo_ddpg():
|
|||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_speed_stds = []
|
||||
episode_hard_brakes = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
|
|
@ -128,6 +132,7 @@ def train_sumo_ddpg():
|
|||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_speed_stds.append(avg_speed_std)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
logger.log(
|
||||
|
|
@ -165,7 +170,7 @@ def train_sumo_ddpg():
|
|||
agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))
|
||||
|
||||
plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ DQN训练脚本 - SUMO VSL环境
|
|||
"""
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
|
@ -31,11 +32,13 @@ def train_sumo_dqn():
|
|||
log_dir = os.path.join("logs", "dqn", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
yaml.dump(runtime_config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "dqn")
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
state_dim = env.state_dim
|
||||
|
||||
# DQN使用单个网络处理所有边
|
||||
|
|
@ -76,6 +79,7 @@ def train_sumo_dqn():
|
|||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_speed_stds = []
|
||||
episode_hard_brakes = []
|
||||
losses = []
|
||||
best_reward = -float("inf")
|
||||
|
|
@ -143,6 +147,7 @@ def train_sumo_dqn():
|
|||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_speed_stds.append(avg_speed_std)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
loss_val = np.mean(losses[-100:]) if losses else None
|
||||
|
|
@ -185,7 +190,7 @@ def train_sumo_dqn():
|
|||
|
||||
# 绘制训练曲线
|
||||
plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,238 @@
|
|||
"""
|
||||
MAPPO training script for SUMO + TraCI VSL.
|
||||
"""
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
|
||||
from envs.edge_vsl_env import SUMOEdgeVSLEnvironment
|
||||
from agents.mappo_agent import MAPPOAgent
|
||||
from utils.config import get_agent_config, get_training_config
|
||||
from utils.logger import TrainingLogger
|
||||
from utils.plot import plot_training_curves
|
||||
|
||||
|
||||
def train_sumo_mappo():
    """Train a MAPPO agent on the SUMO edge-VSL environment.

    Reads ``config_sumo_vsl.yaml`` from the working directory, creates
    timestamped checkpoint and log directories, and runs one PPO update per
    episode. The best-reward model, periodic snapshots and (on Ctrl-C) an
    ``model_interrupted.pt`` snapshot are written to the checkpoint directory;
    training curves are plotted into the log directory at the end.

    The final ``model_ep{num_episodes}.pt`` checkpoint is written only when
    every episode completed, so its name never labels a partially trained
    model.
    """
    with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    agent_config = get_agent_config(config, "mappo")
    train_config = get_training_config(config)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp)
    log_dir = os.path.join("logs", "mappo", timestamp)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # The environment writes its per-episode detector output under the log
    # directory; the original config object is left untouched.
    runtime_config = copy.deepcopy(config)
    runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
    with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(runtime_config, f)

    logger = TrainingLogger(log_dir, "mappo")
    env = SUMOEdgeVSLEnvironment(runtime_config)

    print("=" * 70)
    print("MAPPO training - SUMO+TraCI VSL environment")
    print("=" * 70)
    print(f" State dim: {env.state_dim}")
    print(f" Agents: {env.num_edges}")
    print(f" Actions per agent: {env.action_dim}")
    print(f" Episode steps: {env.episode_length}")
    print(f" Control interval: {env.control_interval}s")
    print(f" Hidden dim: {agent_config.get('hidden_dim', 256)}")
    print(f" LR: {agent_config.get('learning_rate', 3e-4)}")
    print(f" Device: {agent_config.get('device', 'cuda')}")
    print()

    agent = MAPPOAgent(
        state_dim=env.state_dim,
        num_agents=env.num_edges,
        num_actions=env.action_dim,
        edge_feature_dim=env.features_per_edge,
        hidden_dim=agent_config.get("hidden_dim", 256),
        critic_hidden_dim=agent_config.get("critic_hidden_dim", 256),
        learning_rate=agent_config.get("learning_rate", 3e-4),
        gamma=agent_config.get("gamma", 0.99),
        gae_lambda=agent_config.get("gae_lambda", 0.95),
        clip_epsilon=agent_config.get("clip_epsilon", 0.2),
        value_coef=agent_config.get("value_coef", 0.5),
        entropy_coef=agent_config.get("entropy_coef", 0.01),
        max_grad_norm=agent_config.get("max_grad_norm", 0.5),
        ppo_epochs=agent_config.get("ppo_epochs", 4),
        minibatch_size=agent_config.get("batch_size", 15),
        device=agent_config.get("device", "cuda"),
        lr_schedule=agent_config.get("lr_schedule", "cosine"),
        total_episodes=train_config["num_episodes"],
    )

    num_episodes = train_config["num_episodes"]
    save_freq = train_config.get("save_freq", 50)
    log_freq = train_config.get("log_freq", 10)
    base_seed = train_config.get("random_seed", 42)

    episode_rewards = []
    episode_throughputs = []
    episode_mean_speeds = []
    episode_speed_stds = []
    episode_hard_brakes = []
    policy_losses = []
    value_losses = []
    entropies = []
    best_reward = -float("inf")
    training_completed = False

    print("Starting training...\n")

    try:
        for episode in range(1, num_episodes + 1):
            seed = base_seed + episode
            state = env.reset(seed=seed)
            episode_reward = 0.0
            episode_throughput = 0.0
            episode_speed = 0.0
            episode_speed_std = 0.0
            episode_r_flow = 0.0
            episode_r_var = 0.0
            episode_r_brake = 0.0
            episode_r_penalty = 0.0
            episode_brakes = 0
            done = False
            step = 0

            # The context manager closes the bar even when a step raises.
            with tqdm(
                total=env.episode_length,
                desc=f"Ep {episode}/{num_episodes}",
                leave=False,
            ) as pbar:
                while not done:
                    action, log_probs, value = agent.select_action(state, deterministic=False)
                    next_state, reward, done, info = env.step(action)

                    agent.store_transition(state, action, reward, value, log_probs, done)

                    episode_reward += reward
                    episode_throughput += info["throughput"]
                    episode_speed += info["mean_speed_kmh"]
                    # Scaled by 3.6 to match the km/h label used in the logs below.
                    episode_speed_std += info["speed_std"] * 3.6
                    episode_r_flow += info["r_flow"]
                    episode_r_var += info["r_var"]
                    episode_r_brake += info["r_brake"]
                    episode_r_penalty += info["r_penalty"]
                    episode_brakes += info["num_hard_brakes"]
                    state = next_state
                    step += 1

                    pbar.set_postfix(
                        r=f"{episode_reward:.1f}",
                        tp=f"{info['throughput']:.0f}",
                        v=f"{info['mean_speed_kmh']:.1f}",
                    )
                    pbar.update(1)

            next_value = 0.0 if done else agent.get_value(next_state)
            train_stats = agent.update(next_value)

            steps = max(step, 1)
            avg_tp = episode_throughput / steps
            avg_speed = episode_speed / steps
            avg_speed_std = episode_speed_std / steps
            avg_r_flow = episode_r_flow / steps
            avg_r_var = episode_r_var / steps
            avg_r_brake = episode_r_brake / steps
            avg_r_penalty = episode_r_penalty / steps

            episode_rewards.append(episode_reward)
            episode_throughputs.append(avg_tp)
            episode_mean_speeds.append(avg_speed)
            episode_speed_stds.append(avg_speed_std)
            episode_hard_brakes.append(episode_brakes)

            # One log call; the loss fields are added only when an update ran.
            log_fields = dict(
                speed_std=avg_speed_std,
                r_flow=avg_r_flow,
                r_var=avg_r_var,
                r_brake=avg_r_brake,
                r_penalty=avg_r_penalty,
                hard_brakes=episode_brakes,
            )
            if train_stats:
                policy_losses.append(train_stats["policy_loss"])
                value_losses.append(train_stats["value_loss"])
                entropies.append(train_stats["entropy"])
                log_fields.update(
                    policy_loss=train_stats["policy_loss"],
                    value_loss=train_stats["value_loss"],
                    entropy=train_stats["entropy"],
                )
            logger.log(episode, episode_reward, avg_tp, avg_speed, **log_fields)

            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save(os.path.join(checkpoint_dir, "model_best.pt"))

            if episode % log_freq == 0:
                recent_rewards = episode_rewards[-log_freq:]
                print(f"\nEpisode {episode}/{num_episodes}")
                print(f" Reward: {episode_reward:.2f} (Avg: {np.mean(recent_rewards):.2f})")
                print(f" Throughput: {avg_tp:.1f} veh/h")
                print(f" Mean Speed: {avg_speed:.1f} km/h")
                print(f" Speed Std: {avg_speed_std:.2f} km/h")
                print(
                    f" R(flow/var/brake/pen): "
                    f"{avg_r_flow:.3f} / {avg_r_var:.3f} / {avg_r_brake:.3f} / {avg_r_penalty:.3f}"
                )
                if train_stats:
                    print(f" Policy Loss: {train_stats['policy_loss']:.4f}")
                    print(f" Value Loss: {train_stats['value_loss']:.4f}")
                    print(f" Entropy: {train_stats['entropy']:.4f}")

            if episode % save_freq == 0:
                agent.save(os.path.join(checkpoint_dir, f"model_ep{episode}.pt"))

        training_completed = True

    except KeyboardInterrupt:
        print("\nTraining interrupted, saving current model...")
        agent.save(os.path.join(checkpoint_dir, "model_interrupted.pt"))
    finally:
        env.close()

    # An interrupted run has already been saved as model_interrupted.pt;
    # labelling it model_ep{num_episodes} would misreport its training length.
    if training_completed:
        agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}.pt"))

    plot_training_curves(
        episode_rewards,
        episode_throughputs,
        episode_mean_speeds,
        episode_speed_stds,
        episode_hard_brakes,
        policy_losses,
        value_losses,
        save_path=os.path.join(log_dir, "training_curves.png"),
    )

    print("=" * 70)
    print("Training complete")
    print(f" Best reward: {best_reward:.2f}")
    print(f" Model dir: {checkpoint_dir}")
    print(f" Log dir: {log_dir}")
    print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_sumo_mappo()
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
"""
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
|
@ -35,13 +36,15 @@ def train_sumo_ppo():
|
|||
log_dir = os.path.join("logs", "ppo", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
yaml.dump(runtime_config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "ppo")
|
||||
|
||||
# 创建环境
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
||||
state_dim = env.state_dim
|
||||
action_dims = [env.action_dim] * env.num_edges
|
||||
|
|
@ -88,6 +91,7 @@ def train_sumo_ppo():
|
|||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_speed_stds = []
|
||||
episode_hard_brakes = []
|
||||
policy_losses = []
|
||||
value_losses = []
|
||||
|
|
@ -167,6 +171,7 @@ def train_sumo_ppo():
|
|||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_speed_stds.append(avg_speed_std)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
if train_stats:
|
||||
|
|
@ -230,7 +235,7 @@ def train_sumo_ppo():
|
|||
|
||||
# 绘制训练曲线
|
||||
plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
|
||||
policy_losses, value_losses,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
使用 Stable-Baselines3 的 TD3 算法
|
||||
"""
|
||||
import os
|
||||
import copy
|
||||
import yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
|
@ -31,11 +32,13 @@ def train_sumo_td3():
|
|||
log_dir = os.path.join("logs", "td3", timestamp)
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
runtime_config = copy.deepcopy(config)
|
||||
runtime_config.setdefault("runtime", {})["output_dir"] = log_dir
|
||||
with open(os.path.join(checkpoint_dir, "config.yaml"), "w", encoding="utf-8") as f:
|
||||
yaml.dump(config, f)
|
||||
yaml.dump(runtime_config, f)
|
||||
|
||||
logger = TrainingLogger(log_dir, "td3")
|
||||
env = SUMOEdgeVSLEnvironment(config)
|
||||
env = SUMOEdgeVSLEnvironment(runtime_config)
|
||||
|
||||
state_dim = env.state_dim
|
||||
action_dims = [env.action_dim] * env.num_edges
|
||||
|
|
@ -73,6 +76,7 @@ def train_sumo_td3():
|
|||
episode_rewards = []
|
||||
episode_throughputs = []
|
||||
episode_mean_speeds = []
|
||||
episode_speed_stds = []
|
||||
episode_hard_brakes = []
|
||||
best_reward = -float("inf")
|
||||
|
||||
|
|
@ -138,6 +142,7 @@ def train_sumo_td3():
|
|||
episode_rewards.append(episode_reward)
|
||||
episode_throughputs.append(avg_tp)
|
||||
episode_mean_speeds.append(avg_speed)
|
||||
episode_speed_stds.append(avg_speed_std)
|
||||
episode_hard_brakes.append(episode_brakes)
|
||||
|
||||
logger.log(
|
||||
|
|
@ -175,7 +180,7 @@ def train_sumo_td3():
|
|||
agent.save(os.path.join(checkpoint_dir, f"model_ep{num_episodes}"))
|
||||
|
||||
plot_training_curves(
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_hard_brakes,
|
||||
episode_rewards, episode_throughputs, episode_mean_speeds, episode_speed_stds, episode_hard_brakes,
|
||||
save_path=os.path.join(log_dir, "training_curves.png"),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,24 @@
|
|||
"""共享训练曲线绘图工具"""
|
||||
"""Shared training-curve plotting utilities."""
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes,
|
||||
policy_losses=None, value_losses=None, save_path="training_curves.png"):
|
||||
def plot_training_curves(
|
||||
rewards,
|
||||
throughputs,
|
||||
mean_speeds,
|
||||
speed_stds,
|
||||
hard_brakes,
|
||||
policy_losses=None,
|
||||
value_losses=None,
|
||||
save_path="training_curves.png",
|
||||
):
|
||||
window = 20
|
||||
has_losses = policy_losses and value_losses
|
||||
ncols = 4 if has_losses else 2
|
||||
ncols = 4 if has_losses else 3
|
||||
nrows = 2
|
||||
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 10))
|
||||
|
||||
|
|
@ -23,26 +32,50 @@ def plot_training_curves(rewards, throughputs, mean_speeds, hard_brakes,
|
|||
ax.set_title(title)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
summary = (
|
||||
f"Episodes: {len(rewards)}\n"
|
||||
f"Best: {max(rewards):.2f}\n"
|
||||
f"Avg(last20): {np.mean(rewards[-20:]):.2f}"
|
||||
)
|
||||
|
||||
if has_losses:
|
||||
_plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward")
|
||||
_plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)")
|
||||
_plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)")
|
||||
_plot(axes[0, 3], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count")
|
||||
axes[1, 0].plot(policy_losses, "b-", alpha=0.6)
|
||||
axes[1, 0].set_title("Policy Loss"); axes[1, 0].grid(True, alpha=0.3)
|
||||
axes[1, 1].plot(value_losses, "r-", alpha=0.6)
|
||||
axes[1, 1].set_title("Value Loss"); axes[1, 1].grid(True, alpha=0.3)
|
||||
summary = (f"Episodes: {len(rewards)}\nBest: {max(rewards):.2f}\n"
|
||||
f"Avg(last20): {np.mean(rewards[-20:]):.2f}")
|
||||
axes[1, 2].axis("off")
|
||||
axes[1, 2].text(0.1, 0.5, summary, fontsize=12, family="monospace",
|
||||
verticalalignment="center", transform=axes[1, 2].transAxes)
|
||||
_plot(axes[0, 3], speed_stds, "purple", "Speed Std", "Speed Std (km/h)")
|
||||
_plot(axes[1, 0], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count")
|
||||
axes[1, 1].plot(policy_losses, "b-", alpha=0.6)
|
||||
axes[1, 1].set_title("Policy Loss")
|
||||
axes[1, 1].grid(True, alpha=0.3)
|
||||
axes[1, 2].plot(value_losses, "r-", alpha=0.6)
|
||||
axes[1, 2].set_title("Value Loss")
|
||||
axes[1, 2].grid(True, alpha=0.3)
|
||||
axes[1, 3].axis("off")
|
||||
axes[1, 3].text(
|
||||
0.1,
|
||||
0.5,
|
||||
summary,
|
||||
fontsize=12,
|
||||
family="monospace",
|
||||
verticalalignment="center",
|
||||
transform=axes[1, 3].transAxes,
|
||||
)
|
||||
else:
|
||||
_plot(axes[0, 0], rewards, "blue", "Episode Reward", "Total Reward")
|
||||
_plot(axes[0, 1], throughputs, "green", "Throughput", "Avg Throughput (veh/h)")
|
||||
_plot(axes[1, 0], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)")
|
||||
_plot(axes[0, 2], mean_speeds, "orange", "Mean Speed", "Mean Speed (km/h)")
|
||||
_plot(axes[1, 0], speed_stds, "purple", "Speed Std", "Speed Std (km/h)")
|
||||
_plot(axes[1, 1], hard_brakes, "red", "Hard Brakes", "Hard Brakes Count")
|
||||
axes[1, 2].axis("off")
|
||||
axes[1, 2].text(
|
||||
0.1,
|
||||
0.5,
|
||||
summary,
|
||||
fontsize=12,
|
||||
family="monospace",
|
||||
verticalalignment="center",
|
||||
transform=axes[1, 2].transAxes,
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||||
|
|
|
|||
Loading…
Reference in New Issue