diff --git a/training/registry.py b/training/registry.py index 94669e8..2531c1d 100644 --- a/training/registry.py +++ b/training/registry.py @@ -16,8 +16,8 @@ from training.train_tcamappo import train_sumo_tcamappo from training.train_td3 import train_sumo_td3 -DEFAULT_MODELS: List[str] = ["ppo"] -# DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "ddpg", "sac", "td3"] +# DEFAULT_MODELS: List[str] = ["ppo"] +DEFAULT_MODELS: List[str] = ["ppo", "gpro", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "ddpg", "sac", "td3"] ALL_MODELS: List[str] = ["ppo", "gpro", "appo", "mappo", "tcamappo", "dcmappo", "dqn", "madqn", "ddqn", "ddpg", "sac", "td3", "sctd3"]