"""
Multiprocessing-based parallel environment for true parallelization.
"""

import numpy as np
from typing import List, Tuple, Dict
from multiprocessing import Process, Pipe

from environment import TrafficEnvironment


def worker(remote, parent_remote, config):
    """
    Worker process for running environment.

    Builds one TrafficEnvironment from ``config`` and serves ``(cmd, data)``
    requests received on ``remote`` until it gets ``'close'`` or the parent
    end of the pipe is closed.

    Supported commands:
        'step'        -> replies ``(state, reward, done, info)`` for
                         ``env.step(data)``. When ``done`` is true the
                         environment is reset and ``state`` is the state
                         returned by ``env.reset()``, not the terminal state.
        'reset'       -> replies with ``env.reset()``.
        'get_metrics' -> replies with ``env.episode_metrics``.
        'get_dims'    -> replies ``(env.state_dim, env.action_dim)``.
        'close'       -> closes ``remote`` and exits the loop (no reply).

    Args:
        remote: Child end of pipe
        parent_remote: Parent end of pipe (closed in child)
        config: Environment configuration

    Raises:
        NotImplementedError: If an unknown command is received.
    """
    # The child only talks over ``remote``; close its copy of the parent end.
    parent_remote.close()
    env = TrafficEnvironment(config)

    while True:
        try:
            cmd, data = remote.recv()
        except EOFError:
            # The other end of the pipe was closed; nothing left to serve.
            break

        if cmd == 'step':
            state, reward, done, info = env.step(data)
            if done:
                # Auto-reset so callers can keep stepping without an explicit
                # reset. The returned state is the first state of the new episode.
                state = env.reset()
            remote.send((state, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'get_metrics':
            remote.send(env.episode_metrics)
        elif cmd == 'get_dims':
            remote.send((env.state_dim, env.action_dim))
        elif cmd == 'close':
            remote.close()
            break
        else:
            # Fail loudly: silently ignoring the command would leave the
            # parent blocked waiting for a reply that never comes.
            raise NotImplementedError(f"Unknown worker command: {cmd!r}")


class ParallelEnvironment:
    """Multiprocessing-based parallel environment wrapper."""

    def __init__(self, config: dict, num_envs: int = 4):
        """
        Initialize parallel environment with multiprocessing.

        Starts ``num_envs`` worker processes, each owning its own
        TrafficEnvironment built from ``config``.

        Args:
            config: Configuration dictionary
            num_envs: Number of parallel environments

        Raises:
            ValueError: If ``num_envs`` is less than 1.
        """
        if num_envs < 1:
            raise ValueError(f"num_envs must be >= 1, got {num_envs}")

        self.num_envs = num_envs
        self.config = config
        self.closed = False

        # ``remotes`` are the parent's ends; ``work_remotes`` go to the workers.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(num_envs)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            proc = Process(
                target=worker,
                args=(work_remote, remote, config)
            )
            # Daemonic processes are terminated when the parent process exits.
            proc.daemon = True
            proc.start()
            self.processes.append(proc)
            # The parent never uses the child end; drop its copy.
            work_remote.close()

        # Ask a worker for the dimensions instead of constructing an extra
        # TrafficEnvironment in the parent just to read two attributes.
        self.remotes[0].send(('get_dims', None))
        self.state_dim, self.action_dim = self.remotes[0].recv()

        print(f"Created {num_envs} parallel environments with multiprocessing")

    def reset(self) -> np.ndarray:
        """
        Reset all environments in parallel.

        All reset requests are sent before any reply is read, so the workers
        reset concurrently.

        Returns:
            Array of initial states, shape: (num_envs, state_dim)
        """
        for remote in self.remotes:
            remote.send(('reset', None))
        states = [remote.recv() for remote in self.remotes]
        return np.array(states)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        """
        Step all environments in parallel.

        Environments whose episode ends are reset automatically by their
        worker. For those, the returned state is the first state of the
        next episode.

        Args:
            actions: Array of actions, shape: (num_envs,)

        Returns:
            states: Array of next states, shape: (num_envs, state_dim)
            rewards: Array of rewards, shape: (num_envs,)
            dones: Array of done flags, shape: (num_envs,)
            infos: List of info dictionaries

        Raises:
            ValueError: If the number of actions differs from ``num_envs``.
        """
        # zip() would silently truncate, leaving some workers without a
        # command and the recv() loop below blocked forever.
        if len(actions) != self.num_envs:
            raise ValueError(
                f"Expected {self.num_envs} actions, got {len(actions)}"
            )

        # Send actions to all workers first so they step concurrently
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))

        # Receive results from all workers
        results = [remote.recv() for remote in self.remotes]
        states, rewards, dones, infos = zip(*results)

        return np.array(states), np.array(rewards), np.array(dones), list(infos)

    def get_episode_metrics(self) -> List[List[Dict]]:
        """
        Get episode metrics from all environments.

        Returns:
            List with each worker's ``env.episode_metrics``, in environment order
        """
        for remote in self.remotes:
            remote.send(('get_metrics', None))
        metrics = [remote.recv() for remote in self.remotes]
        return metrics

    def close(self):
        """
        Close all worker processes.

        Safe to call more than once; calls after the first do nothing.
        """
        if self.closed:
            return

        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # Best-effort: the worker has already exited (e.g. broken
                # pipe), so there is nobody left to tell.
                pass

        for proc in self.processes:
            proc.join()

        # Release the parent's ends of the pipes.
        for remote in self.remotes:
            remote.close()

        self.closed = True