训练后自动进行测试,自动删去不必要的文件
This commit is contained in:
parent
8a2194039c
commit
e8f32f9942
|
|
@ -14,6 +14,7 @@ uv.lock
|
||||||
*.pt
|
*.pt
|
||||||
*.pth
|
*.pth
|
||||||
runs/
|
runs/
|
||||||
|
latest/
|
||||||
|
|
||||||
# IDEs
|
# IDEs
|
||||||
.vscode/
|
.vscode/
|
||||||
|
|
@ -22,4 +23,7 @@ runs/
|
||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
|
||||||
|
# AI
|
||||||
|
.claude/
|
||||||
19
main.py
19
main.py
|
|
@ -2,6 +2,7 @@
|
||||||
Main entry point for CTM-DQN speed limit control system.
|
Main entry point for CTM-DQN speed limit control system.
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
from train import train
|
from train import train
|
||||||
from test import test
|
from test import test
|
||||||
|
|
||||||
|
|
@ -26,8 +27,8 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default="latest/checkpoints/model_best.pt",
|
||||||
help="Path to model checkpoint (for testing)",
|
help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
@ -37,7 +38,19 @@ def main():
|
||||||
train(args.config)
|
train(args.config)
|
||||||
elif args.mode == "test":
|
elif args.mode == "test":
|
||||||
print("Starting testing mode...")
|
print("Starting testing mode...")
|
||||||
test(args.config, args.model)
|
|
||||||
|
# Auto-detect config from model path if using default model
|
||||||
|
config_path = args.config
|
||||||
|
model_path = args.model
|
||||||
|
|
||||||
|
# If using latest model and default config, use latest config instead
|
||||||
|
if model_path.startswith("latest/") and args.config == "config.yaml":
|
||||||
|
latest_config = os.path.join("latest", "config.yaml")
|
||||||
|
if os.path.exists(latest_config):
|
||||||
|
config_path = latest_config
|
||||||
|
print(f"Auto-detected config from latest run: {config_path}")
|
||||||
|
|
||||||
|
test(config_path, model_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
14
test.py
14
test.py
|
|
@ -9,8 +9,14 @@ from environment import TrafficEnvironment
|
||||||
from dqn_agent import DQNAgent
|
from dqn_agent import DQNAgent
|
||||||
|
|
||||||
|
|
||||||
def test(config_path: str = "config.yaml", model_path: str = None):
|
def test(config_path: str = "config.yaml", model_path: str = None, output_dir: str = None):
|
||||||
"""Test trained DQN agent."""
|
"""Test trained DQN agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to configuration file
|
||||||
|
model_path: Path to model checkpoint
|
||||||
|
output_dir: Optional output directory for test results (overrides config)
|
||||||
|
"""
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
create_directories(config)
|
create_directories(config)
|
||||||
|
|
||||||
|
|
@ -86,7 +92,9 @@ def test(config_path: str = "config.yaml", model_path: str = None):
|
||||||
print("="*50)
|
print("="*50)
|
||||||
|
|
||||||
if render and len(all_densities) > 0:
|
if render and len(all_densities) > 0:
|
||||||
plot_test_results(all_densities, all_actions, env, config["training"]["log_dir"])
|
# Use output_dir if specified, otherwise use config log_dir
|
||||||
|
log_dir = output_dir if output_dir else config["training"]["log_dir"]
|
||||||
|
plot_test_results(all_densities, all_actions, env, log_dir)
|
||||||
|
|
||||||
|
|
||||||
def plot_test_results(densities, actions, env, log_dir):
|
def plot_test_results(densities, actions, env, log_dir):
|
||||||
|
|
|
||||||
8
train.py
8
train.py
|
|
@ -11,6 +11,7 @@ from utils import load_config, create_run_directory
|
||||||
from environment import TrafficEnvironment
|
from environment import TrafficEnvironment
|
||||||
from parallel_env import ParallelEnvironment
|
from parallel_env import ParallelEnvironment
|
||||||
from dqn_agent import DQNAgent
|
from dqn_agent import DQNAgent
|
||||||
|
from test import test
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed: int):
|
def set_random_seed(seed: int):
|
||||||
|
|
@ -199,6 +200,13 @@ def train(config_path: str = "config.yaml"):
|
||||||
episode_rewards, episode_losses, episode_throughputs, log_dir
|
episode_rewards, episode_losses, episode_throughputs, log_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auto-test the best model
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("Starting automatic testing with best model...")
|
||||||
|
print("="*60)
|
||||||
|
config_copy_path = os.path.join(run_dir, "config.yaml")
|
||||||
|
test(config_path=config_copy_path, model_path=best_model_path, output_dir=log_dir)
|
||||||
|
|
||||||
|
|
||||||
def plot_training_results(rewards, losses, throughputs, log_dir):
|
def plot_training_results(rewards, losses, throughputs, log_dir):
|
||||||
"""Plot and save training results."""
|
"""Plot and save training results."""
|
||||||
|
|
|
||||||
42
utils.py
42
utils.py
|
|
@ -31,6 +31,7 @@ def create_directories(config: Dict):
|
||||||
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
|
def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Create a timestamped run directory for this training session.
|
Create a timestamped run directory for this training session.
|
||||||
|
Latest run is saved to 'latest/', previous runs are moved to 'runs/'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Configuration dictionary
|
config: Configuration dictionary
|
||||||
|
|
@ -42,8 +43,37 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
|
||||||
# Create timestamp
|
# Create timestamp
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
# Create run directory
|
# Check if 'latest' directory exists
|
||||||
run_dir = os.path.join("runs", f"run_{timestamp}")
|
latest_dir = "latest"
|
||||||
|
if os.path.exists(latest_dir):
|
||||||
|
# Move existing 'latest' to 'runs' with timestamp from its config
|
||||||
|
try:
|
||||||
|
# Try to read timestamp from existing latest directory
|
||||||
|
old_config_path = os.path.join(latest_dir, "config.yaml")
|
||||||
|
if os.path.exists(old_config_path):
|
||||||
|
# Get modification time of the old config
|
||||||
|
mtime = os.path.getmtime(old_config_path)
|
||||||
|
old_timestamp = datetime.fromtimestamp(mtime).strftime("%Y%m%d_%H%M%S")
|
||||||
|
else:
|
||||||
|
old_timestamp = "unknown"
|
||||||
|
|
||||||
|
# Move to runs directory
|
||||||
|
os.makedirs("runs", exist_ok=True)
|
||||||
|
archive_dir = os.path.join("runs", f"run_{old_timestamp}")
|
||||||
|
shutil.move(latest_dir, archive_dir)
|
||||||
|
|
||||||
|
# Clean up timestamp files in archived directory
|
||||||
|
for file in os.listdir(archive_dir):
|
||||||
|
if file.startswith("timestamp_") and file.endswith(".txt"):
|
||||||
|
timestamp_file_path = os.path.join(archive_dir, file)
|
||||||
|
os.remove(timestamp_file_path)
|
||||||
|
|
||||||
|
print(f"Archived previous run to: {archive_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not archive previous run: {e}")
|
||||||
|
|
||||||
|
# Create new 'latest' directory
|
||||||
|
run_dir = latest_dir
|
||||||
os.makedirs(run_dir, exist_ok=True)
|
os.makedirs(run_dir, exist_ok=True)
|
||||||
|
|
||||||
# Create subdirectories
|
# Create subdirectories
|
||||||
|
|
@ -52,11 +82,17 @@ def create_run_directory(config: Dict, config_path: str = "config.yaml") -> Tupl
|
||||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
# Copy config file to run directory
|
# Copy config file to run directory with timestamp
|
||||||
config_copy_path = os.path.join(run_dir, "config.yaml")
|
config_copy_path = os.path.join(run_dir, "config.yaml")
|
||||||
shutil.copy(config_path, config_copy_path)
|
shutil.copy(config_path, config_copy_path)
|
||||||
|
|
||||||
|
# Also save a timestamped copy for reference
|
||||||
|
timestamp_file = os.path.join(run_dir, f"timestamp_{timestamp}.txt")
|
||||||
|
with open(timestamp_file, "w") as f:
|
||||||
|
f.write(f"Training started at: {timestamp}\n")
|
||||||
|
|
||||||
print(f"Created run directory: {run_dir}")
|
print(f"Created run directory: {run_dir}")
|
||||||
|
print(f"Training timestamp: {timestamp}")
|
||||||
print(f"Config saved to: {config_copy_path}")
|
print(f"Config saved to: {config_copy_path}")
|
||||||
|
|
||||||
return run_dir, checkpoint_dir, log_dir
|
return run_dir, checkpoint_dir, log_dir
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue