
Commit aa79bb9

Myle Ott authored and facebook-github-bot committed
Use 1-based indexing for epochs everywhere (#1053)
Summary: We are somewhat inconsistent in whether we're using 0-based or 1-based indexing for epochs. This should fix things to be 1-based internally, with logging and checkpoint naming still using 1-based indexing.

Pull Request resolved: fairinternal/fairseq-py#1053
Reviewed By: spencerp
Differential Revision: D20160715
Pulled By: myleott
fbshipit-source-id: 4ed94f9c371e1bfe29bcfa087fa6756507d6e627
1 parent 4171b83 commit aa79bb9

26 files changed (+63, -116 lines)
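
Before the per-file diffs, a standalone sketch of the convention this commit adopts (illustrative only, not code taken from the diff): epochs are 1-based everywhere, so defaults move from epoch=0 to epoch=1 and values loaded from older state are clamped up to 1.

# Illustrative sketch only: not fairseq code.
def normalize_epoch(epoch=1):
    return max(epoch, 1)  # we use 1-based indexing for epochs

assert normalize_epoch() == 1   # new default: the first epoch is 1, not 0
assert normalize_epoch(0) == 1  # 0-based values from old state get bumped to 1
assert normalize_epoch(5) == 5  # valid 1-based values pass through unchanged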

examples/roberta/commonsense_qa/commonsense_qa_task.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def setup_task(cls, args, **kwargs):
 
         return cls(args, vocab)
 
-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
         """Load a given dataset split.
 
         Args:

examples/roberta/wsc/wsc_task.py

Lines changed: 2 additions & 2 deletions
@@ -101,7 +101,7 @@ def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space)
         mask[mask_start:mask_start + mask_size] = 1
         return toks, mask
 
-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
         """Load a given dataset split.
 
         Args:
@@ -281,7 +281,7 @@ def setup_task(cls, args, **kwargs):
 
         return cls(args, vocab)
 
-    def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs):
         """Load a given dataset split.
 
         Args:

fairseq/benchmark/dummy_lm.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def setup_task(cls, args, **kwargs):
 
         return cls(args, dictionary)
 
-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
             split (str): name of the split (e.g., train, valid, test)

fairseq/benchmark/dummy_masked_lm.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def setup_task(cls, args, **kwargs):
 
         return cls(args, dictionary)
 
-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
             split (str): name of the split (e.g., train, valid, test)

fairseq/checkpoint_utils.py

Lines changed: 7 additions & 1 deletion
@@ -109,6 +109,7 @@ def is_better(a, b):
         if os.path.lexists(old_chk):
             os.remove(old_chk)
 
+
 def load_checkpoint(args, trainer, **passthrough_args):
     """
     Load a checkpoint and restore the training iterator.
@@ -150,7 +151,7 @@ def load_checkpoint(args, trainer, **passthrough_args):
         epoch_itr.load_state_dict(itr_state)
     else:
         epoch_itr = trainer.get_train_iterator(
-            epoch=0, load_dataset=True, **passthrough_args
+            epoch=1, load_dataset=True, **passthrough_args
         )
 
     trainer.lr_step(epoch_itr.epoch)
@@ -349,6 +350,11 @@ def _upgrade_state_dict(state):
         state["args"].dataset_impl = "raw"
     elif getattr(state["args"], "lazy_load", False):
         state["args"].dataset_impl = "lazy"
+    # epochs start at 1
+    state["extra_state"]["train_iterator"]["epoch"] = max(
+        getattr(state["extra_state"]["train_iterator"], "epoch", 1),
+        1,
+    )
 
     # set any missing default values in the task, model or other registries
     registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
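
The added _upgrade_state_dict lines make checkpoints saved under the old 0-based convention load cleanly under the new one. A hedged, standalone sketch of that behaviour, using a plain dict and .get() rather than the exact committed code:

# Sketch only: clamp the epoch stored in an old checkpoint's train_iterator state.
def upgrade_train_iterator_state(extra_state):
    itr_state = extra_state.setdefault("train_iterator", {})
    itr_state["epoch"] = max(itr_state.get("epoch", 1), 1)  # epochs start at 1
    return extra_state

old = {"train_iterator": {"epoch": 0, "iterations_in_epoch": 12}}
assert upgrade_train_iterator_state(old)["train_iterator"]["epoch"] == 1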

fairseq/data/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -38,7 +38,6 @@
 from .resampling_dataset import ResamplingDataset
 from .roll_dataset import RollDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
-from .sharded_dataset import ShardedDataset
 from .sort_dataset import SortDataset
 from .strip_token_dataset import StripTokenDataset
 from .subsample_dataset import SubsampleDataset
@@ -96,7 +95,6 @@
     'ResamplingDataset',
     'RightPadDataset',
     'RoundRobinZipDatasets',
-    'ShardedDataset',
     'ShardedIterator',
     'SortDataset',
     'StripTokenDataset',

fairseq/data/iterators.py

Lines changed: 12 additions & 7 deletions
@@ -100,17 +100,18 @@ def load_state_dict(self, state_dict):
 
 class StreamingEpochBatchIterator(EpochBatchIterating):
     def __init__(
-        self, dataset, epoch=0, num_shards=1, shard_id=0,
+        self, dataset, epoch=1, num_shards=1, shard_id=0,
     ):
         assert isinstance(dataset, torch.utils.data.IterableDataset)
         self.dataset = dataset
-        self.epoch = epoch
+        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self._current_epoch_iterator = None
         self.num_shards = num_shards
         self.shard_id = shard_id
 
     def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
-        self.epoch += 1
+        if self._current_epoch_iterator is not None and self.end_of_epoch():
+            self.epoch += 1
         self.dataset.set_epoch(self.epoch)
         self._current_epoch_iterator = CountingIterator(
             iterable=ShardedIterator(
@@ -165,12 +166,12 @@ class EpochBatchIterator(EpochBatchIterating):
             loading. 0 means the data will be loaded in the main process
             (default: 0).
         epoch (int, optional): the epoch to start the iterator from
-            (default: 0).
+            (default: 1).
     """
 
     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0, epoch=0,
+        num_workers=0, epoch=1,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -181,7 +182,7 @@ def __init__(
         self.shard_id = shard_id
         self.num_workers = num_workers
 
-        self.epoch = epoch
+        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self.shuffle = True
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
@@ -204,7 +205,8 @@ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
             self._cur_epoch_itr = self._next_epoch_itr
             self._next_epoch_itr = None
         else:
-            self.epoch += 1
+            if self._cur_epoch_itr is not None and self.end_of_epoch():
+                self.epoch += 1
             self._cur_epoch_itr = self._get_iterator_for_epoch(
                 self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
             )
@@ -244,6 +246,9 @@ def load_state_dict(self, state_dict):
                 shuffle=state_dict.get('shuffle', True),
                 offset=itr_pos,
             )
+            if self._next_epoch_itr is None:
+                # we finished the epoch, increment epoch counter
+                self.epoch += 1
 
     def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False, offset=0):
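
The net effect of these iterator changes, shown as a self-contained stand-in rather than the real EpochBatchIterator (which needs a dataset, collate_fn and batch_sampler): the first call to next_epoch_itr() keeps epoch at 1, and the counter only advances once the previous epoch's iterator has been exhausted.

# Minimal stand-in for the new next_epoch_itr() bookkeeping (not fairseq code).
class ToyEpochIterator:
    def __init__(self, data, epoch=1):
        self.data = data
        self.epoch = max(epoch, 1)  # 1-based indexing for epochs
        self._cur = None
        self._consumed = 0

    def end_of_epoch(self):
        return self._consumed >= len(self.data)

    def next_epoch_itr(self):
        if self._cur is not None and self.end_of_epoch():
            self.epoch += 1  # only advance after the previous epoch finished
        self._consumed = 0
        self._cur = iter(self.data)
        return self

    def __iter__(self):
        for x in self._cur:
            self._consumed += 1
            yield x

itr = ToyEpochIterator([1, 2, 3])
assert itr.epoch == 1
list(itr.next_epoch_itr())  # first call: still epoch 1
itr.next_epoch_itr()
assert itr.epoch == 2       # previous epoch was exhausted, so we advanced
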
fairseq/data/resampling_dataset.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ class ResamplingDataset(BaseWrapperDataset):
         batch_by_size (bool): whether or not to batch by sequence length
             (default: True).
         seed (int): RNG seed to use (default: 0).
-        epoch (int): starting epoch number (default: 0).
+        epoch (int): starting epoch number (default: 1).
     """
 
     def __init__(
@@ -42,7 +42,7 @@ def __init__(
         size_ratio=1.0,
         batch_by_size=True,
         seed=0,
-        epoch=0,
+        epoch=1,
     ):
         super().__init__(dataset)
 
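ResamplingDataset keeps its seed argument but now starts at epoch 1. A sketch of the general idea of epoch-dependent, reproducible resampling (illustrative only, not the dataset's actual internals):

# Sketch: derive a per-epoch, reproducible resample from (seed, epoch).
import numpy as np

def resample_indices(num_items, size_ratio=1.0, seed=0, epoch=1):
    rng = np.random.RandomState([seed, epoch])  # distinct but reproducible per epoch
    return rng.choice(num_items, int(num_items * size_ratio), replace=True)

first = resample_indices(10, seed=0, epoch=1)
again = resample_indices(10, seed=0, epoch=1)
assert (first == again).all()  # same (seed, epoch) -> same sample
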
fairseq/data/sharded_dataset.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

fairseq/tasks/cross_lingual_lm.py

Lines changed: 4 additions & 3 deletions
@@ -102,7 +102,7 @@ def _load_single_lang_dataset(self, split, epoch):
 
         paths = utils.split_paths(self.args.data)
         assert len(paths) > 0
-        data_path = paths[epoch % len(paths)]
+        data_path = paths[(epoch - 1) % len(paths)]
 
         for k in itertools.count():
             split_k = split + (str(k) if k > 0 else '')
@@ -136,8 +136,9 @@ def _load_single_lang_dataset(self, split, epoch):
 
         return dataset, sizes
 
-    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
+    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
         """Load a given dataset split.
+
         Args:
             split (str): name of the split (e.g., train, valid, test)
         """
@@ -165,5 +166,5 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
 
         self.datasets[split] = MultiCorpusSampledDataset(dataset_map)
         logger.info('{} {} {} examples'.format(
-            utils.split_paths(self.args.data)[epoch], split, len(self.datasets[split]))
+            utils.split_paths(self.args.data)[epoch - 1], split, len(self.datasets[split]))
         )
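
A small worked example of the shard selection above (the paths are hypothetical): with 1-based epochs, subtracting 1 before the modulo keeps epoch 1 on the first shard and lets later epochs wrap around correctly.

# Worked example of paths[(epoch - 1) % len(paths)] under 1-based epochs.
paths = ["/data/xlm/shard0", "/data/xlm/shard1"]  # hypothetical shard paths

for epoch in (1, 2, 3):
    data_path = paths[(epoch - 1) % len(paths)]
    print(epoch, data_path)  # epoch 1 -> shard0, epoch 2 -> shard1, epoch 3 -> shard0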
