"""Plot live training snapshots from CSV logs.""" import argparse import os import time from typing import Dict, List, Optional import matplotlib import numpy as np import pandas as pd matplotlib.use("Agg") import matplotlib.pyplot as plt MODEL_ORDER = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"] MODEL_LABELS = {name: name.upper() for name in MODEL_ORDER} MODEL_COLORS = { "ppo": "#1f77b4", "appo": "#ff7f0e", "mappo": "#2ca02c", "dqn": "#d62728", "ddpg": "#9467bd", "td3": "#17becf", } def parse_args(): parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.") parser.add_argument("--model", default=None, help="Model name, e.g. ppo/appo/mappo/dqn/ddpg/td3") parser.add_argument( "--all-models", action="store_true", help="Plot all models into one overview figure.", ) parser.add_argument( "--run", default=None, help="Multi-model run timestamp. Default: latest under logs/multi-model/", ) parser.add_argument( "--log-root", default=os.path.join("logs", "multi-model"), help="Multi-model log root directory.", ) parser.add_argument( "--csv-path", default=None, help="Direct training CSV path override.", ) parser.add_argument( "--output", default=None, help="Detailed snapshot output path. Default: /training_snapshot.png", ) parser.add_argument( "--compare-output", default=None, help="Run comparison output path. Default: /run_comparison_snapshot.png", ) parser.add_argument( "--overview-output", default=None, help="All-model overview output path. Default: /all_models_training_snapshot.png", ) parser.add_argument( "--window", type=int, default=20, help="Moving average window.", ) parser.add_argument( "--show-ma", action="store_true", help="Overlay moving-average curves.", ) parser.add_argument( "--watch", action="store_true", help="Continuously refresh plots.", ) parser.add_argument( "--interval", type=int, default=30, help="Refresh interval in seconds when --watch is enabled.", ) args = parser.parse_args() if not args.all_models and not args.model: parser.error("Either --model or --all-models must be specified.") return args def normalize_model_name(name: str) -> str: model = name.strip().lower() if model not in MODEL_LABELS: raise ValueError(f"Unsupported model name: {name}") return model def find_latest_run(log_root: str) -> str: if not os.path.isdir(log_root): raise FileNotFoundError(f"Log root not found: {log_root}") candidates = [ item for item in os.listdir(log_root) if os.path.isdir(os.path.join(log_root, item)) ] if not candidates: raise FileNotFoundError(f"No run directories found under: {log_root}") return max(candidates) def resolve_paths(args, model_name: str): if args.csv_path: csv_path = args.csv_path model_dir = os.path.dirname(os.path.abspath(csv_path)) run_dir = os.path.dirname(model_dir) run_name = os.path.basename(run_dir) else: run_name = args.run or find_latest_run(args.log_root) model_dir = os.path.join(args.log_root, run_name, model_name) csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv") if not os.path.isfile(csv_path): raise FileNotFoundError(f"Training CSV not found: {csv_path}") output_path = args.output or os.path.join(model_dir, "training_snapshot.png") compare_output = args.compare_output or os.path.join(model_dir, "run_comparison_snapshot.png") return run_name, model_dir, csv_path, output_path, compare_output def resolve_run_dir(args): run_name = args.run or find_latest_run(args.log_root) run_dir = os.path.join(args.log_root, run_name) if not os.path.isdir(run_dir): raise FileNotFoundError(f"Run directory not found: {run_dir}") overview_output = args.overview_output or os.path.join(run_dir, "all_models_training_snapshot.png") return run_name, run_dir, overview_output def safe_read_csv(csv_path: str) -> pd.DataFrame: df = pd.read_csv(csv_path, on_bad_lines="skip") if df.empty: raise ValueError(f"CSV exists but contains no rows: {csv_path}") numeric_columns = [ "episode", "reward", "throughput", "mean_speed", "speed_std", "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes", "policy_loss", "value_loss", "entropy", ] for column in numeric_columns: if column in df.columns: df[column] = pd.to_numeric(df[column], errors="coerce") df = df.dropna(subset=["episode"]) df = df.sort_values("episode").drop_duplicates(subset=["episode"], keep="last") return df.reset_index(drop=True) def moving_average(values: pd.Series, window: int) -> pd.Series: return values.rolling(window=min(window, max(len(values), 1)), min_periods=1).mean() def plot_series( ax, df: pd.DataFrame, column: str, title: str, ylabel: str, color: str, window: int, show_ma: bool, ): if column not in df.columns: ax.axis("off") return series = pd.to_numeric(df[column], errors="coerce") if series.notna().sum() == 0: ax.axis("off") return raw_label = "Series" if show_ma else None ax.plot(df["episode"], series, color=color, alpha=0.9, linewidth=1.8, label=raw_label) if show_ma: ax.plot( df["episode"], moving_average(series, window), color=color, alpha=0.35, linewidth=1.0, linestyle="--", label=f"MA{window}", ) ax.set_title(title) ax.set_xlabel("Episode") ax.set_ylabel(ylabel) ax.grid(True, alpha=0.3) if show_ma: ax.legend(loc="best") def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str: last_row = df.iloc[-1] reward_last20 = df["reward"].tail(min(20, len(df))).mean() if "reward" in df else np.nan tp_last20 = df["throughput"].tail(min(20, len(df))).mean() if "throughput" in df else np.nan speed_last20 = df["mean_speed"].tail(min(20, len(df))).mean() if "mean_speed" in df else np.nan brake_last20 = df["hard_brakes"].tail(min(20, len(df))).mean() if "hard_brakes" in df else np.nan return ( f"Run: {run_name}\n" f"Model: {MODEL_LABELS[model_name]}\n" f"Episodes logged: {len(df)}\n" f"Latest episode: {int(last_row['episode'])}\n" f"Best reward: {df['reward'].max():.2f}\n" f"Latest reward: {last_row.get('reward', np.nan):.2f}\n" f"Reward last20: {reward_last20:.2f}\n" f"Throughput last20: {tp_last20:.1f}\n" f"Mean speed last20: {speed_last20:.1f}\n" f"Hard brakes last20: {brake_last20:.1f}\n" f"\nCSV:\n{os.path.abspath(csv_path)}" ) def plot_detailed_snapshot( model_name: str, df: pd.DataFrame, run_name: str, csv_path: str, output_path: str, window: int, show_ma: bool, ): fig, axes = plt.subplots(3, 3, figsize=(18, 12)) axes = axes.flatten() plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma) plot_series(axes[1], df, "throughput", "Throughput", "veh/h", "tab:green", window, show_ma) plot_series(axes[2], df, "mean_speed", "Mean Speed", "km/h", "tab:orange", window, show_ma) plot_series(axes[3], df, "speed_std", "Speed Std", "km/h", "tab:purple", window, show_ma) plot_series(axes[4], df, "hard_brakes", "Hard Brakes", "count", "tab:red", window, show_ma) reward_components = ["r_flow", "r_var", "r_brake", "r_penalty"] has_components = any( col in df.columns and pd.to_numeric(df[col], errors="coerce").notna().sum() > 0 for col in reward_components ) if has_components: for col, color in zip(reward_components, ["tab:green", "tab:purple", "tab:red", "tab:brown"]): series = pd.to_numeric(df[col], errors="coerce") if series.notna().sum() > 0: axes[5].plot( df["episode"], series, linewidth=1.6, alpha=0.75, color=color, label=col, ) if show_ma: axes[5].plot( df["episode"], moving_average(series, window), linewidth=1.0, alpha=0.35, linestyle="--", color=color, label="_nolegend_", ) axes[5].set_title("Reward Components") axes[5].set_xlabel("Episode") axes[5].grid(True, alpha=0.3) axes[5].legend(loc="best") else: axes[5].axis("off") plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma) plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma) axes[8].axis("off") axes[8].text( 0.0, 1.0, build_summary_text(model_name, df, run_name, csv_path), va="top", ha="left", family="monospace", fontsize=11, transform=axes[8].transAxes, ) plt.tight_layout() plt.savefig(output_path, dpi=160, bbox_inches="tight") plt.close() def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]: run_logs = {} for model_name in MODEL_ORDER: csv_path = os.path.join(log_root, run_name, model_name, f"{model_name}_training_log.csv") if os.path.isfile(csv_path): try: run_logs[model_name] = safe_read_csv(csv_path) except Exception: continue return run_logs def plot_waiting_placeholder(output_path: str, run_name: str, message: str): fig, ax = plt.subplots(figsize=(10, 4)) ax.axis("off") ax.text( 0.5, 0.6, f"Run: {run_name}", ha="center", va="center", fontsize=16, family="monospace", transform=ax.transAxes, ) ax.text( 0.5, 0.4, message, ha="center", va="center", fontsize=12, transform=ax.transAxes, ) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close() def plot_run_comparison( target_model: str, run_logs: Dict[str, pd.DataFrame], compare_output: str, window: int, show_ma: bool, ): if len(run_logs) <= 1: return metrics = [ ("reward", "Reward"), ("throughput", "Throughput"), ("mean_speed", "Mean Speed"), ("speed_std", "Speed Std"), ] fig, axes = plt.subplots(2, 2, figsize=(16, 10)) axes = axes.flatten() for ax, (column, title) in zip(axes, metrics): for model_name in MODEL_ORDER: df = run_logs.get(model_name) if df is None or column not in df.columns: continue series = pd.to_numeric(df[column], errors="coerce") if series.notna().sum() == 0: continue linewidth = 2.8 if model_name == target_model else 1.6 alpha = 1.0 if model_name == target_model else 0.5 ax.plot( df["episode"], series, linewidth=linewidth + 0.4, alpha=alpha, color=MODEL_COLORS[model_name], label=MODEL_LABELS[model_name], ) if show_ma: ax.plot( df["episode"], moving_average(series, window), linewidth=1.0, linestyle="--", alpha=min(alpha, 0.4), color=MODEL_COLORS[model_name], label="_nolegend_", ) ax.set_title(title) ax.set_xlabel("Episode") ax.grid(True, alpha=0.3) axes[0].legend(loc="best") plt.tight_layout() plt.savefig(compare_output, dpi=160, bbox_inches="tight") plt.close() def plot_metric_overlay( ax, run_logs: Dict[str, pd.DataFrame], column: str, title: str, ylabel: str, window: int, show_ma: bool, ): plotted = False for model_name in MODEL_ORDER: df = run_logs.get(model_name) if df is None or column not in df.columns: continue series = pd.to_numeric(df[column], errors="coerce") if series.notna().sum() == 0: continue ax.plot( df["episode"], series, alpha=0.85, linewidth=1.8, color=MODEL_COLORS[model_name], label=MODEL_LABELS[model_name], ) if show_ma: ax.plot( df["episode"], moving_average(series, window), linewidth=1.0, linestyle="--", alpha=0.35, color=MODEL_COLORS[model_name], label="_nolegend_", ) plotted = True if not plotted: ax.axis("off") return ax.set_title(title) ax.set_xlabel("Episode") ax.set_ylabel(ylabel) ax.grid(True, alpha=0.3) ax.legend(loc="best") def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str: lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""] for model_name in MODEL_ORDER: df = run_logs.get(model_name) if df is None or df.empty: continue reward = pd.to_numeric(df.get("reward"), errors="coerce") throughput = pd.to_numeric(df.get("throughput"), errors="coerce") latest_episode = int(df["episode"].iloc[-1]) reward_last20 = reward.tail(min(20, len(df))).mean() if reward.notna().sum() > 0 else np.nan tp_last20 = throughput.tail(min(20, len(df))).mean() if throughput.notna().sum() > 0 else np.nan lines.append( f"{MODEL_LABELS[model_name]}: ep={latest_episode}, " f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}" ) return "\n".join(lines) def plot_all_models_overview( run_name: str, run_logs: Dict[str, pd.DataFrame], output_path: str, window: int, show_ma: bool, ): if not run_logs: raise ValueError("No readable model logs found for overview plotting.") fig, axes = plt.subplots(4, 3, figsize=(22, 16)) axes = axes.flatten() metrics = [ ("reward", "Reward", "Reward"), ("throughput", "Throughput", "veh/h"), ("mean_speed", "Mean Speed", "km/h"), ("speed_std", "Speed Std", "km/h"), ("hard_brakes", "Hard Brakes", "count"), ("r_flow", "R_flow", "value"), ("r_var", "R_var", "value"), ("r_brake", "R_brake", "value"), ("r_penalty", "R_penalty", "value"), ("policy_loss", "Policy Loss", "loss"), ("value_loss", "Value Loss", "loss"), ("entropy", "Entropy", "value"), ] for ax, (column, title, ylabel) in zip(axes, metrics): plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma) fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995) plt.tight_layout(rect=[0, 0, 1, 0.98]) plt.savefig(output_path, dpi=170, bbox_inches="tight") plt.close() summary_path = os.path.splitext(output_path)[0] + "_summary.txt" with open(summary_path, "w", encoding="utf-8") as f: f.write(build_overview_summary(run_name, run_logs)) def run_once(args): if args.all_models: run_name, _, overview_output = resolve_run_dir(args) run_logs = load_run_model_logs(args.log_root, run_name) if not run_logs: plot_waiting_placeholder( overview_output, run_name, "Waiting for training logs to accumulate...", ) print( f"[{time.strftime('%H:%M:%S')}] Waiting for model logs | " f"output={overview_output}" ) return plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma) print( f"[{time.strftime('%H:%M:%S')}] Updated all-model overview | " f"models={len(run_logs)} | output={overview_output}" ) return model_name = normalize_model_name(args.model) run_name, _, csv_path, output_path, compare_output = resolve_paths(args, model_name) df = safe_read_csv(csv_path) plot_detailed_snapshot( model_name=model_name, df=df, run_name=run_name, csv_path=csv_path, output_path=output_path, window=args.window, show_ma=args.show_ma, ) run_logs = load_run_model_logs(args.log_root, run_name) plot_run_comparison( target_model=model_name, run_logs=run_logs, compare_output=compare_output, window=args.window, show_ma=args.show_ma, ) print( f"[{time.strftime('%H:%M:%S')}] Updated {MODEL_LABELS[model_name]} snapshot | " f"episodes={len(df)} | output={output_path}" ) if len(run_logs) > 1: print(f"[{time.strftime('%H:%M:%S')}] Updated run comparison | output={compare_output}") def main(): args = parse_args() if args.watch: try: while True: try: run_once(args) except Exception as exc: print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}") time.sleep(max(args.interval, 1)) except KeyboardInterrupt: print("Stopped live plotting.") else: run_once(args) if __name__ == "__main__": main()