diff --git a/scripts/evaluate_models.py b/scripts/evaluate_models.py
index 3fe885a..422cc8a 100644
--- a/scripts/evaluate_models.py
+++ b/scripts/evaluate_models.py
@@ -483,6 +483,21 @@ def select_deterministic_action(model_name: str, agent, state: np.ndarray) -> np
     return action
 
 
+def select_no_control_action(env: SUMOEdgeVSLEnvironment) -> np.ndarray:
+    """Build the action vector used for the baseline (no-control) run.
+
+    Every controlled edge is assigned the highest action index,
+    ``env.action_dim - 1``; an empty int64 array is returned when
+    ``env.num_controlled_edges`` is zero or negative.
+
+    NOTE(review): this presumes the last action index corresponds to the
+    unrestricted speed limit -- confirm against the environment's action table.
+    """
+    if env.num_controlled_edges <= 0:
+        return np.zeros(0, dtype=np.int64)
+    return np.full(env.num_controlled_edges, env.action_dim - 1, dtype=np.int64)
+
+
 def resolve_logged_action_info(
     model_name: str,
     env: SUMOEdgeVSLEnvironment,
@@ -491,7 +506,7 @@ def resolve_logged_action_info(
     edge_idx: int,
     action_speed_kmh: float,
 ) -> Tuple[int, float]:
-    if model_name == BASELINE_NAME or not action_applied_mask[edge_idx]:
+    if not action_applied_mask[edge_idx]:
         return -1, action_speed_kmh
     controlled_idx = edge_idx - env.controlled_edge_start_index
     if action is None or controlled_idx < 0 or controlled_idx >= len(action):
@@ -550,8 +565,8 @@ def evaluate_single_model(
 
     while True:
         if model_name == BASELINE_NAME:
-            action = None
-            next_state, reward, done, info = env.step(action=None, apply_control=False)
+            action = select_no_control_action(env)
+            next_state, reward, done, info = env.step(action, apply_control=True)
         else:
             action = select_deterministic_action(model_name, agent, state)
             next_state, reward, done, info = env.step(action, apply_control=True)