From 1b9f90b5bb8ffbccf0bdb97f13401a78a278579f Mon Sep 17 00:00:00 2001
From: Maple-YZ
Date: Thu, 9 Apr 2026 03:36:12 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=BF=90=E8=A1=8C=E6=97=B6?=
 =?UTF-8?q?=E7=BB=98=E5=88=B6=E5=85=A8=E9=83=A8=E5=AF=B9=E6=AF=94=E5=9B=BE?=
 =?UTF-8?q?=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 run_all_training.py           |  50 ++-
 scripts/plot_live_training.py | 586 ++++++++++++++++++++++++++++++++++
 2 files changed, 628 insertions(+), 8 deletions(-)
 create mode 100644 scripts/plot_live_training.py

diff --git a/run_all_training.py b/run_all_training.py
index b14ff58..73e3f38 100644
--- a/run_all_training.py
+++ b/run_all_training.py
@@ -1,20 +1,25 @@
-"""一键异步启动全部训练。"""
+"""Launch all model training jobs plus a live overview plotter."""
 import os
 import subprocess
 import sys
 from datetime import datetime
 
 AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
+PLOT_REFRESH_INTERVAL = 30
 
 
 def main():
     run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    processes = {}
+    run_log_root = os.path.join("logs", "multi-model", run_timestamp)
+    run_ckpt_root = os.path.join("checkpoints", "multi-model", run_timestamp)
+    os.makedirs(run_log_root, exist_ok=True)
+    os.makedirs(run_ckpt_root, exist_ok=True)
 
+    processes = {}
     for agent in AGENTS:
         print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
-        log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
-        checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
+        log_dir = os.path.join(run_log_root, agent)
+        checkpoint_dir = os.path.join(run_ckpt_root, agent)
         os.makedirs(log_dir, exist_ok=True)
         os.makedirs(checkpoint_dir, exist_ok=True)
         stdout_path = os.path.join(log_dir, "stdout.txt")
@@ -37,13 +42,42 @@ def main():
         processes[agent] = process
         print(f" PID: {process.pid}")
 
+    plotter_stdout_path = os.path.join(run_log_root, "plotter_stdout.txt")
+    plotter_process = subprocess.Popen(
+        [
+            sys.executable,
+            "-u",  # unbuffered: the plotter is stopped via terminate(), which would drop buffered log lines
+            "-m",
+            "scripts.plot_live_training",
+            "--all-models",
+            "--run",
+            run_timestamp,
+            "--watch",
+            "--interval",
+            str(PLOT_REFRESH_INTERVAL),
+        ],
+        stdout=open(plotter_stdout_path, "w", encoding="utf-8"),
+        stderr=subprocess.STDOUT,
+    )
+    print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动联训总览绘图...")
+    print(f" PID: {plotter_process.pid}")
+
     print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
     print(f"本次多模型时间戳: {run_timestamp}\n")
 
-    for agent, process in processes.items():
-        process.wait()
-        status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
-        print(f"[{agent.upper()}] {status}")
+    try:
+        for agent, process in processes.items():
+            process.wait()
+            status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
+            print(f"[{agent.upper()}] {status}")
+    finally:
+        if plotter_process.poll() is None:
+            plotter_process.terminate()
+            try:
+                plotter_process.wait(timeout=10)
+            except subprocess.TimeoutExpired:
+                plotter_process.kill()
+            print("[PLOTTER] 已停止自动绘图进程")
 
 
 if __name__ == "__main__":
diff --git a/scripts/plot_live_training.py b/scripts/plot_live_training.py
new file mode 100644
index 0000000..38b2afe
--- /dev/null
+++ b/scripts/plot_live_training.py
@@ -0,0 +1,586 @@
+"""Plot live training snapshots from CSV logs."""
+import argparse
+import os
+import time
+from typing import Dict, List, Optional
+
+import matplotlib
+import numpy as np
+import pandas as pd
+
+matplotlib.use("Agg")  # non-interactive backend: figures are only written to files
+import matplotlib.pyplot as plt
+
+
+MODEL_ORDER = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
+MODEL_LABELS = {name: name.upper() for name in MODEL_ORDER}
+MODEL_COLORS = {
+    "ppo": "#1f77b4",
+    "appo": "#ff7f0e",
+    "mappo": "#2ca02c",
+    "dqn": "#d62728",
+    "ddpg": "#9467bd",
+    "td3": "#17becf",
+}
+
+
+def parse_args():
+    """Parse CLI options; exit with a usage error unless --model or --all-models is given."""
+    parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.")
+    parser.add_argument("--model", default=None, help="Model name, e.g. ppo/appo/mappo/dqn/ddpg/td3")
+    parser.add_argument(
+        "--all-models",
+        action="store_true",
+        help="Plot all models into one overview figure.",
+    )
+    parser.add_argument(
+        "--run",
+        default=None,
+        help="Multi-model run timestamp. Default: latest under logs/multi-model/",
+    )
+    parser.add_argument(
+        "--log-root",
+        default=os.path.join("logs", "multi-model"),
+        help="Multi-model log root directory.",
+    )
+    parser.add_argument(
+        "--csv-path",
+        default=None,
+        help="Direct training CSV path override.",
+    )
+    parser.add_argument(
+        "--output",
+        default=None,
+        help="Detailed snapshot output path. Default: MODEL_DIR/training_snapshot.png",
+    )
+    parser.add_argument(
+        "--compare-output",
+        default=None,
+        help="Run comparison output path. Default: MODEL_DIR/run_comparison_snapshot.png",
+    )
+    parser.add_argument(
+        "--overview-output",
+        default=None,
+        help="All-model overview output path. Default: RUN_DIR/all_models_training_snapshot.png",
+    )
+    parser.add_argument(
+        "--window",
+        type=int,
+        default=20,
+        help="Moving average window.",
+    )
+    parser.add_argument(
+        "--show-ma",
+        action="store_true",
+        help="Overlay moving-average curves.",
+    )
+    parser.add_argument(
+        "--watch",
+        action="store_true",
+        help="Continuously refresh plots.",
+    )
+    parser.add_argument(
+        "--interval",
+        type=int,
+        default=30,
+        help="Refresh interval in seconds when --watch is enabled.",
+    )
+    args = parser.parse_args()
+    if not args.all_models and not args.model:
+        parser.error("Either --model or --all-models must be specified.")
+    return args
+
+
+def normalize_model_name(name: str) -> str:
+    """Return *name* stripped and lower-cased, or raise ValueError if it is not a known model."""
+    model = name.strip().lower()
+    if model not in MODEL_LABELS:
+        raise ValueError(f"Unsupported model name: {name}")
+    return model
+
+
+def find_latest_run(log_root: str) -> str:
+    """Return the lexicographically greatest run directory name under *log_root*."""
+    if not os.path.isdir(log_root):
+        raise FileNotFoundError(f"Log root not found: {log_root}")
+    candidates = [
+        item for item in os.listdir(log_root)
+        if os.path.isdir(os.path.join(log_root, item))
+    ]
+    if not candidates:
+        raise FileNotFoundError(f"No run directories found under: {log_root}")
+    return max(candidates)
+
+
+def resolve_paths(args, model_name: str):
+    """Return (run_name, model_dir, csv_path, output_path, compare_output) for one model."""
+    if args.csv_path:
+        csv_path = args.csv_path
+        model_dir = os.path.dirname(os.path.abspath(csv_path))
+        run_dir = os.path.dirname(model_dir)
+        run_name = os.path.basename(run_dir)
+    else:
+        run_name = args.run or find_latest_run(args.log_root)
+        model_dir = os.path.join(args.log_root, run_name, model_name)
+        csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")
+
+    if not os.path.isfile(csv_path):
+        raise FileNotFoundError(f"Training CSV not found: {csv_path}")
+
+    output_path = args.output or os.path.join(model_dir, "training_snapshot.png")
+    compare_output = args.compare_output or os.path.join(model_dir, "run_comparison_snapshot.png")
+    return run_name, model_dir, csv_path, output_path, compare_output
+
+
+def resolve_run_dir(args):
+    """Return (run_name, run_dir, overview_output) for the selected or latest run."""
+    run_name = args.run or find_latest_run(args.log_root)
+    run_dir = os.path.join(args.log_root, run_name)
+    if not os.path.isdir(run_dir):
+        raise FileNotFoundError(f"Run directory not found: {run_dir}")
+    overview_output = args.overview_output or os.path.join(run_dir, "all_models_training_snapshot.png")
+    return run_name, run_dir, overview_output
+
+
+def safe_read_csv(csv_path: str) -> pd.DataFrame:
+    """Load a training CSV, coerce known metric columns to numbers and de-duplicate by episode."""
+    df = pd.read_csv(csv_path, on_bad_lines="skip")
+
+    numeric_columns = [
+        "episode", "reward", "throughput", "mean_speed", "speed_std",
+        "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes",
+        "policy_loss", "value_loss", "entropy",
+    ]
+    for column in numeric_columns:
+        if column in df.columns:
+            df[column] = pd.to_numeric(df[column], errors="coerce")
+
+    df = df.dropna(subset=["episode"])
+    df = df.sort_values("episode").drop_duplicates(subset=["episode"], keep="last")
+    if df.empty:  # checked after cleaning so callers never receive a frame with zero usable rows
+        raise ValueError(f"CSV exists but contains no rows: {csv_path}")
+    return df.reset_index(drop=True)
+
+
+def moving_average(values: pd.Series, window: int) -> pd.Series:
+    """Return the rolling mean of *values*; the window is clamped to at least 1 and at most len(values)."""
+    return values.rolling(window=max(1, min(window, len(values))), min_periods=1).mean()
+
+
+def plot_series(
+    ax,
+    df: pd.DataFrame,
+    column: str,
+    title: str,
+    ylabel: str,
+    color: str,
+    window: int,
+    show_ma: bool,
+):
+    """Draw one metric column against episode on *ax*; hide the axes if there is no data."""
+    if column not in df.columns:
+        ax.axis("off")
+        return
+
+    series = pd.to_numeric(df[column], errors="coerce")
+    if series.notna().sum() == 0:
+        ax.axis("off")
+        return
+
+    raw_label = "Series" if show_ma else None
+    ax.plot(df["episode"], series, color=color, alpha=0.9, linewidth=1.8, label=raw_label)
+    if show_ma:
+        ax.plot(
+            df["episode"],
+            moving_average(series, window),
+            color=color,
+            alpha=0.35,
+            linewidth=1.0,
+            linestyle="--",
+            label=f"MA{window}",
+        )
+    ax.set_title(title)
+    ax.set_xlabel("Episode")
+    ax.set_ylabel(ylabel)
+    ax.grid(True, alpha=0.3)
+    if show_ma:
+        ax.legend(loc="best")
+
+
+def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str:
+    """Build the text panel: latest/best reward plus means over the last 20 logged episodes."""
+    last_row = df.iloc[-1]
+    best_reward = df["reward"].max() if "reward" in df else np.nan
+    reward_last20 = df["reward"].tail(min(20, len(df))).mean() if "reward" in df else np.nan
+    tp_last20 = df["throughput"].tail(min(20, len(df))).mean() if "throughput" in df else np.nan
+    speed_last20 = df["mean_speed"].tail(min(20, len(df))).mean() if "mean_speed" in df else np.nan
+    brake_last20 = df["hard_brakes"].tail(min(20, len(df))).mean() if "hard_brakes" in df else np.nan
+
+    return (
+        f"Run: {run_name}\n"
+        f"Model: {MODEL_LABELS[model_name]}\n"
+        f"Episodes logged: {len(df)}\n"
+        f"Latest episode: {int(last_row['episode'])}\n"
+        f"Best reward: {best_reward:.2f}\n"
+        f"Latest reward: {last_row.get('reward', np.nan):.2f}\n"
+        f"Reward last20: {reward_last20:.2f}\n"
+        f"Throughput last20: {tp_last20:.1f}\n"
+        f"Mean speed last20: {speed_last20:.1f}\n"
+        f"Hard brakes last20: {brake_last20:.1f}\n"
+        f"\nCSV:\n{os.path.abspath(csv_path)}"
+    )
+
+
+def plot_detailed_snapshot(
+    model_name: str,
+    df: pd.DataFrame,
+    run_name: str,
+    csv_path: str,
+    output_path: str,
+    window: int,
+    show_ma: bool,
+):
+    """Render the 3x3 single-model dashboard and save it to *output_path*."""
+    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
+    axes = axes.flatten()
+
+    plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma)
+    plot_series(axes[1], df, "throughput", "Throughput", "veh/h", "tab:green", window, show_ma)
+    plot_series(axes[2], df, "mean_speed", "Mean Speed", "km/h", "tab:orange", window, show_ma)
+    plot_series(axes[3], df, "speed_std", "Speed Std", "km/h", "tab:purple", window, show_ma)
+    plot_series(axes[4], df, "hard_brakes", "Hard Brakes", "count", "tab:red", window, show_ma)
+
+    reward_components = ["r_flow", "r_var", "r_brake", "r_penalty"]
+    has_components = any(
+        col in df.columns and pd.to_numeric(df[col], errors="coerce").notna().sum() > 0
+        for col in reward_components
+    )
+    if has_components:
+        for col, color in zip(reward_components, ["tab:green", "tab:purple", "tab:red", "tab:brown"]):
+            series = pd.to_numeric(df.get(col, pd.Series(dtype=float)), errors="coerce")  # column may be absent
+            if series.notna().sum() > 0:
+                axes[5].plot(
+                    df["episode"],
+                    series,
+                    linewidth=1.6,
+                    alpha=0.75,
+                    color=color,
+                    label=col,
+                )
+                if show_ma:
+                    axes[5].plot(
+                        df["episode"],
+                        moving_average(series, window),
+                        linewidth=1.0,
+                        alpha=0.35,
+                        linestyle="--",
+                        color=color,
+                        label="_nolegend_",
+                    )
+        axes[5].set_title("Reward Components")
+        axes[5].set_xlabel("Episode")
+        axes[5].grid(True, alpha=0.3)
+        axes[5].legend(loc="best")
+    else:
+        axes[5].axis("off")
+
+    plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
+    plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
+
+    axes[8].axis("off")
+    axes[8].text(
+        0.0,
+        1.0,
+        build_summary_text(model_name, df, run_name, csv_path),
+        va="top",
+        ha="left",
+        family="monospace",
+        fontsize=11,
+        transform=axes[8].transAxes,
+    )
+
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=160, bbox_inches="tight")
+    plt.close(fig)
+
+
+def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]:
+    """Return {model: DataFrame} for each model in the run whose CSV can be read right now."""
+    run_logs = {}
+    for model_name in MODEL_ORDER:
+        csv_path = os.path.join(log_root, run_name, model_name, f"{model_name}_training_log.csv")
+        if os.path.isfile(csv_path):
+            try:
+                run_logs[model_name] = safe_read_csv(csv_path)
+            except Exception:  # best effort: an unreadable or still-empty log is skipped for this refresh
+                continue
+    return run_logs
+
+
+def plot_waiting_placeholder(output_path: str, run_name: str, message: str):
+    """Save a text-only placeholder figure showing the run name and *message*."""
+    fig, ax = plt.subplots(figsize=(10, 4))
+    ax.axis("off")
+    ax.text(
+        0.5,
+        0.6,
+        f"Run: {run_name}",
+        ha="center",
+        va="center",
+        fontsize=16,
+        family="monospace",
+        transform=ax.transAxes,
+    )
+    ax.text(
+        0.5,
+        0.4,
+        message,
+        ha="center",
+        va="center",
+        fontsize=12,
+        transform=ax.transAxes,
+    )
+    fig.tight_layout()
+    fig.savefig(output_path, dpi=150, bbox_inches="tight")
+    plt.close(fig)
+
+
+def plot_run_comparison(
+    target_model: str,
+    run_logs: Dict[str, pd.DataFrame],
+    compare_output: str,
+    window: int,
+    show_ma: bool,
+):
+    """Save a 2x2 cross-model comparison emphasising *target_model*; no-op with fewer than 2 logs."""
+    if len(run_logs) <= 1:
+        return
+
+    metrics = [
+        ("reward", "Reward"),
+        ("throughput", "Throughput"),
+        ("mean_speed", "Mean Speed"),
+        ("speed_std", "Speed Std"),
+    ]
+    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
+    axes = axes.flatten()
+
+    for ax, (column, title) in zip(axes, metrics):
+        for model_name in MODEL_ORDER:
+            df = run_logs.get(model_name)
+            if df is None or column not in df.columns:
+                continue
+            series = pd.to_numeric(df[column], errors="coerce")
+            if series.notna().sum() == 0:
+                continue
+            linewidth = 2.8 if model_name == target_model else 1.6
+            alpha = 1.0 if model_name == target_model else 0.5
+            ax.plot(
+                df["episode"],
+                series,
+                linewidth=linewidth + 0.4,
+                alpha=alpha,
+                color=MODEL_COLORS[model_name],
+                label=MODEL_LABELS[model_name],
+            )
+            if show_ma:
+                ax.plot(
+                    df["episode"],
+                    moving_average(series, window),
+                    linewidth=1.0,
+                    linestyle="--",
+                    alpha=min(alpha, 0.4),
+                    color=MODEL_COLORS[model_name],
+                    label="_nolegend_",
+                )
+        ax.set_title(title)
+        ax.set_xlabel("Episode")
+        ax.grid(True, alpha=0.3)
+    axes[0].legend(loc="best")
+    fig.tight_layout()
+    fig.savefig(compare_output, dpi=160, bbox_inches="tight")
+    plt.close(fig)
+
+
+def plot_metric_overlay(
+    ax,
+    run_logs: Dict[str, pd.DataFrame],
+    column: str,
+    title: str,
+    ylabel: str,
+    window: int,
+    show_ma: bool,
+):
+    """Overlay one metric for every model that has it on *ax*; hide the axes if none do."""
+    plotted = False
+    for model_name in MODEL_ORDER:
+        df = run_logs.get(model_name)
+        if df is None or column not in df.columns:
+            continue
+        series = pd.to_numeric(df[column], errors="coerce")
+        if series.notna().sum() == 0:
+            continue
+        ax.plot(
+            df["episode"],
+            series,
+            alpha=0.85,
+            linewidth=1.8,
+            color=MODEL_COLORS[model_name],
+            label=MODEL_LABELS[model_name],
+        )
+        if show_ma:
+            ax.plot(
+                df["episode"],
+                moving_average(series, window),
+                linewidth=1.0,
+                linestyle="--",
+                alpha=0.35,
+                color=MODEL_COLORS[model_name],
+                label="_nolegend_",
+            )
+        plotted = True
+
+    if not plotted:
+        ax.axis("off")
+        return
+
+    ax.set_title(title)
+    ax.set_xlabel("Episode")
+    ax.set_ylabel(ylabel)
+    ax.grid(True, alpha=0.3)
+    ax.legend(loc="best")
+
+
+def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str:
+    """Return one summary line per model: latest episode and last-20 reward/throughput means."""
+    lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""]
+    for model_name in MODEL_ORDER:
+        df = run_logs.get(model_name)
+        if df is None or df.empty:
+            continue
+        reward = pd.to_numeric(df.get("reward", pd.Series(dtype=float)), errors="coerce")  # absent column -> nan
+        throughput = pd.to_numeric(df.get("throughput", pd.Series(dtype=float)), errors="coerce")
+        latest_episode = int(df["episode"].iloc[-1])
+        reward_last20 = reward.tail(min(20, len(df))).mean() if reward.notna().sum() > 0 else np.nan
+        tp_last20 = throughput.tail(min(20, len(df))).mean() if throughput.notna().sum() > 0 else np.nan
+        lines.append(
+            f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
+            f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}"
+        )
+    return "\n".join(lines)
+
+
+def plot_all_models_overview(
+    run_name: str,
+    run_logs: Dict[str, pd.DataFrame],
+    output_path: str,
+    window: int,
+    show_ma: bool,
+):
+    """Save the 4x3 all-model overview figure and a sibling *_summary.txt file."""
+    if not run_logs:
+        raise ValueError("No readable model logs found for overview plotting.")
+
+    fig, axes = plt.subplots(4, 3, figsize=(22, 16))
+    axes = axes.flatten()
+    metrics = [
+        ("reward", "Reward", "Reward"),
+        ("throughput", "Throughput", "veh/h"),
+        ("mean_speed", "Mean Speed", "km/h"),
+        ("speed_std", "Speed Std", "km/h"),
+        ("hard_brakes", "Hard Brakes", "count"),
+        ("r_flow", "R_flow", "value"),
+        ("r_var", "R_var", "value"),
+        ("r_brake", "R_brake", "value"),
+        ("r_penalty", "R_penalty", "value"),
+        ("policy_loss", "Policy Loss", "loss"),
+        ("value_loss", "Value Loss", "loss"),
+        ("entropy", "Entropy", "value"),
+    ]
+
+    for ax, (column, title, ylabel) in zip(axes, metrics):
+        plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
+
+    fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995)
+    fig.tight_layout(rect=[0, 0, 1, 0.98])
+    fig.savefig(output_path, dpi=170, bbox_inches="tight")
+    plt.close(fig)
+
+    summary_path = os.path.splitext(output_path)[0] + "_summary.txt"
+    with open(summary_path, "w", encoding="utf-8") as f:
+        f.write(build_overview_summary(run_name, run_logs))
+
+
+def run_once(args):
+    """Perform a single refresh: the all-model overview, or one model's snapshot and comparison."""
+    if args.all_models:
+        run_name, _, overview_output = resolve_run_dir(args)
+        run_logs = load_run_model_logs(args.log_root, run_name)
+        if not run_logs:
+            plot_waiting_placeholder(
+                overview_output,
+                run_name,
+                "Waiting for training logs to accumulate...",
+            )
+            print(
+                f"[{time.strftime('%H:%M:%S')}] Waiting for model logs | "
+                f"output={overview_output}"
+            )
+            return
+        plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma)
+        print(
+            f"[{time.strftime('%H:%M:%S')}] Updated all-model overview | "
+            f"models={len(run_logs)} | output={overview_output}"
+        )
+        return
+
+    model_name = normalize_model_name(args.model)
+    run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name)
+    df = safe_read_csv(csv_path)
+    plot_detailed_snapshot(
+        model_name=model_name,
+        df=df,
+        run_name=run_name,
+        csv_path=csv_path,
+        output_path=output_path,
+        window=args.window,
+        show_ma=args.show_ma,
+    )
+
+    run_logs = load_run_model_logs(args.log_root, run_name)
+    plot_run_comparison(
+        target_model=model_name,
+        run_logs=run_logs,
+        compare_output=compare_output,
+        window=args.window,
+        show_ma=args.show_ma,
+    )
+
+    print(
+        f"[{time.strftime('%H:%M:%S')}] Updated {MODEL_LABELS[model_name]} snapshot | "
+        f"episodes={len(df)} | output={output_path}"
+    )
+    if len(run_logs) > 1:
+        print(f"[{time.strftime('%H:%M:%S')}] Updated run comparison | output={compare_output}")
+
+
+def main():
+    """Run once, or loop every --interval seconds under --watch until interrupted."""
+    args = parse_args()
+    if args.watch:
+        try:
+            while True:
+                try:
+                    run_once(args)
+                except Exception as exc:
+                    print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}")
+                    plt.close("all")  # release any figure the failed update left open
+                time.sleep(max(args.interval, 1))
+        except KeyboardInterrupt:
+            print("Stopped live plotting.")
+    else:
+        run_once(args)
+
+
+if __name__ == "__main__":
+    main()