将每次运行参数及结果独立保存

This commit is contained in:
Zihan Ye 2026-01-05 21:43:04 +08:00
parent 2cecf7804f
commit 8a2194039c
3 changed files with 44 additions and 7 deletions

3
.gitignore vendored
View File

@ -11,10 +11,9 @@ wheels/
uv.lock
# Project specific - CTM-DQN
checkpoints/
logs/
*.pt
*.pth
runs/
# IDEs
.vscode/

View File

@ -7,7 +7,7 @@ import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import load_config, create_directories
from utils import load_config, create_run_directory
from environment import TrafficEnvironment
from parallel_env import ParallelEnvironment
from dqn_agent import DQNAgent
@ -28,7 +28,9 @@ def set_random_seed(seed: int):
def train(config_path: str = "config.yaml"):
"""Train DQN agent."""
config = load_config(config_path)
create_directories(config)
# Create timestamped run directory
run_dir, checkpoint_dir, log_dir = create_run_directory(config, config_path)
# Set random seed for reproducibility
random_seed = config["training"].get("random_seed", 42)
@ -61,8 +63,8 @@ def train(config_path: str = "config.yaml"):
num_episodes = config["training"]["num_episodes"]
save_freq = config["training"]["save_freq"]
log_freq = config["training"]["log_freq"]
checkpoint_dir = config["training"]["checkpoint_dir"]
train_freq = config["training"].get("train_freq", 1) # Train every N steps
# checkpoint_dir and log_dir are already set from create_run_directory
episode_rewards = []
episode_losses = []
@ -194,7 +196,7 @@ def train(config_path: str = "config.yaml"):
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
plot_training_results(
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
episode_rewards, episode_losses, episode_throughputs, log_dir
)

View File

@ -3,7 +3,9 @@ Configuration utilities for loading and managing settings.
"""
import yaml
import os
from typing import Dict
import shutil
from datetime import datetime
from typing import Dict, Tuple
def load_config(config_path: str = "config.yaml") -> Dict:
@ -24,3 +26,37 @@ def create_directories(config: Dict):
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
    """
    Create a unique, timestamped run directory for this training session.

    The directory is created under ``runs/`` (relative to the current working
    directory) as ``run_<YYYYmmdd_HHMMSS>``. Because the timestamp only has
    one-second resolution, a numeric suffix (``_1``, ``_2``, ...) is appended
    when a directory with that name already exists, so two runs started in
    the same second never share (and overwrite) each other's outputs.

    Args:
        config: Configuration dictionary (currently not read by this function;
            kept for interface compatibility with callers)
        config_path: Path to the config file that is copied into the run directory

    Returns:
        Tuple of (run_dir, checkpoint_dir, log_dir)

    Raises:
        FileNotFoundError: If ``config_path`` does not exist (raised by
            ``shutil.copy`` after the directories have been created).
    """
    # Create timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create run directory. The leaf is created *exclusively* (no exist_ok) so
    # a name collision is detected instead of silently reusing another run's
    # directory; on collision, retry with an incrementing suffix.
    base_run_dir = os.path.join("runs", f"run_{timestamp}")
    run_dir = base_run_dir
    attempt = 0
    while True:
        try:
            os.makedirs(run_dir)
            break
        except FileExistsError:
            attempt += 1
            run_dir = f"{base_run_dir}_{attempt}"

    # Create subdirectories (run_dir is fresh, so these cannot pre-exist)
    checkpoint_dir = os.path.join(run_dir, "checkpoints")
    log_dir = os.path.join(run_dir, "logs")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Copy config file to run directory; the copy is always named
    # "config.yaml" regardless of the source file's name.
    config_copy_path = os.path.join(run_dir, "config.yaml")
    shutil.copy(config_path, config_copy_path)

    print(f"Created run directory: {run_dir}")
    print(f"Config saved to: {config_copy_path}")

    return run_dir, checkpoint_dir, log_dir