将每次运行参数及结果独立保存

This commit is contained in:
Zihan Ye 2026-01-05 21:43:04 +08:00
parent 2cecf7804f
commit 8a2194039c
3 changed files with 44 additions and 7 deletions

3
.gitignore vendored
View File

@ -11,10 +11,9 @@ wheels/
uv.lock uv.lock
# Project specific - CTM-DQN # Project specific - CTM-DQN
checkpoints/
logs/
*.pt *.pt
*.pth *.pth
runs/
# IDEs # IDEs
.vscode/ .vscode/

View File

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from utils import load_config, create_directories from utils import load_config, create_run_directory
from environment import TrafficEnvironment from environment import TrafficEnvironment
from parallel_env import ParallelEnvironment from parallel_env import ParallelEnvironment
from dqn_agent import DQNAgent from dqn_agent import DQNAgent
@ -28,7 +28,9 @@ def set_random_seed(seed: int):
def train(config_path: str = "config.yaml"): def train(config_path: str = "config.yaml"):
"""Train DQN agent.""" """Train DQN agent."""
config = load_config(config_path) config = load_config(config_path)
create_directories(config)
# Create timestamped run directory
run_dir, checkpoint_dir, log_dir = create_run_directory(config, config_path)
# Set random seed for reproducibility # Set random seed for reproducibility
random_seed = config["training"].get("random_seed", 42) random_seed = config["training"].get("random_seed", 42)
@ -61,8 +63,8 @@ def train(config_path: str = "config.yaml"):
num_episodes = config["training"]["num_episodes"] num_episodes = config["training"]["num_episodes"]
save_freq = config["training"]["save_freq"] save_freq = config["training"]["save_freq"]
log_freq = config["training"]["log_freq"] log_freq = config["training"]["log_freq"]
checkpoint_dir = config["training"]["checkpoint_dir"]
train_freq = config["training"].get("train_freq", 1) # Train every N steps train_freq = config["training"].get("train_freq", 1) # Train every N steps
# checkpoint_dir and log_dir are already set from create_run_directory
episode_rewards = [] episode_rewards = []
episode_losses = [] episode_losses = []
@ -194,7 +196,7 @@ def train(config_path: str = "config.yaml"):
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})") print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
plot_training_results( plot_training_results(
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"] episode_rewards, episode_losses, episode_throughputs, log_dir
) )

View File

@ -3,7 +3,9 @@ Configuration utilities for loading and managing settings.
""" """
import yaml import yaml
import os import os
from typing import Dict import shutil
from datetime import datetime
from typing import Dict, Tuple
def load_config(config_path: str = "config.yaml") -> Dict: def load_config(config_path: str = "config.yaml") -> Dict:
@ -24,3 +26,37 @@ def create_directories(config: Dict):
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
"""
Create a timestamped run directory for this training session.
Args:
config: Configuration dictionary
config_path: Path to the config file
Returns:
Tuple of (run_dir, checkpoint_dir, log_dir)
"""
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create run directory
run_dir = os.path.join("runs", f"run_{timestamp}")
os.makedirs(run_dir, exist_ok=True)
# Create subdirectories
checkpoint_dir = os.path.join(run_dir, "checkpoints")
log_dir = os.path.join(run_dir, "logs")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
# Copy config file to run directory
config_copy_path = os.path.join(run_dir, "config.yaml")
shutil.copy(config_path, config_copy_path)
print(f"Created run directory: {run_dir}")
print(f"Config saved to: {config_copy_path}")
return run_dir, checkpoint_dir, log_dir