再次修复
This commit is contained in:
parent
5471098e77
commit
2b772f9de0
15
td3_agent.py
15
td3_agent.py
|
|
@ -109,15 +109,12 @@ class TD3Agent:
|
|||
if self.model.replay_buffer.size() < self.learning_starts:
|
||||
return {}
|
||||
|
||||
# 手动更新而不是调用train()
|
||||
self.model._n_updates += 1
|
||||
gradient_steps = 1
|
||||
# 简单调用learn,但捕获logger错误
|
||||
try:
|
||||
self.model.num_timesteps += 1
|
||||
self.model._n_updates += 1
|
||||
|
||||
for _ in range(gradient_steps):
|
||||
self.model._update_learning_rate(self.model.actor.optimizer)
|
||||
self.model._update_learning_rate(self.model.critic.optimizer)
|
||||
|
||||
replay_data = self.model.replay_buffer.sample(self.model.batch_size)
|
||||
replay_data = self.model.replay_buffer.sample(self.batch_size)
|
||||
|
||||
with torch.no_grad():
|
||||
noise = replay_data.actions.clone().data.normal_(0, self.model.target_policy_noise)
|
||||
|
|
@ -143,6 +140,8 @@ class TD3Agent:
|
|||
|
||||
self.model._polyak_update(self.model.critic.parameters(), self.model.critic_target.parameters(), self.model.tau)
|
||||
self.model._polyak_update(self.model.actor.parameters(), self.model.actor_target.parameters(), self.model.tau)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"actor_loss": 0.0, "critic_loss": 0.0}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue