Skip to content

Commit 8f499f7

Browse files
authored
Merge pull request #261 from ntumlgroup/update_api
Update APIs for reading data
2 parents ccb7e2c + 98c351d commit 8f499f7

File tree

4 files changed

+28
-30
lines changed

4 files changed

+28
-30
lines changed

libmultilabel/linear/preprocessor.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def load_data(self, training_data: Union[str, pd.DataFrame] = None,
6565
self.include_test_labels = include_test_labels
6666

6767
if self.data_format == 'txt' or 'dataframe':
68-
data = self._load_libmultilabel(training_data, test_data, eval)
68+
data = self._load_text(training_data, test_data, eval)
6969
elif self.data_format == 'svm':
7070
data = self._load_svm(training_data, test_data, eval)
7171

@@ -86,18 +86,12 @@ def load_data(self, training_data: Union[str, pd.DataFrame] = None,
8686

8787
return data
8888

89-
def _load_libmultilabel(self, training_data, test_data, eval) -> 'dict[str, dict]':
89+
def _load_text(self, training_data, test_data, eval) -> 'dict[str, dict]':
9090
datasets = defaultdict(dict)
9191
if test_data is not None:
92-
if self.data_format == 'txt':
93-
test_data = pd.read_csv(test_data, sep='\t', header=None,
94-
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
9592
test = read_libmultilabel_format(test_data)
9693

9794
if not eval:
98-
if self.data_format == 'txt':
99-
training_data = pd.read_csv(training_data, sep='\t', header=None,
100-
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
10195
train = read_libmultilabel_format(training_data)
10296
self._generate_tfidf(train['text'])
10397

@@ -145,7 +139,18 @@ def _generate_label_mapping(self, labels, classes=None):
145139
self.binarizer.fit(labels)
146140

147141

148-
def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
142+
def read_libmultilabel_format(data: Union[str, pd.DataFrame]) -> 'dict[str,list[str]]':
143+
"""Read multi-label text data from file or pandas dataframe.
144+
145+
Args:
146+
data (Union[str, pd.DataFrame]): A file path to data in `LibMultiLabel format <https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/ov_data_format.html#libmultilabel-format>`_
147+
or a pandas dataframe contains index (optional), label, and text.
148+
Returns:
149+
dict[str,list[str]]: A dictionary with a list of index (optional), label, and text.
150+
"""
151+
if isinstance(data, str):
152+
data = pd.read_csv(data, sep='\t', header=None,
153+
on_bad_lines='warn', quoting=csv.QUOTE_NONE).fillna('')
149154
data = data.astype(str)
150155
if data.shape[1] == 2:
151156
data.columns = ['label', 'text']
@@ -157,6 +162,7 @@ def read_libmultilabel_format(data: pd.DataFrame) -> 'dict[str,list[str]]':
157162
data['label'] = data['label'].map(lambda s: s.split())
158163
return data.to_dict('list')
159164

165+
160166
def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]':
161167
"""Read multi-label LIBSVM-format data.
162168

libmultilabel/nn/data_utils.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data
135135
Returns:
136136
pandas.DataFrame: Data composed of index, label, and tokenized text.
137137
"""
138+
if isinstance(data, str):
139+
logging.info(f'Load data from {data}.')
140+
data = pd.read_csv(data, sep='\t', header=None,
141+
on_bad_lines='warn', quoting=csv.QUOTE_NONE).fillna('')
138142
data = data.astype(str)
139143
if data.shape[1] == 2:
140144
data.columns = ['label', 'text']
@@ -197,31 +201,19 @@ def load_datasets(
197201

198202
datasets = {}
199203
if training_data is not None:
200-
if isinstance(training_data, str):
201-
logging.info(f'Load data from {training_data}.')
202-
training_data = pd.read_csv(training_data, sep='\t', header=None,
203-
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
204-
datasets['train'] = _load_raw_data(training_data, tokenize_text=tokenize_text,
205-
remove_no_label_data=remove_no_label_data)
204+
datasets['train'] = _load_raw_data(
205+
training_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
206206

207207
if val_data is not None:
208-
if isinstance(val_data, str):
209-
logging.info(f'Load data from {val_data}.')
210-
val_data = pd.read_csv(val_data, sep='\t', header=None,
211-
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
212-
datasets['val'] = _load_raw_data(val_data, tokenize_text=tokenize_text,
213-
remove_no_label_data=remove_no_label_data)
208+
datasets['val'] = _load_raw_data(
209+
val_data, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
214210
elif val_size > 0:
215211
datasets['train'], datasets['val'] = train_test_split(
216212
datasets['train'], test_size=val_size, random_state=42)
217213

218214
if test_data is not None:
219-
if isinstance(test_data, str):
220-
logging.info(f'Load data from {test_data}.')
221-
test_data = pd.read_csv(test_data, sep='\t', header=None,
222-
error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
223-
datasets['test'] = _load_raw_data(test_data, is_test=True, tokenize_text=tokenize_text,
224-
remove_no_label_data=remove_no_label_data)
215+
datasets['test'] = _load_raw_data(
216+
test_data, is_test=True, tokenize_text=tokenize_text, remove_no_label_data=remove_no_label_data)
225217

226218
if merge_train_val:
227219
datasets['train'] = datasets['train'] + datasets['val']

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
nltk
2-
pandas
2+
pandas>1.3.0
33
PyYAML
44
scikit-learn
5-
torch>=1.12.0
5+
torch>=1.13.1
66
torchmetrics==0.10.3
77
torchtext>=0.13.0
88
pytorch-lightning==1.7.7

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers =
2626
packages = find:
2727
install_requires =
2828
nltk
29-
pandas
29+
pandas>1.3.0
3030
PyYAML
3131
scikit-learn
3232
torch>=1.13.1

0 commit comments

Comments
 (0)