
Commit 50edc8a

add a foreach version of lion
1 parent 77b6aa2 commit 50edc8a

File tree

  lion_pytorch/foreach.py
  lion_pytorch/lion_pytorch.py
  setup.py

3 files changed: 88 additions, 3 deletions
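The new file moves the per-parameter Python loop of the original Lion implementation onto PyTorch's torch._foreach_* primitives, which apply a single fused operation to an entire list of tensors at once. As a rough illustration of the idea (not part of the commit), the two snippets below do the same work, but the foreach call batches it into far fewer op dispatches:

import torch

params = [torch.randn(3, 3) for _ in range(4)]  # stand-in for one param group's tensors

# per-tensor loop: one op dispatch per tensor
for p in params:
    p.mul_(0.99)

# foreach equivalent: one batched call over the whole list
torch._foreach_mul_(params, 0.99)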

lion_pytorch/foreach.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
from __future__ import annotations
from typing import Tuple, Callable

import torch
from torch.optim.optimizer import Optimizer

# functions

def exists(val):
    return val is not None

# class

class Lion(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        assert all([hasattr(torch, attr) for attr in ('_foreach_mul_', '_foreach_add_', '_foreach_sign_', '_foreach_lerp_')]), 'this version of torch does not have the prerequisite foreach functions'

        defaults = dict(
            lr = lr,
            betas = betas,
            weight_decay = weight_decay
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        closure: Callable | None = None
    ):

        loss = None
        if exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']

            params = []
            grads = []
            exp_avgs = []

            for p in filter(lambda p: exists(p.grad), group['params']):

                grad, state = p.grad, self.state[p]

                # init state - exponential moving average of gradient values

                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)

                exp_avg = state['exp_avg']

                params.append(p)
                grads.append(grad)
                exp_avgs.append(exp_avg)

            # stepweight decay

            torch._foreach_mul_(params, 1. - lr * wd)

            # weight update - sign of the interpolated momentum, scaled by the learning rate

            updates = [t.clone() for t in exp_avgs]
            torch._foreach_lerp_(updates, grads, 1. - beta1)
            torch._foreach_sign_(updates)

            torch._foreach_add_(params, updates, alpha = -lr)

            # decay momentum running average

            torch._foreach_lerp_(exp_avgs, grads, 1. - beta2)

        return loss
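For reference, a minimal usage sketch of the new optimizer; the import path follows from the file location lion_pytorch/foreach.py, and the model and data here are placeholders rather than anything from the repo:

import torch
from lion_pytorch.foreach import Lion

model = torch.nn.Linear(10, 1)
opt = Lion(model.parameters(), lr = 1e-4, weight_decay = 1e-2)

x = torch.randn(8, 10)
loss = model(x).pow(2).mean()
loss.backward()

opt.step()
opt.zero_grad()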

lion_pytorch/lion_pytorch.py

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,5 @@
-from typing import Tuple, Optional, Callable
+from __future__ import annotations
+from typing import Tuple, Callable

 import torch
 from torch.optim.optimizer import Optimizer

@@ -55,7 +56,7 @@ def __init__(
     @torch.no_grad()
     def step(
         self,
-        closure: Optional[Callable] = None
+        closure: Callable | None = None
     ):

         loss = None
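The Optional[Callable] annotation becomes the Callable | None union syntax; the added from __future__ import annotations makes this safe on Python versions older than 3.10, since annotations are then stored as strings rather than evaluated at definition time. A minimal sketch of the pattern (not from the repo):

from __future__ import annotations
from typing import Callable

# with the future import the annotation is never evaluated,
# so the `Callable | None` union parses fine on Python 3.7 through 3.9 as well
def step(closure: Callable | None = None):
    return closure() if closure is not None else None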

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.2.0',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
