diff --git a/evaluate_models.py b/evaluate_models.py index baddc84..c3dde69 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -64,6 +64,7 @@ def evaluate_model(model_type, checkpoint_dir, config): state = env.reset(seed=42) speed_data = [] density_data = [] + action_data = [] brake_counts = [] speed_variances = [] throughputs = [] @@ -77,6 +78,7 @@ def evaluate_model(model_type, checkpoint_dir, config): else: action, _, _ = agent.select_action(state, deterministic=True) + action_data.append(action.copy()) state, reward, done, info = env.step(action) if not done: @@ -132,6 +134,7 @@ def evaluate_model(model_type, checkpoint_dir, config): return { "speed_data": np.array(speed_data), "density_data": np.array(density_data), + "action_data": np.array(action_data), "brake_counts": brake_counts, "speed_variances": speed_variances, "throughputs": throughputs @@ -150,6 +153,9 @@ def plot_results(results_dict, output_dir): density_df = pd.DataFrame(results["density_data"]) density_df.to_csv(os.path.join(output_dir, f"{name}_density.csv"), index=False) + action_df = pd.DataFrame(results["action_data"]) + action_df.to_csv(os.path.join(output_dir, f"{name}_action.csv"), index=False) + metrics_df = pd.DataFrame({ "brake_counts": results["brake_counts"], "speed_variances": results["speed_variances"], @@ -252,6 +258,7 @@ def evaluate_no_control(config): state = env.reset(seed=42) speed_data = [] density_data = [] + action_data = [] brake_counts = [] speed_variances = [] throughputs = [] @@ -259,6 +266,7 @@ def evaluate_no_control(config): while not done: action = np.array([4] * env.num_edges) + action_data.append(action.copy()) state, reward, done, info = env.step(action) if not done: @@ -306,6 +314,7 @@ def evaluate_no_control(config): return { "speed_data": np.array(speed_data), "density_data": np.array(density_data), + "action_data": np.array(action_data), "brake_counts": brake_counts, "speed_variances": speed_variances, "throughputs": throughputs diff --git a/plot_actions.py b/plot_actions.py new file mode 100644 index 0000000..f5b54b5 --- /dev/null +++ b/plot_actions.py @@ -0,0 +1,32 @@ +"""从CSV绘制决策时空图""" +import os +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") + +output_dir = "evaluation_results" +models = ["PPO", "APPO", "DQN", "No Control"] +speed_map = [40, 60, 80, 100, 120] + +fig, axes = plt.subplots(2, 2, figsize=(16, 12)) +axes = axes.flatten() + +for idx, model in enumerate(models): + csv_path = os.path.join(output_dir, f"{model}_action.csv") + df = pd.read_csv(csv_path) + actions = df.values + vsl_speeds = np.vectorize(lambda x: speed_map[int(x)])(actions) + + im = axes[idx].imshow(vsl_speeds.T, aspect="auto", cmap="RdYlGn", + vmin=40, vmax=120, origin="lower") + axes[idx].set_title(f"{model} VSL Decision") + axes[idx].set_xlabel("Time Step") + axes[idx].set_ylabel("Control Edge Index") + cbar = plt.colorbar(im, ax=axes[idx], label="VSL (km/h)") + cbar.set_ticks([40, 60, 80, 100, 120]) + +plt.tight_layout() +plt.savefig(os.path.join(output_dir, "action_heatmaps.png"), dpi=150) +print("决策时空图已保存")