
Commit 307df56

Author: Myle Ott (committed)
Refactor RobertaModel base class (fixes #2186)
1 parent 95294bf · commit 307df56

File tree

1 file changed: +28 -19 lines


fairseq/models/roberta/model.py

Lines changed: 28 additions & 19 deletions
@@ -14,8 +14,8 @@
 
 from fairseq import utils
 from fairseq.models import (
-    FairseqDecoder,
-    FairseqLanguageModel,
+    FairseqEncoder,
+    FairseqEncoderModel,
     register_model,
     register_model_architecture,
 )
@@ -33,7 +33,7 @@
 
 
 @register_model('roberta')
-class RobertaModel(FairseqLanguageModel):
+class RobertaModel(FairseqEncoderModel):
 
     @classmethod
     def hub_models(cls):
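
Note: the base-class swap above turns RobertaModel from a language-model wrapper into an encoder-model wrapper. A minimal, hedged check of what that implies for downstream code (assuming a fairseq checkout that includes this change; the attribute remark is inferred from the later hunks, not from this one):

# Hedged sketch, not part of the commit: confirm the new class hierarchy.
from fairseq.models import FairseqEncoderModel
from fairseq.models.roberta import RobertaModel

assert issubclass(RobertaModel, FairseqEncoderModel)
# The wrapped RobertaEncoder is therefore exposed as `model.encoder`
# (previously `model.decoder`), which forward() relies on below.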
@@ -116,12 +116,20 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla
         if classification_head_name is not None:
             features_only = True
 
-        x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
+        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
 
         if classification_head_name is not None:
             x = self.classification_heads[classification_head_name](x)
         return x, extra
 
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:
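
Note: the new get_normalized_probs helper simply applies softmax (or log-softmax) over the last dimension of the first element of the network output. A self-contained sketch of the same computation with made-up shapes, illustrative only and not taken from the commit:

import torch
import torch.nn.functional as F

# Stand-in for the (logits, extra) tuple returned by forward(); 2 sentences, 3 classes.
net_output = (torch.randn(2, 3), {})

logits = net_output[0].float()
probs = F.softmax(logits, dim=-1)          # the log_probs=False branch
log_probs = F.log_softmax(logits, dim=-1)  # the log_probs=True branch
print(probs.sum(dim=-1))                   # each row sums to ~1.0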
@@ -163,13 +171,23 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])
 
     def upgrade_state_dict_named(self, state_dict, name):
-        super().upgrade_state_dict_named(state_dict, name)
-
         prefix = name + '.' if name != '' else ''
-        current_head_names = [] if not hasattr(self, 'classification_heads') else \
-            self.classification_heads.keys()
+
+        # rename decoder -> encoder before upgrading children modules
+        for k in list(state_dict.keys()):
+            if k.startswith(prefix + 'decoder'):
+                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
+                state_dict[new_k] = state_dict[k]
+                del state_dict[k]
+
+        # upgrade children modules
+        super().upgrade_state_dict_named(state_dict, name)
 
         # Handle new classification heads present in the state dict.
+        current_head_names = (
+            [] if not hasattr(self, 'classification_heads')
+            else self.classification_heads.keys()
+        )
         keys_to_delete = []
         for k in state_dict.keys():
             if not k.startswith(prefix + 'classification_heads.'):
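
Note: the renaming loop above is what keeps checkpoints saved before this refactor loadable: any parameter stored under a decoder.* key is moved to the matching encoder.* key before the child modules upgrade their own entries. A toy, self-contained illustration of that idiom (the keys and values below are made up for the example):

# Toy state dict; keys are illustrative, not copied from a real checkpoint.
state_dict = {
    'decoder.sentence_encoder.embed_tokens.weight': '<tensor>',
    'decoder.lm_head.weight': '<tensor>',
    'classification_heads.mnli.out_proj.weight': '<tensor>',
}
prefix = ''  # top-level model, so `name` is the empty string

for k in list(state_dict.keys()):
    if k.startswith(prefix + 'decoder'):
        new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
        state_dict[new_k] = state_dict.pop(k)

print(sorted(state_dict))
# ['classification_heads.mnli.out_proj.weight',
#  'encoder.lm_head.weight',
#  'encoder.sentence_encoder.embed_tokens.weight']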
@@ -261,24 +279,15 @@ def forward(self, features, **kwargs):
         return x
 
 
-class RobertaEncoder(FairseqDecoder):
-    """RoBERTa encoder.
-
-    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
-    by :class:`~fairseq.models.FairseqLanguageModel`.
-    """
+class RobertaEncoder(FairseqEncoder):
+    """RoBERTa encoder."""
 
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
         self.args = args
 
-        # RoBERTa is a sentence encoder model, so users will intuitively trim
-        # encoder layers. However, the implementation uses the fairseq decoder,
-        # so we fix here.
         if args.encoder_layers_to_keep:
             args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
-            args.decoder_layers_to_keep = args.encoder_layers_to_keep
-            args.encoder_layers_to_keep = None
 
         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=dictionary.pad(),
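
Note: with RobertaEncoder now a FairseqEncoder, the old workaround of mirroring encoder_layers_to_keep into decoder_layers_to_keep disappears; only encoder_layers is recomputed from the kept-layer list. A small sketch of that arithmetic on a throwaway Namespace (the values are made up):

from argparse import Namespace

# Made-up args: keep 6 of 24 layers.
args = Namespace(encoder_layers=24, encoder_layers_to_keep="0,4,8,12,16,20")

if args.encoder_layers_to_keep:
    args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

print(args.encoder_layers)  # 6; no decoder_* mirroring is needed any more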
