
Commit 4a7cd58

alexeib authored and facebook-github-bot committed
set numpy seed explicitly + other minor fixes (#850)
Summary: Not setting the numpy seed explicitly at the beginning was an extremely annoying bug to find: it caused different GPUs to have a different view of the data whenever some randomization was used in the dataset (e.g. SubsampleDataset).

Pull Request resolved: fairinternal/fairseq-py#850
Differential Revision: D17085006
Pulled By: alexeib
fbshipit-source-id: 62bb2116369fb703df878e6bc24c06f1ea4e75a0
1 parent 8777465 commit 4a7cd58
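Why this mattered: torch.manual_seed leaves NumPy's global RNG untouched, so any dataset wrapper that draws from np.random (such as SubsampleDataset in this commit) stays non-deterministic unless numpy is seeded too. A minimal sketch of the failure mode; the worker simulation below is illustrative, not fairseq code:

import numpy as np
import torch


def worker_subsample_indices(seed, dataset_len=10, sample_size=3, seed_numpy=False):
    # Simulates what each distributed worker does at startup before building
    # its datasets. Before this commit only the torch RNG was seeded.
    torch.manual_seed(seed)
    if seed_numpy:
        np.random.seed(seed)  # the fix: seed numpy explicitly as well
    # This mirrors the draw SubsampleDataset makes in its constructor.
    return np.random.choice(list(range(dataset_len)), sample_size, replace=False)


# Only torch is seeded: the two "workers" can pick different subsamples,
# because torch.manual_seed never resets NumPy's RNG state.
print(worker_subsample_indices(1), worker_subsample_indices(1))
# NumPy is seeded too: both workers agree on the subsample.
print(worker_subsample_indices(1, seed_numpy=True),
      worker_subsample_indices(1, seed_numpy=True))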

File tree

3 files changed: +30 -9 lines changed

fairseq/data/replace_dataset.py

Lines changed: 17 additions & 7 deletions
@@ -7,20 +7,30 @@
 
 
 class ReplaceDataset(BaseWrapperDataset):
-    def __init__(self, dataset, replace_map, offset=0):
+    """Replaces tokens found in the dataset by a specified replacement token
+
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset to replace tokens in
+        replace_map (Dictionary[int, int]): map of token to replace -> replacement token
+        offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
+            as many as the number of objects returned by the underlying dataset __getitem__ method.
+    """
+
+    def __init__(self, dataset, replace_map, offsets):
         super().__init__(dataset)
         assert len(replace_map) > 0
         self.replace_map = replace_map
-        self.offset = offset
+        self.offsets = offsets
 
     def __getitem__(self, index):
         item = self.dataset[index]
         is_tuple = isinstance(item, tuple)
-        src = item[0] if is_tuple else item
+        srcs = item if is_tuple else [item]
 
-        for k, v in self.replace_map.items():
-            src_off = src[self.offset:]
-            src_off.masked_fill_(src_off == k, v)
+        for offset, src in zip(self.offsets, srcs):
+            for k, v in self.replace_map.items():
+                src_off = src[offset:] if offset >= 0 else src[:offset]
+                src_off.masked_fill_(src_off == k, v)
 
-        item = tuple((src,) + item[1:]) if is_tuple else src
+        item = srcs if is_tuple else srcs[0]
         return item
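As a usage illustration of the new interface (the toy dataset and token ids below are hypothetical, not from the fairseq test suite): offsets now takes one entry per tensor returned by the wrapped __getitem__, and a negative offset protects a suffix instead of a prefix.

import torch
from torch.utils.data import Dataset

from fairseq.data.replace_dataset import ReplaceDataset  # module path as in this commit


class ToyPairDataset(Dataset):
    """Returns (source, target) pairs of token-id tensors; the ids are made up."""

    def __getitem__(self, index):
        return torch.tensor([0, 7, 7, 3]), torch.tensor([7, 5, 7])

    def __len__(self):
        return 1


# Replace token 7 with 99, but never touch the first source token
# (offset 1, counted from the left) or the last target token (offset -1,
# counted from the right).
ds = ReplaceDataset(ToyPairDataset(), replace_map={7: 99}, offsets=[1, -1])
src, tgt = ds[0]
print(src)  # tensor([ 0, 99, 99,  3])
print(tgt)  # tensor([99,  5,  7])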

fairseq/data/subsample_dataset.py

Lines changed: 11 additions & 2 deletions
@@ -9,15 +9,24 @@
 
 
 class SubsampleDataset(BaseWrapperDataset):
+    """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
+
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset to subsample
+        size_ratio (float): the ratio to subsample to. must be between 0 and 1 (exclusive)
+    """
+
     def __init__(self, dataset, size_ratio):
         super().__init__(dataset)
         assert size_ratio < 1
         self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
         self.indices = np.random.choice(
-            range(len(self.dataset)), self.actual_size, replace=False
+            list(range(len(self.dataset))), self.actual_size, replace=False
         )
         print(
-            "subsampled dataset from {} to {} (ratio={})".format(len(self.dataset), self.actual_size, size_ratio)
+            "subsampled dataset from {} to {} (ratio={})".format(
+                len(self.dataset), self.actual_size, size_ratio
+            )
         )
 
     def __getitem__(self, index):
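A usage sketch tying the two changes together (the toy dataset and seed below are hypothetical): because SubsampleDataset draws its indices with np.random.choice at construction time, every worker has to seed NumPy before building it, otherwise each rank ends up training on a different slice of the data.

import numpy as np

from fairseq.data.subsample_dataset import SubsampleDataset  # module path as in this commit


class RangeDataset:
    """Toy dataset whose i-th example is just the integer i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index


np.random.seed(1)  # without this, each process would draw a different subsample
subsampled = SubsampleDataset(RangeDataset(100), size_ratio=0.25)
# constructor prints: subsampled dataset from 100 to 25 (ratio=0.25)
print(subsampled.actual_size)   # 25
print(subsampled.indices[:5])   # the same five indices on every seeded process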

train.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import collections
 import math
+import numpy as np
 import random
 
 import torch
@@ -28,6 +29,7 @@ def main(args, init_distributed=False):
     # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
+    np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     if init_distributed:
         args.distributed_rank = distributed_utils.distributed_init(args)
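The broader lesson, written out as a hypothetical helper (not part of fairseq): seed every RNG family a training run touches, rather than assuming torch.manual_seed covers them all.

import random

import numpy as np
import torch


def set_all_seeds(seed):
    """Seed the Python, NumPy and torch RNGs so that data-side randomness
    (subsampling, masking, shuffling) agrees across distributed workers."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # GPU RNGs on all devices


set_all_seeds(1)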
