ctm-dqn/training/registry.py

56 lines
2.0 KiB
Python

"""Central registry for training entry functions."""
from typing import Callable, Dict, List
from training.train_appo import train_sumo_appo
from training.train_dcmappo import train_sumo_dcmappo
from training.train_dcqmix import train_sumo_dcqmix
from training.train_ddqn import train_sumo_ddqn
from training.train_ddpg import train_sumo_ddpg
from training.train_dqn import train_sumo_dqn
from training.train_gpro import train_sumo_gpro
from training.train_madqn import train_sumo_madqn
from training.train_mappo import train_sumo_mappo
from training.train_ppo import train_sumo_ppo
from training.train_qmix import train_sumo_qmix
from training.train_sac import train_sumo_sac
from training.train_sctd3 import train_sumo_sctd3
from training.train_tcamappo import train_sumo_tcamappo
from training.train_td3 import train_sumo_td3
# Lookup table from lowercase model key to the training entry function
# imported for it at the top of this module.
TRAINERS: Dict[str, Callable] = {
    "ppo": train_sumo_ppo,
    "gpro": train_sumo_gpro,
    "appo": train_sumo_appo,
    "mappo": train_sumo_mappo,
    "tcamappo": train_sumo_tcamappo,
    "dcmappo": train_sumo_dcmappo,
    "dqn": train_sumo_dqn,
    "madqn": train_sumo_madqn,
    "ddqn": train_sumo_ddqn,
    "qmix": train_sumo_qmix,
    "dcqmix": train_sumo_dcqmix,
    "ddpg": train_sumo_ddpg,
    "sac": train_sumo_sac,
    "td3": train_sumo_td3,
    "sctd3": train_sumo_sctd3,
}

# Every registered model key, in the insertion order of TRAINERS.
ALL_MODELS: List[str] = list(TRAINERS)

# Keys returned by normalize_model_list when the caller supplies no models.
# Holds every registered key except "appo" and "sctd3".
# Swap in ["ppo"] for a PPO-only default.
DEFAULT_MODELS: List[str] = [
    "ppo",
    "gpro",
    "mappo",
    "tcamappo",
    "dcmappo",
    "dqn",
    "madqn",
    "ddqn",
    "qmix",
    "dcqmix",
    "ddpg",
    "sac",
    "td3",
]
def normalize_model_name(name: str) -> str:
    """Return the registry key that corresponds to *name*.

    Surrounding whitespace is stripped and the text is lowercased before
    it is looked up in ``TRAINERS``.

    Args:
        name: Model name as supplied by the caller.

    Returns:
        The stripped, lowercased name.

    Raises:
        ValueError: If the cleaned name is not a key of ``TRAINERS``. The
            message carries the name exactly as it was passed in.
    """
    canonical = name.strip().lower()
    if canonical in TRAINERS:
        return canonical
    raise ValueError(f"Unsupported model name: {name}")
def normalize_model_list(models: List[str] | None) -> List[str]:
    """Normalize a list of model names, falling back to the defaults.

    Args:
        models: Model names to normalize. ``None`` or an empty list
            selects the default set.

    Returns:
        A new list. For a non-empty *models*, each entry is passed through
        ``normalize_model_name`` in order; otherwise a copy of
        ``DEFAULT_MODELS`` is returned so callers cannot mutate the
        module-level list.

    Raises:
        ValueError: Propagated from ``normalize_model_name`` for any entry
            that is not a registered model.
    """
    if models:
        return [normalize_model_name(raw) for raw in models]
    return [*DEFAULT_MODELS]