58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
"""
Main entry point for the CTM-DQN dynamic speed limit control system.

Select ``train`` or ``test`` mode with the ``--mode`` command-line option.
"""
|
|
import argparse
|
|
import os
|
|
from train import train
|
|
from test import test
|
|
|
|
|
|
_DEFAULT_CONFIG = "config.yaml"
_LATEST_RUN_DIR = "latest"


def _build_parser():
    """Build the command-line parser for the train/test entry point."""
    parser = argparse.ArgumentParser(
        description="DQN-based Dynamic Speed Limit Control with CTM"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["train", "test"],
        default="train",
        help="Mode: train or test",
    )
    # default=None (rather than "config.yaml") so that an explicit
    # `--config config.yaml` can be told apart from "not given" and is never
    # silently replaced by the latest run's config.
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help=(
            f"Path to configuration file (default: {_DEFAULT_CONFIG}; in test "
            f"mode with a model under {_LATEST_RUN_DIR}/, "
            f"{_LATEST_RUN_DIR}/config.yaml is used if it exists)"
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        default="latest/checkpoints/model_best.pt",
        help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)",
    )
    return parser


def _is_in_latest_run(model_path):
    """Return True if ``model_path`` lies inside the relative ``latest`` directory.

    The path is normalized first, so ``./latest/...`` and OS-native separators
    (e.g. backslashes on Windows) are recognized, not only a literal
    ``latest/`` prefix.
    """
    parts = os.path.normpath(model_path).split(os.sep)
    return len(parts) > 1 and parts[0] == _LATEST_RUN_DIR


def _resolve_test_config(config_arg, model_path):
    """Choose the configuration file for test mode.

    Args:
        config_arg: Value of ``--config``, or ``None`` if it was not given.
        model_path: Path to the model checkpoint being tested.

    Returns:
        ``config_arg`` when it was given explicitly. Otherwise returns
        ``latest/config.yaml`` when the model is inside ``latest/`` and that
        file exists, and falls back to the default ``config.yaml``.
    """
    if config_arg is not None:
        # An explicit user choice always wins.
        return config_arg
    if _is_in_latest_run(model_path):
        latest_config = os.path.join(_LATEST_RUN_DIR, "config.yaml")
        if os.path.isfile(latest_config):
            print(f"Auto-detected config from latest run: {latest_config}")
            return latest_config
    return _DEFAULT_CONFIG


def main(argv=None):
    """Parse command-line arguments and run training or testing.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line, so existing
            ``main()`` callers are unaffected.
    """
    args = _build_parser().parse_args(argv)

    if args.mode == "train":
        print("Starting training mode...")
        config_path = args.config if args.config is not None else _DEFAULT_CONFIG
        train(config_path)
    elif args.mode == "test":
        print("Starting testing mode...")
        config_path = _resolve_test_config(args.config, args.model)
        test(config_path, args.model)


if __name__ == "__main__":
    main()
|