diff --git a/.gitignore b/.gitignore index 5e0e862..78e2f63 100644 --- a/.gitignore +++ b/.gitignore @@ -11,10 +11,9 @@ wheels/ uv.lock # Project specific - CTM-DQN -checkpoints/ -logs/ *.pt *.pth +runs/ # IDEs .vscode/ diff --git a/train.py b/train.py index cd341ef..fe6daab 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ import numpy as np import torch from tqdm import tqdm import matplotlib.pyplot as plt -from utils import load_config, create_directories +from utils import load_config, create_run_directory from environment import TrafficEnvironment from parallel_env import ParallelEnvironment from dqn_agent import DQNAgent @@ -28,7 +28,9 @@ def set_random_seed(seed: int): def train(config_path: str = "config.yaml"): """Train DQN agent.""" config = load_config(config_path) - create_directories(config) + + # Create timestamped run directory + run_dir, checkpoint_dir, log_dir = create_run_directory(config, config_path) # Set random seed for reproducibility random_seed = config["training"].get("random_seed", 42) @@ -61,8 +63,8 @@ def train(config_path: str = "config.yaml"): num_episodes = config["training"]["num_episodes"] save_freq = config["training"]["save_freq"] log_freq = config["training"]["log_freq"] - checkpoint_dir = config["training"]["checkpoint_dir"] train_freq = config["training"].get("train_freq", 1) # Train every N steps + # checkpoint_dir and log_dir are already set from create_run_directory episode_rewards = [] episode_losses = [] @@ -194,7 +196,7 @@ def train(config_path: str = "config.yaml"): print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})") plot_training_results( - episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"] + episode_rewards, episode_losses, episode_throughputs, log_dir ) diff --git a/utils.py b/utils.py index da4ca99..d093ffb 100644 --- a/utils.py +++ b/utils.py @@ -3,7 +3,9 @@ Configuration utilities for loading and managing settings. """ import yaml import os -from typing import Dict +import shutil +from datetime import datetime +from typing import Dict, Tuple def load_config(config_path: str = "config.yaml") -> Dict: @@ -24,3 +26,37 @@ def create_directories(config: Dict): os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) + + +def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]: + """ + Create a timestamped run directory for this training session. + + Args: + config: Configuration dictionary + config_path: Path to the config file + + Returns: + Tuple of (run_dir, checkpoint_dir, log_dir) + """ + # Create timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create run directory + run_dir = os.path.join("runs", f"run_{timestamp}") + os.makedirs(run_dir, exist_ok=True) + + # Create subdirectories + checkpoint_dir = os.path.join(run_dir, "checkpoints") + log_dir = os.path.join(run_dir, "logs") + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) + + # Copy config file to run directory + config_copy_path = os.path.join(run_dir, "config.yaml") + shutil.copy(config_path, config_copy_path) + + print(f"Created run directory: {run_dir}") + print(f"Config saved to: {config_copy_path}") + + return run_dir, checkpoint_dir, log_dir