训练后自动进行测试,自动删去不必要的文件

This commit is contained in:
Zihan Ye 2026-01-06 13:57:39 +08:00
parent 8a2194039c
commit e8f32f9942
5 changed files with 79 additions and 10 deletions

4
.gitignore vendored
View File

@ -14,6 +14,7 @@ uv.lock
*.pt
*.pth
runs/
latest/
# IDEs
.vscode/
@ -23,3 +24,6 @@ runs/
# OS
.DS_Store
Thumbs.db
# AI
.claude/

19
main.py
View File

@ -2,6 +2,7 @@
Main entry point for CTM-DQN speed limit control system.
"""
import argparse
import os
from train import train
from test import test
@ -26,8 +27,8 @@ def main():
parser.add_argument(
"--model",
type=str,
default=None,
help="Path to model checkpoint (for testing)",
default="latest/checkpoints/model_best.pt",
help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)",
)
args = parser.parse_args()
@ -37,7 +38,19 @@ def main():
train(args.config)
elif args.mode == "test":
print("Starting testing mode...")
test(args.config, args.model)
# Auto-detect config from model path if using default model
config_path = args.config
model_path = args.model
# If using latest model and default config, use latest config instead
if model_path.startswith("latest/") and args.config == "config.yaml":
latest_config = os.path.join("latest", "config.yaml")
if os.path.exists(latest_config):
config_path = latest_config
print(f"Auto-detected config from latest run: {config_path}")
test(config_path, model_path)
if __name__ == "__main__":

14
test.py
View File

@ -9,8 +9,14 @@ from environment import TrafficEnvironment
from dqn_agent import DQNAgent
def test(config_path: str = "config.yaml", model_path: str = None):
"""Test trained DQN agent."""
def test(config_path: str = "config.yaml", model_path: str = None, output_dir: str = None):
"""Test trained DQN agent.
Args:
config_path: Path to configuration file
model_path: Path to model checkpoint
output_dir: Optional output directory for test results (overrides config)
"""
config = load_config(config_path)
create_directories(config)
@ -86,7 +92,9 @@ def test(config_path: str = "config.yaml", model_path: str = None):
print("="*50)
if render and len(all_densities) > 0:
plot_test_results(all_densities, all_actions, env, config["training"]["log_dir"])
# Use output_dir if specified, otherwise use config log_dir
log_dir = output_dir if output_dir else config["training"]["log_dir"]
plot_test_results(all_densities, all_actions, env, log_dir)
def plot_test_results(densities, actions, env, log_dir):

View File

@ -11,6 +11,7 @@ from utils import load_config, create_run_directory
from environment import TrafficEnvironment
from parallel_env import ParallelEnvironment
from dqn_agent import DQNAgent
from test import test
def set_random_seed(seed: int):
@ -199,6 +200,13 @@ def train(config_path: str = "config.yaml"):
episode_rewards, episode_losses, episode_throughputs, log_dir
)
# Auto-test the best model
print("\n" + "="*60)
print("Starting automatic testing with best model...")
print("="*60)
config_copy_path = os.path.join(run_dir, "config.yaml")
test(config_path=config_copy_path, model_path=best_model_path, output_dir=log_dir)
def plot_training_results(rewards, losses, throughputs, log_dir):
"""Plot and save training results."""

View File

@ -31,6 +31,7 @@ def create_directories(config: Dict):
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
"""
Create a timestamped run directory for this training session.
Latest run is saved to 'latest/', previous runs are moved to 'runs/'.
Args:
config: Configuration dictionary
@ -42,8 +43,37 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create run directory
run_dir = os.path.join("runs", f"run_{timestamp}")
# Check if 'latest' directory exists
latest_dir = "latest"
if os.path.exists(latest_dir):
# Move existing 'latest' to 'runs' with timestamp from its config
try:
# Try to read timestamp from existing latest directory
old_config_path = os.path.join(latest_dir, "config.yaml")
if os.path.exists(old_config_path):
# Get modification time of the old config
mtime = os.path.getmtime(old_config_path)
old_timestamp = datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S")
else:
old_timestamp = "unknown"
# Move to runs directory
os.makedirs("runs", exist_ok=True)
archive_dir = os.path.join("runs", f"run_{old_timestamp}")
shutil.move(latest_dir, archive_dir)
# Clean up timestamp files in archived directory
for file in os.listdir(archive_dir):
if file.startswith("timestamp_") and file.endswith(".txt"):
timestamp_file_path = os.path.join(archive_dir, file)
os.remove(timestamp_file_path)
print(f"Archived previous run to: {archive_dir}")
except Exception as e:
print(f"Warning: Could not archive previous run: {e}")
# Create new 'latest' directory
run_dir = latest_dir
os.makedirs(run_dir, exist_ok=True)
# Create subdirectories
@ -52,11 +82,17 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)
# Copy config file to run directory
# Copy config file to run directory with timestamp
config_copy_path = os.path.join(run_dir, "config.yaml")
shutil.copy(config_path, config_copy_path)
# Also save a timestamped copy for reference
timestamp_file = os.path.join(run_dir, f"timestamp_{timestamp}.txt")
with open(timestamp_file, "w") as f:
f.write(f"Training started at: {timestamp}\n")
print(f"Created run directory: {run_dir}")
print(f"Training timestamp: {timestamp}")
print(f"Config saved to: {config_copy_path}")
return run_dir, checkpoint_dir, log_dir