训练后自动进行测试,自动删去不必要的文件
This commit is contained in:
parent
8a2194039c
commit
e8f32f9942
|
|
@ -14,6 +14,7 @@ uv.lock
|
|||
*.pt
|
||||
*.pth
|
||||
runs/
|
||||
latest/
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
|
|
@ -22,4 +23,7 @@ runs/
|
|||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# AI
|
||||
.claude/
|
||||
19
main.py
19
main.py
|
|
@ -2,6 +2,7 @@
|
|||
Main entry point for CTM-DQN speed limit control system.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
from train import train
|
||||
from test import test
|
||||
|
||||
|
|
@ -26,8 +27,8 @@ def main():
|
|||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to model checkpoint (for testing)",
|
||||
default="latest/checkpoints/model_best.pt",
|
||||
help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
|
@ -37,7 +38,19 @@ def main():
|
|||
train(args.config)
|
||||
elif args.mode == "test":
|
||||
print("Starting testing mode...")
|
||||
test(args.config, args.model)
|
||||
|
||||
# Auto-detect config from model path if using default model
|
||||
config_path = args.config
|
||||
model_path = args.model
|
||||
|
||||
# If using latest model and default config, use latest config instead
|
||||
if model_path.startswith("latest/") and args.config == "config.yaml":
|
||||
latest_config = os.path.join("latest", "config.yaml")
|
||||
if os.path.exists(latest_config):
|
||||
config_path = latest_config
|
||||
print(f"Auto-detected config from latest run: {config_path}")
|
||||
|
||||
test(config_path, model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
14
test.py
14
test.py
|
|
@ -9,8 +9,14 @@ from environment import TrafficEnvironment
|
|||
from dqn_agent import DQNAgent
|
||||
|
||||
|
||||
def test(config_path: str = "config.yaml", model_path: str = None):
|
||||
"""Test trained DQN agent."""
|
||||
def test(config_path: str = "config.yaml", model_path: str = None, output_dir: str = None):
|
||||
"""Test trained DQN agent.
|
||||
|
||||
Args:
|
||||
config_path: Path to configuration file
|
||||
model_path: Path to model checkpoint
|
||||
output_dir: Optional output directory for test results (overrides config)
|
||||
"""
|
||||
config = load_config(config_path)
|
||||
create_directories(config)
|
||||
|
||||
|
|
@ -86,7 +92,9 @@ def test(config_path: str = "config.yaml", model_path: str = None):
|
|||
print("="*50)
|
||||
|
||||
if render and len(all_densities) > 0:
|
||||
plot_test_results(all_densities, all_actions, env, config["training"]["log_dir"])
|
||||
# Use output_dir if specified, otherwise use config log_dir
|
||||
log_dir = output_dir if output_dir else config["training"]["log_dir"]
|
||||
plot_test_results(all_densities, all_actions, env, log_dir)
|
||||
|
||||
|
||||
def plot_test_results(densities, actions, env, log_dir):
|
||||
|
|
|
|||
8
train.py
8
train.py
|
|
@ -11,6 +11,7 @@ from utils import load_config, create_run_directory
|
|||
from environment import TrafficEnvironment
|
||||
from parallel_env import ParallelEnvironment
|
||||
from dqn_agent import DQNAgent
|
||||
from test import test
|
||||
|
||||
|
||||
def set_random_seed(seed: int):
|
||||
|
|
@ -199,6 +200,13 @@ def train(config_path: str = "config.yaml"):
|
|||
episode_rewards, episode_losses, episode_throughputs, log_dir
|
||||
)
|
||||
|
||||
# Auto-test the best model
|
||||
print("\n" + "="*60)
|
||||
print("Starting automatic testing with best model...")
|
||||
print("="*60)
|
||||
config_copy_path = os.path.join(run_dir, "config.yaml")
|
||||
test(config_path=config_copy_path, model_path=best_model_path, output_dir=log_dir)
|
||||
|
||||
|
||||
def plot_training_results(rewards, losses, throughputs, log_dir):
|
||||
"""Plot and save training results."""
|
||||
|
|
|
|||
42
utils.py
42
utils.py
|
|
@ -31,6 +31,7 @@ def create_directories(config: Dict):
|
|||
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
|
||||
"""
|
||||
Create a timestamped run directory for this training session.
|
||||
Latest run is saved to 'latest/', previous runs are moved to 'runs/'.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
|
@ -42,8 +43,37 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
|
|||
# Create timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Create run directory
|
||||
run_dir = os.path.join("runs", f"run_{timestamp}")
|
||||
# Check if 'latest' directory exists
|
||||
latest_dir = "latest"
|
||||
if os.path.exists(latest_dir):
|
||||
# Move existing 'latest' to 'runs' with timestamp from its config
|
||||
try:
|
||||
# Try to read timestamp from existing latest directory
|
||||
old_config_path = os.path.join(latest_dir, "config.yaml")
|
||||
if os.path.exists(old_config_path):
|
||||
# Get modification time of the old config
|
||||
mtime = os.path.getmtime(old_config_path)
|
||||
old_timestamp = datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S")
|
||||
else:
|
||||
old_timestamp = "unknown"
|
||||
|
||||
# Move to runs directory
|
||||
os.makedirs("runs", exist_ok=True)
|
||||
archive_dir = os.path.join("runs", f"run_{old_timestamp}")
|
||||
shutil.move(latest_dir, archive_dir)
|
||||
|
||||
# Clean up timestamp files in archived directory
|
||||
for file in os.listdir(archive_dir):
|
||||
if file.startswith("timestamp_") and file.endswith(".txt"):
|
||||
timestamp_file_path = os.path.join(archive_dir, file)
|
||||
os.remove(timestamp_file_path)
|
||||
|
||||
print(f"Archived previous run to: {archive_dir}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not archive previous run: {e}")
|
||||
|
||||
# Create new 'latest' directory
|
||||
run_dir = latest_dir
|
||||
os.makedirs(run_dir, exist_ok=True)
|
||||
|
||||
# Create subdirectories
|
||||
|
|
@ -52,11 +82,17 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
|
|||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Copy config file to run directory
|
||||
# Copy config file to run directory with timestamp
|
||||
config_copy_path = os.path.join(run_dir, "config.yaml")
|
||||
shutil.copy(config_path, config_copy_path)
|
||||
|
||||
# Also save a timestamped copy for reference
|
||||
timestamp_file = os.path.join(run_dir, f"timestamp_{timestamp}.txt")
|
||||
with open(timestamp_file, "w") as f:
|
||||
f.write(f"Training started at: {timestamp}\n")
|
||||
|
||||
print(f"Created run directory: {run_dir}")
|
||||
print(f"Training timestamp: {timestamp}")
|
||||
print(f"Config saved to: {config_copy_path}")
|
||||
|
||||
return run_dir, checkpoint_dir, log_dir
|
||||
|
|
|
|||
Loading…
Reference in New Issue