689 lines
22 KiB
Python
689 lines
22 KiB
Python
"""Plot live training snapshots from CSV logs."""
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import time
|
|
from typing import Dict, List, Optional
|
|
|
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
if PROJECT_ROOT not in sys.path:
|
|
sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
import matplotlib
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
|
|
from envs.reward_system import REWARD_COMPONENT_COLUMNS, REWARD_COMPONENT_LABELS
|
|
from utils.run_dirs import find_latest_run_root, find_run_root_by_timestamp
|
|
|
|
|
|
# One row per supported algorithm: (CLI key, display label, plot colour).
# Row order is the canonical iteration order used when loading and plotting.
_MODEL_TABLE = (
    ("ppo", "PPO", "#1f77b4"),
    ("gpro", "GPRO-PPO", "#6a3d9a"),
    ("appo", "APPO", "#ff7f0e"),
    ("mappo", "MAPPO", "#2ca02c"),
    ("tcamappo", "TCA-MAPPO", "#7f7f7f"),
    ("dcmappo", "DC-MAPPO", "#8c564b"),
    ("dqn", "DQN", "#d62728"),
    ("madqn", "MA-DQN", "#ff9896"),
    ("ddqn", "DDQN", "#ffbb78"),
    ("qmix", "QMIX", "#8dd3c7"),
    ("dcqmix", "DC-QMIX", "#2b8cbe"),
    ("ddpg", "DDPG", "#9467bd"),
    ("d3pg", "D3PG", "#7fc97f"),
    ("sac", "SAC", "#e377c2"),
    ("td3", "TD3", "#17becf"),
    ("sctd3", "SC-TD3", "#bcbd22"),
)

# Ordered list of model keys.
MODEL_ORDER = [key for key, _label, _color in _MODEL_TABLE]
# Model key -> human-readable name used in legends and summaries.
MODEL_LABELS = {key: label for key, label, _color in _MODEL_TABLE}
# Model key -> hex colour used for that model's curves.
MODEL_COLORS = {key: color for key, _label, color in _MODEL_TABLE}
|
|
# CSV column that holds the efficiency reward term.
EFFICIENCY_COLUMN = "r_efficiency"

# Panel title for that column; falls back to a fixed caption when the
# reward-system label table has no entry for it.
EFFICIENCY_LABEL = REWARD_COMPONENT_LABELS.get(
    EFFICIENCY_COLUMN,
    "Running Efficiency",
)
|
|
|
|
|
|
def parse_args():
    """Parse the command line and return the resulting namespace.

    Exits through ``parser.error`` when neither ``--model`` nor
    ``--all-models`` was given.
    """
    cli = argparse.ArgumentParser(description="Plot live training progress from run logs.")

    # (flag, add_argument keyword arguments), registered in --help order.
    option_specs = [
        (
            "--model",
            {
                "default": None,
                "help": "Model name, e.g. ppo/gpro/appo/mappo/tcamappo/dcmappo/dqn/madqn/ddqn/qmix/dcqmix/ddpg/d3pg/sac/td3/sctd3",
            },
        ),
        (
            "--all-models",
            {"action": "store_true", "help": "Plot all models into one overview figure."},
        ),
        (
            "--run",
            {"default": None, "help": "Run timestamp. Default: latest under runs/<timestamp>/logs."},
        ),
        (
            "--log-root",
            {
                "default": None,
                "help": "Log root directory or run root. Default: latest runs/<timestamp>/logs.",
            },
        ),
        (
            "--csv-path",
            {"default": None, "help": "Direct training CSV path override."},
        ),
        (
            "--output",
            {
                "default": None,
                "help": "Detailed snapshot output path. Default: <model_dir>/training_snapshot.png",
            },
        ),
        (
            "--compare-output",
            {
                "default": None,
                "help": "Run comparison output path. Default: <model_dir>/run_comparison_snapshot.png",
            },
        ),
        (
            "--overview-output",
            {
                "default": None,
                "help": "All-model overview output path. Default: <run_dir>/all_models_training_snapshot.png",
            },
        ),
        (
            "--window",
            {"type": int, "default": 20, "help": "Moving average window."},
        ),
        (
            "--show-ma",
            {"action": "store_true", "help": "Overlay moving-average curves."},
        ),
        (
            "--watch",
            {"action": "store_true", "help": "Continuously refresh plots."},
        ),
        (
            "--interval",
            {
                "type": int,
                "default": 30,
                "help": "Refresh interval in seconds when --watch is enabled.",
            },
        ),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    # At least one plotting target has to be selected.
    if not (parsed.all_models or parsed.model):
        cli.error("Either --model or --all-models must be specified.")
    return parsed
|
|
|
|
|
|
def normalize_model_name(name: str) -> str:
    """Return the canonical key for ``name`` (whitespace-trimmed, lower-cased).

    Raises:
        ValueError: if the cleaned name is not a key of ``MODEL_LABELS``.
    """
    key = name.strip().lower()
    if key in MODEL_LABELS:
        return key
    # The message echoes the caller's original spelling, not the cleaned key.
    raise ValueError(f"Unsupported model name: {name}")
|
|
|
|
|
|
def find_latest_run(log_root: str) -> str:
    """Return the name of the sub-directory of ``log_root`` that sorts last.

    Only directories are considered; ordering is plain string comparison
    of the entry names.

    Raises:
        FileNotFoundError: if ``log_root`` is not a directory or contains
            no sub-directories.
    """
    if not os.path.isdir(log_root):
        raise FileNotFoundError(f"Log root not found: {log_root}")

    run_names = []
    for entry in os.listdir(log_root):
        if os.path.isdir(os.path.join(log_root, entry)):
            run_names.append(entry)

    if not run_names:
        raise FileNotFoundError(f"No run directories found under: {log_root}")

    run_names.sort()
    return run_names[-1]
|
|
|
|
|
|
def resolve_run_logs_dir(log_root: Optional[str], run_name: Optional[str]) -> tuple[str, str]:
    """Work out which run is being plotted and where its logs live.

    Args:
        log_root: Explicit directory, or ``None`` to discover the run root
            through ``find_run_root_by_timestamp`` / ``find_latest_run_root``.
        run_name: Run timestamp to use, or ``None`` to pick one automatically.

    Returns:
        ``(run_name, logs_dir)``.

    Raises:
        FileNotFoundError: if no matching run directory exists.
    """
    if log_root is not None:
        base_dir = os.path.abspath(log_root)
    else:
        if run_name:
            run_root = find_run_root_by_timestamp(run_name)
        else:
            run_root = find_latest_run_root()
        nested_logs = os.path.join(run_root, "logs")
        base_dir = nested_logs if os.path.isdir(nested_logs) else run_root

    # Case 1: the directory is itself a ".../<run>/logs" folder.
    if os.path.basename(base_dir).lower() == "logs":
        return os.path.basename(os.path.dirname(base_dir)), base_dir

    # Case 2: the directory is a run root that contains a "logs" folder.
    inner_logs = os.path.join(base_dir, "logs")
    if os.path.isdir(inner_logs):
        return os.path.basename(base_dir), inner_logs

    # Case 3: the directory holds several runs; select one by name.
    selected = run_name or find_latest_run(base_dir)
    run_dir = os.path.join(base_dir, selected)
    for candidate in (os.path.join(run_dir, "logs"), run_dir):
        if os.path.isdir(candidate):
            return selected, candidate

    raise FileNotFoundError(f"Run directory not found: {run_dir}")
|
|
|
|
|
|
def resolve_paths(args, model_name: str, run_name: str, run_logs_dir: str):
    """Locate one model's training CSV and decide where its plots are written.

    ``args.csv_path`` overrides the default
    ``<run_logs_dir>/<model_name>/<model_name>_training_log.csv`` location;
    ``args.output`` / ``args.compare_output`` override the default PNG paths
    inside the model directory.

    Returns:
        ``(run_name, model_dir, csv_path, output_path, compare_output)``.

    Raises:
        FileNotFoundError: if the CSV does not exist.
    """
    csv_path = args.csv_path
    if csv_path:
        # Explicit CSV: outputs default to the folder that contains it.
        model_dir = os.path.dirname(os.path.abspath(csv_path))
    else:
        model_dir = os.path.join(run_logs_dir, model_name)
        csv_path = os.path.join(model_dir, f"{model_name}_training_log.csv")

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"Training CSV not found: {csv_path}")

    snapshot_path = args.output
    if not snapshot_path:
        snapshot_path = os.path.join(model_dir, "training_snapshot.png")

    comparison_path = args.compare_output
    if not comparison_path:
        comparison_path = os.path.join(model_dir, "run_comparison_snapshot.png")

    return run_name, model_dir, csv_path, snapshot_path, comparison_path
|
|
|
|
|
|
def resolve_run_dir(args, run_name: str, run_logs_dir: str):
    """Derive the run directory and the all-model overview image path.

    When ``run_logs_dir`` ends in a ``logs`` component (case-insensitive)
    the run directory is its parent; otherwise it is ``run_logs_dir`` itself.
    ``args.overview_output`` overrides the default image path.

    Returns:
        ``(run_name, run_dir, overview_output)``.
    """
    if os.path.basename(run_logs_dir).lower() == "logs":
        run_dir = os.path.dirname(run_logs_dir)
    else:
        run_dir = run_logs_dir

    overview_path = args.overview_output
    if not overview_path:
        overview_path = os.path.join(run_dir, "all_models_training_snapshot.png")

    return run_name, run_dir, overview_path
|
|
|
|
|
|
def safe_read_csv(csv_path: str) -> pd.DataFrame:
    """Load a training-log CSV and return it cleaned and ordered by episode.

    Malformed lines are skipped by the parser, the known metric columns are
    coerced to numeric (unparseable cells become NaN), rows without a valid
    ``episode`` are dropped, and for duplicated episodes the row that
    appears last in the file is kept.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A non-empty frame sorted by ``episode`` with a fresh 0..n-1 index.

    Raises:
        ValueError: if the file has no rows, has no ``episode`` column, or
            no row carries a usable episode number.
    """
    df = pd.read_csv(csv_path, on_bad_lines="skip")
    if df.empty:
        raise ValueError(f"CSV exists but contains no rows: {csv_path}")
    # Fail with a readable message instead of a bare KeyError further down.
    if "episode" not in df.columns:
        raise ValueError(f"CSV has no 'episode' column: {csv_path}")

    numeric_columns = [
        "episode", "reward", "throughput", "mean_speed", "speed_variance_norm",
        *REWARD_COMPONENT_COLUMNS,
        "stops",
        "policy_loss", "value_loss", "entropy",
    ]
    for column in numeric_columns:
        if column in df.columns:
            df[column] = pd.to_numeric(df[column], errors="coerce")

    df = df.dropna(subset=["episode"])
    # Callers index the last row, so an all-invalid log must not be
    # handed back as an empty frame.
    if df.empty:
        raise ValueError(f"CSV contains no rows with a valid episode: {csv_path}")

    # A stable sort keeps file order among equal episodes, which is what
    # makes keep="last" select the most recently written duplicate.
    df = df.sort_values("episode", kind="stable")
    df = df.drop_duplicates(subset=["episode"], keep="last")
    return df.reset_index(drop=True)
|
|
|
|
|
|
def moving_average(values: pd.Series, window: int) -> pd.Series:
    """Return the trailing moving average of ``values``.

    The effective window is ``window`` clamped to the range
    ``[1, len(values)]`` (and to 1 for an empty series), so a zero or
    negative ``window`` degrades to "no smoothing" instead of making
    ``rolling`` raise. ``min_periods=1`` means the first points are
    averaged over however many samples are available.

    Args:
        values: Series to smooth.
        window: Requested window length in samples.

    Returns:
        A series with the same index as ``values``.
    """
    span = max(1, min(window, len(values)))
    return values.rolling(window=span, min_periods=1).mean()
|
|
|
|
|
|
def plot_series(
    ax,
    df: pd.DataFrame,
    column: str,
    title: str,
    ylabel: str,
    color: str,
    window: int,
    show_ma: bool,
):
    """Draw one metric against ``episode`` on ``ax``.

    The axes are switched off when ``column`` is absent or holds no numeric
    value. With ``show_ma`` a dashed moving-average curve is added and a
    legend is shown.
    """
    values = pd.to_numeric(df[column], errors="coerce") if column in df.columns else None
    if values is None or not values.notna().any():
        ax.axis("off")
        return

    episodes = df["episode"]
    # The raw curve is only labelled when a legend will be drawn.
    ax.plot(
        episodes,
        values,
        color=color,
        alpha=0.9,
        linewidth=1.8,
        label="Series" if show_ma else None,
    )
    if show_ma:
        smoothed_style = {
            "color": color,
            "alpha": 0.35,
            "linewidth": 1.0,
            "linestyle": "--",
            "label": f"MA{window}",
        }
        ax.plot(episodes, moving_average(values, window), **smoothed_style)

    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    if show_ma:
        ax.legend(loc="best")
|
|
|
|
|
|
def build_summary_text(model_name: str, df: pd.DataFrame, run_name: str, csv_path: str) -> str:
    """Format the text block shown in the detailed snapshot's summary panel.

    Every metric tolerates a missing column and is then reported as
    ``nan``; this now includes "Best reward", which previously raised
    ``KeyError`` when the log had no ``reward`` column.

    Args:
        model_name: Key into ``MODEL_LABELS``.
        df: Training log; must contain at least one row and an ``episode``
            column.
        run_name: Run identifier printed on the first line.
        csv_path: Source CSV; printed as an absolute path.

    Returns:
        A multi-line string.
    """

    def _tail_mean(column: str, count: int = 20) -> float:
        # Mean over the last `count` rows, or nan when the column is absent.
        if column not in df:
            return np.nan
        return df[column].tail(count).mean()

    last_row = df.iloc[-1]
    best_reward = df["reward"].max() if "reward" in df else np.nan
    reward_last20 = _tail_mean("reward")
    tp_last20 = _tail_mean("throughput")
    speed_last20 = _tail_mean("mean_speed")
    efficiency_last20 = _tail_mean(EFFICIENCY_COLUMN)
    stop_last20 = _tail_mean("stops")

    return (
        f"Run: {run_name}\n"
        f"Model: {MODEL_LABELS[model_name]}\n"
        f"Episodes logged: {len(df)}\n"
        f"Latest episode: {int(last_row['episode'])}\n"
        f"Best reward: {best_reward:.2f}\n"
        f"Latest reward: {last_row.get('reward', np.nan):.2f}\n"
        f"Reward last20: {reward_last20:.2f}\n"
        f"Throughput last20: {tp_last20:.1f}\n"
        f"Mean speed last20: {speed_last20:.1f}\n"
        f"Efficiency last20: {efficiency_last20:.3f}\n"
        f"Stops last20: {stop_last20:.1f}\n"
        f"\nCSV:\n{os.path.abspath(csv_path)}"
    )
|
|
|
|
|
|
def plot_detailed_snapshot(
    model_name: str,
    df: pd.DataFrame,
    run_name: str,
    csv_path: str,
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the 4x3 single-model dashboard and save it to ``output_path``.

    Panel layout (row-major): 0-5 headline metrics, 6 the remaining reward
    components overlaid, 7-9 policy loss / value loss / entropy, 10 blank,
    11 the text summary.

    Changes relative to the previous implementation:
      * reward-component columns missing from ``df`` are skipped instead
        of raising ``KeyError``;
      * the component palette is cycled, so components beyond the seventh
        are still drawn;
      * the figure is closed even when plotting or saving fails, so a
        repeated caller does not accumulate open figures.

    Args:
        model_name: Key into ``MODEL_LABELS`` (used for the summary panel).
        df: Cleaned training log with an ``episode`` column.
        run_name: Run identifier for the summary panel.
        csv_path: Source CSV path for the summary panel.
        output_path: Destination image file.
        window: Moving-average window.
        show_ma: Whether to overlay moving-average curves.
    """
    fig, axes = plt.subplots(4, 3, figsize=(18, 15))
    try:
        axes = axes.flatten()

        # Panels 0-5: (column, title, y-label, colour).
        headline_panels = [
            ("reward", "Reward", "Reward", "tab:blue"),
            ("throughput", "Throughput", "veh/h", "tab:green"),
            ("mean_speed", "Mean Speed", "km/h", "tab:orange"),
            (EFFICIENCY_COLUMN, EFFICIENCY_LABEL, "value", "tab:olive"),
            ("speed_variance_norm", "Normalized Speed Variance", "norm", "tab:purple"),
            ("stops", "Stops", "count", "tab:red"),
        ]
        for ax, (column, title, ylabel, color) in zip(axes[:6], headline_panels):
            plot_series(ax, df, column, title, ylabel, color, window, show_ma)

        # Panel 6: every reward component except the efficiency term,
        # which already has its own panel above.
        component_colors = [
            "tab:green",
            "tab:blue",
            "tab:orange",
            "tab:purple",
            "tab:red",
            "tab:brown",
            "tab:gray",
        ]
        reward_components = [
            column for column in REWARD_COMPONENT_COLUMNS if column != EFFICIENCY_COLUMN
        ]
        component_ax = axes[6]
        drawn = 0
        for index, col in enumerate(reward_components):
            if col not in df.columns:
                continue
            series = pd.to_numeric(df[col], errors="coerce")
            if series.notna().sum() == 0:
                continue
            # The colour follows the component's position in the full list,
            # so it does not shift when another component is absent.
            color = component_colors[index % len(component_colors)]
            component_ax.plot(
                df["episode"],
                series,
                linewidth=1.6,
                alpha=0.75,
                color=color,
                label=col,
            )
            if show_ma:
                component_ax.plot(
                    df["episode"],
                    moving_average(series, window),
                    linewidth=1.0,
                    alpha=0.35,
                    linestyle="--",
                    color=color,
                    label="_nolegend_",
                )
            drawn += 1

        if drawn:
            component_ax.set_title("Reward Components")
            component_ax.set_xlabel("Episode")
            component_ax.grid(True, alpha=0.3)
            component_ax.legend(loc="best")
        else:
            component_ax.axis("off")

        # Panels 7-9: optimisation diagnostics.
        plot_series(axes[7], df, "policy_loss", "Policy Loss", "loss", "tab:cyan", window, show_ma)
        plot_series(axes[8], df, "value_loss", "Value Loss", "loss", "tab:pink", window, show_ma)
        plot_series(axes[9], df, "entropy", "Entropy", "value", "tab:brown", window, show_ma)

        # Panel 10 stays blank; panel 11 carries the text summary.
        axes[10].axis("off")
        axes[11].axis("off")
        axes[11].text(
            0.0,
            1.0,
            build_summary_text(model_name, df, run_name, csv_path),
            va="top",
            ha="left",
            family="monospace",
            fontsize=11,
            transform=axes[11].transAxes,
        )

        fig.tight_layout()
        fig.savefig(output_path, dpi=160, bbox_inches="tight")
    finally:
        plt.close(fig)
|
|
|
|
|
|
def load_run_model_logs(run_logs_dir: str) -> Dict[str, pd.DataFrame]:
    """Read every available per-model training log of a run.

    For each key in ``MODEL_ORDER`` the file
    ``<run_logs_dir>/<model>/<model>_training_log.csv`` is loaded when it
    exists. Models whose file is missing or cannot be read are left out.

    Returns:
        Mapping of model key to cleaned frame, in ``MODEL_ORDER`` order.
    """
    loaded: Dict[str, pd.DataFrame] = {}
    for model_name in MODEL_ORDER:
        log_file = os.path.join(run_logs_dir, model_name, f"{model_name}_training_log.csv")
        if not os.path.isfile(log_file):
            continue
        try:
            frame = safe_read_csv(log_file)
        except Exception:
            # Best effort: an unreadable log only removes that model.
            continue
        loaded[model_name] = frame
    return loaded
|
|
|
|
|
|
def plot_waiting_placeholder(output_path: str, run_name: str, message: str):
    """Save a small text-only image showing the run name and ``message``."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axis("off")

    # (y position, text, extra text keyword arguments) for the two lines.
    text_lines = [
        (0.6, f"Run: {run_name}", {"fontsize": 16, "family": "monospace"}),
        (0.4, message, {"fontsize": 12}),
    ]
    for y_pos, content, extra in text_lines:
        ax.text(
            0.5,
            y_pos,
            content,
            ha="center",
            va="center",
            transform=ax.transAxes,
            **extra,
        )

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
|
|
|
|
|
|
def plot_run_comparison(
    target_model: str,
    run_logs: Dict[str, pd.DataFrame],
    compare_output: str,
    window: int,
    show_ma: bool,
):
    """Overlay all models of a run per metric, emphasising ``target_model``.

    Nothing is written when ``run_logs`` holds one model or fewer. The
    target model is drawn thicker and fully opaque; only the first panel
    carries a legend.
    """
    if len(run_logs) <= 1:
        return

    metrics = [
        ("reward", "Reward"),
        ("throughput", "Throughput"),
        ("mean_speed", "Mean Speed"),
        (EFFICIENCY_COLUMN, EFFICIENCY_LABEL),
        ("speed_variance_norm", "Normalized Speed Variance"),
    ]
    fig, grid = plt.subplots(3, 2, figsize=(16, 14))
    panels = grid.flatten()

    for panel, (column, title) in zip(panels, metrics):
        for model_name in MODEL_ORDER:
            frame = run_logs.get(model_name)
            if frame is None or column not in frame.columns:
                continue
            values = pd.to_numeric(frame[column], errors="coerce")
            if not values.notna().any():
                continue

            is_target = model_name == target_model
            base_width = 2.8 if is_target else 1.6
            opacity = 1.0 if is_target else 0.5
            model_color = MODEL_COLORS[model_name]

            panel.plot(
                frame["episode"],
                values,
                linewidth=base_width + 0.4,
                alpha=opacity,
                color=model_color,
                label=MODEL_LABELS[model_name],
            )
            if show_ma:
                # Smoothed curve stays faint and is hidden from the legend.
                panel.plot(
                    frame["episode"],
                    moving_average(values, window),
                    linewidth=1.0,
                    linestyle="--",
                    alpha=min(opacity, 0.4),
                    color=model_color,
                    label="_nolegend_",
                )
        panel.set_title(title)
        panel.set_xlabel("Episode")
        panel.grid(True, alpha=0.3)

    panels[0].legend(loc="best")
    # Blank out the grid cells that have no metric assigned.
    for spare in panels[len(metrics):]:
        spare.axis("off")

    plt.tight_layout()
    plt.savefig(compare_output, dpi=160, bbox_inches="tight")
    plt.close()
|
|
|
|
|
|
def plot_metric_overlay(
    ax,
    run_logs: Dict[str, pd.DataFrame],
    column: str,
    title: str,
    ylabel: str,
    window: int,
    show_ma: bool,
):
    """Plot ``column`` for every model that logged it onto a single axes.

    Models are visited in ``MODEL_ORDER``. The axes are switched off when
    no model has numeric data for ``column``.
    """
    curves_drawn = 0
    for model_name in MODEL_ORDER:
        frame = run_logs.get(model_name)
        if frame is None or column not in frame.columns:
            continue
        values = pd.to_numeric(frame[column], errors="coerce")
        if not values.notna().any():
            continue

        model_color = MODEL_COLORS[model_name]
        ax.plot(
            frame["episode"],
            values,
            alpha=0.85,
            linewidth=1.8,
            color=model_color,
            label=MODEL_LABELS[model_name],
        )
        if show_ma:
            # Dashed moving average, excluded from the legend.
            ax.plot(
                frame["episode"],
                moving_average(values, window),
                linewidth=1.0,
                linestyle="--",
                alpha=0.35,
                color=model_color,
                label="_nolegend_",
            )
        curves_drawn += 1

    if curves_drawn == 0:
        ax.axis("off")
        return

    ax.set_title(title)
    ax.set_xlabel("Episode")
    ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="best")
|
|
|
|
|
|
def build_overview_summary(run_name: str, run_logs: Dict[str, pd.DataFrame]) -> str:
    """Build the plain-text summary that accompanies the overview figure.

    One line is produced per model (in ``MODEL_ORDER``) with the latest
    episode and the last-20-row means of reward, throughput and efficiency.
    A metric whose column is missing, or holds no numeric value, is
    reported as ``nan``. Previously a missing column made ``df.get``
    return ``None`` and the summary failed for the whole run.

    Args:
        run_name: Run identifier for the header line.
        run_logs: Mapping of model key to training log.

    Returns:
        The summary as a newline-joined string.
    """

    def _last20_mean(frame: pd.DataFrame, column: str) -> float:
        # nan when the column is absent or entirely non-numeric.
        if column not in frame.columns:
            return np.nan
        values = pd.to_numeric(frame[column], errors="coerce")
        if values.notna().sum() == 0:
            return np.nan
        return values.tail(20).mean()

    lines = [f"Run: {run_name}", f"Models: {len(run_logs)}", ""]
    for model_name in MODEL_ORDER:
        df = run_logs.get(model_name)
        if df is None or df.empty:
            continue
        latest_episode = int(df["episode"].iloc[-1])
        reward_last20 = _last20_mean(df, "reward")
        tp_last20 = _last20_mean(df, "throughput")
        eff_last20 = _last20_mean(df, EFFICIENCY_COLUMN)
        lines.append(
            f"{MODEL_LABELS[model_name]}: ep={latest_episode}, "
            f"reward20={reward_last20:.1f}, tp20={tp_last20:.0f}, eff20={eff_last20:.3f}"
        )
    return "\n".join(lines)
|
|
|
|
|
|
def plot_all_models_overview(
    run_name: str,
    run_logs: Dict[str, pd.DataFrame],
    output_path: str,
    window: int,
    show_ma: bool,
):
    """Render the all-model overview figure and its text summary.

    Each panel overlays one metric for every model. Up to four reward
    components (other than the efficiency term) that have data in at least
    one model are inserted after the fifth base metric. The summary is
    written next to the image as ``<output stem>_summary.txt``.

    Changes relative to the previous implementation:
      * the grid grows beyond 4 rows when needed. Nine base metrics plus
        four reward components is 13 panels, which used to overflow the
        fixed 12-cell grid and silently drop the last panel (Entropy);
      * the figure is closed even when plotting or saving fails.

    Args:
        run_name: Run identifier used in the title and the summary.
        run_logs: Mapping of model key to training log; must not be empty.
        output_path: Destination image file.
        window: Moving-average window.
        show_ma: Whether to overlay moving-average curves.

    Raises:
        ValueError: if ``run_logs`` is empty.
    """
    if not run_logs:
        raise ValueError("No readable model logs found for overview plotting.")

    def _has_data(column: str) -> bool:
        # True when at least one model logged a numeric value for `column`.
        return any(
            df is not None
            and column in df.columns
            and pd.to_numeric(df[column], errors="coerce").notna().sum() > 0
            for df in run_logs.values()
        )

    available_reward_components = [
        column
        for column in REWARD_COMPONENT_COLUMNS
        if column != EFFICIENCY_COLUMN and _has_data(column)
    ][:4]

    metrics = [
        ("reward", "Reward", "Reward"),
        ("throughput", "Throughput", "veh/h"),
        ("mean_speed", "Mean Speed", "km/h"),
        (EFFICIENCY_COLUMN, EFFICIENCY_LABEL, "value"),
        ("speed_variance_norm", "Normalized Speed Variance", "norm"),
        ("stops", "Stops", "count"),
        ("policy_loss", "Policy Loss", "loss"),
        ("value_loss", "Value Loss", "loss"),
        ("entropy", "Entropy", "value"),
    ]
    reward_metrics = [
        (column, REWARD_COMPONENT_LABELS.get(column, column), "value")
        for column in available_reward_components
    ]
    metrics = metrics[:5] + reward_metrics + metrics[5:]

    # Keep the historical 4x3 layout whenever it fits; otherwise add rows
    # (4 inches of height each, matching the original 16-inch figure).
    n_cols = 3
    n_rows = max(4, (len(metrics) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(22, 4 * n_rows))
    try:
        axes = axes.flatten()
        for ax, (column, title, ylabel) in zip(axes, metrics):
            plot_metric_overlay(ax, run_logs, column, title, ylabel, window, show_ma)
        for ax in axes[len(metrics):]:
            ax.axis("off")

        fig.suptitle(f"Training Overview | {run_name}", fontsize=16, y=0.995)
        # Leave a strip at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.98])
        fig.savefig(output_path, dpi=170, bbox_inches="tight")
    finally:
        plt.close(fig)

    summary_path = os.path.splitext(output_path)[0] + "_summary.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(build_overview_summary(run_name, run_logs))
|
|
|
|
|
|
def run_once(args):
    """Perform one plotting pass for the parsed command-line ``args``.

    With ``--all-models`` the overview figure is written (or a "waiting"
    placeholder when no log could be loaded yet). Otherwise the detailed
    snapshot of the selected model is written, followed by the run
    comparison figure. Each step prints a timestamped status line.
    """

    def stamp() -> str:
        # Wall-clock prefix used by every status line.
        return time.strftime('%H:%M:%S')

    run_name, run_logs_dir = resolve_run_logs_dir(args.log_root, args.run)

    if args.all_models:
        overview_output = resolve_run_dir(args, run_name, run_logs_dir)[2]
        run_logs = load_run_model_logs(run_logs_dir)
        if run_logs:
            plot_all_models_overview(run_name, run_logs, overview_output, args.window, args.show_ma)
            print(
                f"[{stamp()}] Updated all-model overview | "
                f"models={len(run_logs)} | output={overview_output}"
            )
        else:
            plot_waiting_placeholder(
                overview_output,
                run_name,
                "Waiting for training logs to accumulate...",
            )
            print(
                f"[{stamp()}] Waiting for model logs | "
                f"output={overview_output}"
            )
        return

    # Single-model mode.
    model_name = normalize_model_name(args.model)
    resolved = resolve_paths(args, model_name, run_name, run_logs_dir)
    run_name, _model_dir, csv_path, output_path, compare_output = resolved

    df = safe_read_csv(csv_path)
    plot_detailed_snapshot(
        model_name=model_name,
        df=df,
        run_name=run_name,
        csv_path=csv_path,
        output_path=output_path,
        window=args.window,
        show_ma=args.show_ma,
    )

    run_logs = load_run_model_logs(run_logs_dir)
    plot_run_comparison(
        target_model=model_name,
        run_logs=run_logs,
        compare_output=compare_output,
        window=args.window,
        show_ma=args.show_ma,
    )

    print(
        f"[{stamp()}] Updated {MODEL_LABELS[model_name]} snapshot | "
        f"episodes={len(df)} | output={output_path}"
    )
    # The comparison figure is only produced for more than one model.
    if len(run_logs) > 1:
        print(f"[{stamp()}] Updated run comparison | output={compare_output}")
|
|
|
|
|
|
def main():
    """Script entry point: plot once, or keep refreshing with ``--watch``.

    In watch mode a failing pass is reported and skipped so the loop keeps
    running; Ctrl-C ends the loop.
    """
    args = parse_args()

    if not args.watch:
        run_once(args)
        return

    try:
        while True:
            try:
                run_once(args)
            except Exception as exc:
                print(f"[{time.strftime('%H:%M:%S')}] Plot update skipped: {exc}")
            # Sleep at least one second between passes.
            time.sleep(max(args.interval, 1))
    except KeyboardInterrupt:
        print("Stopped live plotting.")


if __name__ == "__main__":
    main()
|