""" Testing script for trained DQN agent. """ import os import numpy as np import matplotlib.pyplot as plt from utils import load_config, create_directories from environment import TrafficEnvironment from dqn_agent import DQNAgent def test(config_path: str = "config.yaml", model_path: str = None): """Test trained DQN agent.""" config = load_config(config_path) create_directories(config) env = TrafficEnvironment(config) agent = DQNAgent( state_dim=env.state_dim, action_dim=env.action_dim, hidden_layers=config["agent"]["hidden_layers"], learning_rate=config["agent"]["learning_rate"], gamma=config["agent"]["gamma"], device=config["agent"]["device"], ) if model_path is None: model_path = os.path.join( config["training"]["checkpoint_dir"], "model_final.pt" ) if not os.path.exists(model_path): print(f"Model not found at {model_path}") return agent.load(model_path) print(f"Loaded model from {model_path}") num_episodes = config["testing"]["num_episodes"] render = config["testing"]["render"] episode_rewards = [] episode_throughputs = [] all_densities = [] all_actions = [] print(f"Testing for {num_episodes} episodes...") for episode in range(num_episodes): state = env.reset() episode_reward = 0 episode_densities = [] episode_actions = [] while True: action = agent.select_action(state, training=False) next_state, reward, done, info = env.step(action) episode_reward += reward episode_densities.append(info["densities"].copy()) episode_actions.append(action) state = next_state if done: break episode_rewards.append(episode_reward) avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics]) episode_throughputs.append(avg_throughput) if episode == 0: all_densities = episode_densities all_actions = episode_actions print( f"Episode {episode + 1}/{num_episodes} | " f"Reward: {episode_reward:.2f} | " f"Avg Throughput: {avg_throughput:.2f}" ) print("\n" + "="*50) print("Testing Summary:") print(f"Average Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}") print(f"Average Throughput: {np.mean(episode_throughputs):.2f} ± {np.std(episode_throughputs):.2f}") print("="*50) if render and len(all_densities) > 0: plot_test_results(all_densities, all_actions, env, config["training"]["log_dir"]) def plot_test_results(densities, actions, env, log_dir): """Plot test results with density heatmap and actions.""" densities_array = np.array(densities) actions_array = np.array(actions) fig, axes = plt.subplots(2, 1, figsize=(12, 8)) im = axes[0].imshow( densities_array.T, aspect='auto', cmap='YlOrRd', interpolation='nearest' ) axes[0].set_xlabel("Time Step") axes[0].set_ylabel("Cell Index") axes[0].set_title("Traffic Density Heatmap (vehicles/km)") plt.colorbar(im, ax=axes[0]) time_steps = range(len(actions_array)) speed_limits = [env.speed_actions[a] for a in actions_array] axes[1].plot(time_steps, speed_limits, 'b-', linewidth=2) axes[1].set_xlabel("Time Step") axes[1].set_ylabel("Speed Limit (m/s)") axes[1].set_title("Speed Limit Control Actions") axes[1].grid(True) axes[1].set_ylim([env.min_speed_limit - 2, env.max_speed_limit + 2]) plt.tight_layout() plot_path = os.path.join(log_dir, "test_results.png") plt.savefig(plot_path) print(f"\nTest visualization saved to {plot_path}") plt.close() if __name__ == "__main__": test()