import csv
import gc
import logging
+ import os
+ import re
import warnings
+ import zipfile
+ from urllib.request import urlretrieve
+ from collections import Counter, OrderedDict

import pandas as pd
import torch
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
- from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab
from tqdm import tqdm

transformers.logging.set_verbosity_error()
warnings.simplefilter(action="ignore", category=FutureWarning)

UNK = "<unk>"
PAD = "<pad>"
+ GLOVE_WORD_EMBEDDING = {
+     "glove.42B.300d",
+     "glove.840B.300d",
+     "glove.6B.50d",
+     "glove.6B.100d",
+     "glove.6B.200d",
+     "glove.6B.300d",
+ }


class TextDataset(Dataset):
@@ -31,8 +43,7 @@ class TextDataset(Dataset):
        add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
        tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
            the transformer-based pretrained language model. Defaults to None.
-         word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-             map tokens to indices. Defaults to None.
+         word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
    """

    def __init__(
@@ -55,7 +66,7 @@ def __init__(
        self.num_classes = len(self.classes)
        self.label_binarizer = MultiLabelBinarizer().fit([classes])

-         if not isinstance(self.word_dict, Vocab) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
+         if not isinstance(self.word_dict, dict) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
            raise ValueError("Please specify exactly one of word_dict or tokenizer")

    def __len__(self):
@@ -71,7 +82,7 @@ def __getitem__(self, index):
            else:
                input_ids = self.tokenizer.encode(data["text"], add_special_tokens=False)
        else:
-             input_ids = [self.word_dict[word] for word in data["text"]]
+             input_ids = [self.word_dict.get(word, self.word_dict[UNK]) for word in data["text"]]
        return {
            "text": torch.LongTensor(input_ids[: self.max_seq_length]),
            "label": torch.IntTensor(self.label_binarizer.transform([data["label"]])[0]),
@@ -128,8 +139,7 @@ def get_dataset_loader(
        add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
        tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
            the transformer-based pretrained language model. Defaults to None.
-         word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-             map tokens to indices. Defaults to None.
+         word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.

    Returns:
        torch.utils.data.DataLoader: A pytorch DataLoader.
@@ -154,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
    Args:
        data (Union[str, pandas.DataFrame]): Training, test, or validation data in a file or dataframe.
        is_test (bool, optional): Whether the data is for test or not. Defaults to False.
+         tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
        remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
            This is effective only when is_test=False. Defaults to False.
@@ -265,35 +276,34 @@ def load_or_build_text_dict(
):
    """Build or load the vocabulary from the training dataset or the predefined `vocab_file`.
    The pretrained embedding can be either from a self-defined `embed_file` or from one of
-     the vectors defined in torchtext.vocab.pretrained_aliases
-     (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py).
+     the vectors: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, or `glove.840B.300d`.

    Args:
        dataset (list): List of training instances with index, label, and tokenized text.
        vocab_file (str, optional): Path to a file holding vocabularies. Defaults to None.
        min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-         embed_file (str): Path to a file holding pre-trained embeddings.
+         embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding. Defaults to None.
        embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
        silent (bool, optional): Enable silent mode. Defaults to False.
        normalize_embed (bool, optional): Whether the embedding of each word is normalized to a unit vector. Defaults to False.

    Returns:
-         tuple[torchtext.vocab.Vocab, torch.Tensor]: A vocab object which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
+         tuple[dict, torch.Tensor]: A dictionary which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
    """
    if vocab_file:
        logging.info(f"Load vocab from {vocab_file}")
        with open(vocab_file, "r") as fp:
            vocab_list = [[vocab.strip() for vocab in fp.readlines()]]
        # Keep PAD index 0 to align `padding_idx` of
        # class Embedding in libmultilabel.nn.networks.modules.
-         vocabs = build_vocab_from_iterator(vocab_list, min_freq=1, specials=[PAD, UNK])
+         word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=[PAD, UNK])
    else:
        vocab_list = [set(data["text"]) for data in dataset]
-         vocabs = build_vocab_from_iterator(vocab_list, min_freq=min_vocab_freq, specials=[PAD, UNK])
-         vocabs.set_default_index(vocabs[UNK])
-         logging.info(f"Read {len(vocabs)} vocabularies.")
+         word_dict = _build_word_dict(vocab_list, min_vocab_freq=min_vocab_freq, specials=[PAD, UNK])
+
+     logging.info(f"Read {len(word_dict)} vocabularies.")

-     embedding_weights = get_embedding_weights_from_file(vocabs, embed_file, silent, embed_cache_dir)
+     embedding_weights = get_embedding_weights_from_file(word_dict, embed_file, silent, embed_cache_dir)

    if normalize_embed:
        # To have better precision for calculating the normalization, we convert the original
@@ -306,7 +316,41 @@ def load_or_build_text_dict(
            embedding_weights[i] = vector / float(torch.linalg.norm(vector) + 1e-6)
        embedding_weights = embedding_weights.float()

-     return vocabs, embedding_weights
+     return word_dict, embedding_weights
+
+
+ def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
+     r"""Build a word dictionary, modified from `torchtext.vocab.build_vocab_from_iterator`
+     (https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator).
+
+     Args:
+         vocab_list: List of words.
+         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
+         specials: Special tokens (e.g., <unk>, <pad>) to add. Defaults to None.
+
+     Returns:
+         dict: A dictionary which maps tokens to indices.
+     """
+
+     counter = Counter()
+     for tokens in vocab_list:
+         counter.update(tokens)
+
+     # sort by descending frequency, then lexicographically
+     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+     ordered_dict = OrderedDict(sorted_by_freq_tuples)
+
+     # add special tokens at the beginning
+     tokens = specials or []
+     for token, freq in ordered_dict.items():
+         if freq >= min_vocab_freq:
+             tokens.append(token)
+
+     # build the token-to-index dict
+     word_dict = dict()
+     for idx, token in enumerate(tokens):
+         word_dict[token] = idx
+     return word_dict


def load_or_build_label(datasets, label_file=None, include_test_labels=False):
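
Aside (not part of the commit): a minimal sketch of what _build_word_dict returns, assuming specials=[PAD, UNK] as in the calls above. The specials keep indices 0 and 1 (so PAD stays aligned with padding_idx=0), and the remaining tokens are indexed by descending frequency, with ties broken lexicographically.

vocab_list = [{"apple", "banana"}, {"apple", "cherry"}]
word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=["<pad>", "<unk>"])
assert word_dict == {"<pad>": 0, "<unk>": 1, "apple": 2, "banana": 3, "cherry": 4}
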
@@ -344,70 +388,84 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
    return classes


- def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache=None):
-     """If the word exists in the embedding file, load the pretrained word embedding.
-     Otherwise, assign a zero vector to that word.
+ def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
+     """Obtain the word embeddings from file. If the word exists in the embedding file,
+     load the pretrained word embedding. Otherwise, assign a zero vector to that word.
+     If the given `embed_file` is the name of a pretrained GloVe embedding, the function
+     will first download the corresponding file.

    Args:
-         word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-         embed_file (str): Path to a file holding pre-trained embeddings.
+         word_dict (dict): A dictionary for mapping tokens to indices.
+         embed_file (str): Path to a file holding pre-trained embeddings or the name of a pretrained GloVe embedding.
        silent (bool, optional): Enable silent mode. Defaults to False.
-         cache (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+         cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.

    Returns:
        torch.Tensor: Embedding weights (vocab_size, embed_size).
    """
-     # Load pretrained word embedding
-     load_embedding_from_file = embed_file not in pretrained_aliases
-     if load_embedding_from_file:
-         logging.info(f"Load pretrained embedding from file: {embed_file}.")
-         with open(embed_file) as f:
-             word_vectors = f.readlines()
-         embed_size = len(word_vectors[0].split()) - 1
-         vector_dict = {}
-         for word_vector in tqdm(word_vectors, disable=silent):
-             word, vector = word_vector.rstrip().split(" ", 1)
-             vector = torch.Tensor(list(map(float, vector.split())))
-             vector_dict[word] = vector
-     else:
-         logging.info(f"Load pretrained embedding from torchtext.")
-         # Adapted from https://pytorch.org/text/0.9.0/_modules/torchtext/vocab.html#Vocab.load_vectors.
-         if embed_file not in pretrained_aliases:
-             raise ValueError(
-                 "Got embed_file {}, but allowed pretrained "
-                 "vectors are {}".format(embed_file, list(pretrained_aliases.keys()))
-             )
-
-         # Hotfix: Glove URLs are outdated in Torchtext
-         # (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py#L213-L217)
-         pretrained_cls = pretrained_aliases[embed_file]
-         if embed_file.startswith("glove"):
-             for name, url in pretrained_cls.func.url.items():
-                 file_name = url.split("/")[-1]
-                 pretrained_cls.func.url[name] = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{file_name}"
-
-         vector_dict = pretrained_cls(cache=cache)
-         embed_size = vector_dict.dim

-     embedding_weights = torch.zeros(len(word_dict), embed_size)
+     if embed_file in GLOVE_WORD_EMBEDDING:
+         embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
+     elif not os.path.isfile(embed_file):
+         raise ValueError(
+             "Got embed_file {}, but allowed pretrained embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
+         )
+
+     logging.info(f"Load pretrained embedding from {embed_file}.")
+     with open(embed_file) as f:
+         word_vectors = f.readlines()
+     embed_size = len(word_vectors[0].split()) - 1

-     if load_embedding_from_file:
-         # Add UNK embedding
-         # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
-         # CAML: np.random.randn(embed_size)
-         unk_vector = torch.randn(embed_size)
-         embedding_weights[word_dict[UNK]] = unk_vector
+     vector_dict = {}
+     for word_vector in tqdm(word_vectors, disable=silent):
+         word, vector = word_vector.rstrip().split(" ", 1)
+         vector = torch.Tensor(list(map(float, vector.split())))
+         vector_dict[word] = vector
+
+     embedding_weights = torch.zeros(len(word_dict), embed_size)
+     # Add UNK embedding
+     # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
+     # CAML: np.random.randn(embed_size)
+     unk_vector = torch.randn(embed_size)
+     embedding_weights[word_dict[UNK]] = unk_vector

    # Store pretrained word embedding
    vec_counts = 0
-     for word in word_dict.get_itos():
-         # The condition can be used to process the word that does not in the embedding file.
-         # Note that torchtext vector object has already dealt with this,
-         # so we can directly make a query without addtional handling.
-         if (load_embedding_from_file and word in vector_dict) or not load_embedding_from_file:
+     for word in word_dict.keys():
+         if word in vector_dict:
            embedding_weights[word_dict[word]] = vector_dict[word]
            vec_counts += 1

-     logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
+     logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings")

    return embedding_weights
+
+
+ def _download_glove_embedding(embed_name, cache_dir=None):
+     """Download a pretrained GloVe embedding from https://huggingface.co/stanfordnlp/glove/tree/main.
+
+     Args:
+         embed_name (str): The name of the pretrained GloVe embedding.
+         cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+
+     Returns:
+         str: Path to the file that contains the cached embeddings.
+     """
+     cache_dir = ".vector_cache" if cache_dir is None else cache_dir
+     cached_embed_file = f"{cache_dir}/{embed_name}.txt"
+     if os.path.isfile(cached_embed_file):
+         return cached_embed_file
+     os.makedirs(cache_dir, exist_ok=True)
+
+     remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
+     url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
+     logging.info(f"Downloading pretrained embeddings from {url}.")
+     try:
+         zip_file, _ = urlretrieve(url, f"{cache_dir}/{remote_embed_file}")
+         with zipfile.ZipFile(zip_file, "r") as zf:
+             zf.extractall(cache_dir)
+     except Exception as e:
+         os.remove(zip_file)
+         raise e
+     logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
+     return cached_embed_file
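
Aside (not part of the commit): a hedged usage sketch of the new GloVe path, where train_dataset is a placeholder for the tokenized training instances described in the docstrings above. The glove.6B.* names all resolve to the single glove.6B.zip archive through the re.sub(r"6B.*", "6B", ...) call, while glove.42B.300d and glove.840B.300d map to archives of the same name.

word_dict, embedding_weights = load_or_build_text_dict(
    dataset=train_dataset,       # placeholder: list of {"label": ..., "text": [tokens]} instances
    embed_file="glove.6B.300d",  # downloaded to .vector_cache/ on first use when no cache dir is given
    normalize_embed=True,
)
assert embedding_weights.shape == (len(word_dict), 300)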