@@ -14,8 +14,8 @@
 
 from fairseq import utils
 from fairseq.models import (
-    FairseqDecoder,
-    FairseqLanguageModel,
+    FairseqEncoder,
+    FairseqEncoderModel,
     register_model,
     register_model_architecture,
 )
@@ -33,7 +33,7 @@
 
 
 @register_model('roberta')
-class RobertaModel(FairseqLanguageModel):
+class RobertaModel(FairseqEncoderModel):
 
     @classmethod
     def hub_models(cls):
@@ -116,12 +116,20 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, cla
         if classification_head_name is not None:
             features_only = True
 
-        x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
+        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
 
         if classification_head_name is not None:
             x = self.classification_heads[classification_head_name](x)
         return x, extra
 
+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:
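Note on the new `get_normalized_probs` override: it treats `net_output[0]` as raw logits and normalizes over the last (vocabulary) dimension. A minimal sketch of the expected behavior, assuming the module imports `torch.nn.functional` as `F` (toy tensors, not the real model output):

```python
import torch
import torch.nn.functional as F

# Toy stand-in for what RobertaModel.forward returns: a (logits, extra) pair.
logits = torch.randn(2, 5, 100)              # (batch, seq_len, vocab)
net_output = (logits, {'inner_states': None})

# Mirrors the method body above: normalize over the last (vocabulary) dimension.
probs = F.softmax(net_output[0].float(), dim=-1)
log_probs = F.log_softmax(net_output[0].float(), dim=-1)

assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 5), atol=1e-5)
assert torch.allclose(log_probs.exp(), probs, atol=1e-5)
```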
@@ -163,13 +171,23 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])
 
     def upgrade_state_dict_named(self, state_dict, name):
-        super().upgrade_state_dict_named(state_dict, name)
-
         prefix = name + '.' if name != '' else ''
-        current_head_names = [] if not hasattr(self, 'classification_heads') else \
-            self.classification_heads.keys()
+
+        # rename decoder -> encoder before upgrading children modules
+        for k in list(state_dict.keys()):
+            if k.startswith(prefix + 'decoder'):
+                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
+                state_dict[new_k] = state_dict[k]
+                del state_dict[k]
+
+        # upgrade children modules
+        super().upgrade_state_dict_named(state_dict, name)
 
         # Handle new classification heads present in the state dict.
+        current_head_names = (
+            [] if not hasattr(self, 'classification_heads')
+            else self.classification_heads.keys()
+        )
         keys_to_delete = []
         for k in state_dict.keys():
             if not k.startswith(prefix + 'classification_heads.'):
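The rename loop above rewrites checkpoint keys saved under the old `decoder.*` attribute name so they match the new `encoder` attribute, and it has to run before `super().upgrade_state_dict_named` so child modules see the new names. A self-contained sketch of that key handling on a toy state dict (hypothetical key names, chosen only to illustrate the prefix logic):

```python
# Toy state dict with old-style keys; real RoBERTa checkpoints contain many more entries.
state_dict = {
    'decoder.sentence_encoder.embed_tokens.weight': 'W_embed',
    'decoder.lm_head.weight': 'W_lm',
    'classification_heads.mnli.dense.weight': 'W_head',
}

name = ''                                    # top-level module, so no outer prefix
prefix = name + '.' if name != '' else ''

# Same rename as in upgrade_state_dict_named: decoder.* -> encoder.*
for k in list(state_dict.keys()):
    if k.startswith(prefix + 'decoder'):
        new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
        state_dict[new_k] = state_dict[k]
        del state_dict[k]

print(sorted(state_dict))
# ['classification_heads.mnli.dense.weight',
#  'encoder.lm_head.weight',
#  'encoder.sentence_encoder.embed_tokens.weight']
```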
@@ -261,24 +279,15 @@ def forward(self, features, **kwargs):
         return x
 
 
-class RobertaEncoder(FairseqDecoder):
-    """RoBERTa encoder.
-
-    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
-    by :class:`~fairseq.models.FairseqLanguageModel`.
-    """
+class RobertaEncoder(FairseqEncoder):
+    """RoBERTa encoder."""
 
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
         self.args = args
 
-        # RoBERTa is a sentence encoder model, so users will intuitively trim
-        # encoder layers. However, the implementation uses the fairseq decoder,
-        # so we fix here.
         if args.encoder_layers_to_keep:
             args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
-            args.decoder_layers_to_keep = args.encoder_layers_to_keep
-            args.encoder_layers_to_keep = None
 
         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=dictionary.pad(),
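With `RobertaEncoder` now built directly on `FairseqEncoder`, the old workaround of copying `encoder_layers_to_keep` into `decoder_layers_to_keep` disappears; only `encoder_layers` is recomputed from the keep-list. A minimal sketch of the remaining logic, using a hypothetical `SimpleNamespace` in place of fairseq's parsed args:

```python
from types import SimpleNamespace

# Hypothetical args namespace; in fairseq these values come from the argument parser.
args = SimpleNamespace(encoder_layers=12, encoder_layers_to_keep="0,2,4,6,8,10")

# Same check as in RobertaEncoder.__init__ after this change: the layer count is
# derived from the comma-separated keep-list; no decoder_* fields are touched.
if args.encoder_layers_to_keep:
    args.encoder_layers = len(args.encoder_layers_to_keep.split(","))

print(args.encoder_layers)  # -> 6
```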