diff --git a/.gitignore b/.gitignore index 78e2f63..16af6c9 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ uv.lock *.pt *.pth runs/ +latest/ # IDEs .vscode/ @@ -22,4 +23,7 @@ runs/ # OS .DS_Store -Thumbs.db \ No newline at end of file +Thumbs.db + +# AI +.claude/ \ No newline at end of file diff --git a/main.py b/main.py index dd08ed6..45ffced 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ Main entry point for CTM-DQN speed limit control system. """ import argparse +import os from train import train from test import test @@ -26,8 +27,8 @@ def main(): parser.add_argument( "--model", type=str, - default=None, - help="Path to model checkpoint (for testing)", + default="latest/checkpoints/model_best.pt", + help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)", ) args = parser.parse_args() @@ -37,7 +38,19 @@ def main(): train(args.config) elif args.mode == "test": print("Starting testing mode...") - test(args.config, args.model) + + # Auto-detect config from model path if using default model + config_path = args.config + model_path = args.model + + # If using latest model and default config, use latest config instead + if model_path.startswith("latest/") and args.config == "config.yaml": + latest_config = os.path.join("latest", "config.yaml") + if os.path.exists(latest_config): + config_path = latest_config + print(f"Auto-detected config from latest run: {config_path}") + + test(config_path, model_path) if __name__ == "__main__": diff --git a/test.py b/test.py index 426343a..2a4746c 100644 --- a/test.py +++ b/test.py @@ -9,8 +9,14 @@ from environment import TrafficEnvironment from dqn_agent import DQNAgent -def test(config_path: str = "config.yaml", model_path: str = None): - """Test trained DQN agent.""" +def test(config_path: str = "config.yaml", model_path: str = None, output_dir: str = None): + """Test trained DQN agent. + + Args: + config_path: Path to configuration file + model_path: Path to model checkpoint + output_dir: Optional output directory for test results (overrides config) + """ config = load_config(config_path) create_directories(config) @@ -86,7 +92,9 @@ def test(config_path: str = "config.yaml", model_path: str = None): print("="*50) if render and len(all_densities) > 0: - plot_test_results(all_densities, all_actions, env, config["training"]["log_dir"]) + # Use output_dir if specified, otherwise use config log_dir + log_dir = output_dir if output_dir else config["training"]["log_dir"] + plot_test_results(all_densities, all_actions, env, log_dir) def plot_test_results(densities, actions, env, log_dir): diff --git a/train.py b/train.py index fe6daab..bc8a46b 100644 --- a/train.py +++ b/train.py @@ -11,6 +11,7 @@ from utils import load_config, create_run_directory from environment import TrafficEnvironment from parallel_env import ParallelEnvironment from dqn_agent import DQNAgent +from test import test def set_random_seed(seed: int): @@ -199,6 +200,13 @@ def train(config_path: str = "config.yaml"): episode_rewards, episode_losses, episode_throughputs, log_dir ) + # Auto-test the best model + print("\n" + "="*60) + print("Starting automatic testing with best model...") + print("="*60) + config_copy_path = os.path.join(run_dir, "config.yaml") + test(config_path=config_copy_path, model_path=best_model_path, output_dir=log_dir) + def plot_training_results(rewards, losses, throughputs, log_dir): """Plot and save training results.""" diff --git a/utils.py b/utils.py index d093ffb..460667a 100644 --- a/utils.py +++ b/utils.py @@ -31,6 +31,7 @@ def create_directories(config: Dict): def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]: """ Create a timestamped run directory for this training session. + Latest run is saved to 'latest/', previous runs are moved to 'runs/'. Args: config: Configuration dictionary @@ -42,8 +43,37 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl # Create timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - # Create run directory - run_dir = os.path.join("runs", f"run_{timestamp}") + # Check if 'latest' directory exists + latest_dir = "latest" + if os.path.exists(latest_dir): + # Move existing 'latest' to 'runs' with timestamp from its config + try: + # Try to read timestamp from existing latest directory + old_config_path = os.path.join(latest_dir, "config.yaml") + if os.path.exists(old_config_path): + # Get modification time of the old config + mtime = os.path.getmtime(old_config_path) + old_timestamp = datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S") + else: + old_timestamp = "unknown" + + # Move to runs directory + os.makedirs("runs", exist_ok=True) + archive_dir = os.path.join("runs", f"run_{old_timestamp}") + shutil.move(latest_dir, archive_dir) + + # Clean up timestamp files in archived directory + for file in os.listdir(archive_dir): + if file.startswith("timestamp_") and file.endswith(".txt"): + timestamp_file_path = os.path.join(archive_dir, file) + os.remove(timestamp_file_path) + + print(f"Archived previous run to: {archive_dir}") + except Exception as e: + print(f"Warning: Could not archive previous run: {e}") + + # Create new 'latest' directory + run_dir = latest_dir os.makedirs(run_dir, exist_ok=True) # Create subdirectories @@ -52,11 +82,17 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True) - # Copy config file to run directory + # Copy config file to run directory with timestamp config_copy_path = os.path.join(run_dir, "config.yaml") shutil.copy(config_path, config_copy_path) + # Also save a timestamped copy for reference + timestamp_file = os.path.join(run_dir, f"timestamp_{timestamp}.txt") + with open(timestamp_file, "w") as f: + f.write(f"Training started at: {timestamp}\n") + print(f"Created run directory: {run_dir}") + print(f"Training timestamp: {timestamp}") print(f"Config saved to: {config_copy_path}") return run_dir, checkpoint_dir, log_dir