We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3433de5 — commit 87751da. Copy full SHA for 87751da
d3rlpy/algos/qlearning/torch/rebrac_impl.py
@@ -56,10 +56,12 @@ def compute_actor_loss(
             reduction="min",
         )
         lam = 1 / (q_t.abs().mean()).detach()
-        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
+        bc_loss = ((batch.actions - action.squashed_mu) ** 2).sum(
+            dim=1, keepdim=True
+        )
         return TD3PlusBCActorLoss(
-            actor_loss=lam * -q_t.mean() + self._actor_beta * bc_loss,
-            bc_loss=bc_loss,
+            actor_loss=(lam * -q_t + self._actor_beta * bc_loss).mean(),
+            bc_loss=bc_loss.mean(),
         )

     def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
0 commit comments