@@ -5,6 +5,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from torchvision import transforms as T
 
@@ -37,6 +38,10 @@ def set_requires_grad(model, val):
     for p in model.parameters():
         p.requires_grad = val
 
+def MaybeSyncBatchnorm(is_distributed = None):
+    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
+    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
+
 # loss fn
 
 def loss_fn(x, y):
@@ -75,32 +80,32 @@ def update_moving_average(ema_updater, ma_model, current_model):
 
 # MLP class for projector and predictor
 
-def MLP(dim, projection_size, hidden_size = 4096):
+def MLP(dim, projection_size, hidden_size = 4096, sync_batchnorm = None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace = True),
         nn.Linear(hidden_size, projection_size)
     )
 
-def SimSiamMLP(dim, projection_size, hidden_size = 4096):
+def SimSiamMLP(dim, projection_size, hidden_size = 4096, sync_batchnorm = None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size, bias = False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace = True),
         nn.Linear(hidden_size, hidden_size, bias = False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace = True),
         nn.Linear(hidden_size, projection_size, bias = False),
-        nn.BatchNorm1d(projection_size, affine = False)
+        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine = False)
     )
 
 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets
 
 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False):
+    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
         super().__init__()
         self.net = net
         self.layer = layer
@@ -110,6 +115,7 @@ def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use
         self.projection_hidden_size = projection_hidden_size
 
         self.use_simsiam_mlp = use_simsiam_mlp
+        self.sync_batchnorm = sync_batchnorm
 
         self.hidden = {}
         self.hook_registered = False
@@ -137,7 +143,7 @@ def _register_hook(self):
     def _get_projector(self, hidden):
         _, dim = hidden.shape
         create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
-        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
+        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
         return projector.to(hidden)
 
     def get_representation(self, x):
@@ -178,7 +184,8 @@ def __init__(
         augment_fn = None,
         augment_fn2 = None,
         moving_average_decay = 0.99,
-        use_momentum = True
+        use_momentum = True,
+        sync_batchnorm = None
     ):
         super().__init__()
         self.net = net
@@ -205,7 +212,14 @@ def __init__(
         self.augment1 = default(augment_fn, DEFAULT_AUG)
         self.augment2 = default(augment_fn2, self.augment1)
 
-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer = hidden_layer, use_simsiam_mlp = not use_momentum)
+        self.online_encoder = NetWrapper(
+            net,
+            projection_size,
+            projection_hidden_size,
+            layer = hidden_layer,
+            use_simsiam_mlp = not use_momentum,
+            sync_batchnorm = sync_batchnorm
+        )
 
         self.use_momentum = use_momentum
         self.target_encoder = None
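For context, a minimal usage sketch of the new flag, following the constructor and training-loop pattern from the byol-pytorch README. The ResNet-50 backbone, image size, optimizer, and random batch are illustrative and not part of this diff; sync_batchnorm is the argument introduced here.

# minimal sketch, assuming the BYOL(net, image_size, hidden_layer, ...) API from the README;
# everything except sync_batchnorm is illustrative
import torch
import torch.distributed as dist
from torchvision.models import resnet50
from byol_pytorch import BYOL

net = resnet50(pretrained = True)

learner = BYOL(
    net,
    image_size = 256,
    hidden_layer = 'avgpool',
    # None (default): MaybeSyncBatchnorm auto-detects an initialized process group
    # with world size > 1; pass True / False to force SyncBatchNorm / BatchNorm1d
    sync_batchnorm = None
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

images = torch.randn(8, 3, 256, 256)  # stand-in for a real dataloader batch
loss = learner(images)                # forward returns the BYOL loss
loss.backward()
opt.step()
opt.zero_grad()
learner.update_moving_average()       # EMA update of the target encoder

Run as a single process, the default resolves to plain BatchNorm1d, so the snippet behaves exactly as before this change; under a multi-GPU torch.distributed run, the projector and predictor heads pick up SyncBatchNorm and share batch statistics across replicas.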