Skip to content

Commit 7eaa2f3

Browse files
committed
Fix warnings in sgdp
1 parent 75d625a commit 7eaa2f3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_optimizer/sgdp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
152152

153153
# SGD
154154
buf = state['momentum']
155-
buf.mul_(momentum).add_(1 - dampening, grad)
155+
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
156156
if nesterov:
157157
d_p = grad + momentum * buf
158158
else:
@@ -181,6 +181,6 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
181181
)
182182

183183
# Step
184-
p.data.add_(-group['lr'], d_p)
184+
p.data.add_(d_p, alpha=-group['lr'])
185185

186186
return loss

0 commit comments

Comments
 (0)