ctm-dqn/parallel_env.py

139 lines
4.0 KiB
Python

"""
Multiprocessing-based parallel environment for true parallelization.
"""
import numpy as np
from typing import List, Tuple, Dict
from multiprocessing import Process, Pipe
from environment import TrafficEnvironment
def worker(remote, parent_remote, config):
    """
    Run one TrafficEnvironment in a child process, serving commands over a pipe.

    The loop receives ``(cmd, data)`` tuples on ``remote`` and replies on the
    same connection:

    - ``'step'``: calls ``env.step(data)`` and sends ``(state, reward, done, info)``.
      When ``done`` is true, the env is reset before replying. The ``state`` sent
      back is then the initial state of the next episode, not the terminal state.
    - ``'reset'``: sends the state returned by ``env.reset()``.
    - ``'get_metrics'``: sends ``env.episode_metrics``.
    - ``'close'``: closes ``remote`` and exits the loop without replying.

    The loop also exits if the other end of the pipe is closed (``EOFError``
    from ``recv``).

    Args:
        remote: Child end of the pipe
        parent_remote: Parent end of the pipe. It is closed here because the
            child never uses it.
        config: Environment configuration passed to TrafficEnvironment

    Raises:
        NotImplementedError: If an unknown command is received. Ignoring it would
            leave the parent blocked in ``recv`` waiting for a reply.
    """
    parent_remote.close()
    env = TrafficEnvironment(config)
    while True:
        try:
            cmd, data = remote.recv()
        except EOFError:
            # The parent closed its end of the pipe, so there is nothing left to serve.
            break
        if cmd == 'step':
            state, reward, done, info = env.step(data)
            if done:
                # Auto-reset so the next 'step' starts a fresh episode.
                state = env.reset()
            remote.send((state, reward, done, info))
        elif cmd == 'reset':
            state = env.reset()
            remote.send(state)
        elif cmd == 'get_metrics':
            remote.send(env.episode_metrics)
        elif cmd == 'close':
            remote.close()
            break
        else:
            raise NotImplementedError(f"Unknown worker command: {cmd!r}")
class ParallelEnvironment:
    """
    Multiprocessing-based parallel environment wrapper.

    Each of ``num_envs`` TrafficEnvironment instances runs in its own daemon
    process (``worker``) and is driven over a ``multiprocessing.Pipe``. Every
    call sends one command to each worker and then collects one reply from
    each, in environment order.

    The wrapper can be used as a context manager, so that ``close()`` runs
    even when the body raises.
    """

    def __init__(self, config: dict, num_envs: int = 4):
        """
        Initialize parallel environment with multiprocessing.

        Args:
            config: Configuration dictionary passed to every TrafficEnvironment
            num_envs: Number of parallel environments (must be >= 1)

        Raises:
            ValueError: If ``num_envs`` is less than 1.
        """
        if num_envs < 1:
            # Otherwise zip(*[]) below fails with an opaque unpacking error.
            raise ValueError(f"num_envs must be >= 1, got {num_envs}")
        self.num_envs = num_envs
        self.config = config
        self.closed = False
        # One duplex pipe per env: the parent keeps `remotes`, and each worker gets
        # the matching entry of `work_remotes`.
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(num_envs)])
        self.processes = []
        for work_remote, remote in zip(self.work_remotes, self.remotes):
            proc = Process(
                target=worker,
                args=(work_remote, remote, config)
            )
            # Daemon processes are terminated automatically when the parent exits.
            proc.daemon = True
            proc.start()
            self.processes.append(proc)
            # The child now owns this end, so drop the parent's copy.
            work_remote.close()
        # A local, unused instance is built only to read the dimensions.
        temp_env = TrafficEnvironment(config)
        self.state_dim = temp_env.state_dim
        self.action_dim = temp_env.action_dim
        print(f"Created {num_envs} parallel environments with multiprocessing")

    def reset(self) -> np.ndarray:
        """
        Reset all environments in parallel.

        Returns:
            Array built from each worker's initial state. The expected shape is
            (num_envs, state_dim), assuming each state is a fixed-length vector.
        """
        for remote in self.remotes:
            remote.send(('reset', None))
        states = [remote.recv() for remote in self.remotes]
        return np.array(states)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        """
        Step all environments in parallel.

        Environments that finish an episode are reset by their worker. For
        those, the returned state is the first state of the new episode.

        Args:
            actions: One action per environment, with length ``num_envs``

        Returns:
            states: Array of next states, shape: (num_envs, state_dim)
            rewards: Array of rewards, shape: (num_envs,)
            dones: Array of done flags, shape: (num_envs,)
            infos: List of info dictionaries

        Raises:
            ValueError: If the number of actions differs from ``num_envs``.
        """
        if len(actions) != self.num_envs:
            # zip() would silently truncate, and the recv loop below would then
            # block forever on workers that were never sent a command.
            raise ValueError(
                f"Expected {self.num_envs} actions, got {len(actions)}"
            )
        # Send all actions first so the workers compute concurrently.
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        states, rewards, dones, infos = zip(*results)
        return np.array(states), np.array(rewards), np.array(dones), list(infos)

    def get_episode_metrics(self) -> List[List[Dict]]:
        """
        Get episode metrics from all environments.

        Returns:
            One entry per environment, holding that worker's
            ``env.episode_metrics``. The annotated List[List[Dict]] structure is
            assumed; confirm it against TrafficEnvironment.
        """
        for remote in self.remotes:
            remote.send(('get_metrics', None))
        return [remote.recv() for remote in self.remotes]

    def close(self) -> None:
        """
        Shut down all worker processes and release the parent's pipe ends.

        Safe to call more than once: every call after the first is a no-op.
        Workers whose pipe is already broken (for example, ones that crashed)
        are skipped when the close command is sent.
        """
        if self.closed:
            return
        self.closed = True
        for remote in self.remotes:
            try:
                remote.send(('close', None))
            except OSError:
                # The worker is already gone and its pipe is broken; there is nothing to notify.
                pass
        for proc in self.processes:
            proc.join()
        for remote in self.remotes:
            remote.close()

    def __enter__(self) -> "ParallelEnvironment":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the workers; any exception from the body still propagates.
        self.close()