Skip to content

Commit 5379461

Browse files
alexeibfacebook-github-bot
authored andcommitted
lm rescoring attempt (#1242)
Summary: CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/private/home/abaevski/fairseq-py-master python fairseq_cli/generate.py /checkpoint/henryzhou7/dataset/libri/960h/raw3/decoder --task audio_pretraining --seed 1 --nbest 1 --gen-subset dev_other --max-tokens 600000 --path ~/models/wav2vec2/vox_960h_seq2seq_10kwp.pt --labels 10k --remove-bpe 'wordpiece' --quiet --beam 50 --temperature 1 --scoring wer --lm-path /checkpoint/henryzhou7/wp_lm/transformer_raw3_adam_cosine2node/lr_1e-4_updatefreq_8/checkpoint_best.pt --lm-weight 1 results: no lm: 4.30577896347444 lm (1.5): 24.691650853889943 lm (1): 10.884539582804846 lm (0.5): 4.894205665744457 lm (0.25): 4.012853671917862 lm (0.1): 4.087637055489084 lm (0.05): 4.194788887144875 Pull Request resolved: fairinternal/fairseq-py#1242 Reviewed By: kahne Differential Revision: D23277386 Pulled By: alexeib fbshipit-source-id: 062f483bd45ddd2dd5ff24a8a35cc1c4f34ce6ab
1 parent b880744 commit 5379461

File tree

6 files changed

+62
-8
lines changed

6 files changed

+62
-8
lines changed

examples/speech_recognition/tasks/speech_recognition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def load_dataset(self, split, combine=False, **kwargs):
113113
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
114114
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
115115

116-
def build_generator(self, models, args):
116+
def build_generator(self, models, args, **unused):
117117
w2l_decoder = getattr(args, "w2l_decoder", None)
118118
if w2l_decoder == "viterbi":
119119
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

fairseq/options.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ def add_generation_args(parser):
380380
help='if set, uses attention feedback to compute and print alignment to source tokens')
381381
group.add_argument('--print-step', action='store_true')
382382

383+
group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
384+
help='path to lm checkpoint for lm fusion')
385+
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
386+
help='weight for lm probs for lm fusion')
387+
383388
# arguments for iterative refinement generator
384389
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
385390
help='if > 0.0, it penalized early-stopping in decoding.')

fairseq/sequence_generator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def __init__(
3333
search_strategy=None,
3434
eos=None,
3535
symbols_to_strip_from_output=None,
36+
lm_model=None,
37+
lm_weight=1.0
3638
):
3739
"""Generates translations of a given source sentence.
3840
@@ -94,6 +96,11 @@ def __init__(
9496

9597
self.model.eval()
9698

99+
self.lm_model = lm_model
100+
self.lm_weight = lm_weight
101+
if self.lm_model is not None:
102+
self.lm_model.eval()
103+
97104
def cuda(self):
98105
self.model.cuda()
99106
return self
@@ -292,6 +299,15 @@ def _generate(
292299
incremental_states,
293300
self.temperature,
294301
)
302+
303+
if self.lm_model is not None:
304+
lm_out = self.lm_model(tokens[:, : step + 1])
305+
probs = self.lm_model.get_normalized_probs(
306+
lm_out, log_probs=True, sample=None
307+
)
308+
probs = probs[:, -1, :] * self.lm_weight
309+
lprobs += probs
310+
295311
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
296312

297313
lprobs[:, self.pad] = -math.inf # never select pad
@@ -820,9 +836,11 @@ def forward_decoder(
820836
avg_attn = attn
821837
else:
822838
avg_attn.add_(attn)
839+
823840
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
824841
self.models_size
825842
)
843+
826844
if avg_attn is not None:
827845
avg_attn.div_(self.models_size)
828846
return avg_probs, avg_attn

fairseq/tasks/translation_from_pretrained_bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
8484
append_source_id=True
8585
)
8686

87-
def build_generator(self, models, args):
87+
def build_generator(self, models, args, **unused):
8888
if getattr(args, 'score_reference', False):
8989
from fairseq.sequence_scorer import SequenceScorer
9090
return SequenceScorer(

fairseq/tasks/translation_lev.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _full_mask(target_tokens):
128128
else:
129129
raise NotImplementedError
130130

131-
def build_generator(self, models, args):
131+
def build_generator(self, models, args, **unused):
132132
# add models input to match the API for SequenceGenerator
133133
from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
134134
return IterativeRefinementGenerator(

fairseq_cli/generate.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
Translate pre-processed data with a trained model.
88
"""
99

10+
import ast
11+
from itertools import chain
1012
import logging
1113
import math
1214
import os
@@ -78,17 +80,39 @@ def _main(args, output_file):
7880
src_dict = None
7981
tgt_dict = task.target_dictionary
8082

83+
overrides = ast.literal_eval(args.model_overrides)
84+
8185
# Load ensemble
8286
logger.info('loading model(s) from {}'.format(args.path))
8387
models, _model_args = checkpoint_utils.load_model_ensemble(
8488
utils.split_paths(args.path),
85-
arg_overrides=eval(args.model_overrides),
89+
arg_overrides=overrides,
8690
task=task,
8791
suffix=getattr(args, "checkpoint_suffix", ""),
8892
)
8993

94+
if args.lm_path is not None:
95+
overrides['data'] = args.data
96+
97+
try:
98+
lms, _ = checkpoint_utils.load_model_ensemble(
99+
[args.lm_path],
100+
arg_overrides=overrides,
101+
task=None,
102+
)
103+
except:
104+
logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same "
105+
f"as target dict and is located in the data dir ({args.data})")
106+
raise
107+
108+
assert len(lms) == 1
109+
else:
110+
lms = [None]
111+
90112
# Optimize ensemble for generation
91-
for model in models:
113+
for model in chain(models, lms):
114+
if model is None:
115+
continue
92116
model.prepare_for_inference_(args)
93117
if args.fp16:
94118
model.half()
@@ -124,7 +148,12 @@ def _main(args, output_file):
124148

125149
# Initialize generator
126150
gen_timer = StopwatchMeter()
127-
generator = task.build_generator(models, args)
151+
152+
extra_gen_cls_kwargs = {
153+
'lm_model': lms[0],
154+
'lm_weight': args.lm_weight
155+
}
156+
generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
128157

129158
# Handle tokenization and BPE
130159
tokenizer = encoders.build_tokenizer(args)
@@ -269,9 +298,11 @@ def decode_fn(x):
269298
if has_target:
270299
if args.bpe and not args.sacrebleu:
271300
if args.remove_bpe:
272-
logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
301+
logger.warning(
302+
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
273303
else:
274-
logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
304+
logger.warning(
305+
"If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
275306
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
276307
print(
277308
'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),

0 commit comments

Comments
 (0)