# ctm-dqn/demand_loader.py

"""
Demand loader for reading traffic demand from CSV files.
"""
import numpy as np
import pandas as pd
from typing import Optional, Union
class DemandLoader:
    """Load and manage traffic demand data from CSV files.

    Demand values are held as a 1-D float array and served one per
    simulation step. Requests beyond the end of the data wrap around to the
    beginning.
    """

    def __init__(
        self,
        csv_path: Optional[str] = None,
        time_step: float = 10.0,
        demand_column: str = "demand",
        time_column: str = "time",
    ):
        """
        Initialize demand loader.

        Args:
            csv_path: Path to CSV file with demand data. If None, nothing is
                loaded until load_csv() is called.
            time_step: Simulation time step (seconds). Stored on the instance;
                not used by the loading logic in this class.
            demand_column: Name of the column containing demand values (vehicles/hour)
            time_column: Name of the column containing time values (optional).
                Stored on the instance; not used by the loading logic in this class.
        """
        self.csv_path = csv_path
        self.time_step = time_step
        self.demand_column = demand_column
        self.time_column = time_column
        # 1-D float array of demand values (veh/h), or None until loaded.
        self.demand_data: Optional[np.ndarray] = None
        # Next step index served by get_current_demand().
        self.current_index = 0
        if csv_path is not None:
            self.load_csv(csv_path)

    def load_csv(self, csv_path: str):
        """
        Load demand data from CSV file.

        CSV format options:
        1. A column named ``demand_column`` (other columns are ignored)
        2. A single column of any name, assumed to hold demand values

        The data is fully validated before it replaces any previously loaded
        demand, so a failed load leaves the loader's state unchanged.

        Args:
            csv_path: Path to CSV file (anything accepted by pandas.read_csv)

        Raises:
            RuntimeError: If the file cannot be read or its data is invalid
                (missing column, empty, non-numeric, non-finite or negative
                values). The underlying exception is chained as __cause__.
        """
        try:
            data = self._read_demand_values(csv_path)
        except Exception as e:
            # Callers catch RuntimeError; keep the original error as the cause.
            raise RuntimeError(f"Failed to load CSV file: {e}") from e
        # Only commit state once the data has passed validation.
        self.demand_data = data
        self.csv_path = csv_path
        print(f"Loaded {len(data)} demand values from {csv_path}")
        print(f"Demand range: [{np.min(data):.1f}, {np.max(data):.1f}] veh/h")

    def _read_demand_values(self, csv_path: str) -> np.ndarray:
        """Read the demand column from ``csv_path`` and return it as a validated float array."""
        df = pd.read_csv(csv_path)
        if self.demand_column in df.columns:
            values = df[self.demand_column].to_numpy()
        elif len(df.columns) == 1:
            # Single column, assume it's demand
            values = df.iloc[:, 0].to_numpy()
        else:
            raise ValueError(
                f"Could not find demand column '{self.demand_column}' in CSV file"
            )
        if len(values) == 0:
            raise ValueError("CSV file contains no data")
        # Coerce to float so non-numeric data fails here, not during simulation.
        data = np.asarray(values, dtype=float)
        # NaN compares False against 0, so finiteness must be checked explicitly.
        if not np.all(np.isfinite(data)):
            raise ValueError("Demand values must be finite numbers")
        if np.any(data < 0):
            raise ValueError("Demand values must be non-negative")
        return data

    def reset(self):
        """Reset to the beginning of the demand sequence."""
        self.current_index = 0

    def get_demand(self, step: int) -> float:
        """
        Get demand for a specific time step.

        Steps past the end of the data wrap around (modulo the data length).

        Args:
            step: Time step index

        Returns:
            Demand in vehicles/hour

        Raises:
            RuntimeError: If no demand data has been loaded.
        """
        if self.demand_data is None:
            raise RuntimeError("No demand data loaded. Call load_csv() first.")
        # Use modulo to loop through data if episode is longer than data
        index = step % len(self.demand_data)
        return float(self.demand_data[index])

    def get_current_demand(self) -> float:
        """
        Get demand for current time step and advance index.

        Returns:
            Demand in vehicles/hour
        """
        demand = self.get_demand(self.current_index)
        self.current_index += 1
        return demand

    def __len__(self) -> int:
        """Return number of demand data points (0 if nothing is loaded)."""
        return len(self.demand_data) if self.demand_data is not None else 0

    def get_statistics(self) -> Optional[dict]:
        """
        Get statistics of the demand data.

        Returns:
            Dict with float ``mean``/``std``/``min``/``max`` and int ``length``,
            or None if no data has been loaded.
        """
        if self.demand_data is None:
            return None
        data = self.demand_data
        return {
            "mean": float(np.mean(data)),
            "std": float(np.std(data)),
            "min": float(np.min(data)),
            "max": float(np.max(data)),
            "length": len(data),
        }