# ctm-dqn/training/registry.py (86 lines, 2.4 KiB, Python)

"""Central registry mapping model names to their training entry functions, which are imported lazily on lookup."""
from __future__ import annotations
import importlib
from typing import Callable, Dict, List, Tuple
# Maps each public model name to ``(module path, entry-function name)``.
# Every trainer follows one naming convention -- module
# ``training.train_<name>`` exposing ``train_sumo_<name>`` -- so the table
# is derived from the ordered tuple of names below. Insertion order matters:
# ALL_MODELS is built from it.
TRAINER_SPECS: Dict[str, Tuple[str, str]] = {
    name: (f"training.train_{name}", f"train_sumo_{name}")
    for name in (
        "no_control",
        "ppo",
        "gpro",
        "appo",
        "mappo",
        "tcamappo",
        "dcmappo",
        "dqn",
        "madqn",
        "ddqn",
        "qmix",
        "dcqmix",
        "ddpg",
        "d3pg",
        "sac",
        "td3",
        "sctd3",
    )
}

# Models run when the caller selects none explicitly. This is a subset of
# TRAINER_SPECS: "appo" and "sctd3" are registered but not run by default.
DEFAULT_MODELS: List[str] = [
    "no_control", "ppo", "gpro", "mappo", "tcamappo",
    "dcmappo", "dqn", "madqn", "ddqn", "qmix",
    "dcqmix", "ddpg", "d3pg", "sac", "td3",
]

# Every registered model name, in registration order.
ALL_MODELS: List[str] = list(TRAINER_SPECS)
def get_trainer(model_name: str) -> Callable:
    """Resolve *model_name* to its training entry function.

    The name is first canonicalized (whitespace-stripped, lower-cased) and
    validated by ``normalize_model_name``. The owning module is imported
    only at this point, when the trainer is actually requested.

    Raises:
        ValueError: if the name is not registered in ``TRAINER_SPECS``.
        ImportError: if the registered module cannot be imported.
        AttributeError: if the module lacks the registered entry function.
    """
    module_path, entry_name = TRAINER_SPECS[normalize_model_name(model_name)]
    return getattr(importlib.import_module(module_path), entry_name)
class _TrainerRegistry(dict):
    """Read-only, lazily importing mapping of model name -> trainer function.

    It subclasses ``dict`` so ``TRAINERS`` still passes ``isinstance(..., dict)``,
    but the real dict storage is never populated. Every lookup is delegated
    to ``TRAINER_SPECS`` / ``get_trainer``, so a trainer module is imported
    only when its entry is accessed.

    Because the underlying storage stays empty, each inherited operation that
    would read it directly is overridden here: ``len``, truthiness,
    iteration, ``get``, ``values``, ``repr``. Without these overrides the
    registry would look empty to ordinary dict-style code.

    Lookups are case- and surrounding-whitespace-insensitive. Unknown names
    raise ``ValueError`` (from ``normalize_model_name``), not ``KeyError``.
    This is kept for backward compatibility with existing callers.
    """

    def __contains__(self, key) -> bool:
        return str(key).strip().lower() in TRAINER_SPECS

    def __getitem__(self, key) -> Callable:
        return get_trainer(str(key))

    def __iter__(self):
        # Iterate registered names, not the (empty) inherited dict storage.
        return iter(TRAINER_SPECS)

    def __len__(self) -> int:
        # Also drives bool(TRAINERS); the inherited length would always be 0.
        return len(TRAINER_SPECS)

    def keys(self):
        return TRAINER_SPECS.keys()

    def items(self):
        # Lazy: each trainer module is imported only as the pair is reached.
        for key in TRAINER_SPECS:
            yield key, get_trainer(key)

    def values(self):
        # Lazy counterpart of items(); the inherited values() would be empty.
        for key in TRAINER_SPECS:
            yield get_trainer(key)

    def get(self, key, default=None):
        """Return the trainer for *key*, or *default* if it is not registered.

        Membership is checked first, so errors raised while importing a
        registered trainer still propagate instead of being masked.
        """
        if key not in self:
            return default
        return self[key]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({list(TRAINER_SPECS)!r})"


# Dict-compatible view over the registry; trainers are imported on access.
TRAINERS: Dict[str, Callable] = _TrainerRegistry()
def normalize_model_name(name: str) -> str:
    """Return the canonical registry key for *name*.

    Surrounding whitespace is stripped and the name is lower-cased, so
    ``" PPO "`` maps to ``"ppo"``.

    Raises:
        ValueError: if the normalized name is not a key of ``TRAINER_SPECS``.
            The message keeps the ``"Unsupported model name: <name>"`` prefix
            and appends the supported names, so typos are easy to fix.
    """
    model = name.strip().lower()
    if model not in TRAINER_SPECS:
        supported = ", ".join(TRAINER_SPECS)
        raise ValueError(
            f"Unsupported model name: {name} (supported: {supported})"
        )
    return model
def normalize_model_list(models: List[str] | str | None) -> List[str]:
    """Normalize a model selection into a list of canonical registry names.

    Args:
        models: An iterable of model names, or a single string that may hold
            several comma-separated names (e.g. ``"ppo, dqn"``), or ``None``.
            ``None`` or an empty selection means "use DEFAULT_MODELS".

    Returns:
        A new list of canonical names in the order given (duplicates kept).

    Raises:
        ValueError: if any name is not registered (see normalize_model_name).
    """
    if isinstance(models, str):
        # Iterating a bare string would yield single characters and fail with
        # a misleading "Unsupported model name: p"; treat it as a CSV list
        # and ignore blank pieces (e.g. from a trailing comma).
        models = [part for part in models.split(",") if part.strip()]
    if not models:
        # Fresh copy so callers cannot mutate the module-level default list.
        return list(DEFAULT_MODELS)
    return [normalize_model_name(model) for model in models]