From 79000996f4f24178e4dc17ec4ba26f946407f64d Mon Sep 17 00:00:00 2001
From: Maple-YZ
Date: Tue, 14 Apr 2026 01:04:39 +0800
Subject: [PATCH] =?UTF-8?q?=E5=B0=86=E6=97=A0=E7=AD=96=E7=95=A5=E8=AF=84?=
 =?UTF-8?q?=E4=BC=B0=E8=B0=83=E6=95=B4=E4=B8=BA=E5=9F=BA=E4=BA=8Etraci?=
 =?UTF-8?q?=E7=9A=84=E6=9C=80=E9=AB=98=E9=99=90=E9=80=9F=E8=AE=BE=E7=BD=AE?=
 =?UTF-8?q?=E4=BB=A5=E5=AF=B9=E6=AF=94?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/evaluate_models.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py
index 3fe885a..422cc8a 100644
--- a/scripts/evaluate_models.py
+++ b/scripts/evaluate_models.py
@@ -483,6 +483,12 @@ def select_deterministic_action(model_name: str, agent, state: np.ndarray) -> np
     return action
 
 
+def select_no_control_action(env: SUMOEdgeVSLEnvironment) -> np.ndarray:
+    if env.num_controlled_edges <= 0:
+        return np.zeros(0, dtype=np.int64)
+    return np.full(env.num_controlled_edges, env.action_dim - 1, dtype=np.int64)
+
+
 def resolve_logged_action_info(
     model_name: str,
     env: SUMOEdgeVSLEnvironment,
@@ -491,7 +497,7 @@ def resolve_logged_action_info(
     edge_idx: int,
     action_speed_kmh: float,
 ) -> Tuple[int, float]:
-    if model_name == BASELINE_NAME or not action_applied_mask[edge_idx]:
+    if not action_applied_mask[edge_idx]:
         return -1, action_speed_kmh
     controlled_idx = edge_idx - env.controlled_edge_start_index
     if action is None or controlled_idx < 0 or controlled_idx >= len(action):
@@ -550,8 +556,8 @@ def evaluate_single_model(
 
     while True:
         if model_name == BASELINE_NAME:
-            action = None
-            next_state, reward, done, info = env.step(action=None, apply_control=False)
+            action = select_no_control_action(env)
+            next_state, reward, done, info = env.step(action, apply_control=True)
         else:
             action = select_deterministic_action(model_name, agent, state)
             next_state, reward, done, info = env.step(action, apply_control=True)