修改同时训练时保存的目录
This commit is contained in:
parent
09c0dd86e5
commit
27502241ad
|
|
@ -1,24 +1,50 @@
|
||||||
"""一键异步启动全部训练"""
|
"""一键异步启动全部训练。"""
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
|
AGENTS = ["ppo", "appo", "mappo", "dqn", "ddpg", "td3"]
|
||||||
|
|
||||||
processes = {}
|
|
||||||
for agent in AGENTS:
|
|
||||||
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
|
|
||||||
p = subprocess.Popen(
|
|
||||||
[sys.executable, "-m", f"training.train_{agent}"],
|
|
||||||
stdout=open(f"logs_{agent}_stdout.txt", "w"),
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
)
|
|
||||||
processes[agent] = p
|
|
||||||
print(f" PID: {p.pid}")
|
|
||||||
|
|
||||||
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...\n")
|
def main():
|
||||||
|
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
processes = {}
|
||||||
|
|
||||||
for agent, p in processes.items():
|
for agent in AGENTS:
|
||||||
p.wait()
|
print(f"[{datetime.now().strftime('%H:%M:%S')}] 启动 {agent.upper()} 训练...")
|
||||||
status = "✓ 完成" if p.returncode == 0 else f"✗ 失败(code={p.returncode})"
|
log_dir = os.path.join("logs", "multi-model", run_timestamp, agent)
|
||||||
print(f"[{agent.upper()}] {status}")
|
checkpoint_dir = os.path.join("checkpoints", "multi-model", run_timestamp, agent)
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
stdout_path = os.path.join(log_dir, "stdout.txt")
|
||||||
|
|
||||||
|
process = subprocess.Popen(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
f"training.train_{agent}",
|
||||||
|
"--log-dir",
|
||||||
|
log_dir,
|
||||||
|
"--checkpoint-dir",
|
||||||
|
checkpoint_dir,
|
||||||
|
"--run-timestamp",
|
||||||
|
run_timestamp,
|
||||||
|
],
|
||||||
|
stdout=open(stdout_path, "w", encoding="utf-8"),
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
processes[agent] = process
|
||||||
|
print(f" PID: {process.pid}")
|
||||||
|
|
||||||
|
print(f"\n全部 {len(AGENTS)} 个训练已启动,等待完成...")
|
||||||
|
print(f"本次多模型时间戳: {run_timestamp}\n")
|
||||||
|
|
||||||
|
for agent, process in processes.items():
|
||||||
|
process.wait()
|
||||||
|
status = "完成" if process.returncode == 0 else f"失败(code={process.returncode})"
|
||||||
|
print(f"[{agent.upper()}] {status}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
基于 SUMO+TraCI 的 APPO 训练脚本
|
基于 SUMO+TraCI 的 APPO 训练脚本
|
||||||
使用微观仿真环境训练 VSL 控制策略
|
使用微观仿真环境训练 VSL 控制策略
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
|
|
@ -19,9 +20,10 @@ from agents.appo_agent import APPOAgent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_appo():
|
def train_sumo_appo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
"""SUMO 环境下的 APPO 训练主函数"""
|
"""SUMO 环境下的 APPO 训练主函数"""
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
@ -30,9 +32,12 @@ def train_sumo_appo():
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
start_episode = 1
|
start_episode = 1
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "appo", timestamp)
|
"appo",
|
||||||
log_dir = os.path.join("logs", "appo", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -246,4 +251,10 @@ def train_sumo_appo():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_appo()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_appo(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
基于 SUMO+TraCI 的 DDPG 训练脚本
|
基于 SUMO+TraCI 的 DDPG 训练脚本
|
||||||
使用 Stable-Baselines3 的 DDPG 算法
|
使用 Stable-Baselines3 的 DDPG 算法
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -16,9 +17,10 @@ from agents.ddpg_agent import DDPGAgent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_ddpg():
|
def train_sumo_ddpg(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
"""SUMO 环境下的 DDPG 训练主函数"""
|
"""SUMO 环境下的 DDPG 训练主函数"""
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
@ -26,9 +28,12 @@ def train_sumo_ddpg():
|
||||||
agent_config = get_agent_config(config, "ddpg")
|
agent_config = get_agent_config(config, "ddpg")
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "ddpg", timestamp)
|
"ddpg",
|
||||||
log_dir = os.path.join("logs", "ddpg", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -183,4 +188,10 @@ def train_sumo_ddpg():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_ddpg()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_ddpg(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
DQN训练脚本 - SUMO VSL环境
|
DQN训练脚本 - SUMO VSL环境
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
|
|
@ -17,9 +18,10 @@ from agents.dqn_agent import DQNAgent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_dqn():
|
def train_sumo_dqn(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
@ -27,9 +29,12 @@ def train_sumo_dqn():
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
start_episode = 1
|
start_episode = 1
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "dqn", timestamp)
|
"dqn",
|
||||||
log_dir = os.path.join("logs", "dqn", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -203,4 +208,10 @@ def train_sumo_dqn():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_dqn()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_dqn(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
MAPPO training script for SUMO + TraCI VSL.
|
MAPPO training script for SUMO + TraCI VSL.
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -16,18 +17,22 @@ from agents.mappo_agent import MAPPOAgent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_mappo():
|
def train_sumo_mappo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
agent_config = get_agent_config(config, "mappo")
|
agent_config = get_agent_config(config, "mappo")
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "mappo", timestamp)
|
"mappo",
|
||||||
log_dir = os.path.join("logs", "mappo", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -235,4 +240,10 @@ def train_sumo_mappo():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_mappo()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_mappo(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
基于 SUMO+TraCI 的 PPO 训练脚本
|
基于 SUMO+TraCI 的 PPO 训练脚本
|
||||||
使用微观仿真环境训练 VSL 控制策略
|
使用微观仿真环境训练 VSL 控制策略
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
import copy
|
||||||
|
|
@ -19,9 +20,10 @@ from agents.ppo_agent import PPOAgent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_ppo():
|
def train_sumo_ppo(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
"""SUMO 环境下的 PPO 训练主函数"""
|
"""SUMO 环境下的 PPO 训练主函数"""
|
||||||
# 加载配置
|
# 加载配置
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
|
|
@ -31,9 +33,12 @@ def train_sumo_ppo():
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
start_episode = 1
|
start_episode = 1
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "ppo", timestamp)
|
"ppo",
|
||||||
log_dir = os.path.join("logs", "ppo", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -249,4 +254,10 @@ def train_sumo_ppo():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_ppo()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_ppo(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
基于 SUMO+TraCI 的 TD3 训练脚本
|
基于 SUMO+TraCI 的 TD3 训练脚本
|
||||||
使用 Stable-Baselines3 的 TD3 算法
|
使用 Stable-Baselines3 的 TD3 算法
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -17,9 +18,10 @@ from agents.td3_agent import TD3Agent
|
||||||
from utils.config import get_agent_config, get_training_config
|
from utils.config import get_agent_config, get_training_config
|
||||||
from utils.logger import TrainingLogger
|
from utils.logger import TrainingLogger
|
||||||
from utils.plot import plot_training_curves
|
from utils.plot import plot_training_curves
|
||||||
|
from utils.run_dirs import add_run_dir_args, resolve_run_dirs
|
||||||
|
|
||||||
|
|
||||||
def train_sumo_td3():
|
def train_sumo_td3(log_dir=None, checkpoint_dir=None, run_timestamp=None):
|
||||||
"""SUMO 环境下的 TD3 训练主函数"""
|
"""SUMO 环境下的 TD3 训练主函数"""
|
||||||
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
with open("config_sumo_vsl.yaml", "r", encoding="utf-8") as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
|
|
@ -27,9 +29,12 @@ def train_sumo_td3():
|
||||||
agent_config = get_agent_config(config, "td3")
|
agent_config = get_agent_config(config, "td3")
|
||||||
train_config = get_training_config(config)
|
train_config = get_training_config(config)
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
_, checkpoint_dir, log_dir = resolve_run_dirs(
|
||||||
checkpoint_dir = os.path.join("checkpoints", "td3", timestamp)
|
"td3",
|
||||||
log_dir = os.path.join("logs", "td3", timestamp)
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
run_timestamp=run_timestamp,
|
||||||
|
)
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
runtime_config = copy.deepcopy(config)
|
runtime_config = copy.deepcopy(config)
|
||||||
|
|
@ -193,4 +198,10 @@ def train_sumo_td3():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_sumo_td3()
|
parser = add_run_dir_args(argparse.ArgumentParser())
|
||||||
|
args = parser.parse_args()
|
||||||
|
train_sumo_td3(
|
||||||
|
log_dir=args.log_dir,
|
||||||
|
checkpoint_dir=args.checkpoint_dir,
|
||||||
|
run_timestamp=args.run_timestamp,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def add_run_dir_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||||
|
parser.add_argument("--log-dir", type=str, default=None, help="Training log output directory.")
|
||||||
|
parser.add_argument("--checkpoint-dir", type=str, default=None, help="Model checkpoint output directory.")
|
||||||
|
parser.add_argument("--run-timestamp", type=str, default=None, help="Run timestamp tag for default directories.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_run_dirs(
|
||||||
|
model_name: str,
|
||||||
|
log_dir: Optional[str] = None,
|
||||||
|
checkpoint_dir: Optional[str] = None,
|
||||||
|
run_timestamp: Optional[str] = None,
|
||||||
|
) -> Tuple[str, str, str]:
|
||||||
|
"""Resolve output directories from runtime args, falling back to per-model defaults."""
|
||||||
|
timestamp = run_timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
checkpoint_dir = os.path.join("checkpoints", model_name, timestamp)
|
||||||
|
if log_dir is None:
|
||||||
|
log_dir = os.path.join("logs", model_name, timestamp)
|
||||||
|
|
||||||
|
return timestamp, checkpoint_dir, log_dir
|
||||||
Loading…
Reference in New Issue