66 lines
2.5 KiB
Python
66 lines
2.5 KiB
Python
"""Central registry for training entry functions."""
|
|
from typing import Callable, Dict, List
|
|
|
|
from training.train_appo import train_sumo_appo
|
|
from training.train_dcmappo import train_sumo_dcmappo
|
|
from training.train_dcqmix import train_sumo_dcqmix
|
|
from training.train_d3pg import train_sumo_d3pg
|
|
from training.train_ddqn import train_sumo_ddqn
|
|
from training.train_ddpg import train_sumo_ddpg
|
|
from training.train_dqn import train_sumo_dqn
|
|
from training.train_gpro import train_sumo_gpro
|
|
from training.train_madqn import train_sumo_madqn
|
|
from training.train_mappo import train_sumo_mappo
|
|
from training.train_ppo import train_sumo_ppo
|
|
from training.train_qmix import train_sumo_qmix
|
|
from training.train_rule_vsl import (
|
|
train_sumo_bottleneck_rule_vsl,
|
|
train_sumo_harmonization_rule_vsl,
|
|
train_sumo_occ_rule_vsl,
|
|
)
|
|
from training.train_sac import train_sumo_sac
|
|
from training.train_sctd3 import train_sumo_sctd3
|
|
from training.train_tacmappo import train_sumo_tacmappo
|
|
from training.train_td3 import train_sumo_td3
|
|
|
|
|
|
# Models trained when the caller does not request specific ones.
# NOTE: "appo" and "sctd3" are registered in TRAINERS but deliberately left
# out of this default selection.
DEFAULT_MODELS: List[str] = [
    "occ_rule_vsl",
    "bottleneck_rule_vsl",
    "harmonization_rule_vsl",
    "ppo",
    "gpro",
    "mappo",
    "tacmappo",
    "dcmappo",
    "dqn",
    "madqn",
    "ddqn",
    "qmix",
    "dcqmix",
    "ddpg",
    "d3pg",
    "sac",
    "td3",
]

# Maps each canonical (lower-case) model name to its training entry function.
TRAINERS: Dict[str, Callable] = {
    "occ_rule_vsl": train_sumo_occ_rule_vsl,
    "bottleneck_rule_vsl": train_sumo_bottleneck_rule_vsl,
    "harmonization_rule_vsl": train_sumo_harmonization_rule_vsl,
    "ppo": train_sumo_ppo,
    "gpro": train_sumo_gpro,
    "appo": train_sumo_appo,
    "mappo": train_sumo_mappo,
    "tacmappo": train_sumo_tacmappo,
    "dcmappo": train_sumo_dcmappo,
    "dqn": train_sumo_dqn,
    "madqn": train_sumo_madqn,
    "ddqn": train_sumo_ddqn,
    "qmix": train_sumo_qmix,
    "dcqmix": train_sumo_dcqmix,
    "ddpg": train_sumo_ddpg,
    "d3pg": train_sumo_d3pg,
    "sac": train_sumo_sac,
    "td3": train_sumo_td3,
    "sctd3": train_sumo_sctd3,
}

# Every registered model, in registration order. Derived from TRAINERS (dicts
# preserve insertion order) so the two can never drift apart.
ALL_MODELS: List[str] = list(TRAINERS)

# Fail fast at import time if a default refers to a model that has no trainer;
# otherwise the typo would only surface when the default run is launched.
if not set(DEFAULT_MODELS).issubset(TRAINERS):
    raise RuntimeError(
        "DEFAULT_MODELS references unregistered models: "
        f"{sorted(set(DEFAULT_MODELS).difference(TRAINERS))}"
    )
|
|
|
|
|
|
def normalize_model_name(name: str) -> str:
    """Return the canonical registry key for a user-supplied model name.

    Leading/trailing whitespace is stripped and the name is lower-cased
    before it is looked up in ``TRAINERS``.

    Args:
        name: Model name as provided by the user, e.g. ``" PPO "``.

    Returns:
        The normalized name, which is guaranteed to be a key of ``TRAINERS``.

    Raises:
        ValueError: If the normalized name is not a registered trainer. The
            message lists the supported names so typos are easy to correct.
    """
    model = name.strip().lower()
    if model not in TRAINERS:
        supported = ", ".join(TRAINERS)
        # Keep the original message prefix so existing log/grep patterns
        # still match; append the valid choices for convenience.
        raise ValueError(
            f"Unsupported model name: {name}. Supported models: {supported}"
        )
    return model
|
|
|
|
|
|
def normalize_model_list(models: List[str] | str | None) -> List[str]:
    """Resolve a user-supplied model selection into canonical registry keys.

    Args:
        models: Model names to run. ``None`` or an empty list selects
            ``DEFAULT_MODELS``. A single bare string is treated as a
            one-element list rather than being iterated character by
            character.

    Returns:
        A new list of normalized model names, in the order given. Duplicate
        entries are kept as-is.

    Raises:
        ValueError: If any name is not a registered trainer (raised by
            ``normalize_model_name``).
    """
    if not models:
        # Return a copy so callers cannot mutate the module-level default.
        return list(DEFAULT_MODELS)
    if isinstance(models, str):
        # A str is itself iterable: without this guard "ppo" would be
        # validated as ["p", "p", "o"] and fail with a misleading error.
        models = [models]
    return [normalize_model_name(model) for model in models]
|