
Commit dabbef4

huihuifan authored and facebook-github-bot committed
adding layerdrop code for training, pruning, and readme (#890)
Summary:

TEST 1: EVALUATION TIME WORKS
checked, achieves correct model perplexity: 18.68

TEST 2: TRAINING NEW MODEL WORKS
checked

without layerdrop (--decoder-layerdrop 0 OR no flag at all):
| epoch 001: 10 / 11201 loss=27.469, nll_loss=27.469, ppl=185799477.36, wps=1764, ups=0, wpb=9216.000, bsz=3.000, num_updates=7, lr=0.0004376, gnorm=25.471, clip=1.000, oom=0.000, loss_scale=8.000, wall=37, train_wall=30
| epoch 001: 20 / 11201 loss=27.443, nll_loss=27.443, ppl=182500427.22, wps=2449, ups=0, wpb=9216.000, bsz=3.000, num_updates=17, lr=0.0010626, gnorm=25.273, clip=1.000, oom=0.000, loss_scale=8.000, wall=64, train_wall=57
| epoch 001: 30 / 11201 loss=27.404, nll_loss=27.404, ppl=177612215.78, wps=2720, ups=0, wpb=9216.000, bsz=3.000, num_updates=27, lr=0.0016876, gnorm=25.136, clip=1.000, oom=0.000, loss_scale=8.000, wall=91, train_wall=84
| epoch 001: 40 / 11201 loss=27.009, nll_loss=27.009, ppl=135079983.00, wps=2865, ups=0, wpb=9216.000, bsz=3.000, num_updates=37, lr=0.0023126, gnorm=24.311, clip=1.000, oom=0.000, loss_scale=8.000, wall=119, train_wall=112
| epoch 001: 50 / 11201 loss=26.418, nll_loss=26.418, ppl=89680259.41, wps=2952, ups=0, wpb=9216.000, bsz=3.000, num_updates=47, lr=0.0029376, gnorm=22.775, clip=1.000, oom=0.000, loss_scale=8.000, wall=147, train_wall=140

with layerdrop (regularization effect should be seen in PPL), --decoder-layerdrop 0.2:
| epoch 001: 10 / 11201 loss=25.186, nll_loss=25.186, ppl=38182937.27, wps=2428, ups=0, wpb=9216.000, bsz=3.000, num_updates=8, lr=0.0005001, gnorm=17.082, clip=1.000, oom=0.000, loss_scale=16.000, wall=30, train_wall=24
| epoch 001: 20 / 11201 loss=25.270, nll_loss=25.270, ppl=40451933.50, wps=3173, ups=0, wpb=9216.000, bsz=3.000, num_updates=18, lr=0.0011251, gnorm=17.162, clip=1.000, oom=0.000, loss_scale=16.000, wall=52, train_wall=45
| epoch 001: 30 / 11201 loss=25.349, nll_loss=25.349, ppl=42752256.68, wps=3454, ups=0, wpb=9216.000, bsz=3.000, num_updates=28, lr=0.0017501, gnorm=17.370, clip=1.000, oom=0.000, loss_scale=16.000, wall=75, train_wall=68
| epoch 001: 40 / 11201 loss=25.115, nll_loss=25.115, ppl=36343806.30, wps=3619, ups=0, wpb=9216.000, bsz=3.000, num_updates=38, lr=0.0023751, gnorm=16.945, clip=1.000, oom=0.000, loss_scale=16.000, wall=97, train_wall=90
| epoch 001: 50 / 11201 loss=24.804, nll_loss=24.804, ppl=29284345.78, wps=3716, ups=0, wpb=9216.000, bsz=3.000, num_updates=48, lr=0.0030001, gnorm=16.406, clip=1.000, oom=0.000, loss_scale=16.000, wall=119, train_wall=112

TEST 3: PICKING UP TRAINING FROM EXISTING MODEL
checked
| loaded checkpoint /checkpoint/angelafan/structured_0.1_block_8_sd02/checkpoint_last.pt (epoch 272 @ 381066 updates)
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train

TEST 4: EVALUATING EXISTING BERT MODEL REPRODUCES RESULTS
| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy: 0.9231651376146789
achieves correct accuracy on SST2 for this model

TEST 5: TRAINING NEW BERT MODEL WORKS
checked and works

TEST 6: NMT
without layerdrop (--encoder-layerdrop 0 --decoder-layerdrop 0, OR combinations of flag specified and not specified):
| epoch 001: 10 / 92203 loss=15.820, nll_loss=15.830, ppl=58267.93, wps=4902, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=7.207, clip=0.000, oom=0.000, loss_scale=128.000, wall=60, train_wall=3
| epoch 001: 20 / 92203 loss=15.523, nll_loss=15.501, ppl=46359.29, wps=5037, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.869, clip=0.000, oom=0.000, loss_scale=128.000, wall=63, train_wall=6
| epoch 001: 30 / 92203 loss=15.185, nll_loss=15.123, ppl=35695.79, wps=5085, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.186, clip=0.000, oom=0.000, loss_scale=128.000, wall=66, train_wall=9
| epoch 001: 40 / 92203 loss=14.940, nll_loss=14.849, ppl=29505.60, wps=5116, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=5.610, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=12
| epoch 001: 50 / 92203 loss=14.745, nll_loss=14.630, ppl=25346.87, wps=5070, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.104, clip=0.000, oom=0.000, loss_scale=128.000, wall=71, train_wall=15

with layerdrop (regularization effect should be seen in PPL):
A) works with --encoder-layerdrop 0.2 --decoder-layerdrop 0.2
B) works with different settings --encoder-layerdrop 0.3 --decoder-layerdrop 0.5
C) works with one on and one off --encoder-layerdrop 0.2 --decoder-layerdrop 0
| epoch 001: 10 / 92203 loss=15.817, nll_loss=15.828, ppl=58158.54, wps=5355, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=6.959, clip=0.000, oom=0.000, loss_scale=128.000, wall=59, train_wall=3
| epoch 001: 20 / 92203 loss=15.650, nll_loss=15.641, ppl=51111.63, wps=5515, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.825, clip=0.000, oom=0.000, loss_scale=128.000, wall=61, train_wall=6
| epoch 001: 30 / 92203 loss=15.440, nll_loss=15.408, ppl=43491.58, wps=5602, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.576, clip=0.000, oom=0.000, loss_scale=128.000, wall=64, train_wall=8
| epoch 001: 40 / 92203 loss=15.247, nll_loss=15.193, ppl=37457.14, wps=5676, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=6.124, clip=0.000, oom=0.000, loss_scale=128.000, wall=67, train_wall=11
| epoch 001: 50 / 92203 loss=15.055, nll_loss=14.977, ppl=32259.92, wps=5598, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.661, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=14

TEST 7: PRUNING TEST CASES
A) after adding the pruning flags, the model can evaluate as a full model. Checked, reaches correct PPL:
num. model params: 246933504 | Evaluated 217646 tokens in 196.3s (1108.99 tokens/s) | Loss: 2.9275, Perplexity: 18.68

B) after adding pruning flags, the model can be pruned; this works with multiple flag settings. Checked three cases:
num. model params: 146163712 | Evaluated 217646 tokens in 106.0s (2054.07 tokens/s) | Loss: 3.0932, Perplexity: 22.05
num. model params: 209144832 | Evaluated 217646 tokens in 162.8s (1336.99 tokens/s) | Loss: 2.9526, Perplexity: 19.16

C) the model can pick up training if you want to finetune the pruned model. Checked:
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train
| WARNING: overflow detected, setting loss scale to: 64.0
| WARNING: overflow detected, setting loss scale to: 32.0
| epoch 272: 1500 / 5601 loss=5.015, nll_loss=5.015, ppl=32.33, wps=11598, ups=1, wpb=18432.000, bsz=6.000, num_updates=98, lr=0.0061251, gnorm=0.613, clip=1.000, oom=0.000, loss_scale=32.000, wall=156, train_wall=252396

D) works with BERT. Checked: without specifying any flags, reproduces the correct standard accuracy; with flags, produces the correct pruned accuracy:
| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy: 0.9231651376146789
| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop
| Accuracy: 0.9220183486238532

Pull Request resolved: fairinternal/fairseq-py#890
Reviewed By: edunov
Differential Revision: D18094657
Pulled By: huihuifan
fbshipit-source-id: 2bbaa2ff0039e906782694fc2038b8c17a8693e7
1 parent eb68afc commit dabbef4

File tree: 8 files changed (+209, -24 lines)


examples/layerdrop/README.md

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
This page contains information for how to train models with LayerDrop.

Looking for pretrained models? They will be added shortly.

Looking for code for other forms of Structured Dropout? It will be added shortly.

## Citation:
```bibtex
@article{fan2019reducing,
  title={Reducing Transformer Depth on Demand with Structured Dropout},
  author={Fan, Angela and Grave, Edouard and Joulin, Armand},
  journal={arXiv preprint arXiv:1909.11556},
  year={2019}
}
```

## Example usage

To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For decoder-only Language Models, you only need the decoder flag. For encoder-only models such as RoBERTa, you only need the encoder flag. The encoder and decoder LayerDrop values can be set differently.
```
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
```

To prune a model that has been trained with LayerDrop, add the following flags followed by a comma-separated list of which layers you would like to keep.
```
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
```
Setting these flags should print a message such as:
```
| Pruning model to specified layer configuration
```
You should also see a smaller number of parameters in the model. For example, the 16-layer Transformer Language Model prints:
```
num. model params: 246933504
```
while a model pruned to 8 layers prints:
```
num. model params: 146163712
```

If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example for language modeling:
```
python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model.


Looking to reproduce the results in the paper?

1. For Translation on WMT en-de, we followed the setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md)
2. To train RoBERTa, we followed the setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta)
3. To train Language Models on Wikitext-103, we followed the setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)


## Tips

1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop over-regularizes the model, so it may not reach the best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce by 0.1).

2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (for example, removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2.

3. When pruning layers at inference time, it is best to spread out the remaining layers so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is a good choice.

## Having an issue or have a question?

Please open an issue in this repository with the details of your question. Thanks!
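
The `--encoder-layerdrop` / `--decoder-layerdrop` flags documented above boil down to a per-layer coin flip during training. The following is a minimal, self-contained sketch of that behaviour (the toy `LayerDropStack` module, its layer count, and its dimensions are illustrative assumptions, not fairseq code):

```python
import random

import torch
import torch.nn as nn


class LayerDropStack(nn.Module):
    """Toy stack of layers with LayerDrop, mirroring the loops added in this commit."""

    def __init__(self, num_layers: int = 8, dim: int = 16, layerdrop: float = 0.2):
        super().__init__()
        self.layerdrop = layerdrop
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            # During training, each layer is skipped with probability `layerdrop`;
            # at evaluation time (self.training == False) every layer always runs.
            dropout_probability = random.uniform(0, 1)
            if not self.training or (dropout_probability > self.layerdrop):
                x = layer(x)
        return x


model = LayerDropStack()
model.train()
y_train = model(torch.randn(4, 16))  # some layers are randomly skipped
model.eval()
y_eval = model(torch.randn(4, 16))   # all 8 layers run
```

At evaluation time every layer runs, so a LayerDrop-trained model behaves like a normal network unless the pruning flags are also supplied.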

fairseq/checkpoint_utils.py

Lines changed: 65 additions & 1 deletion
@@ -183,7 +183,7 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
 
         # build model for ensemble
         model = task.build_model(args)
-        model.load_state_dict(state['model'], strict=True)
+        model.load_state_dict(state['model'], strict=True, args=args)
         ensemble.append(model)
     return ensemble, args, task
 

@@ -334,6 +334,70 @@ def _upgrade_state_dict(state):
     return state
 
 
+def prune_state_dict(state_dict, args):
+    """Prune the given state_dict if desired for LayerDrop
+    (https://arxiv.org/abs/1909.11556).
+
+    Training with LayerDrop allows models to be robust to pruning at inference
+    time. This function prunes state_dict to allow smaller models to be loaded
+    from a larger model and re-maps the existing state_dict for this to occur.
+
+    It's called by functions that load models from checkpoints and does not
+    need to be called directly.
+    """
+    if not args:
+        # args should not be none, but don't crash if it is.
+        return state_dict
+
+    encoder_layers_to_keep = args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
+    decoder_layers_to_keep = args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
+
+    if not encoder_layers_to_keep and not decoder_layers_to_keep:
+        return state_dict
+
+    # apply pruning
+    print("| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop")
+
+    def create_pruning_pass(layers_to_keep, layer_name):
+        keep_layers = sorted([int(layer_string) for layer_string in layers_to_keep.split(",")])
+        mapping_dict = {}
+        for i in range(len(keep_layers)):
+            mapping_dict[str(keep_layers[i])] = str(i)
+
+        regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
+        return {
+            "substitution_regex": regex,
+            "mapping_dict": mapping_dict
+        }
+
+    pruning_passes = []
+    if encoder_layers_to_keep:
+        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
+    if decoder_layers_to_keep:
+        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
+
+    new_state_dict = {}
+    for layer_name in state_dict.keys():
+        match = re.search("\.layers\.(\d+)\.", layer_name)
+        # if layer has no number in it, it is a supporting layer, such as an
+        # embedding
+        if not match:
+            new_state_dict[layer_name] = state_dict[layer_name]
+            continue
+
+        # otherwise, layer should be pruned.
+        original_layer_number = match.group(1)
+        # figure out which mapping dict to replace from
+        for pruning_pass in pruning_passes:
+            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
+                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
+                substitution_match = pruning_pass["substitution_regex"].search(layer_name)
+                new_state_key = layer_name[:substitution_match.start(1)] + new_layer_number + layer_name[substitution_match.end(1):]
+                new_state_dict[new_state_key] = state_dict[layer_name]
+
+    return new_state_dict
+
+
 def load_pretrained_component_from_model(
     component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
 ):
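
A hypothetical usage sketch of the new `prune_state_dict` helper (the toy parameter names below are made up; in practice the keys come from a checkpoint's `state['model']` and the function is invoked for you by `load_state_dict`):

```python
from argparse import Namespace

import torch

from fairseq.checkpoint_utils import prune_state_dict  # added in this commit

# Toy stand-in for a 4-layer decoder checkpoint; only the key names matter here.
state_dict = {
    "decoder.embed_tokens.weight": torch.zeros(8, 4),  # no ".layers.N." -> copied through
    "decoder.layers.0.fc1.weight": torch.zeros(4, 4),
    "decoder.layers.1.fc1.weight": torch.zeros(4, 4),
    "decoder.layers.2.fc1.weight": torch.zeros(4, 4),
    "decoder.layers.3.fc1.weight": torch.zeros(4, 4),
}

# Keep every other decoder layer; the encoder flag is left unset.
args = Namespace(encoder_layers_to_keep=None, decoder_layers_to_keep="0,2")

pruned = prune_state_dict(state_dict, args)
print(sorted(pruned.keys()))
# ['decoder.embed_tokens.weight',
#  'decoder.layers.0.fc1.weight',   <- was layer 0
#  'decoder.layers.1.fc1.weight']   <- was layer 2; layers 1 and 3 are dropped
```

Surviving layers are renumbered contiguously, so the pruned weights line up with a model built with fewer layers.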

fairseq/models/fairseq_model.py

Lines changed: 4 additions & 2 deletions
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from fairseq import utils
+from fairseq.checkpoint_utils import prune_state_dict
 from fairseq.data import Dictionary
 from fairseq.models import FairseqDecoder, FairseqEncoder
 
@@ -58,15 +59,16 @@ def max_positions(self):
         """Maximum length supported by the model."""
         return None
 
-    def load_state_dict(self, state_dict, strict=True):
+    def load_state_dict(self, state_dict, strict=True, args=None):
         """Copies parameters and buffers from *state_dict* into this module and
         its descendants.
 
         Overrides the method in :class:`nn.Module`. Compared with that method
         this additionally "upgrades" *state_dicts* from old checkpoints.
         """
         self.upgrade_state_dict(state_dict)
-        return super().load_state_dict(state_dict, strict)
+        new_state_dict = prune_state_dict(state_dict, args)
+        return super().load_state_dict(new_state_dict, strict)
 
     def upgrade_state_dict(self, state_dict):
         """Upgrade old state dicts to work with newer code."""

fairseq/models/roberta/model.py

Lines changed: 15 additions & 0 deletions
@@ -78,6 +78,11 @@ def add_args(parser):
                             help='number of positional embeddings to learn')
         parser.add_argument('--load-checkpoint-heads', action='store_true',
                             help='(re-)register and load heads when loading checkpoints')
+        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
+                            help='LayerDrop probability for encoder')
+        parser.add_argument('--encoder-layers-to-keep', default=None,
+                            help='which layers to *keep* when pruning as a comma-separated list')
 
     @classmethod
     def build_model(cls, args, task):
@@ -245,6 +250,15 @@ class RobertaEncoder(FairseqDecoder):
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
         self.args = args
+
+        # RoBERTa is a sentence encoder model, so users will intuitively trim
+        # encoder layers. However, the implementation uses the fairseq decoder,
+        # so we fix here.
+        if args.encoder_layers_to_keep:
+            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
+            args.decoder_layers_to_keep = args.encoder_layers_to_keep
+            args.encoder_layers_to_keep = None
+
         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=dictionary.pad(),
             vocab_size=len(dictionary),
@@ -255,6 +269,7 @@ def __init__(self, args, dictionary):
             dropout=args.dropout,
             attention_dropout=args.attention_dropout,
             activation_dropout=args.activation_dropout,
+            layerdrop=args.encoder_layerdrop,
             max_seq_len=args.max_positions,
             num_segments=0,
             encoder_normalize_before=True,
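
Because RoBERTa's layer weights live under `decoder.*` internally, the keep-list supplied as an encoder flag is shifted onto the decoder flag before pruning, as the comment in the diff notes. A toy sketch of that remap (hypothetical `Namespace` values, not the real build path):

```python
from argparse import Namespace

# User asks to keep 8 layers of a hypothetical 16-layer RoBERTa encoder.
args = Namespace(
    encoder_layers=16,
    encoder_layers_to_keep="0,2,4,6,8,10,12,14",
    decoder_layers_to_keep=None,
)

if args.encoder_layers_to_keep:
    args.encoder_layers = len(args.encoder_layers_to_keep.split(","))  # -> 8
    # Shift the keep-list onto the decoder flag so prune_state_dict matches
    # the "decoder.*" parameter names used inside the RoBERTa implementation.
    args.decoder_layers_to_keep = args.encoder_layers_to_keep
    args.encoder_layers_to_keep = None
```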

fairseq/models/transformer.py

Lines changed: 40 additions & 17 deletions
@@ -25,6 +25,7 @@
     TransformerDecoderLayer,
     TransformerEncoderLayer,
 )
+import random
 
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
@@ -130,6 +131,15 @@ def add_args(parser):
                             help='perform cross+self-attention')
         parser.add_argument('--layer-wise-attention', default=False, action='store_true',
                             help='perform layer-wise attention (cross-attention or cross+self-attention)')
+        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
+                            help='LayerDrop probability for encoder')
+        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
+                            help='LayerDrop probability for decoder')
+        parser.add_argument('--encoder-layers-to-keep', default=None,
+                            help='which layers to *keep* when pruning as a comma-separated list')
+        parser.add_argument('--decoder-layers-to-keep', default=None,
+                            help='which layers to *keep* when pruning as a comma-separated list')
         # fmt: on
 
     @classmethod
@@ -139,6 +149,11 @@ def build_model(cls, args, task):
         # make sure all arguments are present in older models
         base_architecture(args)
 
+        if args.encoder_layers_to_keep:
+            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
+        if args.decoder_layers_to_keep:
+            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
+
         if not hasattr(args, 'max_source_positions'):
             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
         if not hasattr(args, 'max_target_positions'):
@@ -275,6 +290,7 @@ def __init__(self, args, dictionary, embed_tokens):
         self.register_buffer('version', torch.Tensor([3]))
 
         self.dropout = args.dropout
+        self.encoder_layerdrop = args.encoder_layerdrop
 
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
@@ -300,6 +316,7 @@ def __init__(self, args, dictionary, embed_tokens):
         else:
             self.layer_norm = None
 
+
     def forward_embedding(self, src_tokens):
         # embed tokens and positions
         embed = self.embed_scale * self.embed_tokens(src_tokens)
@@ -345,9 +362,12 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa
 
         # encoder layers
         for layer in self.layers:
-            x = layer(x, encoder_padding_mask)
-            if return_all_hiddens:
-                encoder_states.append(x)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if not self.training or (dropout_probability > self.encoder_layerdrop):
+                x = layer(x, encoder_padding_mask)
+                if return_all_hiddens:
+                    encoder_states.append(x)
 
         if self.layer_norm:
             x = self.layer_norm(x)
@@ -435,6 +455,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         self.register_buffer('version', torch.Tensor([3]))
 
         self.dropout = args.dropout
+        self.decoder_layerdrop = args.decoder_layerdrop
         self.share_input_output_embed = args.share_decoder_input_output_embed
 
         input_embed_dim = embed_tokens.embedding_dim
@@ -594,20 +615,22 @@ def extract_features(
             else:
                 self_attn_mask = None
 
-            x, layer_attn = layer(
-                x,
-                encoder_state,
-                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
-                incremental_state,
-                self_attn_mask=self_attn_mask,
-                self_attn_padding_mask=self_attn_padding_mask,
-                need_attn=(idx == alignment_layer),
-                need_head_weights=(idx == alignment_layer),
-            )
-
-            inner_states.append(x)
-            if layer_attn is not None and idx == alignment_layer:
-                attn = layer_attn.float()
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if not self.training or (dropout_probability > self.decoder_layerdrop):
+                x, layer_attn = layer(
+                    x,
+                    encoder_state,
+                    encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
+                    incremental_state,
+                    self_attn_mask=self_attn_mask,
+                    self_attn_padding_mask=self_attn_padding_mask,
+                    need_attn=(idx == alignment_layer),
+                    need_head_weights=(idx == alignment_layer),
+                )
+                inner_states.append(x)
+                if layer_attn is not None and idx == alignment_layer:
+                    attn = layer_attn.float()
 
         if attn is not None:
             if alignment_heads is not None:
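
With the coin flip above, each layer executes with probability 1 - p during training and always at evaluation. A quick standalone sanity check of the expected training-time depth (illustrative sketch only, not fairseq code):

```python
import random


def expected_layers_run(num_layers: int, layerdrop: float, trials: int = 10000) -> float:
    """Monte-Carlo estimate of how many layers execute per training forward pass."""
    total = 0
    for _ in range(trials):
        total += sum(1 for _ in range(num_layers) if random.uniform(0, 1) > layerdrop)
    return total / trials


# With 16 decoder layers and --decoder-layerdrop 0.2, roughly
# 16 * (1 - 0.2) = 12.8 layers run per training pass; at evaluation all 16 run.
print(expected_layers_run(16, 0.2))
```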

fairseq/models/transformer_lm.py

Lines changed: 8 additions & 0 deletions
@@ -98,6 +98,11 @@ def add_args(parser):
                             help='if set, ties the projection weights of adaptive softmax and adaptive input')
         parser.add_argument('--decoder-learned-pos', action='store_true',
                             help='use learned positional embeddings in the decoder')
+        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
+        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
+                            help='LayerDrop probability for decoder')
+        parser.add_argument('--decoder-layers-to-keep', default=None,
+                            help='which layers to *keep* when pruning as a comma-separated list')
         # fmt: on
 
     @classmethod
@@ -107,6 +112,9 @@ def build_model(cls, args, task):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
 
+        if args.decoder_layers_to_keep:
+            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
+
         if getattr(args, 'max_target_positions', None) is None:
             args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)
 
fairseq/modules/transformer_sentence_encoder.py

Lines changed: 10 additions & 3 deletions
@@ -14,6 +14,7 @@
     PositionalEmbedding,
     TransformerSentenceEncoderLayer,
 )
+import random
 
 
 def init_bert_params(module):
@@ -77,6 +78,7 @@ def __init__(
         dropout: float = 0.1,
         attention_dropout: float = 0.1,
         activation_dropout: float = 0.1,
+        layerdrop: float = 0.0,
         max_seq_len: int = 256,
         num_segments: int = 2,
         use_position_embeddings: bool = True,
@@ -97,6 +99,7 @@ def __init__(
         self.padding_idx = padding_idx
         self.vocab_size = vocab_size
         self.dropout = dropout
+        self.layerdrop = layerdrop
         self.max_seq_len = max_seq_len
         self.embedding_dim = embedding_dim
         self.num_segments = num_segments
@@ -208,9 +211,13 @@ def forward(
             inner_states.append(x)
 
         for layer in self.layers:
-            x, _ = layer(x, self_attn_padding_mask=padding_mask)
-            if not last_state_only:
-                inner_states.append(x)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = random.uniform(0, 1)
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, _ = layer(x, self_attn_padding_mask=padding_mask)
+                if not last_state_only:
+                    inner_states.append(x)
+
 
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

fairseq/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def load_checkpoint(
 
         # load model parameters
         try:
-            self.get_model().load_state_dict(state['model'], strict=True)
+            self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
             if utils.has_parameters(self.get_criterion()):
                 self.get_criterion().load_state_dict(state['criterion'], strict=True)
         except Exception:
