
Commit 1766cc3

Merge pull request #34 from PetrochukM/dataset

Add set operation to dataset

2 parents 9925127 + 6f71f1d

3 files changed: +107 −36 lines changed

tests/datasets/test_dataset.py

Lines changed: 34 additions & 0 deletions
@@ -1,4 +1,5 @@
 import pytest
+import random

 from torchnlp.datasets import Dataset

@@ -24,6 +25,25 @@ def test_dataset_get_column():
         dataset['c']


+def test_dataset_set_column():
+    dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
+
+    # Regular column update
+    dataset['a'] = ['aa', 'aaa']
+    assert dataset['a'] == ['aa', 'aaa']
+
+    # Too little
+    dataset['b'] = ['b']
+    assert dataset['b'] == ['b', None]
+
+    # Too many
+    dataset['c'] = ['c', 'cc', 'ccc']
+    assert dataset['c'] == ['c', 'cc', 'ccc']
+
+    # Smoke (regression test)
+    random.shuffle(dataset)
+
+
 def test_dataset_get_row():
     dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
     assert dataset[0] == {'a': 'a', 'b': 'b'}
@@ -32,6 +52,20 @@ def test_dataset_get_row():
         dataset[2]


+def test_dataset_set_row():
+    dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
+    dataset[0] = {'c': 'c'}
+    assert dataset['c'] == ['c', None]
+    assert dataset['a'] == [None, 'aa']
+
+    dataset[0:2] = [{'d': 'd'}, {'d': 'dd'}]
+    assert dataset[0] == {'d': 'd'}
+    assert dataset[1] == {'d': 'dd'}
+
+    with pytest.raises(IndexError):
+        dataset[2] = {'c': 'c'}
+
+
 def test_dataset_equality():
     dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
     other_dataset = Dataset([{'a': 'a', 'b': 'b'}, {'a': 'aa', 'b': 'bb'}])
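
Why the shuffle smoke test works: random.shuffle permutes a sequence in place
through integer-keyed __getitem__/__setitem__ calls, and integer item assignment
only exists on Dataset after this change. A minimal sketch of the behavior
(the example values are illustrative):

    import random

    from torchnlp.datasets import Dataset

    dataset = Dataset([{'a': 'a'}, {'a': 'aa'}, {'a': 'aaa'}])
    # Before this commit, shuffle raised TypeError because Dataset did not
    # support item assignment; now it permutes the rows in place.
    random.shuffle(dataset)
    print(dataset['a'])  # some permutation of ['a', 'aa', 'aaa']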

torchnlp/datasets/dataset.py

Lines changed: 41 additions & 2 deletions
@@ -39,13 +39,52 @@ def __getitem__(self, key):
         # Given a column string, return a list of column values.
         if isinstance(key, str):
             if key not in self.columns:
-                raise AttributeError
+                raise AttributeError('Key not in columns.')
             return [row[key] if key in row else None for row in self.rows]
         # Given a row integer or slice, return the row values.
         elif isinstance(key, (int, slice)):
             return self.rows[key]
         else:
-            raise TypeError("Invalid argument type.")
+            raise TypeError('Invalid argument type.')
+
+    def __setitem__(self, key, item):
+        """
+        Set a column or row for a dataset.
+
+        Args:
+            key (str, int or slice): String referencing a column, or an integer or
+                slice referencing one or more rows.
+            item (list or dict): Column or row(s) to set in the dataset.
+        """
+        if isinstance(key, str):
+            column = item
+            self.columns.add(key)
+            if len(column) > len(self.rows):
+                for i, value in enumerate(column):
+                    if i < len(self.rows):
+                        self.rows[i][key] = value
+                    else:
+                        self.rows.append({key: value})
+            else:
+                for i, row in enumerate(self.rows):
+                    if i < len(column):
+                        self.rows[i][key] = column[i]
+                    else:
+                        self.rows[i][key] = None
+        elif isinstance(key, slice):
+            rows = item
+            for row in rows:
+                if not isinstance(row, dict):
+                    raise ValueError('Row must be a dict.')
+                self.columns.update(row.keys())
+            self.rows[key] = rows
+        elif isinstance(key, int):
+            row = item
+            if not isinstance(row, dict):
+                raise ValueError('Row must be a dict.')
+            self.columns.update(row.keys())
+            self.rows[key] = row
+        else:
+            raise TypeError('Invalid argument type.')

     def __len__(self):
         return len(self.rows)
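
A quick usage sketch of the semantics implemented above (the example values are
illustrative, not from the diff): a column shorter than the dataset is padded
with None, a longer column appends new rows, and integer assignment replaces a
whole row while registering its keys as columns.

    from torchnlp.datasets import Dataset

    dataset = Dataset([{'a': 1}, {'a': 2}])

    dataset['b'] = [10]                  # too short: missing rows padded with None
    assert dataset['b'] == [10, None]

    dataset['c'] = [7, 8, 9]             # too long: appends a new row {'c': 9}
    assert len(dataset) == 3
    assert dataset['a'] == [1, 2, None]  # __getitem__ fills gaps with None

    dataset[0] = {'d': 'x'}              # row assignment replaces the dict
    assert dataset[0] == {'d': 'x'}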

torchnlp/word_to_vector/bpemb.py

Lines changed: 32 additions & 34 deletions
@@ -1,32 +1,26 @@
 from torchnlp.word_to_vector.pretrained_word_vectors import _PretrainedWordVectors

-
 # List of all 275 supported languages from http://cosyne.h-its.org/bpemb/data/
 SUPPORTED_LANGUAGES = [
-    'ab', 'ace', 'ady', 'af', 'ak', 'als', 'am', 'an', 'ang', 'ar', 'arc',
-    'arz', 'as', 'ast', 'atj', 'av', 'ay', 'az', 'azb', 'ba', 'bar', 'bcl',
-    'be', 'bg', 'bi', 'bjn', 'bm', 'bn', 'bo', 'bpy', 'br', 'bs', 'bug', 'bxr',
-    'ca', 'cdo', 'ce', 'ceb', 'ch', 'chr', 'chy', 'ckb', 'co', 'cr', 'crh',
-    'cs', 'csb', 'cu', 'cv', 'cy', 'da', 'de', 'din', 'diq', 'dsb', 'dty', 'dv',
-    'dz', 'ee', 'el', 'en', 'eo', 'es', 'et', 'eu', 'ext', 'fa', 'ff', 'fi',
-    'fj', 'fo', 'fr', 'frp', 'frr', 'fur', 'fy', 'ga', 'gag', 'gan', 'gd', 'gl',
-    'glk', 'gn', 'gom', 'got', 'gu', 'gv', 'ha', 'hak', 'haw', 'he', 'hi',
-    'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ig', 'ik', 'ilo',
-    'io', 'is', 'it', 'iu', 'ja', 'jam', 'jbo', 'jv', 'ka', 'kaa', 'kab', 'kbd',
-    'kbp', 'kg', 'ki', 'kk', 'kl', 'km', 'kn', 'ko', 'koi', 'krc', 'ks', 'ksh',
-    'ku', 'kv', 'kw', 'ky', 'la', 'lad', 'lb', 'lbe', 'lez', 'lg', 'li', 'lij',
-    'lmo', 'ln', 'lo', 'lrc', 'lt', 'ltg', 'lv', 'mai', 'mdf', 'mg', 'mh',
-    'mhr', 'mi', 'min', 'mk', 'ml', 'mn', 'mr', 'mrj', 'ms', 'mt', 'mwl', 'my',
-    'myv', 'mzn', 'na', 'nap', 'nds', 'ne', 'new', 'ng', 'nl', 'nn', 'no',
-    'nov', 'nrm', 'nso', 'nv', 'ny', 'oc', 'olo', 'om', 'or', 'os', 'pa', 'pag',
-    'pam', 'pap', 'pcd', 'pdc', 'pfl', 'pi', 'pih', 'pl', 'pms', 'pnb', 'pnt',
-    'ps', 'pt', 'qu', 'rm', 'rmy', 'rn', 'ro', 'ru', 'rue', 'rw', 'sa', 'sah',
-    'sc', 'scn', 'sco', 'sd', 'se', 'sg', 'sh', 'si', 'sk', 'sl', 'sm', 'sn',
-    'so', 'sq', 'sr', 'srn', 'ss', 'st', 'stq', 'su', 'sv', 'sw', 'szl', 'ta',
-    'tcy', 'te', 'tet', 'tg', 'th', 'ti', 'tk', 'tl', 'tn', 'to', 'tpi', 'tr',
-    'ts', 'tt', 'tum', 'tw', 'ty', 'tyv', 'udm', 'ug', 'uk', 'ur', 'uz', 've',
-    'vec', 'vep', 'vi', 'vls', 'vo', 'wa', 'war', 'wo', 'wuu', 'xal', 'xh',
-    'xmf', 'yi', 'yo', 'za', 'zea', 'zh', 'zu'
+    'ab', 'ace', 'ady', 'af', 'ak', 'als', 'am', 'an', 'ang', 'ar', 'arc', 'arz', 'as', 'ast',
+    'atj', 'av', 'ay', 'az', 'azb', 'ba', 'bar', 'bcl', 'be', 'bg', 'bi', 'bjn', 'bm', 'bn', 'bo',
+    'bpy', 'br', 'bs', 'bug', 'bxr', 'ca', 'cdo', 'ce', 'ceb', 'ch', 'chr', 'chy', 'ckb', 'co',
+    'cr', 'crh', 'cs', 'csb', 'cu', 'cv', 'cy', 'da', 'de', 'din', 'diq', 'dsb', 'dty', 'dv', 'dz',
+    'ee', 'el', 'en', 'eo', 'es', 'et', 'eu', 'ext', 'fa', 'ff', 'fi', 'fj', 'fo', 'fr', 'frp',
+    'frr', 'fur', 'fy', 'ga', 'gag', 'gan', 'gd', 'gl', 'glk', 'gn', 'gom', 'got', 'gu', 'gv', 'ha',
+    'hak', 'haw', 'he', 'hi', 'hif', 'hr', 'hsb', 'ht', 'hu', 'hy', 'ia', 'id', 'ie', 'ig', 'ik',
+    'ilo', 'io', 'is', 'it', 'iu', 'ja', 'jam', 'jbo', 'jv', 'ka', 'kaa', 'kab', 'kbd', 'kbp', 'kg',
+    'ki', 'kk', 'kl', 'km', 'kn', 'ko', 'koi', 'krc', 'ks', 'ksh', 'ku', 'kv', 'kw', 'ky', 'la',
+    'lad', 'lb', 'lbe', 'lez', 'lg', 'li', 'lij', 'lmo', 'ln', 'lo', 'lrc', 'lt', 'ltg', 'lv',
+    'mai', 'mdf', 'mg', 'mh', 'mhr', 'mi', 'min', 'mk', 'ml', 'mn', 'mr', 'mrj', 'ms', 'mt', 'mwl',
+    'my', 'myv', 'mzn', 'na', 'nap', 'nds', 'ne', 'new', 'ng', 'nl', 'nn', 'no', 'nov', 'nrm',
+    'nso', 'nv', 'ny', 'oc', 'olo', 'om', 'or', 'os', 'pa', 'pag', 'pam', 'pap', 'pcd', 'pdc',
+    'pfl', 'pi', 'pih', 'pl', 'pms', 'pnb', 'pnt', 'ps', 'pt', 'qu', 'rm', 'rmy', 'rn', 'ro', 'ru',
+    'rue', 'rw', 'sa', 'sah', 'sc', 'scn', 'sco', 'sd', 'se', 'sg', 'sh', 'si', 'sk', 'sl', 'sm',
+    'sn', 'so', 'sq', 'sr', 'srn', 'ss', 'st', 'stq', 'su', 'sv', 'sw', 'szl', 'ta', 'tcy', 'te',
+    'tet', 'tg', 'th', 'ti', 'tk', 'tl', 'tn', 'to', 'tpi', 'tr', 'ts', 'tt', 'tum', 'tw', 'ty',
+    'tyv', 'udm', 'ug', 'uk', 'ur', 'uz', 've', 'vec', 'vep', 'vi', 'vls', 'vo', 'wa', 'war', 'wo',
+    'wuu', 'xal', 'xh', 'xmf', 'yi', 'yo', 'za', 'zea', 'zh', 'zu'
 ]

 # All supported vector dimensionalities for which embeddings were trained
@@ -40,6 +34,11 @@ class BPEmb(_PretrainedWordVectors):
     """
     Byte-Pair Encoding (BPE) embeddings trained on Wikipedia for 275 languages

+    A collection of pre-trained subword unit embeddings in 275 languages, based
+    on Byte-Pair Encoding (BPE). In an evaluation using fine-grained entity typing as a
+    testbed, BPEmb performs competitively, and for some languages better than alternative
+    subword approaches, while requiring vastly fewer resources and no tokenization.
+
     References:
         * https://arxiv.org/abs/1710.02187
         * https://github.com/bheinzerling/bpemb
@@ -78,24 +77,23 @@ def __init__(self, language='en', dim=300, merge_ops=50000, **kwargs):
         # Check if all parameters are valid
         if language not in SUPPORTED_LANGUAGES:
             raise ValueError(("Language '%s' not supported. Use one of the "
-                              "following options instead:\n%s"
-                              ) % (language, SUPPORTED_LANGUAGES))
+                              "following options instead:\n%s") % (language, SUPPORTED_LANGUAGES))
         if dim not in SUPPORTED_DIMS:
             raise ValueError(("Embedding dimensionality of '%d' not supported. "
-                              "Use one of the following options instead:\n%s"
-                              ) % (dim, SUPPORTED_DIMS))
+                              "Use one of the following options instead:\n%s") % (dim,
+                                                                                  SUPPORTED_DIMS))
         if merge_ops not in SUPPORTED_MERGE_OPS:
             raise ValueError(("Number of '%d' merge operations not supported. "
-                              "Use one of the following options instead:\n%s"
-                              ) % (merge_ops, SUPPORTED_MERGE_OPS))
+                              "Use one of the following options instead:\n%s") %
+                             (merge_ops, SUPPORTED_MERGE_OPS))

         format_map = {'language': language, 'merge_ops': merge_ops, 'dim': dim}

         # Assemble file name to locally store embeddings under
         name = self.file_name.format_map(format_map)
         # Assemble URL to download the embeddings from
-        url = (self.url_base.format_map(format_map) +
-               self.file_name.format_map(format_map) +
-               self.zip_extension)
+        url = (
+            self.url_base.format_map(format_map) + self.file_name.format_map(format_map) +
+            self.zip_extension)

         super(BPEmb, self).__init__(name, url=url, **kwargs)
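
For context, a hedged usage sketch of the constructor validated above. The
parameter names and defaults come from the signature in this diff; the import
path is inferred from the file location, and constructing the object downloads
the embeddings on first use.

    # Import path assumed from torchnlp/word_to_vector/bpemb.py.
    from torchnlp.word_to_vector.bpemb import BPEmb

    # Defaults per this diff: language='en', dim=300, merge_ops=50000.
    vectors = BPEmb()

    # Invalid parameters fail fast with the ValueError messages shown above.
    try:
        BPEmb(language='xx')  # 'xx' is not in SUPPORTED_LANGUAGES
    except ValueError as error:
        print(error)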
