"""SUMO VSL environment with 1000 m corridor control segments.""" import os import sys import numpy as np import xml.etree.ElementTree as ET from typing import Tuple, Dict, List, Optional, Set try: import sumo as _sumo_pkg _tools = os.path.join(_sumo_pkg.SUMO_HOME, "tools") if _tools not in sys.path: sys.path.insert(0, _tools) except ImportError: pass import traci import traci.constants as tc from envs.network_parser import SUMONetworkParser from envs.reward_system import RewardCalculator, RewardConfig from utils.experiment_corridor import prepare_experiment_corridor_assets from utils.reward_baseline import resolve_baseline_dir, wait_for_episode_baseline class SUMOEdgeVSLEnvironment: """Edge-based VSL environment over derived corridor control segments.""" def __init__(self, config: dict): sumo_cfg = config["sumo"] env_cfg = config["environment"] self.net_file = sumo_cfg["net_file"] self.route_file = sumo_cfg["route_file"] self.step_length = sumo_cfg["step_length"] self.begin_time = sumo_cfg["begin_time"] self.end_time = sumo_cfg["end_time"] self.use_gui = sumo_cfg.get("gui", False) self.no_warnings = sumo_cfg.get("no_warnings", True) runtime_cfg = config.get("runtime", {}) self.runtime_output_dir = runtime_cfg.get("output_dir") self.runtime_metrics_subdir = runtime_cfg.get("metrics_subdir", "sumo_metrics") self.run_timestamp = runtime_cfg.get("run_timestamp") self.evaluation_mode = bool(runtime_cfg.get("evaluation_mode", False)) self.runtime_detector_add_file: Optional[str] = None self.runtime_enex_add_file: Optional[str] = None self.collect_detector_cells = runtime_cfg.get( "collect_detector_cells", env_cfg.get("collect_detector_cells", False), ) self.use_vehicle_subscriptions = runtime_cfg.get( "use_vehicle_subscriptions", env_cfg.get("use_vehicle_subscriptions", True), ) self.collect_trip_events = runtime_cfg.get( "collect_trip_events", env_cfg.get("collect_trip_events", False), ) self.control_interval = env_cfg["control_interval"] self.steps_per_action = int(self.control_interval / self.step_length) self.warmup_time = int(env_cfg.get("warmup_time", 900)) self.episode_length = int( (self.end_time - self.begin_time - self.warmup_time) / self.control_interval ) self.speed_actions_kmh = np.array(env_cfg["speed_actions_kmh"], dtype=float) self.speed_actions_ms = self.speed_actions_kmh / 3.6 self.num_speed_actions = len(self.speed_actions_kmh) self.free_flow_speed = env_cfg["free_flow_speed"] self.base_corridor_edges: List[str] = env_cfg["control_edges"] self.control_segment_length_m = float(env_cfg.get("control_segment_length_m", 1000.0)) self.detector_spacing_m = float(env_cfg.get("detector_spacing_m", 100.0)) self.detector_start_offset_m = float( env_cfg.get("detector_start_offset_m", self.detector_spacing_m * 0.5) ) self.passive_prefix_segment_count = int(env_cfg.get("passive_prefix_segment_count", 0)) self.incident_cfg = dict(env_cfg.get("incident", {}) or {}) self.incident_enabled = bool(self.incident_cfg.get("enabled", False)) self.incident_start_delay_min_range = self._parse_numeric_range( self.incident_cfg.get("start_delay_min_range"), default=(10.0, 20.0), lower_bound=0.0, ) self.incident_duration_min_range = self._parse_numeric_range( self.incident_cfg.get("duration_min_range"), default=(10.0, 20.0), lower_bound=0.0, ) self.incident_downstream_fraction_range = self._parse_numeric_range( self.incident_cfg.get("downstream_fraction_range"), default=(0.65, 0.95), lower_bound=0.0, upper_bound=1.0, ) self.incident_position_buffer_m = float( self.incident_cfg.get("target_position_buffer_m", 25.0) ) self.incident_position_tolerance_m = float( self.incident_cfg.get("position_tolerance_m", 5.0) ) self.incident_stopped_speed_threshold_ms = float( self.incident_cfg.get("stopped_speed_threshold_ms", 0.1) ) corridor_assets = prepare_experiment_corridor_assets( net_file=self.net_file, route_file=self.route_file, corridor_edges=self.base_corridor_edges, control_segment_length_m=self.control_segment_length_m, detector_spacing_m=self.detector_spacing_m, detector_start_offset_m=self.detector_start_offset_m, output_root=env_cfg.get("generated_asset_dir"), ) self.net_file = corridor_assets.net_file self.route_file = corridor_assets.route_file self.detector_add_file = corridor_assets.detector_add_file self.enex_add_file = corridor_assets.enex_add_file self._detector_add_template = self.detector_add_file self._enex_add_template = self.enex_add_file self.control_segments: List[Dict] = corridor_assets.control_segments self.detector_cell_defs: List[Dict] = corridor_assets.detector_cells self.control_edges: List[str] = [ segment["segment_id"] for segment in self.control_segments ] self.segment_edge_map: Dict[str, List[str]] = { segment["segment_id"]: list(segment["edge_ids"]) for segment in self.control_segments } self.segment_detector_map: Dict[str, List[str]] = { segment["segment_id"]: list(segment.get("first_detector_group_ids", [])) for segment in self.control_segments } self.num_edges = len(self.control_edges) self.passive_prefix_segment_count = min(self.passive_prefix_segment_count, self.num_edges) self.passive_segment_indices = set(range(self.passive_prefix_segment_count)) self.passive_segments: List[str] = self.control_edges[: self.passive_prefix_segment_count] self.active_control_segments: List[str] = self.control_edges[self.passive_prefix_segment_count :] self.controlled_edge_start_index = self.passive_prefix_segment_count self.controlled_edge_indices = list(range(self.controlled_edge_start_index, self.num_edges)) self.num_controlled_edges = len(self.controlled_edge_indices) self.physical_control_edges: List[str] = list(corridor_assets.physical_edge_ids) self.control_edge_set = set(self.physical_control_edges) self.reward_cfg = env_cfg.get("reward", {}) self.reward_config = RewardConfig.from_dict( self.reward_cfg, speed_actions_ms=self.speed_actions_ms, ) self.parser = SUMONetworkParser( detector_add_file=self._detector_add_template, net_file=self.net_file, ) self.controlled_length_km = corridor_assets.controlled_length_m / 1000.0 self.physical_edge_ranges = self._build_physical_edge_ranges() self.default_edge_speeds = self._build_default_segment_speeds() self.max_segment_speeds = self.default_edge_speeds.copy() self.reward_calculator = RewardCalculator( config=self.reward_config, controlled_edge_start_index=self.controlled_edge_start_index, evaluation_mode=self.evaluation_mode, ) self.reward_baseline_dir = resolve_baseline_dir(config, self.run_timestamp) self.reward_calculator.set_step_baseline({}) self.action_dims = [self.num_speed_actions] * self.num_controlled_edges self.features_per_edge = 3 self._state_dim = (self.features_per_edge + 1) * self.num_edges + 3 + 1 self.current_step = 0 self._sumo_running = False self._episode_count = 0 self.current_edge_speeds = self.default_edge_speeds.copy() self._prev_edge_speeds = self.default_edge_speeds.copy() self._last_reward = 0.0 self.episode_metrics: List[Dict] = [] self._tracked_vehicle_ids: Set[str] = set() self._mainline_depart_times: Dict[str, float] = {} self._active_mainline_vehicle_ids: Set[str] = set() self._completed_mainline_travel_times: List[float] = [] self._interval_mainline_travel_times: List[float] = [] self._episode_rng = np.random.default_rng() self._incident_state: Dict[str, object] = {} self._reward_baseline_episode = 0 self._reward_baseline_loaded_step = 0 print("SUMO Edge VSL Environment initialized") print(f" Control segments: {self.num_edges}") print(f" Passive prefix segments: {self.passive_prefix_segment_count}") print(f" Active controlled segments: {self.num_controlled_edges}") print(f" Physical corridor edges: {len(self.physical_control_edges)}") print(f" Action: MultiDiscrete {self.action_dims}") print(f" State dim: {self._state_dim}") print(f" Episode length: {self.episode_length} steps") if self.incident_enabled: print(" Incident scenario: enabled") @property def state_dim(self) -> int: return self._state_dim @property def action_dim(self) -> int: return self.num_speed_actions def _build_default_segment_speeds(self) -> np.ndarray: default_speeds = [] for segment in self.control_segments: edge_speeds = [] for edge_id in segment["edge_ids"]: edge_info = self.parser.edge_info.get(edge_id) if edge_info is not None: edge_speeds.append(float(edge_info.speed_limit)) default_speeds.append( float(min(edge_speeds)) if edge_speeds else float(self.free_flow_speed) ) return np.array(default_speeds, dtype=float) @staticmethod def _parse_numeric_range( raw_value, *, default: Tuple[float, float], lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, ) -> Tuple[float, float]: if isinstance(raw_value, (list, tuple)) and len(raw_value) == 2: low, high = raw_value else: low, high = default low = float(low) high = float(high) if low > high: low, high = high, low if lower_bound is not None: low = max(low, lower_bound) high = max(high, lower_bound) if upper_bound is not None: low = min(low, upper_bound) high = min(high, upper_bound) if low > high: low = high return low, high def _build_physical_edge_ranges(self) -> List[Dict]: edge_ranges: List[Dict] = [] cumulative = 0.0 for edge_id in self.physical_control_edges: edge_info = self.parser.edge_info.get(edge_id) length_m = float(edge_info.length) if edge_info is not None else 0.0 edge_ranges.append( { "edge_id": edge_id, "start_m": cumulative, "end_m": cumulative + length_m, "length_m": length_m, } ) cumulative += length_m return edge_ranges def _locate_physical_edge_range(self, distance_m: float) -> Optional[Dict]: for edge_range in self.physical_edge_ranges: if edge_range["start_m"] <= distance_m <= edge_range["end_m"]: return edge_range if self.physical_edge_ranges: return self.physical_edge_ranges[-1] return None def _reset_incident_runtime(self, seed: Optional[int]): self._episode_rng = np.random.default_rng(seed) self._incident_state = { "enabled": False, "pending": False, "commanded": False, "active": False, "completed": False, "vehicle_id": "", "target_edge_id": "", "target_position_m": float("nan"), "target_distance_m": float("nan"), "trigger_time_s": float("nan"), "duration_s": 0.0, "command_time_s": float("nan"), "blocking_start_time_s": float("nan"), "release_time_s": float("nan"), "released_time_s": float("nan"), "lane_index": -1, } if not self.incident_enabled or not self.physical_edge_ranges: return downstream_low, downstream_high = self.incident_downstream_fraction_range target_distance_m = float( self._episode_rng.uniform(downstream_low, downstream_high) * max(self.controlled_length_km * 1000.0, 1.0) ) edge_range = self._locate_physical_edge_range(target_distance_m) if edge_range is None: return edge_length_m = max(float(edge_range["length_m"]), 1.0) position_buffer_m = min( max(self.incident_position_buffer_m, 1.0), max(edge_length_m * 0.5 - 1.0, 1.0), ) target_position_m = float( np.clip( target_distance_m - float(edge_range["start_m"]), position_buffer_m, max(edge_length_m - position_buffer_m, position_buffer_m), ) ) start_delay_s = float( self._episode_rng.uniform(*self.incident_start_delay_min_range) * 60.0 ) duration_s = float( self._episode_rng.uniform(*self.incident_duration_min_range) * 60.0 ) self._incident_state.update( { "enabled": True, "pending": True, "target_edge_id": str(edge_range["edge_id"]), "target_position_m": target_position_m, "target_distance_m": target_distance_m, "trigger_time_s": float(self.begin_time + self.warmup_time) + start_delay_s, "duration_s": duration_s, } ) def _is_incident_vehicle_present(self, vehicle_id: str) -> bool: if not vehicle_id: return False try: return vehicle_id in set(traci.vehicle.getIDList()) except Exception: return False def _clear_incident_assignment(self): self._incident_state["commanded"] = False self._incident_state["active"] = False self._incident_state["vehicle_id"] = "" self._incident_state["command_time_s"] = float("nan") self._incident_state["blocking_start_time_s"] = float("nan") self._incident_state["release_time_s"] = float("nan") self._incident_state["lane_index"] = -1 def _try_command_incident_vehicle(self, sim_time: float): incident = self._incident_state if not incident.get("enabled", False): return target_edge_id = str(incident.get("target_edge_id", "")) if not target_edge_id: return edge_info = self.parser.edge_info.get(target_edge_id) traffic_lane_indices = set(edge_info.traffic_lane_indices) if edge_info else set() target_position_m = float(incident.get("target_position_m", float("nan"))) try: vehicle_ids = list(traci.edge.getLastStepVehicleIDs(target_edge_id)) except Exception: vehicle_ids = [] candidates = [] for veh_id in vehicle_ids: try: if traci.vehicle.getRoadID(veh_id) != target_edge_id: continue lane_index = int(traci.vehicle.getLaneIndex(veh_id)) lane_position_m = float(traci.vehicle.getLanePosition(veh_id)) speed_ms = max(float(traci.vehicle.getSpeed(veh_id)), 0.0) except Exception: continue if traffic_lane_indices and lane_index not in traffic_lane_indices: continue remaining_distance_m = target_position_m - lane_position_m if remaining_distance_m <= max(self.incident_position_tolerance_m, 1.0): continue min_braking_distance_m = max(40.0, speed_ms * 5.0) if remaining_distance_m <= min_braking_distance_m: continue candidates.append((remaining_distance_m, veh_id, lane_index)) if not candidates: return for _, vehicle_id, lane_index in sorted(candidates, key=lambda item: item[0]): try: traci.vehicle.setStop( vehicle_id, target_edge_id, pos=target_position_m, laneIndex=lane_index, duration=float(incident["duration_s"]), ) except Exception: continue incident["commanded"] = True incident["vehicle_id"] = str(vehicle_id) incident["lane_index"] = int(lane_index) incident["command_time_s"] = float(sim_time) return def _update_incident_state(self, sim_time: float): incident = self._incident_state if not incident.get("enabled", False) or incident.get("completed", False): return vehicle_id = str(incident.get("vehicle_id", "")) if incident.get("active", False): if not self._is_incident_vehicle_present(vehicle_id): incident["active"] = False incident["pending"] = False incident["completed"] = True incident["released_time_s"] = float(sim_time) return release_time_s = float(incident.get("release_time_s", float("inf"))) if sim_time >= release_time_s: try: if hasattr(traci.vehicle, "getStopState") and traci.vehicle.getStopState(vehicle_id) != 0: traci.vehicle.resume(vehicle_id) except Exception: pass incident["active"] = False incident["pending"] = False incident["completed"] = True incident["released_time_s"] = float(sim_time) return if incident.get("commanded", False): if not self._is_incident_vehicle_present(vehicle_id): self._clear_incident_assignment() else: try: road_id = traci.vehicle.getRoadID(vehicle_id) lane_position_m = float(traci.vehicle.getLanePosition(vehicle_id)) speed_ms = float(traci.vehicle.getSpeed(vehicle_id)) except Exception: self._clear_incident_assignment() else: position_error_m = abs( lane_position_m - float(incident.get("target_position_m", 0.0)) ) if ( road_id == incident.get("target_edge_id") and position_error_m <= self.incident_position_tolerance_m and speed_ms <= self.incident_stopped_speed_threshold_ms ): incident["active"] = True incident["blocking_start_time_s"] = float(sim_time) incident["release_time_s"] = float(sim_time) + float(incident["duration_s"]) if incident.get("commanded", False): return trigger_time_s = float(incident.get("trigger_time_s", float("inf"))) if sim_time < trigger_time_s: return self._try_command_incident_vehicle(sim_time) def _append_incident_info(self, info: Dict): incident = self._incident_state info["incident_enabled"] = bool(incident.get("enabled", False)) info["incident_pending"] = bool(incident.get("pending", False)) info["incident_commanded"] = bool(incident.get("commanded", False)) info["incident_active"] = bool(incident.get("active", False)) info["incident_completed"] = bool(incident.get("completed", False)) info["incident_vehicle_id"] = str(incident.get("vehicle_id", "")) info["incident_target_edge_id"] = str(incident.get("target_edge_id", "")) info["incident_target_position_m"] = float( incident.get("target_position_m", float("nan")) ) info["incident_target_distance_m"] = float( incident.get("target_distance_m", float("nan")) ) info["incident_trigger_time_s"] = float(incident.get("trigger_time_s", float("nan"))) info["incident_duration_s"] = float(incident.get("duration_s", 0.0)) info["incident_command_time_s"] = float( incident.get("command_time_s", float("nan")) ) info["incident_blocking_start_time_s"] = float( incident.get("blocking_start_time_s", float("nan")) ) info["incident_release_time_s"] = float( incident.get("release_time_s", float("nan")) ) info["incident_released_time_s"] = float( incident.get("released_time_s", float("nan")) ) info["incident_lane_index"] = int(incident.get("lane_index", -1)) def _load_episode_reward_baseline(self, episode: int) -> Dict[int, Dict[str, float]]: mode = str(self.reward_config.mode).lower() if mode in {"paired_no_control", "episode_baseline"}: return wait_for_episode_baseline( baseline_dir=self.reward_baseline_dir, episode=episode, min_step=self._reward_baseline_loaded_step + 1, timeout_s=self.reward_config.baseline_wait_timeout_s, poll_interval_s=self.reward_config.baseline_poll_interval_s, ) return {} def _sync_episode_reward_baseline(self, min_step: int) -> None: mode = str(self.reward_config.mode).lower() if mode not in {"paired_no_control", "episode_baseline"}: return if self._reward_baseline_episode != self._episode_count: self._reward_baseline_episode = self._episode_count self._reward_baseline_loaded_step = 0 self.reward_calculator.set_step_baseline({}) if self._reward_baseline_loaded_step >= min_step: return baseline = wait_for_episode_baseline( baseline_dir=self.reward_baseline_dir, episode=self._episode_count, min_step=min_step, timeout_s=self.reward_config.baseline_wait_timeout_s, poll_interval_s=self.reward_config.baseline_poll_interval_s, ) self.reward_calculator.set_step_baseline(baseline) self._reward_baseline_loaded_step = max(baseline) if baseline else 0 def _start_sumo(self, seed: Optional[int] = None): if self._sumo_running: self._close_sumo() self._prepare_runtime_additional_files() binary_name = "sumo-gui" if self.use_gui else "sumo" try: import sumolib sumo_binary = sumolib.checkBinary(binary_name) except Exception: sumo_binary = binary_name detector_add_file = self.runtime_detector_add_file or self.detector_add_file enex_add_file = self.runtime_enex_add_file or self.enex_add_file cmd = [ sumo_binary, "-n", self.net_file, "-r", self.route_file, "-a", f"{detector_add_file},{enex_add_file}", "--step-length", str(self.step_length), "-b", str(self.begin_time), "-e", str(self.end_time), "--collision.action", "warn", "--no-step-log", "true", "--quit-on-end", "true", ] if self.no_warnings: cmd += ["--no-warnings", "true"] if seed is not None: cmd += ["--seed", str(seed)] if self.use_gui: cmd += ["--start", "true", "--gui-settings-file", "sumo_resource/gui.settings.xml"] traci.start(cmd, label=f"vsl_{self._episode_count}") self._sumo_running = True @staticmethod def _to_sumo_path(path: str) -> str: return os.path.abspath(path).replace("\\", "/") def _rewrite_additional_file(self, template_path: str, runtime_add_path: str, output_xml_path: str): tree = ET.parse(template_path) root = tree.getroot() for elem in root.iter(): if "file" in elem.attrib: elem.set("file", output_xml_path) tree.write(runtime_add_path, encoding="utf-8", xml_declaration=True) def _prepare_runtime_additional_files(self): if not self.runtime_output_dir: self.runtime_detector_add_file = None self.runtime_enex_add_file = None return output_dir = os.path.join( os.path.abspath(self.runtime_output_dir), self.runtime_metrics_subdir, ) os.makedirs(output_dir, exist_ok=True) suffix = f"ep{self._episode_count:04d}" detector_output_file = self._to_sumo_path( os.path.join(output_dir, f"metrics_il_output_{suffix}.xml") ) enex_output_file = self._to_sumo_path( os.path.join(output_dir, f"metrics_enex_output_{suffix}.xml") ) detector_add_file = os.path.join(output_dir, f"runtime_metrics_il_{suffix}.add.xml") enex_add_file = os.path.join(output_dir, f"runtime_metrics_enex_{suffix}.add.xml") self._rewrite_additional_file( self._detector_add_template, detector_add_file, detector_output_file, ) self._rewrite_additional_file( self._enex_add_template, enex_add_file, enex_output_file, ) self.runtime_detector_add_file = detector_add_file self.runtime_enex_add_file = enex_add_file def _close_sumo(self): if self._sumo_running: try: traci.close() except Exception: pass self._sumo_running = False self._tracked_vehicle_ids.clear() self._mainline_depart_times.clear() self._active_mainline_vehicle_ids.clear() def reset(self, seed: Optional[int] = None) -> np.ndarray: self._episode_count += 1 self.current_step = 0 self.episode_metrics = [] self.current_edge_speeds = self.default_edge_speeds.copy() self._prev_edge_speeds = self.default_edge_speeds.copy() self._last_reward = 0.0 self._tracked_vehicle_ids.clear() self._mainline_depart_times = {} self._active_mainline_vehicle_ids = set() self._completed_mainline_travel_times = [] self._interval_mainline_travel_times = [] self._reset_incident_runtime(seed) mode = str(self.reward_config.mode).lower() if mode in {"paired_no_control", "episode_baseline"}: self._reward_baseline_episode = self._episode_count self._reward_baseline_loaded_step = 0 self.reward_calculator.set_step_baseline({}) else: self.reward_calculator.set_step_baseline(self._load_episode_reward_baseline(self._episode_count)) self._start_sumo(seed=seed) warmup_steps = int(self.warmup_time / self.control_interval) for _ in range(warmup_steps): for _ in range(self.steps_per_action): traci.simulationStep() return self._collect_state() def step( self, action: Optional[np.ndarray], apply_control: bool = True, ) -> Tuple[np.ndarray, float, bool, Dict]: self._prev_edge_speeds = self.current_edge_speeds.copy() if apply_control: if action is None: raise ValueError("action must be provided when apply_control=True") edge_speeds = self._decode_action(action) else: edge_speeds = self.default_edge_speeds.copy() self.current_edge_speeds = edge_speeds if apply_control: self._apply_vsl(edge_speeds) self._interval_arrived = 0 self._interval_departed = 0 self._interval_departed_vehicle_events = [] self._interval_arrived_vehicle_events = [] self._interval_mainline_travel_times = [] for _ in range(self.steps_per_action): traci.simulationStep() sim_time = float(traci.simulation.getTime()) self._update_incident_state(sim_time) self._interval_arrived += traci.simulation.getArrivedNumber() self._interval_departed += traci.simulation.getDepartedNumber() self._update_mainline_trip_tracking(sim_time) detector_data = self._get_edge_detector_data() state = self._collect_state(detector_data) info = self._collect_runtime_metrics(detector_data) info["step"] = self.current_step + 1 self._sync_episode_reward_baseline(int(info["step"])) info["detector_cells"] = self._collect_all_detector_cells() if self.collect_detector_cells else [] reward = self._calculate_reward(info) self._last_reward = reward self.current_step += 1 done = self.current_step >= self.episode_length info["reward"] = reward info["step"] = self.current_step info["edge_speeds_kmh"] = (edge_speeds * 3.6).tolist() info["action_applied_mask"] = [ bool(apply_control and idx not in self.passive_segment_indices) for idx in range(self.num_edges) ] if self.collect_trip_events: info["departed_vehicle_events"] = list(self._interval_departed_vehicle_events) info["arrived_vehicle_events"] = list(self._interval_arrived_vehicle_events) self._append_incident_info(info) self.episode_metrics.append(info) if done: self._close_sumo() return state, reward, done, info def close(self): self._close_sumo() def _update_mainline_trip_tracking(self, sim_time: float): departed_ids = [] arrived_ids = [] try: departed_ids = list(traci.simulation.getDepartedIDList()) except Exception: departed_ids = [] try: arrived_ids = list(traci.simulation.getArrivedIDList()) except Exception: arrived_ids = [] for veh_id in departed_ids: is_mainline = False try: route_edges = traci.vehicle.getRoute(veh_id) is_mainline = any(edge_id in self.control_edge_set for edge_id in route_edges) except Exception: is_mainline = False if is_mainline: self._mainline_depart_times[veh_id] = float(sim_time) self._active_mainline_vehicle_ids.add(veh_id) if self.collect_trip_events: self._interval_departed_vehicle_events.append( { "vehicle_id": veh_id, "sim_time": sim_time, "is_mainline": bool(is_mainline), } ) for veh_id in arrived_ids: if self.collect_trip_events: self._interval_arrived_vehicle_events.append( { "vehicle_id": veh_id, "sim_time": sim_time, } ) if veh_id not in self._active_mainline_vehicle_ids: continue depart_time = self._mainline_depart_times.pop(veh_id, None) self._active_mainline_vehicle_ids.discard(veh_id) if depart_time is None: continue travel_time = float(sim_time) - float(depart_time) if travel_time < 0: continue self._interval_mainline_travel_times.append(travel_time) self._completed_mainline_travel_times.append(travel_time) def _decode_action(self, action: np.ndarray) -> np.ndarray: action_array = np.asarray(action, dtype=np.int64).reshape(-1) if action_array.size != self.num_controlled_edges: raise ValueError( f"Expected {self.num_controlled_edges} control actions, got {action_array.size}" ) edge_speeds = self.default_edge_speeds.copy() if self.num_controlled_edges == 0: return edge_speeds requested_speeds = np.array( [self.speed_actions_ms[int(a)] for a in action_array], dtype=float, ) edge_speeds[self.controlled_edge_start_index :] = np.minimum( requested_speeds, self.max_segment_speeds[self.controlled_edge_start_index :], ) return edge_speeds def _apply_vsl(self, edge_speeds: np.ndarray): for idx, segment_id in enumerate(self.control_edges): if idx in self.passive_segment_indices: continue for edge_id in self.segment_edge_map.get(segment_id, []): edge_info = self.parser.edge_info.get(edge_id) original_edge_speed = ( float(edge_info.speed_limit) if edge_info is not None else float(edge_speeds[idx]) ) safe_speed = min(float(edge_speeds[idx]), original_edge_speed) traci.edge.setMaxSpeed(edge_id, safe_speed) def _get_edge_detector_data(self) -> Tuple[List[float], List[float], List[int], List[float]]: speeds, occs, counts, valid_speeds = [], [], [], [] for segment_id in self.control_edges: det_ids = self.segment_detector_map.get(segment_id, []) if not det_ids: speeds.append(self.free_flow_speed) occs.append(0.0) counts.append(0) continue lane_speeds, lane_occs, lane_counts = [], [], [] for det_id in det_ids: try: spd = traci.inductionloop.getLastIntervalMeanSpeed(det_id) occ = traci.inductionloop.getLastIntervalOccupancy(det_id) cnt = traci.inductionloop.getLastIntervalVehicleNumber(det_id) if spd > 0: lane_speeds.append(spd) if occ >= 0: lane_occs.append(occ) if cnt >= 0: lane_counts.append(cnt) except Exception: pass speeds.append(np.mean(lane_speeds) if lane_speeds else self.free_flow_speed) occs.append(np.mean(lane_occs) if lane_occs else 0.0) counts.append(sum(lane_counts)) if lane_speeds: valid_speeds.append(np.mean(lane_speeds)) return speeds, occs, counts, valid_speeds def _collect_all_detector_cells(self) -> List[Dict]: detector_rows = [] for cell in self.detector_cell_defs: lane_speeds = [] lane_occs = [] lane_counts = [] for det_id in cell["detector_ids"]: try: speed = traci.inductionloop.getLastIntervalMeanSpeed(det_id) occupancy = traci.inductionloop.getLastIntervalOccupancy(det_id) count = traci.inductionloop.getLastIntervalVehicleNumber(det_id) except Exception: continue if speed > 0: lane_speeds.append(speed) if occupancy >= 0: lane_occs.append(occupancy) if count >= 0: lane_counts.append(count) detector_rows.append( { "edge_index": int(cell["segment_index"]), "edge_id": cell["segment_id"], "physical_edge_id": cell["edge_id"], "pos_index": int(cell["pos_index"]), "position_m": float(cell["segment_position_m"]), "distance_m": float(cell["distance_m"]), "speed_ms": float(np.mean(lane_speeds)) if lane_speeds else np.nan, "occupancy": float(np.mean(lane_occs)) if lane_occs else np.nan, "vehicle_count": int(sum(lane_counts)), } ) return detector_rows def _collect_state( self, detector_data: Optional[Tuple[List[float], List[float], List[int], List[float]]] = None, ) -> np.ndarray: state_parts = [] if detector_data is None: speeds, occs, counts, _ = self._get_edge_detector_data() else: speeds, occs, counts, _ = detector_data for spd, occ, cnt in zip(speeds, occs, counts): mean_speed_norm = np.clip(spd / self.free_flow_speed, 0.0, 1.5) mean_occ = np.clip(occ / 100.0, 0.0, 1.0) flow_norm = min(cnt / 50.0, 1.0) state_parts.extend([mean_speed_norm, mean_occ, flow_norm]) for idx in range(self.num_edges): state_parts.append(self.current_edge_speeds[idx] / self.free_flow_speed) time_progress = self.current_step / max(self.episode_length, 1) state_parts.append(time_progress) state_parts.append(np.sin(2 * np.pi * time_progress)) state_parts.append(np.cos(2 * np.pi * time_progress)) state_parts.append(self._last_reward) return np.array(state_parts, dtype=np.float32) def _sync_vehicle_subscriptions(self, current_vehicle_ids: Set[str]): if not self.use_vehicle_subscriptions: return new_vehicle_ids = current_vehicle_ids - self._tracked_vehicle_ids for veh_id in new_vehicle_ids: try: traci.vehicle.subscribe( veh_id, (tc.VAR_ROAD_ID, tc.VAR_SPEED), ) except Exception: pass self._tracked_vehicle_ids = set(current_vehicle_ids) def _collect_controlled_vehicle_metrics(self) -> Tuple[int, int, List[float], List[float]]: try: current_vehicle_ids = set(traci.vehicle.getIDList()) except Exception: return 0, 0, [], [] if not current_vehicle_ids: self._sync_vehicle_subscriptions(set()) return 0, 0, [], [] self._sync_vehicle_subscriptions(current_vehicle_ids) controlled_vehicle_count = 0 valid_following_count = 0 relative_speed_samples: List[float] = [] ttc_samples: List[float] = [] if self.use_vehicle_subscriptions: for veh_id in current_vehicle_ids: try: results = traci.vehicle.getSubscriptionResults(veh_id) or {} except Exception: continue road_id = results.get(tc.VAR_ROAD_ID) if road_id not in self.control_edge_set: continue controlled_vehicle_count += 1 speed = results.get(tc.VAR_SPEED) relative_speed = 0.0 ttc_value = 0.0 if speed is not None: try: leader_info = traci.vehicle.getLeader( veh_id, self.reward_config.leader_gap_threshold_m, ) except Exception: leader_info = None if leader_info: leader_id, _ = leader_info try: leader_road_id = traci.vehicle.getRoadID(leader_id) leader_speed = traci.vehicle.getSpeed(leader_id) except Exception: leader_road_id = None leader_speed = None if ( leader_road_id in self.control_edge_set and leader_speed is not None and leader_speed >= 0 ): valid_following_count += 1 relative_speed = max(float(speed) - float(leader_speed), 0.0) if relative_speed > 1e-6: gap_m = max(float(leader_info[1]), 0.0) ttc_value = gap_m / relative_speed if gap_m > 0.0 else 0.0 relative_speed_samples.append(relative_speed) ttc_samples.append(ttc_value) if controlled_vehicle_count > 0: return controlled_vehicle_count, valid_following_count, relative_speed_samples, ttc_samples for veh_id in current_vehicle_ids: try: road_id = traci.vehicle.getRoadID(veh_id) except Exception: continue if road_id not in self.control_edge_set: continue controlled_vehicle_count += 1 try: speed = traci.vehicle.getSpeed(veh_id) except Exception: speed = None relative_speed = 0.0 ttc_value = 0.0 if speed is not None and speed >= 0: try: leader_info = traci.vehicle.getLeader( veh_id, self.reward_config.leader_gap_threshold_m, ) except Exception: leader_info = None if leader_info: leader_id, _ = leader_info try: leader_road_id = traci.vehicle.getRoadID(leader_id) leader_speed = traci.vehicle.getSpeed(leader_id) except Exception: leader_road_id = None leader_speed = None if ( leader_road_id in self.control_edge_set and leader_speed is not None and leader_speed >= 0 ): valid_following_count += 1 relative_speed = max(float(speed) - float(leader_speed), 0.0) if relative_speed > 1e-6: gap_m = max(float(leader_info[1]), 0.0) ttc_value = gap_m / relative_speed if gap_m > 0.0 else 0.0 relative_speed_samples.append(relative_speed) ttc_samples.append(ttc_value) return controlled_vehicle_count, valid_following_count, relative_speed_samples, ttc_samples def _collect_runtime_metrics( self, detector_data: Tuple[List[float], List[float], List[int], List[float]], ) -> Dict: info = {} throughput = self._interval_arrived * (3600.0 / self.control_interval) info["throughput"] = throughput info["arrived_count"] = self._interval_arrived info["departed_count"] = self._interval_departed edge_speeds, edge_occs, edge_counts, valid_speeds = detector_data info["edge_speeds_ms"] = edge_speeds info["edge_occupancies"] = edge_occs info["edge_vehicle_counts"] = edge_counts info["mean_speed"] = np.mean(valid_speeds) if valid_speeds else 0.0 info["mean_speed_kmh"] = info["mean_speed"] * 3.6 info["mean_occupancy"] = np.mean(edge_occs) if edge_occs else 0.0 controlled_vehicle_count, valid_following_count, relative_speed_samples, ttc_samples = ( self._collect_controlled_vehicle_metrics() ) info["num_vehicles"] = controlled_vehicle_count info["density"] = ( controlled_vehicle_count / self.controlled_length_km if self.controlled_length_km > 0 else 0.0 ) info["valid_following_count"] = int(valid_following_count) info["risky_following_count"] = int(sum(value > 0 for value in relative_speed_samples)) info["relative_speed_samples"] = relative_speed_samples info["relative_speed_variance"] = float(np.var(relative_speed_samples)) if relative_speed_samples else 0.0 info["speed_variance_norm"] = ( info["relative_speed_variance"] / max(self.free_flow_speed ** 2, 1e-6) ) info["ttc_samples"] = ttc_samples positive_ttc_samples = [value for value in ttc_samples if value > 0.0] info["ttc_min_s"] = float(min(positive_ttc_samples)) if positive_ttc_samples else float("inf") ttc_threshold_s = max(float(self.reward_config.ttc_threshold_s), 1e-6) ttc_risk_samples = [ max(0.0, min(1.0, (ttc_threshold_s - value) / ttc_threshold_s)) if value > 0.0 else 0.0 for value in ttc_samples ] info["ttc_risk_samples"] = ttc_risk_samples info["ttc_risk_rate"] = ( float(np.sum(ttc_risk_samples)) / float(controlled_vehicle_count) if controlled_vehicle_count > 0 else 0.0 ) info["critical_ttc_vehicle_count"] = int( sum(0.0 < value < ttc_threshold_s for value in ttc_samples) ) active_start = self.controlled_edge_start_index active_occs = np.asarray(edge_occs[active_start:], dtype=float) active_counts = np.asarray(edge_counts[active_start:], dtype=float) if active_occs.size > 0: bottleneck_window = active_occs[-self.reward_config.bottleneck_window_size :] bottleneck_rel_idx = int(np.argmax(bottleneck_window)) bottleneck_abs_idx = active_start + max( 0, active_occs.size - self.reward_config.bottleneck_window_size, ) + bottleneck_rel_idx info["bottleneck_segment_index"] = bottleneck_abs_idx info["bottleneck_occupancy"] = float(bottleneck_window[bottleneck_rel_idx]) else: info["bottleneck_segment_index"] = -1 info["bottleneck_occupancy"] = 0.0 info["downstream_mainline_outflow"] = ( float(active_counts[-1]) * (3600.0 / self.control_interval) if active_counts.size > 0 else 0.0 ) interval_tt_mean = ( float(np.mean(self._interval_mainline_travel_times)) if self._interval_mainline_travel_times else np.nan ) cumulative_tt_mean = ( float(np.mean(self._completed_mainline_travel_times)) if self._completed_mainline_travel_times else np.nan ) info["mainline_completed_count"] = len(self._interval_mainline_travel_times) info["mainline_interval_travel_time_mean_s"] = interval_tt_mean info["mainline_travel_time_cumulative_mean_s"] = cumulative_tt_mean info["mainline_active_vehicle_count"] = len(self._active_mainline_vehicle_ids) try: info["sim_time"] = traci.simulation.getTime() except Exception: info["sim_time"] = 0.0 return info def _collect_metrics( self, detector_data: Tuple[List[float], List[float], List[int], List[float]], ) -> Dict: return self._collect_runtime_metrics(detector_data) def _calculate_reward(self, info: Dict) -> float: return self.reward_calculator.calculate( info=info, current_edge_speeds=self.current_edge_speeds, prev_edge_speeds=self._prev_edge_speeds, episode_index=self._episode_count, )