136 lines
4.1 KiB
Python
136 lines
4.1 KiB
Python
"""
Testing script for a trained DQN agent: runs evaluation episodes, prints
summary statistics, and optionally saves a plot of the first episode.
"""
|
|
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

from dqn_agent import DQNAgent
from environment import TrafficEnvironment
from utils import load_config, create_directories
|
|
|
|
|
|
def test(
    config_path: str = "config.yaml",
    model_path: Optional[str] = None,
    output_dir: Optional[str] = None,
) -> None:
    """Test trained DQN agent.

    Builds the environment and agent from the configuration, loads the
    checkpoint, runs ``config["testing"]["num_episodes"]`` episodes with
    ``training=False`` action selection, and prints per-episode and summary
    statistics. When ``config["testing"]["render"]`` is truthy, the
    densities and actions recorded during the first episode are plotted.

    Args:
        config_path: Path to configuration file.
        model_path: Path to model checkpoint. When ``None``, defaults to
            ``model_final.pt`` inside ``config["training"]["checkpoint_dir"]``.
        output_dir: Optional output directory for test results (overrides
            ``config["training"]["log_dir"]``). Created if it does not exist.

    Returns:
        None. If the checkpoint file does not exist, a message is printed
        and the function returns before running any episode.
    """
    config = load_config(config_path)
    create_directories(config)

    env = TrafficEnvironment(config)
    agent_cfg = config["agent"]
    agent = DQNAgent(
        state_dim=env.state_dim,
        action_dim=env.action_dim,
        hidden_layers=agent_cfg["hidden_layers"],
        learning_rate=agent_cfg["learning_rate"],
        gamma=agent_cfg["gamma"],
        device=agent_cfg["device"],
    )

    if model_path is None:
        model_path = os.path.join(
            config["training"]["checkpoint_dir"], "model_final.pt"
        )

    if not os.path.exists(model_path):
        print(f"Model not found at {model_path}")
        return

    agent.load(model_path)
    print(f"Loaded model from {model_path}")

    num_episodes = config["testing"]["num_episodes"]
    render = config["testing"]["render"]

    episode_rewards = []
    episode_throughputs = []
    # Only the first episode's trajectory is kept for the visualization.
    first_episode_densities = []
    first_episode_actions = []

    print(f"Testing for {num_episodes} episodes...")

    for episode in range(num_episodes):
        state = env.reset()
        episode_reward = 0
        episode_densities = []
        episode_actions = []

        done = False
        while not done:
            action = agent.select_action(state, training=False)
            state, reward, done, info = env.step(action)

            episode_reward += reward
            # Copy so later environment steps cannot mutate the stored snapshot.
            episode_densities.append(info["densities"].copy())
            episode_actions.append(action)

        episode_rewards.append(episode_reward)
        avg_throughput = np.mean([m["throughput"] for m in env.episode_metrics])
        episode_throughputs.append(avg_throughput)

        if episode == 0:
            first_episode_densities = episode_densities
            first_episode_actions = episode_actions

        print(
            f"Episode {episode + 1}/{num_episodes} | "
            f"Reward: {episode_reward:.2f} | "
            f"Avg Throughput: {avg_throughput:.2f}"
        )

    if not episode_rewards:
        # np.mean/np.std on empty lists would print "nan" and emit warnings.
        print("No episodes were run; nothing to summarize.")
        return

    print("\n" + "=" * 50)
    print("Testing Summary:")
    print(
        f"Average Reward: {np.mean(episode_rewards):.2f} "
        f"± {np.std(episode_rewards):.2f}"
    )
    print(
        f"Average Throughput: {np.mean(episode_throughputs):.2f} "
        f"± {np.std(episode_throughputs):.2f}"
    )
    print("=" * 50)

    if render and first_episode_densities:
        # Use output_dir if specified, otherwise use config log_dir
        log_dir = output_dir if output_dir else config["training"]["log_dir"]
        # A caller-supplied output_dir is not covered by create_directories(config),
        # so make sure it exists before the plot is written.
        os.makedirs(log_dir, exist_ok=True)
        plot_test_results(
            first_episode_densities, first_episode_actions, env, log_dir
        )
|
|
|
|
|
|
def plot_test_results(densities, actions, env, log_dir):
    """Plot test results with density heatmap and actions.

    Writes a two-panel figure to ``<log_dir>/test_results.png``: the top
    panel is a heatmap of ``densities`` (transposed, x axis labelled
    "Time Step", y axis labelled "Cell Index"); the bottom panel is the
    speed limit selected at each step.

    Args:
        densities: Sequence of per-step density snapshots; converted with
            ``np.array``. Presumably shaped (steps, cells) -- verify against
            the environment's ``info["densities"]``.
        actions: Sequence of action indices, each used to index
            ``env.speed_actions``.
        env: Environment providing ``speed_actions``, ``min_speed_limit``
            and ``max_speed_limit``.
        log_dir: Directory the image is written to. Created if missing.
    """
    densities_array = np.array(densities)
    actions_array = np.array(actions)

    # Make sure the target directory exists so savefig cannot fail on it.
    os.makedirs(log_dir, exist_ok=True)
    plot_path = os.path.join(log_dir, "test_results.png")

    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    density_ax, action_ax = axes[0], axes[1]
    try:
        im = density_ax.imshow(
            densities_array.T,
            aspect='auto',
            cmap='YlOrRd',
            interpolation='nearest',
        )
        density_ax.set_xlabel("Time Step")
        density_ax.set_ylabel("Cell Index")
        density_ax.set_title("Traffic Density Heatmap (vehicles/km)")
        fig.colorbar(im, ax=density_ax)

        time_steps = range(len(actions_array))
        # Translate each action index into the speed limit it selects.
        speed_limits = [env.speed_actions[a] for a in actions_array]
        action_ax.plot(time_steps, speed_limits, 'b-', linewidth=2)
        action_ax.set_xlabel("Time Step")
        action_ax.set_ylabel("Speed Limit (m/s)")
        action_ax.set_title("Speed Limit Control Actions")
        action_ax.grid(True)
        action_ax.set_ylim([env.min_speed_limit - 2, env.max_speed_limit + 2])

        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

    print(f"\nTest visualization saved to {plot_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    # Expose test()'s parameters on the command line; the defaults reproduce
    # the previous bare ``test()`` call when no arguments are given.
    parser = argparse.ArgumentParser(description="Test a trained DQN agent.")
    parser.add_argument(
        "--config",
        default="config.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Path to model checkpoint (default: model_final.pt in the "
             "configured checkpoint directory)",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Output directory for test results (overrides the configured "
             "log directory)",
    )
    args = parser.parse_args()

    test(
        config_path=args.config,
        model_path=args.model,
        output_dir=args.output_dir,
    )
|