"""Central registry for training entry functions."""

from __future__ import annotations

import importlib
from typing import Callable, Dict, Iterator, List, Tuple

# Maps a canonical (lower-case) model name to the pair
# ``(module path, entry-function name)`` that ``get_trainer`` resolves.
# Only the names are stored here; the modules themselves are imported on
# demand, so importing this registry does not import any trainer module.
TRAINER_SPECS: Dict[str, Tuple[str, str]] = {
    "no_control": ("training.train_no_control", "train_sumo_no_control"),
    "ppo": ("training.train_ppo", "train_sumo_ppo"),
    "gpro": ("training.train_gpro", "train_sumo_gpro"),
    "appo": ("training.train_appo", "train_sumo_appo"),
    "mappo": ("training.train_mappo", "train_sumo_mappo"),
    "tcamappo": ("training.train_tcamappo", "train_sumo_tcamappo"),
    "dcmappo": ("training.train_dcmappo", "train_sumo_dcmappo"),
    "dqn": ("training.train_dqn", "train_sumo_dqn"),
    "madqn": ("training.train_madqn", "train_sumo_madqn"),
    "ddqn": ("training.train_ddqn", "train_sumo_ddqn"),
    "qmix": ("training.train_qmix", "train_sumo_qmix"),
    "dcqmix": ("training.train_dcqmix", "train_sumo_dcqmix"),
    "ddpg": ("training.train_ddpg", "train_sumo_ddpg"),
    "d3pg": ("training.train_d3pg", "train_sumo_d3pg"),
    "sac": ("training.train_sac", "train_sumo_sac"),
    "td3": ("training.train_td3", "train_sumo_td3"),
    "sctd3": ("training.train_sctd3", "train_sumo_sctd3"),
}

# Model names returned by ``normalize_model_list`` when the caller passes
# nothing. This is a subset of ``TRAINER_SPECS`` ("appo" and "sctd3" are
# registered above but not listed here).
DEFAULT_MODELS: List[str] = [
    "no_control",
    "ppo",
    "gpro",
    "mappo",
    "tcamappo",
    "dcmappo",
    "dqn",
    "madqn",
    "ddqn",
    "qmix",
    "dcqmix",
    "ddpg",
    "d3pg",
    "sac",
    "td3",
]

# Every registered model name, in ``TRAINER_SPECS`` insertion order.
ALL_MODELS: List[str] = list(TRAINER_SPECS)


def get_trainer(model_name: str) -> Callable:
    """Import and return the training entry function for *model_name*.

    The name is normalized with ``normalize_model_name`` (whitespace
    stripped, lower-cased) before it is looked up in ``TRAINER_SPECS``.

    Raises:
        ValueError: if the normalized name is not in ``TRAINER_SPECS``.
        ImportError: propagated from ``importlib.import_module`` when the
            trainer module cannot be imported.
        AttributeError: if the imported module lacks the named function.
    """
    model = normalize_model_name(model_name)
    module_name, function_name = TRAINER_SPECS[model]
    module = importlib.import_module(module_name)
    return getattr(module, function_name)


class _TrainerRegistry(dict):
    """Dict-like view over ``TRAINER_SPECS`` that resolves trainers lazily.

    The inherited dict storage is never populated: every accessor defined
    here reads ``TRAINER_SPECS`` and, when a callable is requested, imports
    it through ``get_trainer``. Iteration, ``len()``, ``get()`` and
    ``values()`` are overridden as well so that they agree with
    ``keys()``, ``items()`` and the ``in`` operator instead of reflecting
    the empty inherited storage.

    Items assigned with ``registry[name] = fn`` land in the inherited
    storage and are not consulted by the accessors overridden here.
    """

    def __contains__(self, key):
        """Return True if *key*, stripped and lower-cased, is registered."""
        return str(key).strip().lower() in TRAINER_SPECS

    def __getitem__(self, key):
        """Return the trainer for *key*, importing its module on demand.

        Unknown names raise ``ValueError`` (from ``normalize_model_name``),
        not ``KeyError``.
        """
        return get_trainer(str(key))

    def __iter__(self) -> Iterator[str]:
        """Iterate over the registered model names."""
        return iter(TRAINER_SPECS)

    def __len__(self) -> int:
        """Return the number of registered models."""
        return len(TRAINER_SPECS)

    def __repr__(self) -> str:
        """Show the registered names rather than the empty dict storage."""
        return f"{type(self).__name__}({list(TRAINER_SPECS)!r})"

    def get(self, key, default=None):
        """Return the trainer for *key*, or *default* if it is not registered."""
        if key in self:
            return self[key]
        return default

    def keys(self):
        """Return the key view of ``TRAINER_SPECS``."""
        return TRAINER_SPECS.keys()

    def values(self):
        """Yield each trainer callable, importing its module as it is reached."""
        for key in TRAINER_SPECS:
            yield get_trainer(key)

    def items(self):
        """Yield ``(name, trainer)`` pairs, importing each module as it is reached."""
        for key in TRAINER_SPECS:
            yield key, get_trainer(key)


TRAINERS: Dict[str, Callable] = _TrainerRegistry()


def normalize_model_name(name: str) -> str:
    """Return *name* stripped of surrounding whitespace and lower-cased.

    Raises:
        ValueError: if the normalized name is not a key of ``TRAINER_SPECS``.
            The message quotes *name* exactly as it was passed in.
    """
    model = name.strip().lower()
    if model not in TRAINER_SPECS:
        raise ValueError(f"Unsupported model name: {name}")
    return model


def normalize_model_list(models: List[str] | None) -> List[str]:
    """Normalize every entry of *models*, defaulting to ``DEFAULT_MODELS``.

    A falsy argument (``None`` or an empty list) yields a fresh copy of
    ``DEFAULT_MODELS`` so callers can mutate the result safely.

    Raises:
        ValueError: if any entry is not a supported model name.
    """
    if not models:
        return list(DEFAULT_MODELS)
    return [normalize_model_name(model) for model in models]