139 lines
4.0 KiB
Python
139 lines
4.0 KiB
Python
"""
|
|
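Multiprocessing-based parallel environment: each TrafficEnvironment runs in its own worker process, so environment steps execute in parallel rather than being serialized by the GIL.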
|
|
"""
|
|
import numpy as np
|
|
from typing import List, Tuple, Dict
|
|
from multiprocessing import Process, Pipe
|
|
from environment import TrafficEnvironment
|
|
|
|
|
|
def worker(remote, parent_remote, config):
    """
    Worker process for running environment.

    Builds one TrafficEnvironment from ``config`` and then serves commands
    sent by the parent over ``remote`` until told to stop. Each message is a
    ``(cmd, data)`` tuple. Supported commands:

    - ``'step'``: call ``env.step(data)`` and reply with
      ``(state, reward, done, info)``. When ``done`` is true the environment
      is reset before replying, so the returned ``state`` is the first state
      of the next episode.
    - ``'reset'``: reply with ``env.reset()``.
    - ``'get_metrics'``: reply with ``env.episode_metrics``.
    - ``'close'``: stop serving (no reply is sent).

    The loop also stops if the parent's end of the pipe is closed (EOF).
    The child end of the pipe is closed on every exit path.

    Args:
        remote: Child end of pipe
        parent_remote: Parent end of pipe (closed in child)
        config: Environment configuration

    Raises:
        NotImplementedError: If an unknown command is received. Silently
            ignoring it would leave the parent waiting for a reply that
            never comes; raising surfaces the bad command with a traceback.
    """
    # The child never uses the parent's end; close its inherited copy so that
    # EOF is reported once the real parent end goes away.
    parent_remote.close()

    env = TrafficEnvironment(config)

    try:
        while True:
            try:
                cmd, data = remote.recv()
            except EOFError:
                # Parent closed its end of the pipe (or exited): stop serving.
                break

            if cmd == 'step':
                state, reward, done, info = env.step(data)
                if done:
                    # Auto-reset so callers can keep stepping without tracking
                    # episode boundaries per environment.
                    state = env.reset()
                remote.send((state, reward, done, info))
            elif cmd == 'reset':
                remote.send(env.reset())
            elif cmd == 'get_metrics':
                remote.send(env.episode_metrics)
            elif cmd == 'close':
                break
            else:
                raise NotImplementedError(f"Unknown worker command: {cmd!r}")
    finally:
        remote.close()
|
|
|
|
|
|
class ParallelEnvironment:
    """
    Multiprocessing-based parallel environment wrapper.

    Runs ``num_envs`` independent TrafficEnvironment instances, each in its
    own daemon worker process, and exposes batched ``reset``/``step`` calls.
    Each call sends one command to every worker and then gathers the replies
    in environment order. Instances can be used as a context manager;
    ``close()`` is called on exit.
    """

    def __init__(self, config: dict, num_envs: int = 4):
        """
        Initialize parallel environment with multiprocessing.

        Args:
            config: Configuration dictionary, passed unchanged to every
                worker's TrafficEnvironment (it must be picklable when a
                non-fork start method such as 'spawn' is used).
            num_envs: Number of parallel environments

        Raises:
            ValueError: If ``num_envs`` is less than 1.
        """
        if num_envs < 1:
            raise ValueError(f"num_envs must be at least 1, got {num_envs}")

        self.num_envs = num_envs
        self.config = config
        self._closed = False

        # Create pipes and processes
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(num_envs)])
        self.processes = []

        for work_remote, remote in zip(self.work_remotes, self.remotes):
            proc = Process(target=worker, args=(work_remote, remote, config))
            # Daemon workers are terminated automatically if the parent exits
            # without calling close().
            proc.daemon = True
            proc.start()
            self.processes.append(proc)
            # The child now owns this end; the parent keeps only ``remote``.
            work_remote.close()

        # Get dimensions from a throwaway local instance built from the same config.
        temp_env = TrafficEnvironment(config)
        self.state_dim = temp_env.state_dim
        self.action_dim = temp_env.action_dim

        print(f"Created {num_envs} parallel environments with multiprocessing")

    def reset(self) -> np.ndarray:
        """
        Reset all environments in parallel.

        Returns:
            Array of initial states stacked in environment order,
            shape: (num_envs, state_dim)
        """
        # Send to every worker first so the resets run concurrently.
        for remote in self.remotes:
            remote.send(('reset', None))

        states = [remote.recv() for remote in self.remotes]
        return np.array(states)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        """
        Step all environments in parallel.

        Environments whose episode ends are reset by their worker, so the
        returned state for a done environment is the first state of its next
        episode.

        Args:
            actions: Array of actions, one per environment, shape: (num_envs,)

        Returns:
            states: Array of next states, shape: (num_envs, state_dim)
            rewards: Array of rewards, shape: (num_envs,)
            dones: Array of done flags, shape: (num_envs,)
            infos: List of info dictionaries

        Raises:
            ValueError: If the number of actions differs from ``num_envs``.
                This is checked before anything is sent. A mismatch would
                otherwise either drop actions silently or block forever
                waiting for replies from workers that were never stepped.
        """
        actions = list(actions)
        if len(actions) != self.num_envs:
            raise ValueError(
                f"Expected {self.num_envs} actions (one per environment), "
                f"got {len(actions)}"
            )

        # Send actions to all workers
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))

        # Receive results from all workers
        results = [remote.recv() for remote in self.remotes]
        states, rewards, dones, infos = zip(*results)

        return np.array(states), np.array(rewards), np.array(dones), list(infos)

    def get_episode_metrics(self) -> List[List[Dict]]:
        """
        Get episode metrics from all environments.

        Returns:
            List with one entry per environment, in environment order, holding
            that worker's ``env.episode_metrics``.
        """
        for remote in self.remotes:
            remote.send(('get_metrics', None))

        metrics = [remote.recv() for remote in self.remotes]
        return metrics

    def close(self):
        """
        Close all worker processes and the parent's ends of their pipes.

        Safe to call more than once; calls after the first are no-ops.
        Workers that already exited, and pipes that are already closed, are
        tolerated.
        """
        if self._closed:
            return
        self._closed = True

        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # Best effort: the worker is already gone (broken pipe) or the
                # connection was already closed, so there is nobody to notify.
                pass

        for proc in self.processes:
            proc.join()

        for remote in self.remotes:
            remote.close()

    def __enter__(self):
        """Return ``self`` so the wrapper can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close all workers on leaving the ``with`` block; exceptions propagate."""
        self.close()
|
|
|