# ctm-dqn/training/train_sctd3.py
#
# 15 lines
# 420 B
# Python

from agents.sctd3_agent import SCTD3Agent
from training.train_td3 import train_sumo_td3
def train_sumo_sctd3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
    """Run ``train_sumo_td3`` configured for the SC-TD3 agent.

    This is a thin wrapper: the three arguments are forwarded unchanged and
    the SC-TD3-specific settings (model name, config key, display name and
    agent class) are filled in here.

    Args:
        log_dir: Passed through as ``log_dir``. Defaults to ``None``.
        checkpoint_dir: Passed through as ``checkpoint_dir``. Defaults to
            ``None``.
        run_timestamp: Passed through as ``run_timestamp``. Defaults to
            ``None``.

    Returns:
        Whatever ``train_sumo_td3`` returns.
    """
    # Fixed settings that select the SC-TD3 variant in the shared routine.
    sctd3_settings = {
        "model_name": "sctd3",
        "config_key": "sctd3",
        "display_name": "SC-TD3",
        "agent_class": SCTD3Agent,
    }
    return train_sumo_td3(
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
        **sctd3_settings,
    )