将每次运行参数及结果独立保存
This commit is contained in:
parent
2cecf7804f
commit
8a2194039c
|
|
@ -11,10 +11,9 @@ wheels/
|
||||||
uv.lock
|
uv.lock
|
||||||
|
|
||||||
# Project specific - CTM-DQN
|
# Project specific - CTM-DQN
|
||||||
checkpoints/
|
|
||||||
logs/
|
|
||||||
*.pt
|
*.pt
|
||||||
*.pth
|
*.pth
|
||||||
|
runs/
|
||||||
|
|
||||||
# IDEs
|
# IDEs
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
|
||||||
10
train.py
10
train.py
|
|
@ -7,7 +7,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from utils import load_config, create_directories
|
from utils import load_config, create_run_directory
|
||||||
from environment import TrafficEnvironment
|
from environment import TrafficEnvironment
|
||||||
from parallel_env import ParallelEnvironment
|
from parallel_env import ParallelEnvironment
|
||||||
from dqn_agent import DQNAgent
|
from dqn_agent import DQNAgent
|
||||||
|
|
@ -28,7 +28,9 @@ def set_random_seed(seed: int):
|
||||||
def train(config_path: str = "config.yaml"):
|
def train(config_path: str = "config.yaml"):
|
||||||
"""Train DQN agent."""
|
"""Train DQN agent."""
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
create_directories(config)
|
|
||||||
|
# Create timestamped run directory
|
||||||
|
run_dir, checkpoint_dir, log_dir = create_run_directory(config, config_path)
|
||||||
|
|
||||||
# Set random seed for reproducibility
|
# Set random seed for reproducibility
|
||||||
random_seed = config["training"].get("random_seed", 42)
|
random_seed = config["training"].get("random_seed", 42)
|
||||||
|
|
@ -61,8 +63,8 @@ def train(config_path: str = "config.yaml"):
|
||||||
num_episodes = config["training"]["num_episodes"]
|
num_episodes = config["training"]["num_episodes"]
|
||||||
save_freq = config["training"]["save_freq"]
|
save_freq = config["training"]["save_freq"]
|
||||||
log_freq = config["training"]["log_freq"]
|
log_freq = config["training"]["log_freq"]
|
||||||
checkpoint_dir = config["training"]["checkpoint_dir"]
|
|
||||||
train_freq = config["training"].get("train_freq", 1) # Train every N steps
|
train_freq = config["training"].get("train_freq", 1) # Train every N steps
|
||||||
|
# checkpoint_dir and log_dir are already set from create_run_directory
|
||||||
|
|
||||||
episode_rewards = []
|
episode_rewards = []
|
||||||
episode_losses = []
|
episode_losses = []
|
||||||
|
|
@ -194,7 +196,7 @@ def train(config_path: str = "config.yaml"):
|
||||||
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
|
print(f"Best model saved to {best_model_path} (Best Reward: {best_reward:.2f})")
|
||||||
|
|
||||||
plot_training_results(
|
plot_training_results(
|
||||||
episode_rewards, episode_losses, episode_throughputs, config["training"]["log_dir"]
|
episode_rewards, episode_losses, episode_throughputs, log_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
38
utils.py
38
utils.py
|
|
@ -3,7 +3,9 @@ Configuration utilities for loading and managing settings.
|
||||||
"""
|
"""
|
||||||
import yaml
|
import yaml
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str = "config.yaml") -> Dict:
|
def load_config(config_path: str = "config.yaml") -> Dict:
|
||||||
|
|
@ -24,3 +26,37 @@ def create_directories(config: Dict):
|
||||||
|
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
|
||||||
|
"""
|
||||||
|
Create a timestamped run directory for this training session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
config_path: Path to the config file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (run_dir, checkpoint_dir, log_dir)
|
||||||
|
"""
|
||||||
|
# Create timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
# Create run directory
|
||||||
|
run_dir = os.path.join("runs", f"run_{timestamp}")
|
||||||
|
os.makedirs(run_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Create subdirectories
|
||||||
|
checkpoint_dir = os.path.join(run_dir, "checkpoints")
|
||||||
|
log_dir = os.path.join(run_dir, "logs")
|
||||||
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Copy config file to run directory
|
||||||
|
config_copy_path = os.path.join(run_dir, "config.yaml")
|
||||||
|
shutil.copy(config_path, config_copy_path)
|
||||||
|
|
||||||
|
print(f"Created run directory: {run_dir}")
|
||||||
|
print(f"Config saved to: {config_copy_path}")
|
||||||
|
|
||||||
|
return run_dir, checkpoint_dir, log_dir
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue