ctm-dqn/main.py

58 lines
1.6 KiB
Python

"""
Main entry point for CTM-DQN speed limit control system.
"""
import argparse
import os
from train import train
from test import test
_DEFAULT_CONFIG = "config.yaml"
_LATEST_DIR = "latest"


def _is_under_latest(model_path):
    """Return True if *model_path* is a relative path whose first component is ``latest``.

    The path is normalized first, so ``latest/x``, ``./latest/x`` and (on
    Windows) ``latest\\x`` are all recognized.
    """
    first_component = os.path.normpath(model_path).split(os.sep)[0]
    return first_component == _LATEST_DIR


def _resolve_test_config(config_arg, model_path):
    """Choose the configuration file to use in test mode.

    Args:
        config_arg: Value of ``--config``, or None if the flag was omitted.
        model_path: Value of ``--model``.

    Returns:
        The config path to use:

        - An explicitly passed ``--config`` always wins.
        - Otherwise, if the model lives under ``latest/`` and
          ``latest/config.yaml`` exists, that file is used and a notice is
          printed.
        - Otherwise ``config.yaml`` is used.
    """
    if config_arg is not None:
        return config_arg
    if _is_under_latest(model_path):
        latest_config = os.path.join(_LATEST_DIR, _DEFAULT_CONFIG)
        if os.path.exists(latest_config):
            print(f"Auto-detected config from latest run: {latest_config}")
            return latest_config
    return _DEFAULT_CONFIG


def main():
    """Parse command-line arguments and dispatch to training or testing.

    ``--mode train`` calls ``train(config_path)``; ``--model`` is ignored.
    ``--mode test`` calls ``test(config_path, model_path)``. The config path
    comes from ``_resolve_test_config``: an explicit ``--config`` is honoured
    as-is.
    """
    parser = argparse.ArgumentParser(
        description="DQN-based Dynamic Speed Limit Control with CTM"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["train", "test"],
        default="train",
        help="Mode: train or test",
    )
    parser.add_argument(
        "--config",
        type=str,
        # None (not "config.yaml") so an omitted flag can be told apart from
        # an explicit "--config config.yaml", which must never be overridden.
        default=None,
        help=(
            f"Path to configuration file (default: {_DEFAULT_CONFIG}; in test "
            f"mode with a model under {_LATEST_DIR}/, "
            f"{_LATEST_DIR}/{_DEFAULT_CONFIG} is used instead if it exists)"
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        default="latest/checkpoints/model_best.pt",
        help="Path to model checkpoint (for testing, default: latest/checkpoints/model_best.pt)",
    )
    args = parser.parse_args()

    if args.mode == "train":
        print("Starting training mode...")
        config_path = _DEFAULT_CONFIG if args.config is None else args.config
        train(config_path)
    elif args.mode == "test":
        print("Starting testing mode...")
        config_path = _resolve_test_config(args.config, args.model)
        test(config_path, args.model)
if __name__ == "__main__":
    # Launch the CLI only when run as a script, never on import.
    main()