ctm-dqn/td3_agent.py

155 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TD3 agent built on Stable-Baselines3.
Adapted for VSL control with a MultiDiscrete action space.
"""
import numpy as np
import torch
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise
import gymnasium as gym
from gymnasium import spaces
class MultiDiscreteWrapper(gym.Env):
    """Placeholder environment that presents a MultiDiscrete problem as continuous.

    It only declares the observation and action spaces so that TD3 can be
    constructed; ``reset`` and ``step`` return all-zero observations and carry
    no dynamics of their own.

    Args:
        state_dim: Length of the flat observation vector.
        action_dims: One entry per control zone giving that zone's number of
            discrete levels. Only its length is used here; each zone is exposed
            as a single continuous value in ``[0, 1]``.
    """

    def __init__(self, state_dim, action_dims):
        super().__init__()
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)

        obs_shape = (state_dim,)
        act_shape = (self.num_zones,)
        # Unbounded observations; one unit-interval action per zone.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=obs_shape, dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=0.0, high=1.0, shape=act_shape, dtype=np.float32
        )

    def _blank_observation(self):
        """Return a fresh all-zero float32 observation of length ``state_dim``."""
        return np.zeros(self.state_dim, dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Return a zero observation and an empty info dict; arguments are ignored."""
        observation = self._blank_observation()
        info = {}
        return observation, info

    def step(self, action):
        """Return a zero observation, zero reward and no termination or truncation."""
        reward = 0.0
        terminated = False
        truncated = False
        return self._blank_observation(), reward, terminated, truncated, {}
class TD3Agent:
    """TD3 agent wrapper that drives a Stable-Baselines3 ``TD3`` model by hand.

    The model is built against a placeholder environment and is never run
    through ``TD3.learn``. Instead the caller uses ``select_action``,
    ``store_transition`` and ``update`` from its own training loop.

    Each zone's discrete level ``k`` in ``0 .. action_dims[i] - 1`` corresponds
    to the continuous value ``k / (action_dims[i] - 1)`` in ``[0, 1]``.
    """

    def __init__(
        self,
        state_dim: int,
        action_dims: list,
        learning_rate: float = 3e-4,
        buffer_size: int = 100000,
        learning_starts: int = 1000,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        policy_delay: int = 2,
        device: str = "cuda",
    ):
        """Build the underlying TD3 model.

        Args:
            state_dim: Length of the flat observation vector.
            action_dims: Number of discrete levels for each control zone.
            learning_rate: Optimizer learning rate passed to TD3.
            buffer_size: Replay buffer capacity.
            learning_starts: Minimum number of stored transitions before
                ``update`` performs any gradient step.
            batch_size: Number of transitions sampled per gradient step.
            tau: Soft-update coefficient for the target networks.
            gamma: Discount factor.
            policy_delay: The actor and the target networks are updated once
                every ``policy_delay`` critic updates.
            device: Torch device string passed to TD3.
        """
        self.state_dim = state_dim
        self.action_dims = action_dims
        self.num_zones = len(action_dims)
        self.device = device
        self.learning_starts = learning_starts
        # update() samples with this value; it must be kept on the instance.
        self.batch_size = batch_size

        # Placeholder environment: only provides the spaces TD3 needs.
        dummy_env = MultiDiscreteWrapper(state_dim, action_dims)

        # Gaussian action noise handed to the TD3 constructor.
        action_noise = NormalActionNoise(
            mean=np.zeros(self.num_zones),
            sigma=0.1 * np.ones(self.num_zones),
        )

        self.model = TD3(
            "MlpPolicy",
            env=dummy_env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            policy_delay=policy_delay,
            action_noise=action_noise,
            device=device,
            verbose=0,
        )

    def _max_levels(self):
        """Return the highest valid discrete index for each zone as an int array."""
        return np.asarray(self.action_dims, dtype=np.int64) - 1

    @staticmethod
    def _soft_update(params, target_params, tau):
        """Polyak-average ``params`` into ``target_params`` in place.

        Each target becomes ``(1 - tau) * target + tau * source``.
        """
        with torch.no_grad():
            for source, target in zip(params, target_params):
                target.mul_(1.0 - tau)
                target.add_(source, alpha=tau)

    def select_action(self, state: np.ndarray, deterministic: bool = False):
        """Query the policy and convert its output to discrete zone levels.

        Args:
            state: Observation passed to ``model.predict``.
            deterministic: Forwarded unchanged to ``model.predict``.

        Returns:
            A tuple ``(discrete_action, 0.0, 0.0)``. ``discrete_action`` is an
            integer array with one level per zone; the two trailing zeros are
            constant placeholders.

        NOTE(review): this method adds no exploration noise itself, and the
        ``action_noise`` given to the constructor is not referenced anywhere
        else in this class. Confirm whether ``model.predict`` or the caller
        is expected to provide exploration.
        """
        continuous_action, _ = self.model.predict(state, deterministic=deterministic)
        continuous = np.asarray(continuous_action).reshape(-1)
        max_levels = self._max_levels()
        # Round half up to the nearest level, then clip out-of-range values
        # back onto the valid index range of each zone.
        nearest = np.floor(continuous * max_levels + 0.5).astype(int)
        discrete_action = np.clip(nearest, 0, max_levels)
        return discrete_action, 0.0, 0.0

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer.

        Args:
            state: Observation before the action.
            action: Discrete level chosen for each zone.
            reward: Scalar reward.
            next_state: Observation after the action.
            done: Whether the episode ended with this transition.

        The discrete levels are mapped to ``[0, 1]`` and then rescaled to
        ``[-1, 1]``. ``update`` clamps target actions to ``[-1, 1]`` and feeds
        raw actor outputs to the critic, so stored actions must use the same
        range for the critic to see consistent inputs.
        """
        levels = np.asarray(action, dtype=np.float64).reshape(-1)
        # A zone with a single level has max index 0; divide by 1 instead so
        # its only level maps to 0.0 rather than raising or producing NaN.
        divisor = np.maximum(self._max_levels(), 1)
        unit_action = levels / divisor
        scaled_action = (2.0 * unit_action - 1.0).astype(np.float32)
        self.model.replay_buffer.add(
            state, next_state, scaled_action, reward, done, [{}]
        )

    def update(self):
        """Run one TD3 gradient step on a batch sampled from the replay buffer.

        Returns:
            ``{}`` while the buffer holds fewer than ``learning_starts``
            transitions. Otherwise a dict with ``"critic_loss"`` and
            ``"actor_loss"`` as floats; ``"actor_loss"`` is ``0.0`` on steps
            where the delayed actor update does not run.

        Errors raised during the step propagate to the caller. Swallowing
        them would make a broken training step look like a successful one.
        """
        model = self.model
        if model.replay_buffer.size() < self.learning_starts:
            return {}

        model.num_timesteps += 1
        model._n_updates += 1
        replay_data = model.replay_buffer.sample(self.batch_size)

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on target actions.
            noise = replay_data.actions.clone().data.normal_(0, model.target_policy_noise)
            noise = noise.clamp(-model.target_noise_clip, model.target_noise_clip)
            next_actions = (model.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
            # Clipped double-Q: take the minimum over the target critics.
            next_q_values = torch.cat(
                model.critic_target(replay_data.next_observations, next_actions), dim=1
            )
            next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True)
            target_q_values = (
                replay_data.rewards
                + (1 - replay_data.dones) * model.gamma * next_q_values
            )

        current_q_values = model.critic(replay_data.observations, replay_data.actions)
        critic_loss = sum(
            torch.nn.functional.mse_loss(current_q, target_q_values)
            for current_q in current_q_values
        )
        model.critic.optimizer.zero_grad()
        critic_loss.backward()
        model.critic.optimizer.step()

        actor_loss_value = 0.0
        if model._n_updates % model.policy_delay == 0:
            actor_loss = -model.critic.q1_forward(
                replay_data.observations, model.actor(replay_data.observations)
            ).mean()
            model.actor.optimizer.zero_grad()
            actor_loss.backward()
            model.actor.optimizer.step()
            actor_loss_value = float(actor_loss.item())

            # Target networks track the online networks only on delayed steps.
            self._soft_update(
                model.critic.parameters(), model.critic_target.parameters(), model.tau
            )
            self._soft_update(
                model.actor.parameters(), model.actor_target.parameters(), model.tau
            )

        return {
            "actor_loss": actor_loss_value,
            "critic_loss": float(critic_loss.item()),
        }

    def save(self, path: str):
        """Save the underlying TD3 model to ``path``."""
        self.model.save(path)

    def load(self, path: str):
        """Replace the current model with one loaded from ``path`` on ``self.device``."""
        self.model = TD3.load(path, device=self.device)