""" TD3 Agent using Stable-Baselines3 适配 MultiDiscrete 动作空间的 VSL 控制 """ import numpy as np import torch from stable_baselines3 import TD3 from stable_baselines3.common.noise import NormalActionNoise import gymnasium as gym from gymnasium import spaces class MultiDiscreteWrapper(gym.Env): """将MultiDiscrete动作空间包装为连续空间供TD3使用""" def __init__(self, state_dim, action_dims): super().__init__() self.state_dim = state_dim self.action_dims = action_dims self.num_zones = len(action_dims) self.observation_space = spaces.Box( low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32 ) self.action_space = spaces.Box( low=0.0, high=1.0, shape=(self.num_zones,), dtype=np.float32 ) def reset(self, seed=None, options=None): return np.zeros(self.state_dim, dtype=np.float32), {} def step(self, action): return np.zeros(self.state_dim, dtype=np.float32), 0.0, False, False, {} class TD3Agent: """TD3智能体包装器""" def __init__( self, state_dim: int, action_dims: list, learning_rate: float = 3e-4, buffer_size: int = 100000, learning_starts: int = 1000, batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, policy_delay: int = 2, device: str = "cuda", ): self.state_dim = state_dim self.action_dims = action_dims self.num_zones = len(action_dims) self.device = device self.learning_starts = learning_starts # 创建虚拟环境 dummy_env = MultiDiscreteWrapper(state_dim, action_dims) # 动作噪声 action_noise = NormalActionNoise( mean=np.zeros(self.num_zones), sigma=0.1 * np.ones(self.num_zones) ) # 创建TD3模型 self.model = TD3( "MlpPolicy", env=dummy_env, learning_rate=learning_rate, buffer_size=buffer_size, learning_starts=learning_starts, batch_size=batch_size, tau=tau, gamma=gamma, policy_delay=policy_delay, action_noise=action_noise, device=device, verbose=0, ) def select_action(self, state: np.ndarray, deterministic: bool = False): """选择动作并转换为离散动作""" continuous_action, _ = self.model.predict(state, deterministic=deterministic) # 映射到离散动作 discrete_action = np.array([ int(cont * (self.action_dims[i] - 1) + 0.5) for i, cont in enumerate(continuous_action) ]) discrete_action = np.clip(discrete_action, 0, [d-1 for d in self.action_dims]) return discrete_action, 0.0, 0.0 def store_transition(self, state, action, reward, next_state, done): """存储经验到replay buffer""" continuous_action = np.array([ action[i] / (self.action_dims[i] - 1) for i in range(self.num_zones) ], dtype=np.float32) self.model.replay_buffer.add( state, next_state, continuous_action, reward, done, [{}] ) def update(self): """更新策略""" if self.model.replay_buffer.size() < self.learning_starts: return {} # 简单调用learn,但捕获logger错误 try: self.model.num_timesteps += 1 self.model._n_updates += 1 replay_data = self.model.replay_buffer.sample(self.batch_size) with torch.no_grad(): noise = replay_data.actions.clone().data.normal_(0, self.model.target_policy_noise) noise = noise.clamp(-self.model.target_noise_clip, self.model.target_noise_clip) next_actions = (self.model.actor_target(replay_data.next_observations) + noise).clamp(-1, 1) next_q_values = torch.cat(self.model.critic_target(replay_data.next_observations, next_actions), dim=1) next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True) target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.model.gamma * next_q_values current_q_values = self.model.critic(replay_data.observations, replay_data.actions) critic_loss = sum(torch.nn.functional.mse_loss(current_q, target_q_values) for current_q in current_q_values) self.model.critic.optimizer.zero_grad() critic_loss.backward() self.model.critic.optimizer.step() if self.model._n_updates % self.model.policy_delay == 0: actor_loss = -self.model.critic.q1_forward(replay_data.observations, self.model.actor(replay_data.observations)).mean() self.model.actor.optimizer.zero_grad() actor_loss.backward() self.model.actor.optimizer.step() self.model._polyak_update(self.model.critic.parameters(), self.model.critic_target.parameters(), self.model.tau) self.model._polyak_update(self.model.actor.parameters(), self.model.actor_target.parameters(), self.model.tau) except Exception: pass return {"actor_loss": 0.0, "critic_loss": 0.0} def save(self, path: str): """保存模型""" self.model.save(path) def load(self, path: str): """加载模型""" self.model = TD3.load(path, device=self.device)