Skip to content

Commit 9a9d233

Browse files
Add SGDP optimizer (#145)
* add adamp * update test for adamp * update doc for adamp * update readme with adamp example * update method to staticmethod * add nesterov for AdamP * update viz and docs for AdamP * add check for delta & wd_ratio values * add SGDP optimizer * update tests for SGDP * update doc & readme with SGDP * Bump matplotlib from 3.2.1 to 3.2.2 Bumps [matplotlib](https://github.com/matplotlib/matplotlib) from 3.2.1 to 3.2.2. - [Release notes](https://github.com/matplotlib/matplotlib/releases) - [Commits](matplotlib/matplotlib@v3.2.1...v3.2.2) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump torch from 1.5.0 to 1.5.1 Bumps [torch](https://github.com/pytorch/pytorch) from 1.5.0 to 1.5.1. - [Release notes](https://github.com/pytorch/pytorch/releases) - [Commits](pytorch/pytorch@v1.5.0...v1.5.1) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump torchvision from 0.6.0 to 0.6.1 Bumps [torchvision](https://github.com/pytorch/vision) from 0.6.0 to 0.6.1. - [Release notes](https://github.com/pytorch/vision/releases) - [Commits](https://github.com/pytorch/vision/commits) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump mypy from 0.780 to 0.781 Bumps [mypy](https://github.com/python/mypy) from 0.780 to 0.781. - [Release notes](https://github.com/python/mypy/releases) - [Commits](python/mypy@v0.780...v0.781) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump sphinx-autodoc-typehints from 1.10.3 to 1.11.0 Bumps [sphinx-autodoc-typehints](https://github.com/agronholm/sphinx-autodoc-typehints) from 1.10.3 to 1.11.0. - [Release notes](https://github.com/agronholm/sphinx-autodoc-typehints/releases) - [Changelog](https://github.com/agronholm/sphinx-autodoc-typehints/blob/master/CHANGELOG.rst) - [Commits](agronholm/sphinx-autodoc-typehints@1.10.3...1.11.0) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump mypy from 0.781 to 0.782 Bumps [mypy](https://github.com/python/mypy) from 0.781 to 0.782. - [Release notes](https://github.com/python/mypy/releases) - [Commits](python/mypy@v0.781...v0.782) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump ipdb from 0.13.2 to 0.13.3 Bumps [ipdb](https://github.com/gotcha/ipdb) from 0.13.2 to 0.13.3. - [Release notes](https://github.com/gotcha/ipdb/releases) - [Changelog](https://github.com/gotcha/ipdb/blob/master/HISTORY.txt) - [Commits](gotcha/ipdb@0.13.2...0.13.3) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump ipython from 7.15.0 to 7.16.1 Bumps [ipython](https://github.com/ipython/ipython) from 7.15.0 to 7.16.1. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](ipython/ipython@7.15.0...7.16.1) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump numpy from 1.18.5 to 1.19.0 Bumps [numpy](https://github.com/numpy/numpy) from 1.18.5 to 1.19.0. - [Release notes](https://github.com/numpy/numpy/releases) - [Changelog](https://github.com/numpy/numpy/blob/master/doc/HOWTO_RELEASE.rst.txt) - [Commits](numpy/numpy@v1.18.5...v1.19.0) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump matplotlib from 3.2.1 to 3.2.2 Bumps [matplotlib](https://github.com/matplotlib/matplotlib) from 3.2.1 to 3.2.2. - [Release notes](https://github.com/matplotlib/matplotlib/releases) - [Commits](matplotlib/matplotlib@v3.2.1...v3.2.2) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump torch from 1.5.0 to 1.5.1 Bumps [torch](https://github.com/pytorch/pytorch) from 1.5.0 to 1.5.1. - [Release notes](https://github.com/pytorch/pytorch/releases) - [Commits](pytorch/pytorch@v1.5.0...v1.5.1) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump torchvision from 0.6.0 to 0.6.1 Bumps [torchvision](https://github.com/pytorch/vision) from 0.6.0 to 0.6.1. - [Release notes](https://github.com/pytorch/vision/releases) - [Commits](https://github.com/pytorch/vision/commits) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump mypy from 0.780 to 0.781 Bumps [mypy](https://github.com/python/mypy) from 0.780 to 0.781. - [Release notes](https://github.com/python/mypy/releases) - [Commits](python/mypy@v0.780...v0.781) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump sphinx-autodoc-typehints from 1.10.3 to 1.11.0 Bumps [sphinx-autodoc-typehints](https://github.com/agronholm/sphinx-autodoc-typehints) from 1.10.3 to 1.11.0. - [Release notes](https://github.com/agronholm/sphinx-autodoc-typehints/releases) - [Changelog](https://github.com/agronholm/sphinx-autodoc-typehints/blob/master/CHANGELOG.rst) - [Commits](agronholm/sphinx-autodoc-typehints@1.10.3...1.11.0) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump mypy from 0.781 to 0.782 Bumps [mypy](https://github.com/python/mypy) from 0.781 to 0.782. - [Release notes](https://github.com/python/mypy/releases) - [Commits](python/mypy@v0.781...v0.782) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump ipdb from 0.13.2 to 0.13.3 Bumps [ipdb](https://github.com/gotcha/ipdb) from 0.13.2 to 0.13.3. - [Release notes](https://github.com/gotcha/ipdb/releases) - [Changelog](https://github.com/gotcha/ipdb/blob/master/HISTORY.txt) - [Commits](gotcha/ipdb@0.13.2...0.13.3) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump ipython from 7.15.0 to 7.16.1 Bumps [ipython](https://github.com/ipython/ipython) from 7.15.0 to 7.16.1. - [Release notes](https://github.com/ipython/ipython/releases) - [Commits](ipython/ipython@7.15.0...7.16.1) Signed-off-by: dependabot-preview[bot] <[email protected]> * Bump numpy from 1.18.5 to 1.19.0 Bumps [numpy](https://github.com/numpy/numpy) from 1.18.5 to 1.19.0. - [Release notes](https://github.com/numpy/numpy/releases) - [Changelog](https://github.com/numpy/numpy/blob/master/doc/HOWTO_RELEASE.rst.txt) - [Commits](numpy/numpy@v1.18.5...v1.19.0) Signed-off-by: dependabot-preview[bot] <[email protected]> Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
1 parent 8452433 commit 9a9d233

15 files changed

+253
-7
lines changed

README.rst

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ Intuitively, this operation prevents the unnecessary update along the radial dir
230230
that only increases the weight norm without contributing to the loss minimization.
231231

232232
+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+
233-
| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdamP.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdamP.png |
233+
| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_AdamP.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_AdamP.png |
234234
+------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------+
235235

236236
.. code:: python
@@ -577,6 +577,36 @@ RangerVA
577577
**Reference Code**: https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
578578

579579

580+
SGDP
581+
----
582+
583+
+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+
584+
| .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rastrigin_SGDP.png | .. image:: https://raw.githubusercontent.com/jettify/pytorch-optimizer/master/docs/rosenbrock_SGDP.png |
585+
+--------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------+
586+
587+
.. code:: python
588+
589+
import torch_optimizer as optim
590+
591+
# model = ...
592+
optimizer = optim.SGDP(
593+
m.parameters(),
594+
lr= 1e-3,
595+
momentum=0,
596+
dampening=0,
597+
weight_decay=1e-2,
598+
nesterov=False,
599+
delta = 0.1,
600+
wd_ratio = 0.1
601+
)
602+
optimizer.step()
603+
604+
605+
**Paper**: *Slowing Down the Weight Norm Increase in Momentum-based Optimizers.* (2020) [https://arxiv.org/abs/2006.08217]
606+
607+
**Reference Code**: https://github.com/clovaai/AdamP
608+
609+
580610
SGDW
581611
----
582612

docs/api.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ RAdam
8989
.. autoclass:: torch_optimizer.RAdam
9090
:members:
9191

92+
.. _SGDP:
93+
94+
SGDP
95+
----
96+
97+
.. autoclass:: torch_optimizer.SGDP
98+
:members:
99+
92100
.. _SGDW:
93101

94102
SGDW

docs/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ Supported Optimizers
7474
| :ref:`RangerVA` | https://arxiv.org/abs/1908.00700v2 |
7575
+-----------------+-------------------------------------------------------------------------------+
7676
| | |
77+
| :ref:`SGDP` | https://arxiv.org/abs/2006.08217 |
78+
+-----------------+-------------------------------------------------------------------------------+
79+
| | |
7780
| :ref:`SGDW` | https://arxiv.org/abs/1608.03983 |
7881
+-----------------+-------------------------------------------------------------------------------+
7982
| | |

docs/rastrigin_AdamP.png

529 KB
Loading

docs/rastrigin_SGDP.png

726 KB
Loading

docs/rosenbrock_AdamP.png

-1.98 KB
Loading

docs/rosenbrock_SGDP.png

453 KB
Loading

examples/viz_optimizers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def LookaheadYogi(*a, **kw):
166166
# Adam based
167167
(optim.AdaBound, -8, 0.3),
168168
(optim.AdaMod, -8, 0.2),
169+
(optim.AdamP, -8, 0.2),
169170
(optim.DiffGrad, -8, 0.4),
170171
(optim.Lamb, -8, -2.9),
171172
(optim.NovoGrad, -8, -1.7),
@@ -174,6 +175,7 @@ def LookaheadYogi(*a, **kw):
174175
# SGD/Momentum based
175176
(optim.AccSGD, -8, -1.4),
176177
(optim.SGDW, -8, -1.5),
178+
(optim.SGDP, -8, -1.5),
177179
(optim.PID, -8, -1.0),
178180
(optim.QHM, -6, -0.2),
179181
(optim.QHAdam, -8, 0.1),

tests/test_basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def build_lookahead(*a, **kw):
6060
(optim.AccSGD, {'lr': 0.015}, 800),
6161
(build_lookahead, {'lr': 1.0}, 500),
6262
(optim.QHAdam, {'lr': 1.0}, 500),
63-
(optim.AdamP, {'lr': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-3}, 800)
63+
(optim.AdamP, {'lr': 0.01, 'betas': (0.9, 0.95), 'eps': 1e-3}, 800),
64+
(optim.SGDP, {'lr': 0.002, 'momentum': 0.91}, 900),
6465
]
6566

6667

tests/test_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def build_lookahead(*a, **kw):
7777
optim.QHM,
7878
optim.RAdam,
7979
optim.SGDW,
80+
optim.SGDP,
8081
optim.Yogi,
8182
build_lookahead,
8283
optim.Ranger,

0 commit comments

Comments
 (0)