Skip to content

Commit 2aa84ee

Browse files
committed
default to returning projections, but can be turned off with return_projection = False on forward
1 parent 8c08cef commit 2aa84ee

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

byol_pytorch/byol_pytorch.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ def get_representation(self, x):
145145
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
146146
return hidden
147147

148-
def forward(self, x, return_embedding = False):
148+
def forward(self, x, return_projection = True):
149149
representation = self.get_representation(x)
150150

151-
if return_embedding:
151+
if not return_projection:
152152
return representation
153153

154154
projector = self._get_projector(representation)
@@ -225,9 +225,14 @@ def update_moving_average(self):
225225
assert self.target_encoder is not None, 'target encoder has not been created yet'
226226
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
227227

228-
def forward(self, x, return_embedding = False):
228+
def forward(
229+
self,
230+
x,
231+
return_embedding = False,
232+
return_projection = True
233+
):
229234
if return_embedding:
230-
return self.online_encoder(x, True)
235+
return self.online_encoder(x, return_projection = return_projection)
231236

232237
image_one, image_two = self.augment1(x), self.augment2(x)
233238

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'byol-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.5.5',
6+
version = '0.5.6',
77
license='MIT',
88
description = 'Self-supervised contrastive learning made simple',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)