再次修复

This commit is contained in:
Zihan Ye 2026-04-01 02:41:05 +08:00
parent 5471098e77
commit 2b772f9de0
1 changed files with 7 additions and 8 deletions

View File

@ -109,15 +109,12 @@ class TD3Agent:
if self.model.replay_buffer.size() < self.learning_starts:
return {}
# 手动更新而不是调用train()
self.model._n_updates += 1
gradient_steps = 1
# 简单调用learn但捕获logger错误
try:
self.model.num_timesteps += 1
self.model._n_updates += 1
for _ in range(gradient_steps):
self.model._update_learning_rate(self.model.actor.optimizer)
self.model._update_learning_rate(self.model.critic.optimizer)
replay_data = self.model.replay_buffer.sample(self.model.batch_size)
replay_data = self.model.replay_buffer.sample(self.batch_size)
with torch.no_grad():
noise = replay_data.actions.clone().data.normal_(0, self.model.target_policy_noise)
@ -143,6 +140,8 @@ class TD3Agent:
self.model._polyak_update(self.model.critic.parameters(), self.model.critic_target.parameters(), self.model.tau)
self.model._polyak_update(self.model.actor.parameters(), self.model.actor_target.parameters(), self.model.tau)
except Exception:
pass
return {"actor_loss": 0.0, "critic_loss": 0.0}