15 lines
420 B
Python
15 lines
420 B
Python
from agents.sctd3_agent import SCTD3Agent
|
|
from training.train_td3 import train_sumo_td3
|
|
|
|
|
|
def train_sumo_sctd3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
    """Train the SC-TD3 agent by delegating to the shared ``train_sumo_td3`` loop.

    This function adds no training logic of its own. It fixes the
    SC-TD3-specific identifiers (model name, config key, display name, agent
    class) and passes the caller's output locations straight through.

    Args:
        log_dir: Forwarded unchanged as ``log_dir`` to ``train_sumo_td3``.
        checkpoint_dir: Forwarded unchanged as ``checkpoint_dir``.
        run_timestamp: Forwarded unchanged as ``run_timestamp``.

    Returns:
        Whatever ``train_sumo_td3`` returns for this configuration.
    """
    # Fixed settings that select the SC-TD3 variant of the shared TD3 trainer.
    sctd3_settings = {
        "model_name": "sctd3",
        "config_key": "sctd3",
        "display_name": "SC-TD3",
        "agent_class": SCTD3Agent,
    }
    return train_sumo_td3(
        log_dir=log_dir,
        checkpoint_dir=checkpoint_dir,
        run_timestamp=run_timestamp,
        **sctd3_settings,
    )
|