将每次运行参数及结果独立保存
This commit is contained in:
parent
2cecf7804f
commit
8a2194039c
|
|
@ -11,10 +11,9 @@ wheels/
|
|||
uv.lock
|
||||
|
||||
# Project specific - CTM-DQN
|
||||
checkpoints/
|
||||
logs/
|
||||
*.pt
|
||||
*.pth
|
||||
runs/
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
|
|
|
|||
10
train.py
10
train.py
|
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
from utils import load_config, create_directories
|
||||
from utils import load_config, create_run_directory
|
||||
from environment import TrafficEnvironment
|
||||
from parallel_env import ParallelEnvironment
|
||||
from dqn_agent import DQNAgent
|
||||
|
|
@ -28,7 +28,9 @@ def set_random_seed(seed: int):
|
|||
def train(config_path: str = "config.yaml"):
|
||||
"""Train DQN agent."""
|
||||
config = load_config(config_path)
|
||||
create_directories(config)
|
||||
|
||||
# Create timestamped run directory
|
||||
run_dir, checkpoint_dir, log_dir = create_run_directory(config, config_path)
|
||||
|
||||
# Set random seed for reproducibility
|
||||
random_seed = config["training"].get("random_seed", 42)
|
||||
|
|
@ -61,8 +63,8 @@ def train(config_path: str = "config.yaml"):
|
|||
num_episodes = config["training"]["num_episodes"]
|
||||
save_freq = config["training"]["save_freq"]
|
||||
log_freq = config["training"]["log_freq"]
|
||||
checkpoint_dir = config["training"]["checkpoint_dir"]
|
||||
train_freq = config["training"].get("train_freq", 1) # Train every N steps
|
||||
# checkpoint_dir and log_dir are already set from create_run_directory
|
||||
|
||||
episode_rewards = []
|
||||
episode_losses = []
|
||||
|
|
@ -194,7 +196,7 @@ def train(config_path: str = "config.yaml"):
|
|||
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
|
||||
|
||||
plot_training_results(
|
||||
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
|
||||
episode_rewards, episode_losses, episode_throughputs, log_dir
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
38
utils.py
38
utils.py
|
|
@ -3,7 +3,9 @@ Configuration utilities for loading and managing settings.
|
|||
"""
|
||||
import yaml
|
||||
import os
|
||||
from typing import Dict
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
def load_config(config_path: str = "config.yaml") -> Dict:
|
||||
|
|
@ -24,3 +26,37 @@ def create_directories(config: Dict):
|
|||
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
|
||||
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
    """
    Create a uniquely named, timestamped run directory for this training session.

    The directory is created under ``runs/`` as ``run_<YYYYmmdd_HHMMSS>``. If a
    directory with that name already exists (for example, two runs started
    within the same second), a numeric suffix (``_1``, ``_2``, ...) is appended
    so that runs never share checkpoints, logs, or the saved config copy.

    Args:
        config: Configuration dictionary (not read here; kept so the call
            signature stays stable for existing callers)
        config_path: Path to the config file that is copied into the run directory

    Returns:
        Tuple of (run_dir, checkpoint_dir, log_dir)

    Raises:
        FileNotFoundError: If ``config_path`` does not exist when it is copied.
    """
    # Create timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_run_dir = os.path.join("runs", f"run_{timestamp}")

    # The parent may legitimately exist already; only the run directory itself
    # must be new.
    os.makedirs("runs", exist_ok=True)

    # os.mkdir fails if the directory exists, so claiming the name and
    # detecting a collision happen in a single step (no check-then-create race).
    run_dir = base_run_dir
    collision_count = 0
    while True:
        try:
            os.mkdir(run_dir)
            break
        except FileExistsError:
            collision_count += 1
            run_dir = f"{base_run_dir}_{collision_count}"

    # Create subdirectories
    checkpoint_dir = os.path.join(run_dir, "checkpoints")
    log_dir = os.path.join(run_dir, "logs")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Copy config file to run directory so the run's parameters are preserved
    config_copy_path = os.path.join(run_dir, "config.yaml")
    shutil.copy(config_path, config_copy_path)

    print(f"Created run directory: {run_dir}")
    print(f"Config saved to: {config_copy_path}")

    return run_dir, checkpoint_dir, log_dir
|
||||
|
|
|
|||
Loading…
Reference in New Issue