# ctm-dqn/scripts/plot_live_training.py
# 567 lines | 18 KiB | Python

"""Plot live training snapshots from CSV logs."""
import argparse
import os
import time
from typing import Dict, List, Optional
import matplotlib
import numpy as np
import pandas as pd
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Canonical order of the supported algorithms; every multi-model figure and
# summary iterates models in this order.
MODEL_ORDER = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]

# Legend/summary label for each model: its upper-cased key (e.g. "ppo" -> "PPO").
MODEL_LABELS = dict((key, key.upper()) for key in MODEL_ORDER)

# Fixed line colour per model so a model looks the same across all figures.
MODEL_COLORS = dict(
    ppo="#1f77b4",
    appo="#ff7f0e",
    mappo="#2ca02c",
    dqn="#d62728",
    ddpg="#9467bd",
    td3="#17becf",
)
def parse_args():
    """Parse and validate the command-line options.

    One plotting mode must be selected: ``--model`` (detailed snapshot of one
    model plus a comparison with the other models of its run) or
    ``--all-models`` (a single overview figure for the whole run).

    Returns:
        argparse.Namespace holding the parsed options.

    Exits via ``parser.error`` (status 2) when neither mode is given or when
    ``--window`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(description="Plot live training progress from multi-model logs.")
    parser.add_argument("--model", default=None, help="Model name, e.g. ppo/appo/mappo/dqn/ddpg/td3")
    parser.add_argument(
        "--all-models",
        action="store_true",
        help="Plot all models into one overview figure.",
    )
    parser.add_argument(
        "--run",
        default=None,
        help="Multi-model run timestamp. Default: latest under logs/multi-model/",
    )
    parser.add_argument(
        "--log-root",
        default=os.path.join("logs", "multi-model"),
        help="Multi-model log root directory.",
    )
    parser.add_argument(
        "--csv-path",
        default=None,
        help="Direct training CSV path override.",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Detailed snapshot output path. Default: <model_dir>/training_snapshot.png",
    )
    parser.add_argument(
        "--compare-output",
        default=None,
        help="Run comparison output path. Default: <model_dir>/run_comparison_snapshot.png",
    )
    parser.add_argument(
        "--overview-output",
        default=None,
        help="All-model overview output path. Default: <run_dir>/all_models_training_snapshot.png",
    )
    parser.add_argument(
        "--window",
        type=int,
        default=20,
        help="Moving average window.",
    )
    parser.add_argument(
        "--show-ma",
        action="store_true",
        help="Overlay moving-average curves.",
    )
    parser.add_argument(
        "--watch",
        action="store_true",
        help="Continuously refresh plots.",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=30,
        help="Refresh interval in seconds when --watch is enabled.",
    )
    args = parser.parse_args()
    if not args.all_models and not args.model:
        parser.error("Either --model or --all-models must be specified.")
    # A window below 1 makes pandas' rolling() raise at plot time; in --watch
    # mode that would silently fail every refresh, so reject it up front.
    if args.window < 1:
        parser.error("--window must be at least 1.")
    return args
def normalize_model_name(name: str) -> str:
    """Return *name* stripped and lower-cased, checked against MODEL_LABELS.

    Raises:
        ValueError: if the normalized name is not one of the supported models.
    """
    candidate = name.strip().lower()
    if candidate in MODEL_LABELS:
        return candidate
    raise ValueError(f"Unsupported model name: {name}")
def find_latest_run(log_root: str) -> str:
    """Return the name of the "latest" run directory directly under *log_root*.

    "Latest" is the lexicographically greatest sub-directory name. That matches
    chronological order only if run names are sortable timestamps
    (NOTE(review): confirm the naming scheme used by the training launcher).
    Plain files under *log_root* are ignored.

    Raises:
        FileNotFoundError: if *log_root* does not exist or has no sub-directories.
    """
    if not os.path.isdir(log_root):
        raise FileNotFoundError(f"Log root not found: {log_root}")
    run_dirs = sorted(
        entry
        for entry in os.listdir(log_root)
        if os.path.isdir(os.path.join(log_root, entry))
    )
    if not run_dirs:
        raise FileNotFoundError(f"No run directories found under: {log_root}")
    return run_dirs[-1]
def resolve_paths(args, model_name: str):
    """Locate one model's training CSV and choose where its plots are written.

    If ``args.csv_path`` is set, that file is used directly. Its directory is
    treated as the model directory, and that directory's parent as the run
    directory. Otherwise the CSV is expected at
    ``<log_root>/<run>/<model>/<model>_training_log.csv``, where ``<run>`` is
    ``args.run`` or the latest run under ``args.log_root``.

    Returns:
        Tuple ``(run_name, model_dir, csv_path, output_path, compare_output)``.

    Raises:
        FileNotFoundError: if the CSV does not exist (or no run can be found).
    """
    if args.csv_path:
        csv_path = args.csv_path
        model_dir = os.path.dirname(os.path.abspath(csv_path))
        run_name = os.path.basename(os.path.dirname(model_dir))
    else:
        run_name = args.run or find_latest_run(args.log_root)
        model_dir = os.path.join(args.log_root, run_name, model_name)
        csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"Training CSV not found: {csv_path}")

    # Explicit CLI paths win; otherwise the images go next to the CSV.
    output_path = (
        args.output
        if args.output
        else os.path.join(model_dir, "training_snapshot.png")
    )
    compare_output = (
        args.compare_output
        if args.compare_output
        else os.path.join(model_dir, "run_comparison_snapshot.png")
    )
    return run_name, model_dir, csv_path, output_path, compare_output
def resolve_run_dir(args):
    """Resolve the run directory and the path of the all-model overview image.

    Uses ``args.run`` when given, otherwise the latest run under ``args.log_root``.

    Returns:
        Tuple ``(run_name, run_dir, overview_output)``.

    Raises:
        FileNotFoundError: if the run directory does not exist.
    """
    run_name = args.run if args.run else find_latest_run(args.log_root)
    run_dir = os.path.join(args.log_root, run_name)
    if os.path.isdir(run_dir):
        default_output = os.path.join(run_dir, "all_models_training_snapshot.png")
        return run_name, run_dir, args.overview_output or default_output
    raise FileNotFoundError(f"Run directory not found: {run_dir}")
def safe_read_csv(csv_path: str) -> pd.DataFrame:
    """Load a (possibly still growing) training CSV into a clean DataFrame.

    Processing steps:
      - Malformed lines are skipped.
      - Known metric columns are coerced to numbers; unparseable cells become NaN.
      - Rows without a numeric ``episode`` are dropped.
      - When an episode appears more than once, the row written last in the file is kept.

    The result is sorted by episode and has a fresh 0..n-1 index.

    Raises:
        ValueError: if the file has no data rows, has no ``episode`` column,
            or has no row with a numeric episode. pandas' ``EmptyDataError``
            for a 0-byte file is also a ``ValueError`` subclass.
    """
    df = pd.read_csv(csv_path, on_bad_lines="skip")
    if df.empty:
        raise ValueError(f"CSV exists but contains no rows: {csv_path}")
    if "episode" not in df.columns:
        raise ValueError(f"CSV has no 'episode' column: {csv_path}")
    numeric_columns = [
        "episode", "reward", "throughput", "mean_speed", "speed_std",
        "r_flow", "r_var", "r_brake", "r_penalty", "hard_brakes",
        "policy_loss", "value_loss", "entropy",
    ]
    for column in numeric_columns:
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors="coerce")
    df = df.dropna(subset=["episode"])
    # Deduplicate in file order *before* sorting. sort_values' default
    # quicksort is not stable, so keep="last" applied after sorting would not
    # reliably keep the most recently written row for a repeated episode.
    df = df.drop_duplicates(subset=["episode"], keep="last").sort_values("episode")
    if df.empty:
        # Returning an empty frame would crash downstream code that reads
        # df.iloc[-1]; fail here with a clear message instead.
        raise ValueError(f"CSV contains no rows with a numeric episode: {csv_path}")
    return df.reset_index(drop=True)
def moving_average(values: pd.Series, window: int) -> pd.Series:
    """Return the trailing rolling mean of *values*.

    The window is clamped to ``[1, len(values)]`` and ``min_periods=1`` is
    used. The first points therefore average over whatever samples exist so
    far, and the output always has the input's length (empty in, empty out).
    A non-positive *window* degrades to 1, i.e. the raw values, instead of
    making pandas raise.
    """
    effective_window = max(1, min(window, len(values)))
    return values.rolling(window=effective_window, min_periods=1).mean()
def plot_series(
    ax,
    df: pd.DataFrame,
    column: str,
    title: str,
    ylabel: str,
    color: str,
    window: int,
    show_ma: bool,
):
    """Draw one metric column of *df* against ``episode`` on *ax*.

    The axis is hidden instead when *column* is missing or holds no numeric
    value. With *show_ma* a fainter dashed moving-average curve is overlaid,
    and a legend distinguishes it from the raw series.
    """
    values = pd.to_numeric(df[column], errors="coerce") if column in df.columns else None
    if values is None or not values.notna().any():
        ax.axis("off")
        return

    episodes = df["episode"]
    ax.plot(
        episodes,
        values,
        color=color,
        alpha=0.9,
        linewidth=1.8,
        label="Series" if show_ma else None,
    )
    if show_ma:
        ax.plot(
            episodes,
            moving_average(values, window),
            color=color,
            alpha=0.35,
            linewidth=1.0,
            linestyle="--",
            label=f"MA{window}",
        )
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    if show_ma:
        ax.legend(loc="best")
def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str:
    """Build the monospace text panel shown in the detailed snapshot.

    The panel reports:
      - run and model identity;
      - number of logged episodes and the latest episode;
      - best and latest reward;
      - means over the last 20 logged rows (or fewer if fewer exist) for
        reward, throughput, mean speed and hard brakes;
      - the absolute CSV path.

    Any metric column that is absent renders as ``nan`` instead of raising.
    *df* must be non-empty and contain an ``episode`` column.
    """
    tail_len = 20

    def metric(name: str) -> pd.Series:
        # Missing columns become an all-NaN series so every statistic below
        # degrades to NaN instead of raising KeyError.
        if name in df.columns:
            return pd.to_numeric(df[name], errors="coerce")
        return pd.Series(np.nan, index=df.index, dtype=float)

    def recent_mean(name: str) -> float:
        return metric(name).tail(tail_len).mean()

    last_row = df.iloc[-1]
    reward = metric("reward")
    return (
        f"Run: {run_name}\n"
        f"Model: {MODEL_LABELS[model_name]}\n"
        f"Episodes logged: {len(df)}\n"
        f"Latest episode: {int(last_row['episode'])}\n"
        f"Best reward: {reward.max():.2f}\n"
        f"Latest reward: {reward.iloc[-1]:.2f}\n"
        f"Reward last20: {recent_mean('reward'):.2f}\n"
        f"Throughput last20: {recent_mean('throughput'):.1f}\n"
        f"Mean speed last20: {recent_mean('mean_speed'):.1f}\n"
        f"Hard brakes last20: {recent_mean('hard_brakes'):.1f}\n"
        f"\nCSV:\n{os.path.abspath(csv_path)}"
    )
def plot_detailed_snapshot(
    model_name: str,
    df: pd.DataFrame,
    run_name: str,
    csv_path: str,
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the 3x3 detailed training snapshot of one model to *output_path*.

    Panels, in order: reward, throughput, mean speed, speed std, hard brakes,
    overlaid reward components, policy loss, value loss, and a text summary.
    A panel is hidden when its data is missing.
    """
    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
    try:
        axes = axes.flatten()
        plot_series(axes[0], df, "reward", "Reward", "Reward", "tab:blue", window, show_ma)
        plot_series(axes[1], df, "throughput", "Throughput", "veh/h", "tab:green", window, show_ma)
        plot_series(axes[2], df, "mean_speed", "Mean Speed", "km/h", "tab:orange", window, show_ma)
        plot_series(axes[3], df, "speed_std", "Speed Std", "km/h", "tab:purple", window, show_ma)
        plot_series(axes[4], df, "hard_brakes", "Hard Brakes", "count", "tab:red", window, show_ma)
        _plot_reward_components(axes[5], df, window, show_ma)
        plot_series(axes[6], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
        plot_series(axes[7], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
        axes[8].axis("off")
        axes[8].text(
            0.0,
            1.0,
            build_summary_text(model_name, df, run_name, csv_path),
            va="top",
            ha="left",
            family="monospace",
            fontsize=11,
            transform=axes[8].transAxes,
        )
        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        # Always release the figure. In --watch mode a failure above would
        # otherwise leave one open figure behind per refresh.
        plt.close(fig)


def _plot_reward_components(ax, df: pd.DataFrame, window: int, show_ma: bool):
    """Overlay the r_flow/r_var/r_brake/r_penalty columns on *ax*; hide *ax* if none has data."""
    components = [
        ("r_flow", "tab:green"),
        ("r_var", "tab:purple"),
        ("r_brake", "tab:red"),
        ("r_penalty", "tab:brown"),
    ]
    plotted = False
    for col, color in components:
        # A log may carry only some of the components; skip absent ones
        # rather than raising KeyError.
        if col not in df.columns:
            continue
        series = pd.to_numeric(df[col], errors="coerce")
        if series.notna().sum() == 0:
            continue
        ax.plot(
            df["episode"],
            series,
            linewidth=1.6,
            alpha=0.75,
            color=color,
            label=col,
        )
        if show_ma:
            ax.plot(
                df["episode"],
                moving_average(series, window),
                linewidth=1.0,
                alpha=0.35,
                linestyle="--",
                color=color,
                label="_nolegend_",
            )
        plotted = True
    if not plotted:
        ax.axis("off")
        return
    ax.set_title("Reward Components")
    ax.set_xlabel("Episode")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
def load_run_model_logs(log_root: str, run_name: str) -> Dict[str, pd.DataFrame]:
    """Load every readable model training CSV belonging to one run.

    For each model in MODEL_ORDER this looks for
    ``<log_root>/<run_name>/<model>/<model>_training_log.csv``. Models whose
    CSV is absent or fails to load are left out, so the returned dict (keyed
    by model name, in MODEL_ORDER order) may be empty.
    """
    run_dir = os.path.join(log_root, run_name)
    loaded: Dict[str, pd.DataFrame] = {}
    for model in MODEL_ORDER:
        csv_path = os.path.join(run_dir, model, f"{model}_training_log.csv")
        if not os.path.isfile(csv_path):
            continue
        try:
            loaded[model] = safe_read_csv(csv_path)
        except Exception:
            # Deliberate best-effort: a log that is empty, mid-write or
            # malformed must not stop the other models from being plotted.
            continue
    return loaded
def plot_waiting_placeholder(output_path: str, run_name: str, message: str):
    """Save a small text-only figure showing the run name and a status *message*."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axis("off")
    # (vertical position, text, extra text kwargs) for the two centred lines.
    text_lines = [
        (0.6, f"Run: {run_name}", {"fontsize": 16, "family": "monospace"}),
        (0.4, message, {"fontsize": 12}),
    ]
    for y_pos, content, extra in text_lines:
        ax.text(
            0.5,
            y_pos,
            content,
            ha="center",
            va="center",
            transform=ax.transAxes,
            **extra,
        )
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
def plot_run_comparison(
    target_model: str,
    run_logs: Dict[str, pd.DataFrame],
    compare_output: str,
    window: int,
    show_ma: bool,
):
    """Overlay four key metrics for every model of a run, highlighting one model.

    Writes a 2x2 figure (reward, throughput, mean speed, speed std) to
    *compare_output*. *target_model* is drawn thick and opaque; the other
    models are thinner and semi-transparent. Nothing is written when fewer
    than two models are loaded, since there is nothing to compare.
    """
    if len(run_logs) <= 1:
        return
    metrics = [
        ("reward", "Reward"),
        ("throughput", "Throughput"),
        ("mean_speed", "Mean Speed"),
        ("speed_std", "Speed Std"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    try:
        axes = axes.flatten()
        for ax, (column, title) in zip(axes, metrics):
            for model_name in MODEL_ORDER:
                df = run_logs.get(model_name)
                if df is None or column not in df.columns:
                    continue
                series = pd.to_numeric(df[column], errors="coerce")
                if series.notna().sum() == 0:
                    continue
                is_target = model_name == target_model
                linewidth = 3.2 if is_target else 2.0
                alpha = 1.0 if is_target else 0.5
                ax.plot(
                    df["episode"],
                    series,
                    linewidth=linewidth,
                    alpha=alpha,
                    color=MODEL_COLORS[model_name],
                    label=MODEL_LABELS[model_name],
                )
                if show_ma:
                    ax.plot(
                        df["episode"],
                        moving_average(series, window),
                        linewidth=1.0,
                        linestyle="--",
                        alpha=min(alpha, 0.4),
                        color=MODEL_COLORS[model_name],
                        label="_nolegend_",
                    )
            ax.set_title(title)
            ax.set_xlabel("Episode")
            ax.grid(True, alpha=0.3)
        # One shared legend on the first panel. Skip it when that panel has no
        # labelled lines (e.g. no model logs reward); matplotlib would only warn.
        handles, _ = axes[0].get_legend_handles_labels()
        if handles:
            axes[0].legend(loc="best")
        fig.tight_layout()
        fig.savefig(compare_output, dpi=160, bbox_inches="tight")
    finally:
        # Release the figure even on failure so --watch mode does not leak figures.
        plt.close(fig)
def plot_metric_overlay(
    ax,
    run_logs: Dict[str, pd.DataFrame],
    column: str,
    title: str,
    ylabel: str,
    window: int,
    show_ma: bool,
):
    """Overlay *column* for every model in *run_logs* on a single axis.

    Models are drawn in MODEL_ORDER with their fixed colours. A model is
    skipped if it lacks the column or has no numeric value in it, and the axis
    is hidden if no model contributes a line. With *show_ma* each curve gets a
    faint dashed moving average that is kept out of the legend.
    """
    drawn_any = False
    for model_name in MODEL_ORDER:
        frame = run_logs.get(model_name)
        if frame is None or column not in frame.columns:
            continue
        values = pd.to_numeric(frame[column], errors="coerce")
        if not values.notna().any():
            continue
        colour = MODEL_COLORS[model_name]
        episodes = frame["episode"]
        ax.plot(
            episodes,
            values,
            alpha=0.85,
            linewidth=1.8,
            color=colour,
            label=MODEL_LABELS[model_name],
        )
        if show_ma:
            ax.plot(
                episodes,
                moving_average(values, window),
                linewidth=1.0,
                linestyle="--",
                alpha=0.35,
                color=colour,
                label="_nolegend_",
            )
        drawn_any = True

    if not drawn_any:
        ax.axis("off")
        return
    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str:
    """Return a plain-text run summary with one line per loaded model.

    The header gives the run name and ``len(run_logs)``. Each model line
    (in MODEL_ORDER) reports the latest logged episode and the mean reward and
    throughput over the last 20 logged rows (or fewer if fewer exist). A metric
    whose column is missing or entirely non-numeric prints as ``nan``. Entries
    that are ``None`` or empty get no line.
    """
    tail_len = 20

    def recent_mean(df: pd.DataFrame, column: str) -> float:
        # A missing column must give NaN. pd.to_numeric(df.get(column)) would
        # instead return a bare scalar NaN (no .notna()) and crash the summary.
        if column not in df.columns:
            return np.nan
        values = pd.to_numeric(df[column], errors="coerce")
        return values.tail(tail_len).mean() if values.notna().any() else np.nan

    lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""]
    for model_name in MODEL_ORDER:
        df = run_logs.get(model_name)
        if df is None or df.empty:
            continue
        latest_episode = int(df["episode"].iloc[-1])
        reward_last20 = recent_mean(df, "reward")
        tp_last20 = recent_mean(df, "throughput")
        lines.append(
            f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
            f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}"
        )
    return "\n".join(lines)
def plot_all_models_overview(
    run_name: str,
    run_logs: Dict[str, pd.DataFrame],
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the 4x3 all-model overview figure plus a companion text summary.

    Each of the 12 panels overlays one metric for every model in *run_logs*.
    The figure is saved to *output_path*. The text from build_overview_summary
    is written next to it as ``<output path without extension>_summary.txt``.

    Raises:
        ValueError: if *run_logs* is empty.
    """
    if not run_logs:
        raise ValueError("No readable model logs found for overview plotting.")
    metrics = [
        ("reward", "Reward", "Reward"),
        ("throughput", "Throughput", "veh/h"),
        ("mean_speed", "Mean Speed", "km/h"),
        ("speed_std", "Speed Std", "km/h"),
        ("hard_brakes", "Hard Brakes", "count"),
        ("r_flow", "R_flow", "value"),
        ("r_var", "R_var", "value"),
        ("r_brake", "R_brake", "value"),
        ("r_penalty", "R_penalty", "value"),
        ("policy_loss", "Policy Loss", "loss"),
        ("value_loss", "Value Loss", "loss"),
        ("entropy", "Entropy", "value"),
    ]
    fig, axes = plt.subplots(4, 3, figsize=(22, 16))
    try:
        for ax, (column, title, ylabel) in zip(axes.flatten(), metrics):
            plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
        fig.suptitle(f"Multi-Model Training Overview | {run_name}", fontsize=16, y=0.995)
        # Lay the panels out in the lower 98% of the canvas, leaving room for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.98])
        fig.savefig(output_path, dpi=170, bbox_inches="tight")
    finally:
        # Release the figure even on failure so --watch mode does not leak figures.
        plt.close(fig)

    summary_path = os.path.splitext(output_path)[0] + "_summary.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(build_overview_summary(run_name, run_logs))
def run_once(args):
    """Perform a single plotting pass in the mode selected on the command line.

    With ``--all-models`` the run overview is rendered, or a placeholder image
    if no model CSV is readable yet. Otherwise the selected model's detailed
    snapshot is rendered first. A comparison against the run's other models
    follows; it is skipped inside plot_run_comparison when fewer than two
    models are loaded. A timestamped status line is printed after each pass.
    """
    if args.all_models:
        run_name, _, overview_output = resolve_run_dir(args)
        run_logs = load_run_model_logs(args.log_root, run_name)
        if not run_logs:
            plot_waiting_placeholder(
                overview_output,
                run_name,
                "Waiting for training logs to accumulate...",
            )
            print(
                f"[{time.strftime('%H:%M:%S')}] Waiting for model logs | "
                f"output={overview_output}"
            )
            return
        plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma)
        print(
            f"[{time.strftime('%H:%M:%S')}] Updated all-model overview | "
            f"models={len(run_logs)} | output={overview_output}"
        )
        return

    model_name = normalize_model_name(args.model)
    run_name, model_dir, csv_path, output_path, compare_output = resolve_paths(args, model_name)
    df = safe_read_csv(csv_path)
    plot_detailed_snapshot(
        model_name=model_name,
        df=df,
        run_name=run_name,
        csv_path=csv_path,
        output_path=output_path,
        window=args.window,
        show_ma=args.show_ma,
    )
    # With --csv-path the run can live outside --log-root. Its sibling models
    # then sit under the CSV's grandparent directory (<root>/<run>/<model>/),
    # so look there instead of under args.log_root.
    if args.csv_path:
        log_root = os.path.dirname(os.path.dirname(model_dir))
    else:
        log_root = args.log_root
    run_logs = load_run_model_logs(log_root, run_name)
    plot_run_comparison(
        target_model=model_name,
        run_logs=run_logs,
        compare_output=compare_output,
        window=args.window,
        show_ma=args.show_ma,
    )
    print(
        f"[{time.strftime('%H:%M:%S')}] Updated {MODEL_LABELS[model_name]} snapshot | "
        f"episodes={len(df)} | output={output_path}"
    )
    if len(run_logs) > 1:
        print(f"[{time.strftime('%H:%M:%S')}] Updated run comparison | output={compare_output}")
def main():
    """Script entry point: plot once, or with --watch keep refreshing until Ctrl+C."""
    args = parse_args()
    if not args.watch:
        run_once(args)
        return

    pause_seconds = max(args.interval, 1)
    try:
        while True:
            try:
                run_once(args)
            except Exception as exc:
                # A failed refresh is reported and retried on the next tick
                # instead of ending the watch loop.
                print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}")
            time.sleep(pause_seconds)
    except KeyboardInterrupt:
        print("Stopped live plotting.")


if __name__ == "__main__":
    main()