diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py index 8d35ec4..be4072e 100644 --- a/scripts/evaluate_models.py +++ b/scripts/evaluate_models.py @@ -110,6 +110,12 @@ def parse_args(): default=None, help="Override SUMO simulation step length for evaluation only. Default: use training config.", ) + parser.add_argument( + "--route-file", + type=str, + default=None, + help="Override SUMO route/flow file for evaluation only. Supports absolute paths or project-relative paths.", + ) return parser.parse_args() @@ -221,6 +227,14 @@ def deep_merge_dicts(base: dict, override: dict) -> dict: return merged +def resolve_project_path(path_str: Optional[str]) -> Optional[str]: + if not path_str: + return None + if os.path.isabs(path_str): + return path_str + return os.path.abspath(os.path.join(PROJECT_ROOT, path_str)) + + def resolve_model_load_path(model_name: str, checkpoint_dir: str) -> str: if model_name in {"ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn"}: best_path = os.path.join(checkpoint_dir, "model_best.pt") @@ -482,6 +496,7 @@ def evaluate_single_model( end_time: int, with_gui: bool, step_length: Optional[float], + route_file: Optional[str], ) -> Tuple[pd.DataFrame, pd.DataFrame, dict]: config = load_config_for_checkpoint(checkpoint_dir, fallback_config_path) if checkpoint_dir else load_config_for_checkpoint("", fallback_config_path) runtime_config = copy.deepcopy(config) @@ -497,6 +512,8 @@ def evaluate_single_model( runtime_config["sumo"]["gui"] = with_gui if step_length is not None: runtime_config["sumo"]["step_length"] = step_length + if route_file is not None: + runtime_config["sumo"]["route_file"] = route_file runtime_config.setdefault("runtime", {})["output_dir"] = os.path.join(output_dir, model_name) runtime_config["runtime"]["metrics_subdir"] = "eval_sumo_metrics" runtime_config["runtime"]["collect_detector_cells"] = True @@ -631,11 +648,12 @@ def evaluate_single_model( "end_time": effective_end_time, "with_gui": with_gui, "step_length": runtime_config["sumo"].get("step_length"), + "route_file": runtime_config["sumo"].get("route_file", ""), } return step_df, edge_df, detector_df, meta -def evaluate_worker(task: Tuple[str, Optional[str], str, str, int, Optional[int], Optional[int], bool, Optional[float]]): +def evaluate_worker(task: Tuple[str, Optional[str], str, str, int, Optional[int], Optional[int], bool, Optional[float], Optional[str]]): return evaluate_single_model(*task) @@ -873,6 +891,9 @@ def print_summary(summary_df: pd.DataFrame, output_dir: str): def main(): args = parse_args() + route_file = resolve_project_path(args.route_file) + if route_file is not None and not os.path.isfile(route_file): + raise FileNotFoundError(f"Custom route file not found: {route_file}") checkpoint_root = resolve_checkpoint_root(args.checkpoint_root) model_dirs = discover_model_dirs(checkpoint_root, args.models) if not model_dirs: @@ -904,6 +925,7 @@ def main(): args.end_time, args.with_gui, args.step_length, + route_file, ) ) @@ -922,6 +944,7 @@ def main(): args.end_time, args.with_gui, args.step_length, + route_file, ), )