86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
"""Central registry for training entry functions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
from typing import Callable, Dict, List, Tuple
|
|
|
|
|
|
# Registry of supported models.  Each key is the canonical (lower-case) model
# name; each value is ``(module path, function name)``.  Only the strings are
# stored here -- ``get_trainer`` imports the module when a trainer is requested.
TRAINER_SPECS: Dict[str, Tuple[str, str]] = {
    "no_control": ("training.train_no_control", "train_sumo_no_control"),
    "ppo": ("training.train_ppo", "train_sumo_ppo"),
    "gpro": ("training.train_gpro", "train_sumo_gpro"),
    "appo": ("training.train_appo", "train_sumo_appo"),
    "mappo": ("training.train_mappo", "train_sumo_mappo"),
    "tcamappo": ("training.train_tcamappo", "train_sumo_tcamappo"),
    "dcmappo": ("training.train_dcmappo", "train_sumo_dcmappo"),
    "dqn": ("training.train_dqn", "train_sumo_dqn"),
    "madqn": ("training.train_madqn", "train_sumo_madqn"),
    "ddqn": ("training.train_ddqn", "train_sumo_ddqn"),
    "qmix": ("training.train_qmix", "train_sumo_qmix"),
    "dcqmix": ("training.train_dcqmix", "train_sumo_dcqmix"),
    "ddpg": ("training.train_ddpg", "train_sumo_ddpg"),
    "d3pg": ("training.train_d3pg", "train_sumo_d3pg"),
    "sac": ("training.train_sac", "train_sumo_sac"),
    "td3": ("training.train_td3", "train_sumo_td3"),
    "sctd3": ("training.train_sctd3", "train_sumo_sctd3"),
}

# Models used when the caller supplies no explicit list (see
# ``normalize_model_list``).  Note that "appo" and "sctd3" are registered in
# ``TRAINER_SPECS`` but are not part of this default set.
DEFAULT_MODELS: List[str] = [
    "no_control",
    "ppo",
    "gpro",
    "mappo",
    "tcamappo",
    "dcmappo",
    "dqn",
    "madqn",
    "ddqn",
    "qmix",
    "dcqmix",
    "ddpg",
    "d3pg",
    "sac",
    "td3",
]

# Every registered model name, in registration order.
ALL_MODELS: List[str] = [*TRAINER_SPECS]
|
|
|
|
|
|
def get_trainer(model_name: str) -> Callable:
    """Return the training entry function registered for *model_name*.

    The name is canonicalised with ``normalize_model_name`` (which raises
    ``ValueError`` for unregistered names), then the module listed in
    ``TRAINER_SPECS`` is imported and the named attribute is returned.
    """
    module_path, attr_name = TRAINER_SPECS[normalize_model_name(model_name)]
    # The import happens here, at lookup time, not when this registry module
    # itself is imported.
    return getattr(importlib.import_module(module_path), attr_name)
|
|
|
|
|
|
class _TrainerRegistry(dict):
    """Dict-compatible view over ``TRAINER_SPECS`` that resolves trainers lazily.

    The inherited ``dict`` storage is never populated: every mapping
    operation is answered from ``TRAINER_SPECS``, and trainer callables are
    obtained through ``get_trainer`` only when they are actually requested.

    All read-side mapping methods are overridden so they agree with each
    other; otherwise ``len()``, iteration, ``get()`` and ``values()`` would
    fall through to the empty underlying dict and report an empty registry
    while ``keys()`` and ``in`` report a populated one.
    """

    def __contains__(self, key):
        """Return True if *key*, stripped and lower-cased, is registered."""
        return str(key).strip().lower() in TRAINER_SPECS

    def __getitem__(self, key):
        """Return the trainer for *key*.

        Unknown names raise ``ValueError`` (propagated from
        ``normalize_model_name`` via ``get_trainer``), not ``KeyError``.
        """
        return get_trainer(str(key))

    def __iter__(self):
        """Iterate over the registered model names."""
        return iter(TRAINER_SPECS)

    def __len__(self):
        """Return the number of registered models."""
        return len(TRAINER_SPECS)

    def __repr__(self):
        # Show the registered names instead of the empty underlying dict.
        return f"{type(self).__name__}({list(TRAINER_SPECS)!r})"

    def get(self, key, default=None):
        """Return the trainer for *key*, or *default* if it is not registered."""
        # dict.get() bypasses __getitem__, so it must be routed explicitly.
        if key in self:
            return self[key]
        return default

    def keys(self):
        """Return a view of the registered model names."""
        return TRAINER_SPECS.keys()

    def values(self):
        """Yield each registered trainer, resolving it on demand."""
        for key in TRAINER_SPECS:
            yield get_trainer(key)

    def items(self):
        """Yield ``(name, trainer)`` pairs, resolving each trainer on demand."""
        for key in TRAINER_SPECS:
            yield key, get_trainer(key)
|
|
|
|
|
|
# Module-level registry instance.  ``TRAINERS[name]`` does not read stored
# callables; it resolves the trainer through ``get_trainer`` at lookup time.
TRAINERS: Dict[str, Callable] = _TrainerRegistry()
|
|
|
|
|
|
def normalize_model_name(name: str) -> str:
    """Return the canonical registry key for *name*.

    Surrounding whitespace is removed and the name is lower-cased before it
    is checked against ``TRAINER_SPECS``.

    Raises:
        ValueError: if the canonical form is not a registered model. The
            message reports the name exactly as it was passed in.
    """
    canonical = name.strip().lower()
    if canonical in TRAINER_SPECS:
        return canonical
    raise ValueError(f"Unsupported model name: {name}")
|
|
|
|
|
|
def normalize_model_list(models: List[str] | None) -> List[str]:
    """Canonicalise a list of model names, falling back to the defaults.

    When *models* is ``None`` or empty, a fresh copy of ``DEFAULT_MODELS`` is
    returned so callers can mutate the result safely. Otherwise each entry is
    passed through ``normalize_model_name``, which raises ``ValueError`` on
    the first unregistered name. Order and duplicates are preserved.
    """
    if models:
        return [normalize_model_name(entry) for entry in models]
    return [*DEFAULT_MODELS]
|