Skip to content

Commit 87751da

Browse files
committed
Fix bc_loss calculation in ReBRAC
1 parent 3433de5 commit 87751da

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

d3rlpy/algos/qlearning/torch/rebrac_impl.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ def compute_actor_loss(
             reduction="min",
         )
         lam = 1 / (q_t.abs().mean()).detach()
-        bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
+        bc_loss = ((batch.actions - action.squashed_mu) ** 2).sum(
+            dim=1, keepdim=True
+        )
         return TD3PlusBCActorLoss(
-            actor_loss=lam * -q_t.mean() + self._actor_beta * bc_loss,
-            bc_loss=bc_loss,
+            actor_loss=(lam * -q_t + self._actor_beta * bc_loss).mean(),
+            bc_loss=bc_loss.mean(),
         )
 
     def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:

0 commit comments

Comments
 (0)