Commit f12087b
just automatically sync batchnorm if more than one machine is detected; allow overriding with the sync_batchnorm keyword argument
1 parent 6717204 commit f12087b

File tree

3 files changed: +41 additions, -21 deletions

.github/workflows/python-publish.yml

Lines changed: 15 additions & 10 deletions
@@ -1,11 +1,16 @@
-# This workflows will upload a Python Package using Twine when a release is created
+# This workflow will upload a Python Package using Twine when a release is created
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
 name: Upload Python Package
 
 on:
   release:
-    types: [created]
+    types: [published]
 
 jobs:
   deploy:
@@ -21,11 +26,11 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install setuptools wheel twine
-    - name: Build and publish
-      env:
-        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-      run: |
-        python setup.py sdist bdist_wheel
-        twine upload dist/*
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
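
In effect, the workflow now builds distributions with the PyPA build frontend (python -m build) and hands the upload off to the pinned pypa/gh-action-pypi-publish action, so publishing needs only a PyPI API token stored in the PYPI_API_TOKEN secret instead of username/password credentials, and it fires when a release is published rather than when it is created.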

byol_pytorch/byol_pytorch.py

Lines changed: 24 additions & 10 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from torchvision import transforms as T
 
@@ -37,6 +38,10 @@ def set_requires_grad(model, val):
     for p in model.parameters():
         p.requires_grad = val
 
+def MaybeSyncBatchnorm(is_distributed = None):
+    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
+    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
+
 # loss fn
 
 def loss_fn(x, y):
@@ -75,32 +80,32 @@ def update_moving_average(ema_updater, ma_model, current_model):
 
 # MLP class for projector and predictor
 
-def MLP(dim, projection_size, hidden_size=4096):
+def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, projection_size)
     )
 
-def SimSiamMLP(dim, projection_size, hidden_size=4096):
+def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size, bias=False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, hidden_size, bias=False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, projection_size, bias=False),
-        nn.BatchNorm1d(projection_size, affine=False)
+        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
     )
 
 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets
 
 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False):
+    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
         super().__init__()
         self.net = net
         self.layer = layer
@@ -110,6 +115,7 @@ def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use
         self.projection_hidden_size = projection_hidden_size
 
         self.use_simsiam_mlp = use_simsiam_mlp
+        self.sync_batchnorm = sync_batchnorm
 
         self.hidden = {}
         self.hook_registered = False
@@ -137,7 +143,7 @@ def _register_hook(self):
     def _get_projector(self, hidden):
         _, dim = hidden.shape
         create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
-        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
+        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
         return projector.to(hidden)
 
     def get_representation(self, x):
@@ -178,7 +184,8 @@ def __init__(
         augment_fn = None,
         augment_fn2 = None,
         moving_average_decay = 0.99,
-        use_momentum = True
+        use_momentum = True,
+        sync_batchnorm = None
     ):
         super().__init__()
         self.net = net
@@ -205,7 +212,14 @@ def __init__(
         self.augment1 = default(augment_fn, DEFAULT_AUG)
         self.augment2 = default(augment_fn2, self.augment1)
 
-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_simsiam_mlp=not use_momentum)
+        self.online_encoder = NetWrapper(
+            net,
+            projection_size,
+            projection_hidden_size,
+            layer = hidden_layer,
+            use_simsiam_mlp = not use_momentum,
+            sync_batchnorm = sync_batchnorm
+        )
 
         self.use_momentum = use_momentum
         self.target_encoder = None
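
Taken together, the change means the projector and predictor MLPs now use nn.SyncBatchNorm whenever torch.distributed is initialized with a world size above one, falling back to nn.BatchNorm1d otherwise, and the new sync_batchnorm keyword on BYOL forces either behavior. A minimal usage sketch of the new keyword (the ResNet backbone and hyperparameters mirror the repository README; the explicit False override is illustrative):

import torch
from torchvision import models
from byol_pytorch import BYOL

resnet = models.resnet50()

# sync_batchnorm = None (the default) auto-detects: SyncBatchNorm is used
# only when torch.distributed is initialized with world_size > 1.
# Passing True or False overrides the detection.
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    sync_batchnorm = False  # force plain BatchNorm1d even when distributed
)

images = torch.randn(4, 3, 256, 256)
loss = learner(images)  # BYOL loss on two augmented views
loss.backward()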

setup.py

Lines changed: 2 additions & 1 deletion
@@ -3,12 +3,13 @@
 setup(
   name = 'byol-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.6.0',
+  version = '0.7.0',
   license='MIT',
   description = 'Self-supervised contrastive learning made simple',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/byol-pytorch',
+  long_description_content_type = 'text/markdown',
   keywords = [
     'self-supervised learning',
     'artificial intelligence'
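
The added long_description_content_type field tells PyPI to render the package's long description as Markdown on the project page; presumably the README is wired in as the long description, though that line falls outside this diff.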
