Skip to content

Commit 8679339

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Deprecate aggregate_logging_outputs API (use reduce_metrics instead) (#1611)
Summary: Pull Request resolved: #1611 Pull Request resolved: fairinternal/fairseq-py#974 Differential Revision: D19292402 Pulled By: myleott fbshipit-source-id: d51327584e048d3e39c133e9ef57a791e0329a66
1 parent 0ce722d commit 8679339

23 files changed

+444
-429
lines changed

fairseq/criterions/adaptive_loss.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
import math
7+
88
import torch.nn.functional as F
99

10-
from fairseq import utils
11-
from . import FairseqCriterion, register_criterion
10+
from fairseq import metrics, utils
11+
from fairseq.criterions import FairseqCriterion, register_criterion
1212

1313

1414
@register_criterion('adaptive_loss')
@@ -74,28 +74,24 @@ def forward(self, model, sample, reduce=True):
7474
return loss, sample_size, logging_output
7575

7676
@staticmethod
77-
def aggregate_logging_outputs(logging_outputs):
77+
def reduce_metrics(logging_outputs) -> None:
7878
"""Aggregate logging outputs from data parallel training."""
7979
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
8080
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
81-
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
8281
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
83-
agg_output = {
84-
'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
85-
'nll_loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
86-
'ntokens': ntokens,
87-
'nsentences': nsentences,
88-
'sample_size': sample_size,
89-
}
82+
83+
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
9084
if sample_size != ntokens:
91-
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.
92-
return agg_output
85+
metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
86+
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
87+
else:
88+
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))
9389

9490
@staticmethod
9591
def logging_outputs_can_be_summed() -> bool:
9692
"""
9793
Whether the logging outputs returned by `forward` can be summed
98-
across workers prior to calling `aggregate_logging_outputs`.
99-
Setting this to True will improves distributed training speed.
94+
across workers prior to calling `reduce_metrics`. Setting this
95+
to True will improves distributed training speed.
10096
"""
10197
return True

fairseq/criterions/binary_cross_entropy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import math
7+
78
import numpy as np
89
import torch
910
import torch.nn.functional as F
1011

1112
from fairseq import utils
12-
13-
from . import FairseqCriterion, register_criterion
13+
from fairseq.criterions import FairseqCriterion, register_criterion
1414

1515

1616
@register_criterion('binary_cross_entropy')

fairseq/criterions/composite_loss.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch import nn
77

88
from fairseq import utils
9-
from . import FairseqCriterion, register_criterion
9+
from fairseq.criterions import FairseqCriterion, register_criterion
1010

1111

1212
@register_criterion('composite_loss')
@@ -88,4 +88,8 @@ def forward(self, model, sample, reduce=True):
8888
def aggregate_logging_outputs(logging_outputs):
8989
return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)
9090

91+
@staticmethod
92+
def reduce_metrics(logging_outputs) -> None:
93+
underlying_criterion.__class__.reduce_metrics(logging_outputs)
94+
9195
return _CompositeLoss(args, task, underlying_criterion)

fairseq/criterions/cross_entropy.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import math
7-
import torch.nn.functional as F
87

9-
from fairseq import utils
8+
import torch.nn.functional as F
109

11-
from . import FairseqCriterion, register_criterion
10+
from fairseq import metrics, utils
11+
from fairseq.criterions import FairseqCriterion, register_criterion
1212

1313

1414
@register_criterion('cross_entropy')
@@ -30,7 +30,6 @@ def forward(self, model, sample, reduce=True):
3030
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
3131
logging_output = {
3232
'loss': utils.item(loss.data) if reduce else loss.data,
33-
'nll_loss': utils.item(loss.data) if reduce else loss.data,
3433
'ntokens': sample['ntokens'],
3534
'nsentences': sample['target'].size(0),
3635
'sample_size': sample_size,
@@ -50,27 +49,24 @@ def compute_loss(self, model, net_output, sample, reduce=True):
5049
return loss, loss
5150

5251
@staticmethod
53-
def aggregate_logging_outputs(logging_outputs):
52+
def reduce_metrics(logging_outputs) -> None:
5453
"""Aggregate logging outputs from data parallel training."""
5554
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
5655
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
57-
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
5856
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
59-
agg_output = {
60-
'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
61-
'ntokens': ntokens,
62-
'nsentences': nsentences,
63-
'sample_size': sample_size,
64-
}
57+
58+
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
6559
if sample_size != ntokens:
66-
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
67-
return agg_output
60+
metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
61+
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
62+
else:
63+
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))
6864

6965
@staticmethod
7066
def logging_outputs_can_be_summed() -> bool:
7167
"""
7268
Whether the logging outputs returned by `forward` can be summed
73-
across workers prior to calling `aggregate_logging_outputs`.
74-
Setting this to True will improves distributed training speed.
69+
across workers prior to calling `reduce_metrics`. Setting this
70+
to True will improves distributed training speed.
7571
"""
7672
return True

fairseq/criterions/fairseq_criterion.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Any, Dict, List
7+
68
from torch.nn.modules.loss import _Loss
79

10+
from fairseq import metrics, utils
11+
812

913
class FairseqCriterion(_Loss):
1014

@@ -34,15 +38,34 @@ def forward(self, model, sample, reduce=True):
3438
raise NotImplementedError
3539

3640
@staticmethod
37-
def aggregate_logging_outputs(logging_outputs):
41+
def aggregate_logging_outputs(
42+
logging_outputs: List[Dict[str, Any]],
43+
) -> Dict[str, Any]:
3844
"""Aggregate logging outputs from data parallel training."""
45+
utils.deprecation_warning(
46+
'The aggregate_logging_outputs API is deprecated. '
47+
'Please use the reduce_metrics API instead.'
48+
)
3949
raise NotImplementedError
4050

51+
@classmethod
52+
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
53+
"""Aggregate logging outputs from data parallel training."""
54+
utils.deprecation_warning(
55+
'Criterions should implement the reduce_metrics API. '
56+
'Falling back to deprecated aggregate_logging_outputs API.'
57+
)
58+
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
59+
for k, v in agg_logging_outputs.items():
60+
if k in {'nsentences', 'ntokens', 'sample_size'}:
61+
continue
62+
metrics.log_scalar(k, v)
63+
4164
@staticmethod
4265
def logging_outputs_can_be_summed() -> bool:
4366
"""
4467
Whether the logging outputs returned by `forward` can be summed
45-
across workers prior to calling `aggregate_logging_outputs`.
46-
Setting this to True will improves distributed training speed.
68+
across workers prior to calling `reduce_metrics`. Setting this
69+
to True will improves distributed training speed.
4770
"""
4871
return False

fairseq/criterions/label_smoothed_cross_entropy.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55

66
import math
77

8-
from fairseq import utils
9-
10-
from . import FairseqCriterion, register_criterion
8+
from fairseq import metrics, utils
9+
from fairseq.criterions import FairseqCriterion, register_criterion
1110

1211

1312
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
@@ -76,24 +75,22 @@ def compute_loss(self, model, net_output, sample, reduce=True):
7675
return loss, nll_loss
7776

7877
@staticmethod
79-
def aggregate_logging_outputs(logging_outputs):
78+
def reduce_metrics(logging_outputs) -> None:
8079
"""Aggregate logging outputs from data parallel training."""
80+
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
81+
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
8182
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
82-
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
8383
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
84-
return {
85-
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
86-
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
87-
'ntokens': ntokens,
88-
'nsentences': nsentences,
89-
'sample_size': sample_size,
90-
}
84+
85+
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
86+
metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
87+
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
9188

9289
@staticmethod
9390
def logging_outputs_can_be_summed() -> bool:
9491
"""
9592
Whether the logging outputs returned by `forward` can be summed
96-
across workers prior to calling `aggregate_logging_outputs`.
97-
Setting this to True will improves distributed training speed.
93+
across workers prior to calling `reduce_metrics`. Setting this
94+
to True will improves distributed training speed.
9895
"""
9996
return True

fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
import math
77

8-
from fairseq import utils
8+
from fairseq import metrics, utils
9+
from fairseq.criterions import register_criterion
910

1011
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
11-
from . import register_criterion
1212

1313

1414
@register_criterion('label_smoothed_cross_entropy_with_alignment')
@@ -75,25 +75,24 @@ def compute_alignment_loss(self, sample, net_output):
7575
return loss
7676

7777
@staticmethod
78-
def aggregate_logging_outputs(logging_outputs):
78+
def reduce_metrics(logging_outputs) -> None:
7979
"""Aggregate logging outputs from data parallel training."""
80+
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
81+
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
82+
alignment_loss_sum = sum(log.get('alignment_loss', 0) for log in logging_outputs)
8083
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
81-
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
8284
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
83-
return {
84-
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
85-
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
86-
'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
87-
'ntokens': ntokens,
88-
'nsentences': nsentences,
89-
'sample_size': sample_size,
90-
}
85+
86+
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
87+
metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
88+
metrics.log_scalar('alignment_loss', alignment_loss_sum / sample_size / math.log(2), sample_size, round=3)
89+
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
9190

9291
@staticmethod
9392
def logging_outputs_can_be_summed() -> bool:
9493
"""
9594
Whether the logging outputs returned by `forward` can be summed
96-
across workers prior to calling `aggregate_logging_outputs`.
97-
Setting this to True will improves distributed training speed.
95+
across workers prior to calling `reduce_metrics`. Setting this
96+
to True will improves distributed training speed.
9897
"""
9998
return True

fairseq/criterions/legacy_masked_lm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
import math
7+
88
import torch
99
import torch.nn.functional as F
1010

1111
from fairseq import utils
12-
from . import FairseqCriterion, register_criterion
12+
from fairseq.criterions import FairseqCriterion, register_criterion
1313

1414

1515
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
@@ -150,7 +150,7 @@ def aggregate_logging_outputs(logging_outputs):
150150
def logging_outputs_can_be_summed() -> bool:
151151
"""
152152
Whether the logging outputs returned by `forward` can be summed
153-
across workers prior to calling `aggregate_logging_outputs`.
154-
Setting this to True will improves distributed training speed.
153+
across workers prior to calling `reduce_metrics`. Setting this
154+
to True will improves distributed training speed.
155155
"""
156156
return True

fairseq/criterions/masked_lm.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
import torch
99
import torch.nn.functional as F
1010

11-
from fairseq import utils
12-
13-
from . import FairseqCriterion, register_criterion
11+
from fairseq import metrics, utils
12+
from fairseq.criterions import FairseqCriterion, register_criterion
1413

1514

1615
@register_criterion('masked_lm')
@@ -19,11 +18,9 @@ class MaskedLmLoss(FairseqCriterion):
1918
Implementation for the loss used in masked language model (MLM) training.
2019
"""
2120

22-
def __init__(self, args, task):
23-
super().__init__(args, task)
24-
2521
def forward(self, model, sample, reduce=True):
2622
"""Compute the loss for the given sample.
23+
2724
Returns a tuple with three elements:
2825
1) the loss
2926
2) the sample size, which is used as the denominator for the gradient
@@ -56,35 +53,26 @@ def forward(self, model, sample, reduce=True):
5653
)
5754
logging_output = {
5855
'loss': utils.item(loss.data) if reduce else loss.data,
59-
'nll_loss': utils.item(loss.data) if reduce else loss.data,
6056
'ntokens': sample['ntokens'],
6157
'nsentences': sample['nsentences'],
6258
'sample_size': sample_size,
6359
}
6460
return loss, sample_size, logging_output
6561

6662
@staticmethod
67-
def aggregate_logging_outputs(logging_outputs):
63+
def reduce_metrics(logging_outputs) -> None:
6864
"""Aggregate logging outputs from data parallel training."""
69-
loss = sum(log.get('loss', 0) for log in logging_outputs)
70-
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
71-
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
65+
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
7266
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
7367

74-
agg_output = {
75-
'loss': loss / sample_size / math.log(2),
76-
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
77-
'ntokens': ntokens,
78-
'nsentences': nsentences,
79-
'sample_size': sample_size,
80-
}
81-
return agg_output
68+
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
69+
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))
8270

8371
@staticmethod
8472
def logging_outputs_can_be_summed() -> bool:
8573
"""
8674
Whether the logging outputs returned by `forward` can be summed
87-
across workers prior to calling `aggregate_logging_outputs`.
88-
Setting this to True will improves distributed training speed.
75+
across workers prior to calling `reduce_metrics`. Setting this
76+
to True will improves distributed training speed.
8977
"""
9078
return True

0 commit comments

Comments
 (0)