Commit ac49f51

Merge pull request #22 from Eleven1Liu/free_torchtext
Remove torchtext
2 parents 0e5c4f6 + 084b63f commit ac49f51

13 files changed (+147, -90 lines)


README.md

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ This is an on-going development so many improvements are still being made. Comme
 
 ## Environments
 - Python: 3.10+
-- CUDA: 11.8, 12.1 (if training neural networks by GPU)
-- Pytorch: 2.0.1+
+- CUDA: 11.8, 12.1, 12.6 (if training neural networks by GPU)
+- Pytorch: 2.3.0+
 
 If you have a different version of CUDA, follow the installation instructions for PyTorch LTS at their [website](https://pytorch.org/).

docs/cli/nn.rst

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ If a model was trained before by this package, the training procedure can start
 
 To use your own word embeddings or vocabulary set, specify the following parameters:
 
-- **embed_file**: choose one of the pretrained embeddings defined in `torchtext <https://pytorch.org/text/0.9.0/vocab.html#torchtext.vocab.Vocab.load_vectors>`_ or specify the path to your word embeddings with each line containing a word followed by its vectors. Example:
+- **embed_file**: choose one of the pretrained embeddings: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, `glove.840B.300d`, or specify the path to your word embeddings with each line containing a word followed by its vectors. Example:
 
 .. code-block::
docs/examples/plot_KimCNN_quickstart.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 # To run KimCNN, LibMultiLabel tokenizes documents and uses an embedding vector for each word.
 # Thus, ``tokenize_text=True`` is set.
 #
-# We choose ``glove.6B.300d`` from torchtext as embedding vectors.
+# We choose ``glove.6B.300d`` as embedding vectors.
 
 datasets = load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt", tokenize_text=True)
 classes = load_or_build_label(datasets)
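
The quickstart's next step builds the vocabulary and embedding matrix for KimCNN. A minimal sketch, assuming the usual flow after `load_datasets` (only the function and parameter names come from this diff; the call values and the `"train"` split key are illustrative):

```python
# Sketch (not part of the commit): build the word dictionary and GloVe vectors
# for the datasets loaded above.
from libmultilabel.nn.data_utils import load_or_build_text_dict

word_dict, embed_vecs = load_or_build_text_dict(
    dataset=datasets["train"],   # assumed key for the training split
    embed_file="glove.6B.300d",  # any name in GLOVE_WORD_EMBEDDING also works
)
```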

libmultilabel/nn/attentionxml.py

Lines changed: 1 addition & 1 deletion
@@ -489,7 +489,7 @@ def reformat_text(self, dataset):
         # Convert words to numbers according to their indices in word_dict. Then pad each instance to a certain length.
         encoded_text = list(
             map(
-                lambda text: torch.tensor([self.word_dict[word] for word in text], dtype=torch.int64)
+                lambda text: torch.tensor([self.word_dict.get(word, self.word_dict[UNK]) for word in text], dtype=torch.int64)
                 if text
                 else torch.tensor([self.word_dict[UNK]], dtype=torch.int64),
                 [instance["text"][: self.max_seq_length] for instance in dataset],
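
The `.get(word, self.word_dict[UNK])` fallback is required because the plain dict that replaces torchtext's `Vocab` has no default index (the removed code relied on `set_default_index(vocabs[UNK])`). A self-contained sketch with an illustrative hand-built dictionary:

```python
# Illustrative only: unknown words must be mapped to UNK explicitly when
# word_dict is a plain dict instead of a torchtext Vocab.
UNK = "<unk>"
word_dict = {"<pad>": 0, "<unk>": 1, "the": 2}

ids = [word_dict.get(w, word_dict[UNK]) for w in ["the", "zebra"]]
print(ids)  # [2, 1] -- "zebra" falls back to the <unk> index
```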

libmultilabel/nn/data_utils.py

Lines changed: 125 additions & 67 deletions
@@ -1,7 +1,12 @@
 import csv
 import gc
 import logging
+import os
+import re
 import warnings
+import zipfile
+from urllib.request import urlretrieve
+from collections import Counter, OrderedDict
 
 import pandas as pd
 import torch
@@ -11,14 +16,21 @@
 from sklearn.preprocessing import MultiLabelBinarizer
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import Dataset
-from torchtext.vocab import build_vocab_from_iterator, pretrained_aliases, Vocab
 from tqdm import tqdm
 
 transformers.logging.set_verbosity_error()
 warnings.simplefilter(action="ignore", category=FutureWarning)
 
 UNK = "<unk>"
 PAD = "<pad>"
+GLOVE_WORD_EMBEDDING = {
+    "glove.42B.300d",
+    "glove.840B.300d",
+    "glove.6B.50d",
+    "glove.6B.100d",
+    "glove.6B.200d",
+    "glove.6B.300d",
+}
 
 
 class TextDataset(Dataset):
@@ -31,8 +43,7 @@ class TextDataset(Dataset):
         add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
         tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
             the transformer-based pretrained language model. Defaults to None.
-        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-            map tokens to indices. Defaults to None.
+        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
     """
 
     def __init__(
@@ -55,7 +66,7 @@ def __init__(
         self.num_classes = len(self.classes)
         self.label_binarizer = MultiLabelBinarizer().fit([classes])
 
-        if not isinstance(self.word_dict, Vocab) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
+        if not isinstance(self.word_dict, dict) ^ isinstance(self.tokenizer, transformers.PreTrainedTokenizerBase):
             raise ValueError("Please specify exactly one of word_dict or tokenizer")
 
     def __len__(self):
@@ -71,7 +82,7 @@ def __getitem__(self, index):
             else:
                 input_ids = self.tokenizer.encode(data["text"], add_special_tokens=False)
         else:
-            input_ids = [self.word_dict[word] for word in data["text"]]
+            input_ids = [self.word_dict.get(word, self.word_dict[UNK]) for word in data["text"]]
         return {
             "text": torch.LongTensor(input_ids[: self.max_seq_length]),
             "label": torch.IntTensor(self.label_binarizer.transform([data["label"]])[0]),
@@ -128,8 +139,7 @@ def get_dataset_loader(
         add_special_tokens (bool, optional): Whether to add the special tokens. Defaults to True.
         tokenizer (transformers.PreTrainedTokenizerBase, optional): HuggingFace's tokenizer of
             the transformer-based pretrained language model. Defaults to None.
-        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-            map tokens to indices. Defaults to None.
+        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
 
     Returns:
         torch.utils.data.DataLoader: A pytorch DataLoader.
@@ -154,6 +164,7 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
     Args:
         data (Union[str, pandas,.Dataframe]): Training, test, or validation data in file or dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
+        tokenize_text (bool, optional): Whether to tokenize text. Defaults to True.
        remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
            This is effective only when is_test=False. Defaults to False.
@@ -265,35 +276,34 @@ def load_or_build_text_dict(
 ):
     """Build or load the vocabulary from the training dataset or the predefined `vocab_file`.
     The pretrained embedding can be either from a self-defined `embed_file` or from one of
-    the vectors defined in torchtext.vocab.pretrained_aliases
-    (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py).
+    the vectors: `glove.6B.50d`, `glove.6B.100d`, `glove.6B.200d`, `glove.6B.300d`, `glove.42B.300d`, or `glove.840B.300d`.
 
     Args:
         dataset (list): List of training instances with index, label, and tokenized text.
         vocab_file (str, optional): Path to a file holding vocabuaries. Defaults to None.
         min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding. Defaults to None.
         embed_cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
         silent (bool, optional): Enable silent mode. Defaults to False.
         normalize_embed (bool, optional): Whether the embeddings of each word is normalized to a unit vector. Defaults to False.
 
     Returns:
-        tuple[torchtext.vocab.Vocab, torch.Tensor]: A vocab object which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
+        tuple[dict, torch.Tensor]: A dictionary which maps tokens to indices and the pre-trained word vectors of shape (vocab_size, embed_dim).
     """
     if vocab_file:
         logging.info(f"Load vocab from {vocab_file}")
         with open(vocab_file, "r") as fp:
             vocab_list = [[vocab.strip() for vocab in fp.readlines()]]
         # Keep PAD index 0 to align `padding_idx` of
         # class Embedding in libmultilabel.nn.networks.modules.
-        vocabs = build_vocab_from_iterator(vocab_list, min_freq=1, specials=[PAD, UNK])
+        word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=[PAD, UNK])
     else:
         vocab_list = [set(data["text"]) for data in dataset]
-        vocabs = build_vocab_from_iterator(vocab_list, min_freq=min_vocab_freq, specials=[PAD, UNK])
-        vocabs.set_default_index(vocabs[UNK])
-        logging.info(f"Read {len(vocabs)} vocabularies.")
+        word_dict = _build_word_dict(vocab_list, min_vocab_freq=min_vocab_freq, specials=[PAD, UNK])
+
+    logging.info(f"Read {len(word_dict)} vocabularies.")
 
-    embedding_weights = get_embedding_weights_from_file(vocabs, embed_file, silent, embed_cache_dir)
+    embedding_weights = get_embedding_weights_from_file(word_dict, embed_file, silent, embed_cache_dir)
 
     if normalize_embed:
         # To have better precision for calculating the normalization, we convert the original
@@ -306,7 +316,41 @@
             embedding_weights[i] = vector / float(torch.linalg.norm(vector) + 1e-6)
         embedding_weights = embedding_weights.float()
 
-    return vocabs, embedding_weights
+    return word_dict, embedding_weights
+
+
+def _build_word_dict(vocab_list, min_vocab_freq=1, specials=None):
+    r"""Build word dictionary, modified from `torchtext.vocab.build-vocab-from-iterator`
+    (https://docs.pytorch.org/text/stable/vocab.html#build-vocab-from-iterator)
+
+    Args:
+        vocab_list: List of words.
+        min_vocab_freq (int, optional): The minimum frequency needed to include a token in the vocabulary. Defaults to 1.
+        specials: Special tokens (e.g., <unk>, <pad>) to add. Defaults to None.
+
+    Returns:
+        dict: A dictionary which maps tokens to indices.
+    """
+
+    counter = Counter()
+    for tokens in vocab_list:
+        counter.update(tokens)
+
+    # sort by descending frequency, then lexicographically
+    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+    ordered_dict = OrderedDict(sorted_by_freq_tuples)
+
+    # add special tokens at the beginning
+    tokens = specials or []
+    for token, freq in ordered_dict.items():
+        if freq >= min_vocab_freq:
+            tokens.append(token)
+
+    # build token to indices dict
+    word_dict = dict()
+    for idx, token in enumerate(tokens):
+        word_dict[token] = idx
+    return word_dict
 
 
 def load_or_build_label(datasets, label_file=None, include_test_labels=False):
@@ -344,70 +388,84 @@ def load_or_build_label(datasets, label_file=None, include_test_labels=False):
     return classes
 
 
-def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache=None):
-    """If the word exists in the embedding file, load the pretrained word embedding.
-    Otherwise, assign a zero vector to that word.
+def get_embedding_weights_from_file(word_dict, embed_file, silent=False, cache_dir=None):
+    """Obtain the word embeddings from file. If the word exists in the embedding file,
+    load the pretrained word embedding. Otherwise, assign a zero vector to that word.
+    If the given `embed_file` is the name of a pretrained GloVe embedding, the function
+    will first download the corresponding file.
 
     Args:
-        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
-        embed_file (str): Path to a file holding pre-trained embeddings.
+        word_dict (dict): A dictionary for mapping tokens to indices.
+        embed_file (str): Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding.
         silent (bool, optional): Enable silent mode. Defaults to False.
-        cache (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
 
     Returns:
         torch.Tensor: Embedding weights (vocab_size, embed_size).
     """
-    # Load pretrained word embedding
-    load_embedding_from_file = embed_file not in pretrained_aliases
-    if load_embedding_from_file:
-        logging.info(f"Load pretrained embedding from file: {embed_file}.")
-        with open(embed_file) as f:
-            word_vectors = f.readlines()
-        embed_size = len(word_vectors[0].split()) - 1
-        vector_dict = {}
-        for word_vector in tqdm(word_vectors, disable=silent):
-            word, vector = word_vector.rstrip().split(" ", 1)
-            vector = torch.Tensor(list(map(float, vector.split())))
-            vector_dict[word] = vector
-    else:
-        logging.info(f"Load pretrained embedding from torchtext.")
-        # Adapted from https://pytorch.org/text/0.9.0/_modules/torchtext/vocab.html#Vocab.load_vectors.
-        if embed_file not in pretrained_aliases:
-            raise ValueError(
-                "Got embed_file {}, but allowed pretrained "
-                "vectors are {}".format(embed_file, list(pretrained_aliases.keys()))
-            )
-
-        # Hotfix: Glove URLs are outdated in Torchtext
-        # (https://github.com/pytorch/text/blob/main/torchtext/vocab/vectors.py#L213-L217)
-        pretrained_cls = pretrained_aliases[embed_file]
-        if embed_file.startswith("glove"):
-            for name, url in pretrained_cls.func.url.items():
-                file_name = url.split("/")[-1]
-                pretrained_cls.func.url[name] = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{file_name}"
-
-        vector_dict = pretrained_cls(cache=cache)
-        embed_size = vector_dict.dim
 
-    embedding_weights = torch.zeros(len(word_dict), embed_size)
+    if embed_file in GLOVE_WORD_EMBEDDING:
+        embed_file = _download_glove_embedding(embed_file, cache_dir=cache_dir)
+    elif not os.path.isfile(embed_file):
+        raise ValueError(
+            "Got embed_file {}, but allowed pretrained " "embeddings are {}".format(embed_file, GLOVE_WORD_EMBEDDING)
+        )
+
+    logging.info(f"Load pretrained embedding from {embed_file}.")
+    with open(embed_file) as f:
+        word_vectors = f.readlines()
+    embed_size = len(word_vectors[0].split()) - 1
 
-    if load_embedding_from_file:
-        # Add UNK embedding
-        # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
-        # CAML: np.random.randn(embed_size)
-        unk_vector = torch.randn(embed_size)
-        embedding_weights[word_dict[UNK]] = unk_vector
+    vector_dict = {}
+    for word_vector in tqdm(word_vectors, disable=silent):
+        word, vector = word_vector.rstrip().split(" ", 1)
+        vector = torch.Tensor(list(map(float, vector.split())))
+        vector_dict[word] = vector
+
+    embedding_weights = torch.zeros(len(word_dict), embed_size)
+    # Add UNK embedding
+    # AttentionXML: np.random.uniform(-1.0, 1.0, embed_size)
+    # CAML: np.random.randn(embed_size)
+    unk_vector = torch.randn(embed_size)
+    embedding_weights[word_dict[UNK]] = unk_vector
 
     # Store pretrained word embedding
     vec_counts = 0
-    for word in word_dict.get_itos():
-        # The condition can be used to process the word that does not in the embedding file.
-        # Note that torchtext vector object has already dealt with this,
-        # so we can directly make a query without addtional handling.
-        if (load_embedding_from_file and word in vector_dict) or not load_embedding_from_file:
+    for word in word_dict.keys():
+        if word in vector_dict:
             embedding_weights[word_dict[word]] = vector_dict[word]
             vec_counts += 1
 
-    logging.info(f"loaded {vec_counts}/{len(word_dict)} word embeddings")
+    logging.info(f"Loaded {vec_counts}/{len(word_dict)} word embeddings")
 
     return embedding_weights
+
+
+def _download_glove_embedding(embed_name, cache_dir=None):
+    """Download pretrained glove embedding from https://huggingface.co/stanfordnlp/glove/tree/main.
+
+    Args:
+        embed_name (str): The name of the pretrained GloVe embedding. Defaults to None.
+        cache_dir (str, optional): Path to a directory for storing cached embeddings. Defaults to None.
+
+    Returns:
+        str: Path to the file that contains the cached embeddings.
+    """
+    cache_dir = ".vector_cache" if cache_dir is None else cache_dir
+    cached_embed_file = f"{cache_dir}/{embed_name}.txt"
+    if os.path.isfile(cached_embed_file):
+        return cached_embed_file
+    os.makedirs(cache_dir, exist_ok=True)
+
+    remote_embed_file = re.sub(r"6B.*", "6B", embed_name) + ".zip"
+    url = f"https://huggingface.co/stanfordnlp/glove/resolve/main/{remote_embed_file}"
+    logging.info(f"Downloading pretrained embeddings from {url}.")
+    try:
+        zip_file, _ = urlretrieve(url, f"{cache_dir}/{remote_embed_file}")
+        with zipfile.ZipFile(zip_file, "r") as zf:
+            zf.extractall(cache_dir)
+    except Exception as e:
+        os.remove(zip_file)
+        raise e
+    logging.info(f"Downloaded pretrained embeddings {embed_name} to {cached_embed_file}.")
+    return cached_embed_file
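
The new `_build_word_dict` helper reproduces the part of torchtext's `build_vocab_from_iterator` that LibMultiLabel relied on: special tokens come first (so `PAD` keeps index 0, as required by `padding_idx` in `libmultilabel.nn.networks.modules`), and remaining tokens are ordered by descending frequency with ties broken lexicographically. A tiny sketch of what it returns (the toy `vocab_list` is illustrative, not from the commit):

```python
# Illustrative usage of the helper added above; only _build_word_dict, PAD and
# UNK come from the diff, the input data is made up.
from libmultilabel.nn.data_utils import _build_word_dict, PAD, UNK

vocab_list = [["the", "cat", "sat"], ["the", "dog"]]
word_dict = _build_word_dict(vocab_list, min_vocab_freq=1, specials=[PAD, UNK])

# Specials first, then tokens by frequency (ties broken alphabetically):
# {"<pad>": 0, "<unk>": 1, "the": 2, "cat": 3, "dog": 4, "sat": 5}
assert word_dict[PAD] == 0 and word_dict[UNK] == 1
```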

libmultilabel/nn/model.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ class Model(MultiLabelModel):
 
     Args:
         classes (list): List of class names.
-        word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
+        word_dict (dict): A dictionary for mapping tokens to indices.
         network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
         loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
             cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.

libmultilabel/nn/nn_utils.py

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ def init_model(
         model_name (str): Model to be used such as KimCNN.
         network_config (dict): Configuration for defining the network.
         classes (list): List of class names.
-        word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
-            map tokens to indices. Defaults to None.
+        word_dict (dict, optional): A dictionary for mapping tokens to indices. Defaults to None.
         embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
             (vocab_size, embed_dim). Defaults to None.
         init_weight (str): Weight initialization method from `torch.nn.init`.
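
Downstream, `init_model` now receives the plain dictionary and embedding tensor instead of a torchtext `Vocab`. A hedged sketch of a call (parameter names follow the docstring above; the `network_config` keys and the variables `classes`, `word_dict`, `embed_vecs` are placeholders for values produced earlier in the pipeline):

```python
# Sketch only: argument names are from the init_model docstring; the
# network_config contents are placeholders, not taken from this commit.
from libmultilabel.nn.nn_utils import init_model

model = init_model(
    model_name="KimCNN",
    network_config={"embed_dropout": 0.2, "num_filter_per_size": 128, "filter_sizes": [2, 4, 8]},
    classes=classes,        # from load_or_build_label
    word_dict=word_dict,    # plain dict, no longer a torchtext Vocab
    embed_vecs=embed_vecs,  # torch.Tensor of shape (vocab_size, embed_dim)
)
```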

main.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def add_all_arguments(parser):
     # pretrained vocab / embeddings
     parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
     parser.add_argument(
-        "--embed_file", type=str, help="Path to a file holding pre-trained embeddings (default: %(default)s)"
+        "--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
     )
     parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")

requirements_nn.txt

Lines changed: 2 additions & 4 deletions
@@ -1,8 +1,6 @@
 nltk
 lightning
 # https://github.com/pytorch/text/releases
-torch<=2.3
+torch
 torchmetrics==0.10.3
-torchtext
-# https://github.com/huggingface/transformers/issues/38464
-transformers<=4.51.3
+transformers

search_params.py

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
     Args:
         config (dict): Config of the experiment.
         datasets (dict): A dictionary of datasets.
-        classes(list): List of class names.
-        word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
+        classes (list): List of class names.
+        word_dict (dict): A dictionary for mapping tokens to indices.
     """
 
     # ray convert AttributeDict to dict
