"""Central registry for training entry functions."""

from typing import Callable, Dict, List

from training.train_appo import train_sumo_appo
from training.train_dcmappo import train_sumo_dcmappo
from training.train_dcqmix import train_sumo_dcqmix
from training.train_d3pg import train_sumo_d3pg
from training.train_ddqn import train_sumo_ddqn
from training.train_ddpg import train_sumo_ddpg
from training.train_dqn import train_sumo_dqn
from training.train_gpro import train_sumo_gpro
from training.train_madqn import train_sumo_madqn
from training.train_mappo import train_sumo_mappo
from training.train_ppo import train_sumo_ppo
from training.train_qmix import train_sumo_qmix
from training.train_sac import train_sumo_sac
from training.train_sctd3 import train_sumo_sctd3
from training.train_tcamappo import train_sumo_tcamappo
from training.train_td3 import train_sumo_td3

# Maps each supported (lowercase) model name to its training entry function.
# This dict is the single source of truth for which model names are accepted.
TRAINERS: Dict[str, Callable] = {
    "ppo": train_sumo_ppo,
    "gpro": train_sumo_gpro,
    "appo": train_sumo_appo,
    "mappo": train_sumo_mappo,
    "tcamappo": train_sumo_tcamappo,
    "dcmappo": train_sumo_dcmappo,
    "dqn": train_sumo_dqn,
    "madqn": train_sumo_madqn,
    "ddqn": train_sumo_ddqn,
    "qmix": train_sumo_qmix,
    "dcqmix": train_sumo_dcqmix,
    "ddpg": train_sumo_ddpg,
    "d3pg": train_sumo_d3pg,
    "sac": train_sumo_sac,
    "td3": train_sumo_td3,
    "sctd3": train_sumo_sctd3,
}

# Every registered model name, in registration order. Derived from TRAINERS
# (dicts preserve insertion order) so the two can never drift apart.
ALL_MODELS: List[str] = list(TRAINERS)

# Models used when the caller requests none explicitly. This is every
# registered model except "appo" and "sctd3".
DEFAULT_MODELS: List[str] = [
    "ppo",
    "gpro",
    "mappo",
    "tcamappo",
    "dcmappo",
    "dqn",
    "madqn",
    "ddqn",
    "qmix",
    "dcqmix",
    "ddpg",
    "d3pg",
    "sac",
    "td3",
]


def _check_defaults_registered() -> None:
    """Fail fast at import time if DEFAULT_MODELS names an unregistered model.

    Without this check, a typo in DEFAULT_MODELS would only surface later as
    a KeyError when the missing trainer is looked up.
    """
    unregistered = [model for model in DEFAULT_MODELS if model not in TRAINERS]
    if unregistered:
        raise RuntimeError(
            f"DEFAULT_MODELS contains unregistered models: {unregistered}"
        )


_check_defaults_registered()


def normalize_model_name(name: str) -> str:
    """Return the canonical registry key for a user-supplied model name.

    Surrounding whitespace is stripped and the name is lowercased before
    lookup, so " PPO " resolves to "ppo".

    Args:
        name: Model name as provided by the caller.

    Returns:
        The normalized name, guaranteed to be a key of TRAINERS.

    Raises:
        ValueError: If the normalized name is not registered in TRAINERS.
    """
    model = name.strip().lower()
    if model not in TRAINERS:
        supported = ", ".join(TRAINERS)
        raise ValueError(
            f"Unsupported model name: {name}. Supported models: {supported}"
        )
    return model


def normalize_model_list(models: List[str] | None) -> List[str]:
    """Normalize a list of requested model names.

    Args:
        models: Requested model names, or None.

    Returns:
        A fresh copy of DEFAULT_MODELS when ``models`` is None or empty.
        Otherwise, each entry normalized via ``normalize_model_name``, in the
        original order (duplicates are kept).

    Raises:
        ValueError: On the first entry that is not a registered model.
    """
    if not models:
        # Return a copy so callers cannot mutate the module-level default.
        return list(DEFAULT_MODELS)
    return [normalize_model_name(model) for model in models]