
Commit a92bcda

Naman Goyal authored and facebook-github-bot committed
adding first version of bart code release (#902)
Summary: This is the first version of the BART code / model release. It still requires a lot of clean-up, instructions, and making sure results are reproducible before we can release it.

Pull Request resolved: fairinternal/fairseq-py#902
Differential Revision: D18389535
fbshipit-source-id: 77f16800307ce831bd29538fdd34800793210f46
1 parent e98bf7e commit a92bcda

18 files changed: +1360 −39 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ modeling and other text generation tasks.
 
 ### What's New:
 
+- November 2019: [BART model and code released](examples/bart/README.md)
 - November 2019: [XLM-R models and code released](examples/xlmr/README.md)
 - September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
 - August 2019: [WMT'19 models released](examples/wmt19/README.md)

examples/bart/README.glue.md

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@

# Fine-tuning BART on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.

### 3) Fine-tuning on a GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61       # 6 percent of the number of updates
LR=1e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \
    --restore-file $BART_PATH \
    --max-sentences $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --add-prev-output-tokens \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 \
    --arch bart_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each GLUE task, you will need to use the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107

For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task.

b) The above command-line arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--max-sentences`; see the sketch below for how these values relate.
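
The relationship between these values can be illustrated with the rough sketch below. This is not part of the original instructions: the dataset size, GPU count, and accumulation factor are placeholders to substitute for your own setup.
```bash
# Rough, hypothetical sketch of how the schedule hyperparameters relate.
NUM_EXAMPLES=10000   # number of training examples in the task (placeholder)
MAX_SENTENCES=32     # per-GPU batch size
UPDATE_FREQ=1        # gradient accumulation steps (placeholder)
NUM_GPUS=1           # GPUs visible to train.py (placeholder)
MAX_EPOCH=10

# The effective batch size grows with gradient accumulation and the number of
# GPUs, which is why increasing --update-freq lets you shrink --max-sentences.
EFFECTIVE_BSZ=$((MAX_SENTENCES * UPDATE_FREQ * NUM_GPUS))
TOTAL_NUM_UPDATES=$((NUM_EXAMPLES * MAX_EPOCH / EFFECTIVE_BSZ))
WARMUP_UPDATES=$((TOTAL_NUM_UPDATES * 6 / 100))   # ~6% warmup, matching the table
echo "TOTAL_NUM_UPDATES=$TOTAL_NUM_UPDATES WARMUP_UPDATES=$WARMUP_UPDATES"
```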

### Inference on GLUE task
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:

```python
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='RTE-bin'
)

label_fn = lambda label: bart.task.label_dictionary.string(
    [label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
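
For `STS-B`, which is trained with `--regression-target`, the classification snippet above does not apply directly. A possible variant is sketched below. It is not from the original release notes and rests on two assumptions to verify: that the hub interface exposes `predict(..., return_logits=True)` as RoBERTa's does, and that sentence1, sentence2, and the gold score are the last three tab-separated columns of `glue_data/STS-B/dev.tsv`.

```python
# Hypothetical sketch for STS-B (regression) evaluation; verify the
# assumptions noted above before relying on it.
from scipy.stats import pearsonr

from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='STS-B-bin'
)
bart.cuda()
bart.eval()

predictions, golds = [], []
with open('glue_data/STS-B/dev.tsv') as fin:
    fin.readline()  # skip header
    for line in fin:
        cols = line.rstrip('\n').split('\t')
        # Assumption: sentence1, sentence2 and the gold similarity score
        # are the last three columns of the dev file.
        sent1, sent2, score = cols[-3], cols[-2], cols[-1]
        tokens = bart.encode(sent1, sent2)
        # Assumption: return_logits mirrors the RoBERTa hub interface and
        # returns the raw output of the single-output regression head.
        pred = bart.predict('sentence_classification_head', tokens, return_logits=True)
        predictions.append(pred.item())
        golds.append(float(score))
print('| Pearson: ', pearsonr(predictions, golds)[0])
```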

examples/bart/README.md

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@

# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

[https://arxiv.org/pdf/1910.13461.pdf]

## Introduction

BART is a sequence-to-sequence model trained with denoising as the pretraining objective. We show that this pretraining objective is more generic: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN/DailyMail), long-form generative question answering (ELI5), and dialog response generation (ConvAI2). See the associated paper for more details.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2

**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(dev set, no additional data used)_

Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90

## Example usage

##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load BART (for PyTorch 1.0 or custom models):
```python
# Download bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz

# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens)  # 'Hello world!'
```

##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features from the decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()

batch_of_pairs = [
    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```

##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```

#### Evaluating the `bart.large.mnli` model:

Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9010
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)

## Citation

```bibtex
@article{lewis2019bart,
  title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
           Language Generation, Translation, and Comprehension},
  author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
            Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
            and Luke Zettlemoyer},
  journal = {arXiv preprint arXiv:1910.13461},
  year = {2019},
}
```

examples/roberta/README.glue.md

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ roberta = RobertaModel.from_pretrained(
 )
 
 label_fn = lambda label: roberta.task.label_dictionary.string(
-    [label + roberta.task.target_dictionary.nspecial]
+    [label + roberta.task.label_dictionary.nspecial]
 )
 ncorrect, nsamples = 0, 0
 roberta.cuda()

fairseq/data/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -9,11 +9,13 @@
 
 from .base_wrapper_dataset import BaseWrapperDataset
 
+from .append_token_dataset import AppendTokenDataset
 from .audio.raw_audio_dataset import FileAudioDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .colorize_dataset import ColorizeDataset
 from .concat_dataset import ConcatDataset
 from .concat_sentences_dataset import ConcatSentencesDataset
+from .denoising_dataset import DenoisingDataset
 from .id_dataset import IdDataset
 from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
@@ -33,6 +35,7 @@
 from .raw_label_dataset import RawLabelDataset
 from .replace_dataset import ReplaceDataset
 from .resampling_dataset import ResamplingDataset
+from .roll_dataset import RollDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
 from .sharded_dataset import ShardedDataset
 from .sort_dataset import SortDataset
@@ -42,7 +45,6 @@
 from .transform_eos_dataset import TransformEosDataset
 from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
 from .truncate_dataset import TruncateDataset
-from .resampling_dataset import ResamplingDataset
 
 from .iterators import (
     CountingIterator,
@@ -52,12 +54,14 @@
 )
 
 __all__ = [
+    'AppendTokenDataset',
     'BacktranslationDataset',
     'BaseWrapperDataset',
     'ColorizeDataset',
     'ConcatDataset',
     'ConcatSentencesDataset',
     'CountingIterator',
+    'DenoisingDataset',
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
@@ -83,9 +87,10 @@
     'PrependDataset',
     'PrependTokenDataset',
     'ReplaceDataset',
+    'RollDataset',
     'FileAudioDataset',
     'RawLabelDataset',
-    'ResamplingDataset'
+    'ResamplingDataset',
     'RightPadDataset',
     'RoundRobinZipDatasets',
     'ShardedDataset',

fairseq/data/append_token_dataset.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item, item.new([self.token])])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n