ctm-dqn/test.py

136 lines
4.1 KiB
Python

"""
Testing script for trained DQN agent.
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from utils import load_config, create_directories
from environment import TrafficEnvironment
from dqn_agent import DQNAgent
def test(config_path: str = "config.yaml", model_path: str = None, output_dir: str = None):
    """Evaluate a trained DQN agent and print reward/throughput statistics.

    Loads the configuration, builds a ``TrafficEnvironment`` and a ``DQNAgent``,
    restores the agent's weights from a checkpoint, then runs
    ``config["testing"]["num_episodes"]`` episodes selecting actions with
    ``training=False``. Per-episode total reward and mean throughput are
    printed, followed by a mean ± std summary. When
    ``config["testing"]["render"]`` is truthy, the density/action trace of the
    first episode is plotted with ``plot_test_results``.

    Args:
        config_path: Path to the configuration file passed to ``load_config``.
        model_path: Path to the model checkpoint. Defaults to
            ``<config["training"]["checkpoint_dir"]>/model_final.pt``.
        output_dir: Optional directory for the test plot; overrides
            ``config["training"]["log_dir"]`` when given.

    Returns:
        None. Returns early (after printing a message) if the checkpoint file
        does not exist.
    """
    config = load_config(config_path)
    create_directories(config)

    env = TrafficEnvironment(config)
    agent = DQNAgent(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        hidden_layers=config["agent"]["hidden_layers"],
        learning_rate=config["agent"]["learning_rate"],
        gamma=config["agent"]["gamma"],
        device=config["agent"]["device"],
    )

    if model_path is None:
        model_path = os.path.join(
            config["training"]["checkpoint_dir"], "model_final.pt"
        )
    if not os.path.exists(model_path):
        print(f"Model not found at {model_path}")
        return
    agent.load(model_path)
    print(f"Loaded model from {model_path}")

    num_episodes = config["testing"]["num_episodes"]
    render = config["testing"]["render"]

    episode_rewards = []
    episode_throughputs = []
    # Only the first episode's trajectory is kept for plotting.
    first_densities = []
    first_actions = []

    print(f"Testing for {num_episodes} episodes...")
    for episode in range(num_episodes):
        episode_reward, densities, actions = _run_episode(env, agent)
        episode_rewards.append(episode_reward)

        # NOTE(review): assumes env.reset() clears env.episode_metrics so this
        # covers only the episode just run — confirm in TrafficEnvironment.
        avg_throughput = _mean_or_nan(
            [m["throughput"] for m in env.episode_metrics]
        )
        episode_throughputs.append(avg_throughput)

        if episode == 0:
            first_densities = densities
            first_actions = actions

        print(
            f"Episode {episode + 1}/{num_episodes} | "
            f"Reward: {episode_reward:.2f} | "
            f"Avg Throughput: {avg_throughput:.2f}"
        )

    print("\n" + "=" * 50)
    print("Testing Summary:")
    if episode_rewards:
        print(f"Average Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")
        print(f"Average Throughput: {np.mean(episode_throughputs):.2f} ± {np.std(episode_throughputs):.2f}")
    else:
        # np.mean/np.std of an empty list would print nan and warn.
        print("No episodes were run.")
    print("=" * 50)

    if render and len(first_densities) > 0:
        # Use output_dir if specified, otherwise use config log_dir
        log_dir = output_dir if output_dir else config["training"]["log_dir"]
        plot_test_results(first_densities, first_actions, env, log_dir)


def _run_episode(env, agent):
    """Run one evaluation episode; return (total reward, per-step densities, per-step actions)."""
    state = env.reset()
    total_reward = 0
    densities = []
    actions = []
    done = False
    while not done:
        action = agent.select_action(state, training=False)
        state, reward, done, info = env.step(action)
        total_reward += reward
        # Copy so later in-place updates by the env cannot alter recorded steps.
        densities.append(info["densities"].copy())
        actions.append(action)
    return total_reward, densities, actions


def _mean_or_nan(values):
    """Return the mean of ``values``, or NaN (without a RuntimeWarning) if empty."""
    if len(values) == 0:
        return float("nan")
    return float(np.mean(values))
def plot_test_results(densities, actions, env, log_dir):
    """Save a two-panel figure of one test episode to ``<log_dir>/test_results.png``.

    The top panel is a heatmap of the recorded densities, with time step on the
    x-axis and cell index on the y-axis (the stacked array is transposed for
    display). The bottom panel plots the speed limit chosen at each step,
    looked up as ``env.speed_actions[action]``.

    Args:
        densities: Sequence of per-step density vectors, one per time step.
        actions: Sequence of per-step action indices into ``env.speed_actions``.
        env: Environment providing ``speed_actions``, ``min_speed_limit`` and
            ``max_speed_limit``.
        log_dir: Output directory. It is created if missing, because callers may
            pass a user-supplied override that does not exist yet.
    """
    densities_array = np.array(densities)
    actions_array = np.array(actions)
    speed_limits = [env.speed_actions[a] for a in actions_array]

    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    plot_path = os.path.join(log_dir, "test_results.png")

    fig, (density_ax, speed_ax) = plt.subplots(2, 1, figsize=(12, 8))
    try:
        im = density_ax.imshow(
            densities_array.T,
            aspect='auto',
            cmap='YlOrRd',
            interpolation='nearest'
        )
        density_ax.set_xlabel("Time Step")
        density_ax.set_ylabel("Cell Index")
        density_ax.set_title("Traffic Density Heatmap (vehicles/km)")
        fig.colorbar(im, ax=density_ax)

        speed_ax.plot(range(len(actions_array)), speed_limits, 'b-', linewidth=2)
        speed_ax.set_xlabel("Time Step")
        speed_ax.set_ylabel("Speed Limit (m/s)")
        speed_ax.set_title("Speed Limit Control Actions")
        speed_ax.grid(True)
        # 2-unit margin so the line is not drawn on the axis border.
        speed_ax.set_ylim([env.min_speed_limit - 2, env.max_speed_limit + 2])

        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Close this specific figure even if plotting/saving fails, to avoid leaks.
        plt.close(fig)
    print(f"\nTest visualization saved to {plot_path}")
if __name__ == "__main__":
    # Expose test()'s parameters on the command line; with no flags this is
    # equivalent to calling test() with its defaults.
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate a trained DQN agent.")
    parser.add_argument(
        "--config", default="config.yaml", help="Path to configuration file"
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Path to model checkpoint (default: <checkpoint_dir>/model_final.pt)",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory for test plots (overrides config log_dir)",
    )
    args = parser.parse_args()
    test(config_path=args.config, model_path=args.model, output_dir=args.output_dir)