Skip to content

Commit ee53873

Browse files
authored
Merge pull request #134 from ntumlgroup/autotest-linear
Added linear tests and support for saving top-k predictions
2 parents e8b4ac4 + 0a5a40e commit ee53873

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

linear_trainer.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
2+
import logging
23
from math import ceil
34

45
import numpy as np
56

67
import libmultilabel.linear as linear
8+
from libmultilabel.utils import dump_log, argsort_top_k
79

810

911
def linear_test(config, model, datasets):
@@ -13,12 +15,24 @@ def linear_test(config, model, datasets):
1315
datasets['test']['y'].shape[1]
1416
)
1517
num_instance = datasets['test']['x'].shape[0]
18+
19+
k = config.save_k_predictions
20+
top_k_idx = np.zeros((num_instance, k), dtype='i')
21+
top_k_scores = np.zeros((num_instance, k), dtype='d')
22+
1623
for i in range(ceil(num_instance / config.eval_batch_size)):
1724
slice = np.s_[i*config.eval_batch_size:(i+1)*config.eval_batch_size]
1825
preds = linear.predict_values(model, datasets['test']['x'][slice])
1926
target = datasets['test']['y'][slice].toarray()
2027
metrics.update(preds, target)
21-
print(linear.tabulate_metrics(metrics.compute(), 'test'))
28+
29+
if k > 0:
30+
top_k_idx[slice] = argsort_top_k(preds, k, axis=1)
31+
top_k_scores[slice] = np.take_along_axis(
32+
preds, top_k_idx[slice], axis=1)
33+
34+
metric_dict = metrics.compute()
35+
return (metric_dict, top_k_idx, top_k_scores)
2236

2337

2438
def linear_train(datasets, config):
@@ -46,5 +60,18 @@ def linear_run(config):
4660
linear.save_pipeline(config.checkpoint_dir, preprocessor, model)
4761

4862
if os.path.exists(config.test_path):
49-
linear_test(config, model, datasets)
50-
# TODO: dump logs?
63+
metric_dict, top_k_idx, top_k_scores = linear_test(
64+
config, model, datasets)
65+
66+
dump_log(config=config, metrics=metric_dict,
67+
split='test', log_path=config.log_path)
68+
print(linear.tabulate_metrics(metric_dict, 'test'))
69+
70+
if config.save_k_predictions > 0:
71+
classes = preprocessor.binarizer.classes_
72+
with open(config.predict_out_path, 'w') as fp:
73+
for idx, score in zip(top_k_idx, top_k_scores):
74+
out_str = ' '.join([f'{classes[i]}:{s:.4}' for i, s in zip(
75+
idx, score)])
76+
fp.write(out_str+'\n')
77+
logging.info(f'Saved predictions to: {config.predict_out_path}')

main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def get_config():
145145
)
146146
config.checkpoint_dir = os.path.join(config.result_dir, config.run_name)
147147
config.log_path = os.path.join(config.checkpoint_dir, 'logs.json')
148+
config.predict_out_path = config.predict_out_path or os.path.join(config.checkpoint_dir, 'predictions.txt')
148149

149150
config.train_path = config.train_path or os.path.join(config.data_dir, 'train.txt')
150151
config.val_path = config.val_path or os.path.join(config.data_dir, 'valid.txt')

tests/autotest.sh

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ main() {
7070
run_test "rcv1" "kim_cnn" "$template"
7171
done
7272

73+
TEST_COMMAND_TEMPLATES=(
74+
# Run default linear 1vsrest
75+
"python3 main.py --config example_config/%s/%s.yml --result_dir $RESULT_DIR"
76+
)
77+
for template in "${TEST_COMMAND_TEMPLATES[@]}"; do
78+
run_test "rcv1" "l2svm" "$template"
79+
done
80+
7381
# Print the test results and remove the intermediate files.
7482
all_tests=$(less $REPORT_PATH | wc -l)
7583
passed_tests=$(grep "PASSED" $REPORT_PATH | wc -l)
@@ -87,4 +95,4 @@ if $(echo $(pwd) | grep -q "tests"); then
8795
echo "Go to the LibMultilabel directory and run: bash tests/autotest.sh"
8896
else
8997
main
90-
fi
98+
fi

torch_trainer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,7 @@ def test(self, split='test'):
196196
metric_dict = self.trainer.test(self.model, test_dataloaders=test_loader)[0]
197197

198198
if self.config.save_k_predictions > 0:
199-
if not self.config.predict_out_path:
200-
predict_out_path = os.path.join(self.checkpoint_dir, 'predictions.txt')
201-
else:
202-
predict_out_path = self.config.predict_out_path
203-
self._save_predictions(test_loader, predict_out_path)
199+
self._save_predictions(test_loader, self.config.predict_out_path)
204200

205201
return metric_dict
206202

0 commit comments

Comments (0)