ctm-dqn/envs/network_parser.py

283 lines
9.9 KiB
Python

"""
SUMO 网络拓扑解析器
解析 SUMO 网络文件和感应线圈定义文件,构建:
1. Edge 顺序和长度信息
2. 检测器到 edge/lane/position 的映射
3. Zone 分组结构
"""
import os
import sys
import xml.etree.ElementTree as ET
import re
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
# Make sure SUMO's bundled Python tools (traci / sumolib) are importable.
try:
    import sumo

    _tools = os.path.join(sumo.SUMO_HOME, "tools")
    if _tools not in sys.path:
        sys.path.insert(0, _tools)
except ImportError:
    try:
        # Fallback for the server environment: add the tools directory of the
        # virtualenv's sumo package directly to the import path.
        sys.path.insert(0, "/workspace/ctm-dqn/.venv/lib/python3.12/site-packages/sumo/tools")
        import sumo
    except ImportError:
        # Best effort only: code that needs sumolib imports it lazily and
        # handles its absence. Catch ImportError specifically so that
        # KeyboardInterrupt / SystemExit are not swallowed here.
        pass
@dataclass
class DetectorInfo:
    """Description of a single induction-loop detector.

    Attributes:
        det_id: Detector id as written in the detector definition file.
        edge: Id of the edge the detector sits on.
        lane_id: Full lane id, e.g. "G1523_AM3_4.1_1".
        lane_idx: Lane number (0 = emergency lane; 1, 2, 3 = traffic lanes).
        position: Position of the detector along the edge, in metres.
        pos_index: Detector sequence number (metrics_0, metrics_1, ...).
    """

    det_id: str
    edge: str
    lane_id: str
    lane_idx: int
    position: float
    pos_index: int
@dataclass
class EdgeInfo:
    """Per-edge metadata.

    Attributes:
        edge_id: Id of the edge.
        length: Edge length in metres.
        speed_limit: Speed limit in m/s.
        total_lanes: Number of lanes on the edge.
        traffic_lane_indices: Lane numbers left after excluding emergency
            lanes.
        detector_positions: De-duplicated detector positions in ascending
            order.
        detectors: Mapping ``(lane_idx, pos_index) -> det_id``.
    """

    edge_id: str
    length: float = 0.0
    speed_limit: float = 30.0
    total_lanes: int = 0
    traffic_lane_indices: List[int] = field(default_factory=list)
    detector_positions: List[float] = field(default_factory=list)
    detectors: Dict[Tuple[int, int], str] = field(default_factory=dict)
class SUMONetworkParser:
    """Parser for SUMO detector (and optionally network) topology.

    From the induction-loop definition file this builds the ordered list of
    monitored edges, per-edge metadata, and a mapping from detector id to
    edge / lane / position. When a network file is supplied, edge length,
    speed limit and lane metadata are taken from it instead of the
    detector-based estimates.
    """

    # Detector ids ending in "_metrics_<N>" carry their position index N.
    _POS_INDEX_RE = re.compile(r"_metrics_(\d+)$")

    def __init__(self, detector_add_file: str, net_file: Optional[str] = None):
        """Parse the detector file and, if given, the network file.

        Args:
            detector_add_file: Path to the induction-loop definition file
                (metrics_il.add.xml).
            net_file: Optional path to the SUMO network file
                (modified.net.xml).
        """
        self.detector_add_file = detector_add_file
        self.net_file = net_file

        # Parse results.
        self.detectors: Dict[str, DetectorInfo] = {}  # det_id -> DetectorInfo
        # Edge ids in order of first appearance in the detector file
        # (presumably the driving order -- confirm against the file).
        self.edges: List[str] = []
        self.edge_info: Dict[str, EdgeInfo] = {}  # edge_id -> EdgeInfo
        # edge_id -> cumulative distance at which the edge starts.
        self.edge_cumulative_offset: Dict[str, float] = {}

        # Derived sizes.
        self.max_traffic_lanes: int = 0
        self.max_detectors_per_edge: int = 0  # max distinct detector positions on one edge

        self._parse_detectors()
        if net_file:
            self._enrich_from_net()
            self._refresh_lane_metadata_from_net()
        self._compute_cumulative_offsets()

    def _parse_detectors(self) -> None:
        """Parse the detector file and build the detector/edge mappings.

        Fills ``detectors``, ``edges`` and ``edge_info`` and computes the
        derived maxima. A file without ``inductionLoop`` elements yields an
        empty topology instead of raising.
        """
        root = ET.parse(self.detector_add_file).getroot()

        # edge_id -> detectors on that edge; dict insertion order doubles as
        # the order in which edges first appear in the file.
        detectors_by_edge = defaultdict(list)
        for loop in root.findall("inductionLoop"):
            det_id = loop.get("id")
            lane_id = loop.get("lane")
            pos = float(loop.get("pos"))

            # Lane id layout: "<edge_id>_<lane_idx>".
            edge_id, _, lane_suffix = lane_id.rpartition("_")
            lane_idx = int(lane_suffix)

            # Position index from the detector id; 0 when the suffix is absent.
            match = self._POS_INDEX_RE.search(det_id)
            pos_index = int(match.group(1)) if match else 0

            info = DetectorInfo(
                det_id=det_id,
                edge=edge_id,
                lane_id=lane_id,
                lane_idx=lane_idx,
                position=pos,
                pos_index=pos_index,
            )
            self.detectors[det_id] = info
            detectors_by_edge[edge_id].append(info)

        self.edges = list(detectors_by_edge)

        for edge_id, dets in detectors_by_edge.items():
            lane_indices = sorted({d.lane_idx for d in dets})
            # Lane 0 is treated as the emergency lane and excluded, unless it
            # is the only lane that carries detectors.
            traffic_lanes = [idx for idx in lane_indices if idx > 0] or lane_indices

            positions = sorted({d.position for d in dets})
            det_map = {(d.lane_idx, d.pos_index): d.det_id for d in dets}

            # Estimated edge length: last detector position + 100 m. The speed
            # limit keeps the EdgeInfo default. Both are overwritten by
            # _enrich_from_net when a network file is available.
            self.edge_info[edge_id] = EdgeInfo(
                edge_id=edge_id,
                length=positions[-1] + 100.0,
                total_lanes=lane_indices[-1] + 1,
                traffic_lane_indices=traffic_lanes,
                detector_positions=positions,
                detectors=det_map,
            )

        self._update_derived_maxima()

    def _update_derived_maxima(self) -> None:
        """Recompute ``max_traffic_lanes`` / ``max_detectors_per_edge`` from ``edge_info``."""
        infos = self.edge_info.values()
        # default=0 keeps an empty topology from raising ValueError.
        self.max_traffic_lanes = max(
            (len(ei.traffic_lane_indices) for ei in infos), default=0
        )
        self.max_detectors_per_edge = max(
            (len(ei.detector_positions) for ei in infos), default=0
        )

    def _enrich_from_net(self) -> None:
        """Overwrite estimated edge length and speed limit with values from net.xml.

        Best effort: if sumolib cannot be imported, or an edge cannot be
        resolved in the network, the estimated values are kept.
        """
        try:
            import sumolib
        except ImportError:
            return  # sumolib unavailable: keep the estimates

        net = sumolib.net.readNet(self.net_file)
        for edge_id in self.edges:
            try:
                edge = net.getEdge(edge_id)
                self.edge_info[edge_id].length = edge.getLength()
                lanes = edge.getLanes()
                if lanes:
                    self.edge_info[edge_id].speed_limit = max(
                        lane.getSpeed() for lane in lanes
                    )
            except Exception:
                # Deliberately broad: any failure for a single edge leaves
                # that edge's estimated values in place.
                continue

    def _compute_cumulative_offsets(self) -> None:
        """Compute the cumulative start distance of every edge, in ``edges`` order."""
        cumulative = 0.0
        for edge_id in self.edges:
            self.edge_cumulative_offset[edge_id] = cumulative
            cumulative += self.edge_info[edge_id].length

    def _refresh_lane_metadata_from_net(self) -> None:
        """Use lane permissions from net.xml instead of assuming lane 0 is emergency.

        A lane counts as a traffic lane unless its ``allow`` attribute is
        exactly ``emergency``. If every lane of an edge is emergency-only,
        all of its lanes are kept. The derived maxima are recomputed
        afterwards so they match the refreshed lane lists.
        """
        net_root = ET.parse(self.net_file).getroot()
        for edge in net_root.findall("edge"):
            edge_id = edge.get("id")
            if edge_id not in self.edge_info:
                continue
            lane_nodes = edge.findall("lane")
            if not lane_nodes:
                continue

            traffic_lanes = []
            for lane in lane_nodes:
                allow_tokens = set((lane.get("allow") or "").split())
                if allow_tokens != {"emergency"}:
                    traffic_lanes.append(int(lane.get("index", "0")))
            if not traffic_lanes:
                traffic_lanes = list(range(len(lane_nodes)))

            info = self.edge_info[edge_id]
            info.traffic_lane_indices = sorted(traffic_lanes)
            info.total_lanes = len(lane_nodes)

        # traffic_lane_indices may have changed, so max_traffic_lanes (used by
        # get_state_dimensions) must not keep its detector-based value.
        self._update_derived_maxima()

    @property
    def total_length(self) -> float:
        """Total length of the monitored edges (m)."""
        return sum(ei.length for ei in self.edge_info.values())

    @property
    def num_edges(self) -> int:
        """Number of monitored edges."""
        return len(self.edges)

    def get_detector_ids_for_edge_lane(
        self, edge_id: str, lane_idx: int
    ) -> List[str]:
        """Return the detector ids on one lane of an edge, ordered by position index.

        Only position indices ``0 .. len(detector_positions) - 1`` are
        looked up; indices without a detector on this lane are skipped.

        Raises:
            KeyError: If ``edge_id`` is not a known edge.
        """
        ei = self.edge_info[edge_id]
        candidates = (
            ei.detectors.get((lane_idx, pos_index))
            for pos_index in range(len(ei.detector_positions))
        )
        return [det_id for det_id in candidates if det_id]

    def get_all_detector_ids(self) -> List[str]:
        """Return the ids of all detectors, in file order."""
        return list(self.detectors)

    def get_state_dimensions(self, num_features: int = 2) -> int:
        """Compute the total dimension of the state vector.

        state = [
            for each edge (num_edges):
                for each cell (padded to max_detectors_per_edge):
                    for each traffic lane (padded to max_traffic_lanes):
                        [speed, occupancy] (num_features)
            + the normalised current speed limit of each edge (num_edges)
        ]
        """
        per_edge_dim = self.max_detectors_per_edge * self.max_traffic_lanes * num_features
        return self.num_edges * per_edge_dim + self.num_edges

    def summary(self) -> None:
        """Print a summary of the parsed topology to stdout."""
        print("=== SUMO Network Topology ===")
        print(f"Edges: {self.num_edges}")
        print(f"Max traffic lanes: {self.max_traffic_lanes}")
        print(f"Max detectors per edge: {self.max_detectors_per_edge}")
        print(f"Total length: {self.total_length:.0f} m ({self.total_length/1000:.1f} km)")
        print(f"Total detectors: {len(self.detectors)}")
        print(f"State dimension: {self.get_state_dimensions()}")
        print()
        for edge_id in self.edges:
            ei = self.edge_info[edge_id]
            print(
                f" {edge_id:25s} len={ei.length:8.1f}m "
                f"lanes={ei.traffic_lane_indices} "
                f"dets={len(ei.detector_positions):2d} "
                f"v_max={ei.speed_limit*3.6:.0f}km/h"
            )
if __name__ == "__main__":
    # Script entry point: parse the bundled SUMO resources and print the
    # resulting topology summary.
    SUMONetworkParser(
        "sumo_resource/metrics_il.add.xml",
        net_file="sumo_resource/modified.net.xml",
    ).summary()