ctm-dqn/scripts/plot_live_training.py

621 lines
20 KiB
Python

"""Plot live training snapshots from CSV logs."""
import argparse
import os
import sys
import time
from typing import Dict, List, Optional
# Put the repository root (the parent of this script's directory) first on
# sys.path so the project-local `utils` package imports when the script is
# executed directly.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
import matplotlib
import numpy as np
import pandas as pd
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
# Canonical order in which models are iterated, plotted and summarized.
MODEL_ORDER = [
    "ppo",
    "gpro",
    "appo",
    "mappo",
    "tcamappo",
    "dcmappo",
    "dqn",
    "ddpg",
    "sac",
    "td3",
    "sctd3",
]

# Display names used in legends and summaries, keyed by model name.
MODEL_LABELS = dict(
    ppo="PPO",
    gpro="GPRO-PPO",
    appo="APPO",
    mappo="MAPPO",
    tcamappo="TCA-MAPPO",
    dcmappo="DC-MAPPO",
    dqn="DQN",
    ddpg="DDPG",
    sac="SAC",
    td3="TD3",
    sctd3="SC-TD3",
)

# Fixed line color per model so a model looks the same across every figure.
MODEL_COLORS = dict(
    ppo="#1f77b4",
    gpro="#6a3d9a",
    appo="#ff7f0e",
    mappo="#2ca02c",
    tcamappo="#7f7f7f",
    dcmappo="#8c564b",
    dqn="#d62728",
    ddpg="#9467bd",
    sac="#e377c2",
    td3="#17becf",
    sctd3="#bcbd22",
)
def parse_args():
    """Parse and validate the command-line options.

    One plotting mode is required: ``--model <name>`` (detailed snapshot of one
    model plus a run comparison) or ``--all-models`` (combined overview).

    Returns:
        argparse.Namespace with the parsed options.

    Exits via ``parser.error`` (``SystemExit``) when no mode is selected or
    when ``--window`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Plot live training progress from run logs.")
    parser.add_argument("--model", default=None, help="Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/ddpg/sac/td3/sctd3")
    parser.add_argument(
        "--all-models",
        action="store_true",
        help="Plot all models into one overview figure.",
    )
    parser.add_argument(
        "--run",
        default=None,
        help="Run timestamp. Default: latest under runs/<timestamp>/logs.",
    )
    parser.add_argument(
        "--log-root",
        default=None,
        help="Log root directory or run root. Default: latest runs/<timestamp>/logs.",
    )
    parser.add_argument(
        "--csv-path",
        default=None,
        help="Direct training CSV path override.",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Detailed snapshot output path. Default: <model_dir>/training_snapshot.png",
    )
    parser.add_argument(
        "--compare-output",
        default=None,
        help="Run comparison output path. Default: <model_dir>/run_comparison_snapshot.png",
    )
    parser.add_argument(
        "--overview-output",
        default=None,
        help="All-model overview output path. Default: <run_dir>/all_models_training_snapshot.png",
    )
    parser.add_argument(
        "--window",
        type=int,
        default=20,
        help="Moving average window.",
    )
    parser.add_argument(
        "--show-ma",
        action="store_true",
        help="Overlay moving-average curves.",
    )
    parser.add_argument(
        "--watch",
        action="store_true",
        help="Continuously refresh plots.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=30,
        help="Refresh interval in seconds when --watch is enabled.",
    )
    args = parser.parse_args()
    if not args.all_models and not args.model:
        parser.error("Either --model or --all-models must be specified.")
    # pandas' rolling() rejects a window below 1; reject it up front with a
    # clear CLI error instead of failing on every (watch-mode) refresh.
    if args.window < 1:
        parser.error("--window must be a positive integer.")
    return args
def normalize_model_name(name: str) -> str:
    """Return the canonical model key for *name* (whitespace stripped, lower-cased).

    Raises:
        ValueError: if the normalized key is not a key of ``MODEL_LABELS``.
    """
    key = name.strip().lower()
    if key in MODEL_LABELS:
        return key
    raise ValueError(f"Unsupported model name: {name}")
def find_latest_run(log_root: str) -> str:
    """Return the name of the lexicographically greatest subdirectory of *log_root*.

    Plain files are ignored. If run directories are named with sortable
    timestamps (assumed, not checked), this is the most recent run.

    Raises:
        FileNotFoundError: if *log_root* is not a directory or has no subdirectories.
    """
    if not os.path.isdir(log_root):
        raise FileNotFoundError(f"Log root not found: {log_root}")
    subdirs = sorted(
        entry
        for entry in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, entry))
    )
    if not subdirs:
        raise FileNotFoundError(f"No run directories found under: {log_root}")
    return subdirs[-1]
def resolve_run_logs_dir(log_root: Optional[str], run_name: Optional[str]) -> tuple[str, str]:
    """Resolve ``(run_name, logs_dir)`` for the run whose logs should be plotted.

    Args:
        log_root: ``None`` to locate the run through ``utils.run_dirs`` (the run
            matching *run_name*, otherwise the latest run). Otherwise one of:
            a ``logs`` directory, a run root containing ``logs/``, or a
            directory whose subdirectories are runs.
        run_name: Optional run directory name / timestamp to select.

    Returns:
        ``(run_name, logs_dir)`` where ``logs_dir`` is the run's ``logs``
        subdirectory when it exists and the run directory itself otherwise.

    Raises:
        FileNotFoundError: if the selected run directory does not exist, or
            (via ``find_latest_run``) if *log_root* contains no runs.
    """
    if log_root is None:
        run_root = find_run_root_by_timestamp(run_name) if run_name else find_latest_run_root()
        run_root_logs = os.path.join(run_root, "logs")
        resolved_log_root = run_root_logs if os.path.isdir(run_root_logs) else run_root
    else:
        resolved_log_root = os.path.abspath(log_root)

    # Already pointing at a logs/ directory: the run is its parent.
    if os.path.basename(resolved_log_root).lower() == "logs":
        inferred_run_name = os.path.basename(os.path.dirname(resolved_log_root))
        return inferred_run_name, resolved_log_root

    # A run root that contains logs/.
    nested_logs = os.path.join(resolved_log_root, "logs")
    if os.path.isdir(nested_logs):
        return os.path.basename(resolved_log_root), nested_logs

    # A run root located via utils.run_dirs that has no logs/ yet (e.g. a run
    # that just started). Use it directly; falling through would treat it as a
    # directory of runs and select one of its subdirectories as the "run".
    if log_root is None:
        return os.path.basename(os.path.normpath(resolved_log_root)), resolved_log_root

    # Otherwise log_root is a directory of runs: pick the requested/latest one.
    chosen_run_name = run_name or find_latest_run(resolved_log_root)
    candidate_run_dir = os.path.join(resolved_log_root, chosen_run_name)
    candidate_logs_dir = os.path.join(candidate_run_dir, "logs")
    if os.path.isdir(candidate_logs_dir):
        return chosen_run_name, candidate_logs_dir
    if os.path.isdir(candidate_run_dir):
        return chosen_run_name, candidate_run_dir
    raise FileNotFoundError(f"Run directory not found: {candidate_run_dir}")
def resolve_paths(args, model_name: str, run_name: str, run_logs_dir: str):
    """Locate one model's training CSV and choose its snapshot output paths.

    An explicit ``args.csv_path`` overrides the conventional
    ``<run_logs_dir>/<model>/<model>_training_log.csv`` location; in that case
    the CSV's directory is used as the model directory. ``args.output`` and
    ``args.compare_output`` override the default image paths inside the model
    directory.

    Returns:
        ``(run_name, model_dir, csv_path, output_path, compare_output)``;
        *run_name* is passed through unchanged.

    Raises:
        FileNotFoundError: if the resolved CSV file does not exist.
    """
    override = args.csv_path
    if not override:
        model_dir = os.path.join(run_logs_dir, model_name)
        csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")
    else:
        csv_path = override
        model_dir = os.path.dirname(os.path.abspath(override))

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"Training CSV not found: {csv_path}")

    snapshot_png = args.output or os.path.join(model_dir, "training_snapshot.png")
    comparison_png = args.compare_output or os.path.join(model_dir, "run_comparison_snapshot.png")
    return run_name, model_dir, csv_path, snapshot_png, comparison_png
def resolve_run_dir(args, run_name: str, run_logs_dir: str):
    """Derive the run directory and the all-model overview image path.

    The run directory is the parent of *run_logs_dir* when that directory is
    named ``logs`` (case-insensitive), otherwise *run_logs_dir* itself.
    ``args.overview_output`` overrides the default
    ``<run_dir>/all_models_training_snapshot.png``.

    Returns:
        ``(run_name, run_dir, overview_output)``; *run_name* is passed through.
    """
    run_dir = run_logs_dir
    if os.path.basename(run_logs_dir).lower() == "logs":
        run_dir = os.path.dirname(run_logs_dir)
    default_overview = os.path.join(run_dir, "all_models_training_snapshot.png")
    return run_name, run_dir, args.overview_output or default_overview
def safe_read_csv(csv_path: str) -> pd.DataFrame:
    """Load a training-log CSV into a clean, episode-sorted DataFrame.

    Lines pandas cannot parse are skipped (``on_bad_lines="skip"``). Known
    metric columns are coerced to numbers, with unparseable cells becoming NaN.
    Rows without a numeric ``episode`` are dropped. When an episode appears
    more than once, only the row that appears last in the file is kept.

    Returns:
        DataFrame sorted by ``episode`` with a fresh 0..n-1 index.

    Raises:
        ValueError: if the CSV has no data rows, has no ``episode`` column, or
            has no row with a numeric episode. A completely empty file raises
            pandas' ``EmptyDataError``, which is also a ``ValueError``.
    """
    df = pd.read_csv(csv_path, on_bad_lines="skip")
    if df.empty:
        raise ValueError(f"CSV exists but contains no rows: {csv_path}")
    if "episode" not in df.columns:
        raise ValueError(f"CSV has no 'episode' column: {csv_path}")
    numeric_columns = [
        "episode", "reward", "throughput", "mean_speed", "speed_std",
        "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes",
        "policy_loss", "value_loss", "entropy",
    ]
    for column in numeric_columns:
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors="coerce")
    df = df.dropna(subset=["episode"])
    # De-duplicate in file order *before* sorting: the default sort_values
    # algorithm is not stable, so sorting first could make keep="last" retain
    # an arbitrary duplicate instead of the most recently logged row.
    df = df.drop_duplicates(subset=["episode"], keep="last").sort_values("episode")
    if df.empty:
        raise ValueError(f"CSV contains no rows with a numeric episode: {csv_path}")
    return df.reset_index(drop=True)
def moving_average(values: pd.Series, window: int) -> pd.Series:
    """Return the trailing moving average of *values*.

    ``min_periods=1`` makes the first points average over however many samples
    exist so far, so the result has the same length as *values* with no
    leading NaNs. The window is clamped to ``[1, len(values)]``. A non-positive
    window, which pandas rejects, degrades to 1, i.e. the series itself.
    """
    effective_window = max(1, min(window, len(values)))
    return values.rolling(window=effective_window, min_periods=1).mean()
def plot_series(
    ax,
    df: pd.DataFrame,
    column: str,
    title: str,
    ylabel: str,
    color: str,
    window: int,
    show_ma: bool,
):
    """Draw one metric column of *df* against ``episode`` on *ax*.

    The axis is switched off when *column* is absent or holds no numeric
    value. With *show_ma*, a dashed moving-average line (window *window*) is
    overlaid and a legend is shown.
    """
    values = pd.to_numeric(df[column], errors="coerce") if column in df.columns else None
    if values is None or not values.notna().any():
        ax.axis("off")
        return

    episodes = df["episode"]
    ax.plot(
        episodes,
        values,
        color=color,
        alpha=0.9,
        linewidth=1.8,
        label="Series" if show_ma else None,
    )
    if show_ma:
        ax.plot(
            episodes,
            moving_average(values, window),
            color=color,
            alpha=0.35,
            linewidth=1.0,
            linestyle="--",
            label=f"MA{window}",
        )

    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    if show_ma:
        ax.legend(loc="best")
def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str:
    """Build the monospace summary-panel text for one model's snapshot.

    Reports the run, model label, episode count, latest episode, best and
    latest reward, and the mean over the last 20 logged rows of reward,
    throughput, mean speed and hard brakes. Metrics whose column is missing
    render as ``nan`` instead of raising.

    Args:
        model_name: Normalized model key (looked up in ``MODEL_LABELS``).
        df: Non-empty log frame (``df.iloc[-1]`` is read).
        run_name: Run identifier shown in the header.
        csv_path: Source CSV, shown as an absolute path.
    """
    last_row = df.iloc[-1]
    recent = df.tail(20)  # tail() already handles frames shorter than 20 rows

    def recent_mean(column: str) -> float:
        # Missing columns yield NaN (printed as "nan") rather than a KeyError.
        return recent[column].mean() if column in df.columns else np.nan

    # Guarded like the other metrics; previously a missing reward column
    # raised KeyError here.
    best_reward = df["reward"].max() if "reward" in df.columns else np.nan
    return (
        f"Run: {run_name}\n"
        f"Model: {MODEL_LABELS[model_name]}\n"
        f"Episodes logged: {len(df)}\n"
        f"Latest episode: {int(last_row['episode'])}\n"
        f"Best reward: {best_reward:.2f}\n"
        f"Latest reward: {last_row.get('reward', np.nan):.2f}\n"
        f"Reward last20: {recent_mean('reward'):.2f}\n"
        f"Throughput last20: {recent_mean('throughput'):.1f}\n"
        f"Mean speed last20: {recent_mean('mean_speed'):.1f}\n"
        f"Hard brakes last20: {recent_mean('hard_brakes'):.1f}\n"
        f"\nCSV:\n{os.path.abspath(csv_path)}"
    )
def _plot_reward_components(ax, df: pd.DataFrame, window: int, show_ma: bool) -> None:
    """Overlay whichever reward-component columns *df* has on *ax*; switch the axis off if none."""
    palette = zip(
        ["r_flow", "r_var", "r_brake", "r_penalty"],
        ["tab:green", "tab:purple", "tab:red", "tab:brown"],
    )
    plotted = False
    for col, color in palette:
        # Skip absent columns individually: a log may carry only some components.
        if col not in df.columns:
            continue
        series = pd.to_numeric(df[col], errors="coerce")
        if series.notna().sum() == 0:
            continue
        ax.plot(
            df["episode"],
            series,
            linewidth=1.6,
            alpha=0.75,
            color=color,
            label=col,
        )
        if show_ma:
            ax.plot(
                df["episode"],
                moving_average(series, window),
                linewidth=1.0,
                alpha=0.35,
                linestyle="--",
                color=color,
                label="_nolegend_",
            )
        plotted = True
    if not plotted:
        ax.axis("off")
        return
    ax.set_title("Reward Components")
    ax.set_xlabel("Episode")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")


def plot_detailed_snapshot(
    model_name: str,
    df: pd.DataFrame,
    run_name: str,
    csv_path: str,
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Save a 3x3 detailed training snapshot for a single model.

    Panels: reward, throughput, mean speed, speed std, hard brakes, reward
    components, policy loss, value loss, and a text summary built by
    ``build_summary_text``. The image is written to *output_path* at 160 dpi.
    The figure is always closed, even if plotting or saving fails.
    """
    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
    try:
        axes = axes.flatten()
        plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma)
        plot_series(axes[1], df, "throughput", "Throughput", "veh/h", "tab:green", window, show_ma)
        plot_series(axes[2], df, "mean_speed", "Mean Speed", "km/h", "tab:orange", window, show_ma)
        plot_series(axes[3], df, "speed_std", "Speed Std", "km/h", "tab:purple", window, show_ma)
        plot_series(axes[4], df, "hard_brakes", "Hard Brakes", "count", "tab:red", window, show_ma)
        _plot_reward_components(axes[5], df, window, show_ma)
        plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
        plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
        axes[8].axis("off")
        axes[8].text(
            0.0,
            1.0,
            build_summary_text(model_name, df, run_name, csv_path),
            va="top",
            ha="left",
            family="monospace",
            fontsize=11,
            transform=axes[8].transAxes,
        )
        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        # In --watch mode a failed refresh would otherwise leak one open
        # figure per interval.
        plt.close(fig)
def load_run_model_logs(run_logs_dir: str) -> Dict[str, pd.DataFrame]:
    """Load the training CSV of every known model found under *run_logs_dir*.

    For each name in ``MODEL_ORDER`` this looks for
    ``<run_logs_dir>/<model>/<model>_training_log.csv``. The result maps model
    name to its cleaned frame in ``MODEL_ORDER`` order. Models whose CSV is
    missing, or for which ``safe_read_csv`` raises, are omitted.
    """
    loaded: Dict[str, pd.DataFrame] = {}
    for name in MODEL_ORDER:
        path = os.path.join(run_logs_dir, name, f"{name}_training_log.csv")
        if not os.path.isfile(path):
            continue
        try:
            frame = safe_read_csv(path)
        except Exception:
            # Best-effort by design: an unreadable or not-yet-populated log
            # simply drops that model from this refresh.
            continue
        loaded[name] = frame
    return loaded
def plot_waiting_placeholder(output_path: str, run_name: str, message: str):
    """Save a text-only image showing *run_name* and a status *message*.

    ``run_once`` writes this to the overview path while no model log can be
    read yet. The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        ax.axis("off")
        ax.text(
            0.5,
            0.6,
            f"Run: {run_name}",
            ha="center",
            va="center",
            fontsize=16,
            family="monospace",
            transform=ax.transAxes,
        )
        ax.text(
            0.5,
            0.4,
            message,
            ha="center",
            va="center",
            fontsize=12,
            transform=ax.transAxes,
        )
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure explicitly so repeated watch-mode failures
        # don't accumulate open figures.
        plt.close(fig)
def plot_run_comparison(
    target_model: str,
    run_logs: Dict[str, pd.DataFrame],
    compare_output: str,
    window: int,
    show_ma: bool,
):
    """Save a 2x2 figure comparing all models of the run on the core metrics.

    Panels: reward, throughput, mean speed and speed std versus episode. The
    *target_model* is drawn thicker and fully opaque; the other models are
    thinner and semi-transparent. Nothing is written when *run_logs* has at
    most one model. The figure is always closed, even if plotting or saving
    fails.
    """
    if len(run_logs) <= 1:
        return
    metrics = [
        ("reward", "Reward"),
        ("throughput", "Throughput"),
        ("mean_speed", "Mean Speed"),
        ("speed_std", "Speed Std"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    try:
        axes = axes.flatten()
        for ax, (column, title) in zip(axes, metrics):
            for model_name in MODEL_ORDER:
                df = run_logs.get(model_name)
                if df is None or column not in df.columns:
                    continue
                series = pd.to_numeric(df[column], errors="coerce")
                if series.notna().sum() == 0:
                    continue
                is_target = model_name == target_model
                alpha = 1.0 if is_target else 0.5
                ax.plot(
                    df["episode"],
                    series,
                    linewidth=3.2 if is_target else 2.0,  # highlight the model being inspected
                    alpha=alpha,
                    color=MODEL_COLORS[model_name],
                    label=MODEL_LABELS[model_name],
                )
                if show_ma:
                    ax.plot(
                        df["episode"],
                        moving_average(series, window),
                        linewidth=1.0,
                        linestyle="--",
                        alpha=min(alpha, 0.4),
                        color=MODEL_COLORS[model_name],
                        label="_nolegend_",
                    )
            ax.set_title(title)
            ax.set_xlabel("Episode")
            ax.grid(True, alpha=0.3)
        # A single legend suffices: every panel uses the same model colors.
        axes[0].legend(loc="best")
        fig.tight_layout()
        fig.savefig(compare_output, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
def plot_metric_overlay(
    ax,
    run_logs: Dict[str, pd.DataFrame],
    column: str,
    title: str,
    ylabel: str,
    window: int,
    show_ma: bool,
):
    """Overlay one metric for every model in *run_logs* on a shared axis.

    Models are drawn in ``MODEL_ORDER`` using ``MODEL_COLORS``. A model is
    skipped when it lacks *column* or has no numeric value in it. If no model
    contributes a line, the axis is switched off and left undecorated.
    """
    drew_any = False
    for name in MODEL_ORDER:
        frame = run_logs.get(name)
        if frame is None or column not in frame.columns:
            continue
        values = pd.to_numeric(frame[column], errors="coerce")
        if not values.notna().any():
            continue
        colour = MODEL_COLORS[name]
        ax.plot(
            frame["episode"],
            values,
            alpha=0.85,
            linewidth=1.8,
            color=colour,
            label=MODEL_LABELS[name],
        )
        if show_ma:
            ax.plot(
                frame["episode"],
                moving_average(values, window),
                linewidth=1.0,
                linestyle="--",
                alpha=0.35,
                color=colour,
                label="_nolegend_",
            )
        drew_any = True

    if drew_any:
        ax.set_title(title)
        ax.set_xlabel("Episode")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="best")
    else:
        ax.axis("off")
def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str:
    """Build the plain-text summary written next to the all-model overview.

    The header gives the run name and model count. Then there is one line per
    non-empty model in ``MODEL_ORDER`` with its latest episode and the mean
    reward / throughput over its last 20 logged rows. A metric that is missing
    or entirely non-numeric renders as ``nan``.
    """

    def last20_mean(df: pd.DataFrame, column: str) -> float:
        # df.get() on a missing column returns None, and pd.to_numeric(None) is
        # a scalar with no .notna(). Check the column explicitly so one model
        # without e.g. throughput does not abort the whole overview.
        if column not in df.columns:
            return np.nan
        values = pd.to_numeric(df[column], errors="coerce")
        return values.tail(20).mean() if values.notna().any() else np.nan

    lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""]
    for model_name in MODEL_ORDER:
        df = run_logs.get(model_name)
        if df is None or df.empty:
            continue
        latest_episode = int(df["episode"].iloc[-1])
        reward_last20 = last20_mean(df, "reward")
        tp_last20 = last20_mean(df, "throughput")
        lines.append(
            f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
            f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}"
        )
    return "\n".join(lines)
def plot_all_models_overview(
    run_name: str,
    run_logs: Dict[str, pd.DataFrame],
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Save a 4x3 all-model overview figure plus a text summary.

    Each panel overlays every model in *run_logs* for one metric: outcome
    metrics, reward components, and policy/value losses plus entropy. The
    summary from ``build_overview_summary`` is written to
    ``<output_path stem>_summary.txt``. The figure is always closed, even if
    plotting or saving fails.

    Raises:
        ValueError: if *run_logs* is empty.
    """
    if not run_logs:
        raise ValueError("No readable model logs found for overview plotting.")
    metrics = [
        ("reward", "Reward", "Reward"),
        ("throughput", "Throughput", "veh/h"),
        ("mean_speed", "Mean Speed", "km/h"),
        ("speed_std", "Speed Std", "km/h"),
        ("hard_brakes", "Hard Brakes", "count"),
        ("r_flow", "R_flow", "value"),
        ("r_var", "R_var", "value"),
        ("r_brake", "R_brake", "value"),
        ("r_penalty", "R_penalty", "value"),
        ("policy_loss", "Policy Loss", "loss"),
        ("value_loss", "Value Loss", "loss"),
        ("entropy", "Entropy", "value"),
    ]
    fig, axes = plt.subplots(4, 3, figsize=(22, 16))
    try:
        for ax, (column, title, ylabel) in zip(axes.flatten(), metrics):
            plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
        fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995)
        # Leave headroom at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.98])
        fig.savefig(output_path, dpi=170, bbox_inches="tight")
    finally:
        # Avoid leaking a large figure per failed refresh in --watch mode.
        plt.close(fig)
    summary_path = os.path.splitext(output_path)[0] + "_summary.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(build_overview_summary(run_name, run_logs))
def _refresh_overview(args, run_name: str, run_logs_dir: str) -> None:
    """Redraw the all-model overview, or a waiting placeholder when no log is readable yet."""
    _, _, overview_output = resolve_run_dir(args, run_name, run_logs_dir)
    run_logs = load_run_model_logs(run_logs_dir)
    if run_logs:
        plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma)
        print(
            f"[{time.strftime('%H:%M:%S')}] Updated all-model overview | "
            f"models={len(run_logs)} | output={overview_output}"
        )
        return
    plot_waiting_placeholder(
        overview_output,
        run_name,
        "Waiting for training logs to accumulate...",
    )
    print(
        f"[{time.strftime('%H:%M:%S')}] Waiting for model logs | "
        f"output={overview_output}"
    )


def _refresh_single_model(args, run_name: str, run_logs_dir: str) -> None:
    """Redraw one model's detailed snapshot and, when the run has other models, the comparison."""
    model_name = normalize_model_name(args.model)
    run_name, _, csv_path, output_path, compare_output = resolve_paths(
        args, model_name, run_name, run_logs_dir
    )
    df = safe_read_csv(csv_path)
    plot_detailed_snapshot(
        model_name=model_name,
        df=df,
        run_name=run_name,
        csv_path=csv_path,
        output_path=output_path,
        window=args.window,
        show_ma=args.show_ma,
    )
    run_logs = load_run_model_logs(run_logs_dir)
    plot_run_comparison(
        target_model=model_name,
        run_logs=run_logs,
        compare_output=compare_output,
        window=args.window,
        show_ma=args.show_ma,
    )
    print(
        f"[{time.strftime('%H:%M:%S')}] Updated {MODEL_LABELS[model_name]} snapshot | "
        f"episodes={len(df)} | output={output_path}"
    )
    # Mirrors plot_run_comparison's own "more than one model" condition.
    if len(run_logs) > 1:
        print(f"[{time.strftime('%H:%M:%S')}] Updated run comparison | output={compare_output}")


def run_once(args):
    """Perform one refresh of the requested plots.

    ``--all-models`` takes precedence over ``--model``.
    """
    run_name, run_logs_dir = resolve_run_logs_dir(args.log_root, args.run)
    refresh = _refresh_overview if args.all_models else _refresh_single_model
    refresh(args, run_name, run_logs_dir)
def main():
    """CLI entry point: render once, or keep re-rendering with --watch until Ctrl+C."""
    args = parse_args()
    if not args.watch:
        run_once(args)
        return

    pause_seconds = max(args.interval, 1)  # never busy-loop on a zero/negative interval
    try:
        while True:
            try:
                run_once(args)
            except Exception as exc:
                # A single failed refresh must not stop the watcher.
                print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}")
            time.sleep(pause_seconds)
    except KeyboardInterrupt:
        print("Stopped live plotting.")


if __name__ == "__main__":
    main()