Commit 2cfbf20

RAdam fix for issue #96. (#103)
* Modify RAdam code to follow the original RAdam repo: different param groups should use different buffers.
* Wrap long code lines to fit the guidelines.
1 parent df65965 commit 2cfbf20

File tree: 1 file changed, +13 -3 lines

torch_optimizer/radam.py

Lines changed: 13 additions & 3 deletions
@@ -62,10 +62,20 @@ def __init__(
                 'Invalid weight_decay value: {}'.format(weight_decay)
             )
 
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        self._buffer = [[None, None, None] for ind in range(10)]
+        if isinstance(params, (list, tuple)) and \
+                len(params) > 0 and isinstance(params[0], dict):
+            for param in params:
+                if 'betas' in param and (param['betas'][0] != betas[0] or
+                                         param['betas'][1] != betas[1]):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+                        buffer=[[None, None, None] for _ in range(10)])
         super(RAdam, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(RAdam, self).__setstate__(state)
+
     def step(self, closure: OptLossClosure = None) -> OptFloat:
         r"""Performs a single optimization step.
 
@@ -114,7 +124,7 @@ def step(self, closure: OptLossClosure = None) -> OptFloat:
                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
 
                 state['step'] += 1
-                buffered = self._buffer[int(state['step'] % 10)]
+                buffered = group['buffer'][int(state['step'] % 10)]
                 if state['step'] == buffered[0]:
                     N_sma, step_size = buffered[1], buffered[2]
                 else:

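And a hedged usage sketch of the scenario the commit fixes (the model and hyperparameters are illustrative): two param groups with different betas. With the old shared self._buffer, both groups read the same cache slot; after this change, a group whose betas differ from the defaults gets its own buffer.

import torch
import torch_optimizer as optim

model = torch.nn.Linear(10, 2)  # illustrative model

# Two param groups with different betas. Before this fix, both shared
# one self._buffer cache keyed only by step % 10, so one group could
# read a step_size cached for the other group's betas.
optimizer = optim.RAdam(
    [
        {'params': model.weight},                      # uses the defaults' buffer
        {'params': model.bias, 'betas': (0.8, 0.99)},  # gets its own buffer
    ],
    lr=1e-3,
    betas=(0.9, 0.999),
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()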