Skip to content

Commit a5b6222

Browse files
authored
Merge branch 'master' into prob-estimate
2 parents acdf0c6 + 213f612 commit a5b6222

File tree

4 files changed

+193
-27
lines changed

4 files changed

+193
-27
lines changed

docs/examples/plot_linear_tree_tutorial.py

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,38 @@
22
Handling Data with Many Labels Using Linear Methods
33
====================================================
44
5-
For the case that the amount of labels is very large,
6-
the training time of the standard ``train_1vsrest`` method may be unpleasantly long.
7-
The ``train_tree`` method in LibMultiLabel can vastly improve the training time on such data sets.
5+
For datasets with a very large number of labels, the training time of the standard ``train_1vsrest`` method can be prohibitively long. LibMultiLabel offers tree-based methods like ``train_tree`` and ``train_ensemble_tree`` to vastly improve training time in such scenarios.
86
9-
To illustrate this speedup, we will use the `EUR-Lex dataset <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html#EUR-Lex>`_, which contains 3,956 labels.
10-
The data in the following example is downloaded under the directory ``data/eur-lex``
117
12-
Users can use the following command to easily apply the ``train_tree`` method.
13-
14-
.. code-block:: bash
15-
16-
$ python3 main.py --training_file data/eur-lex/train.txt
17-
--test_file data/eur-lex/test.txt
18-
--linear
19-
--linear_technique tree
20-
21-
Besides CLI usage, users can also use API to apply ``train_tree`` method.
22-
Below is an example.
8+
We will use the `EUR-Lex dataset <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html#EUR-Lex>`_, which contains 3,956 labels. The data is assumed to be downloaded under the directory ``data/eur-lex``.
239
"""
2410

2511
import math
2612
import libmultilabel.linear as linear
2713
import time
2814

15+
# Load and preprocess the dataset
2916
datasets = linear.load_dataset("txt", "data/eurlex/train.txt", "data/eurlex/test.txt")
3017
preprocessor = linear.Preprocessor()
3118
datasets = preprocessor.fit_transform(datasets)
3219

3320

21+
######################################################################
22+
# Standard Training and Prediction
23+
# --------------------------------
24+
#
25+
# Users can use the following command to easily apply the ``train_tree`` method.
26+
#
27+
# .. code-block:: bash
28+
#
29+
# $ python3 main.py --training_file data/eur-lex/train.txt \\
30+
# --test_file data/eur-lex/test.txt \\
31+
# --linear \\
32+
# --linear_technique tree
33+
#
34+
# Besides CLI usage, users can also use API to apply ``train_tree`` method.
35+
# Below is an example.
36+
3437
training_start = time.time()
3538
# the standard one-vs-rest method for multi-label problems
3639
ovr_model = linear.train_1vsrest(datasets["train"]["y"], datasets["train"]["x"])
@@ -99,3 +102,81 @@ def metrics_in_batches(model):
99102
print("Score of 1vsrest:", metrics_in_batches(ovr_model))
100103
print("Score of tree:", metrics_in_batches(tree_model))
101104

105+
106+
######################################################################
107+
# Ensemble of Tree Models
108+
# -----------------------
109+
#
110+
# While the ``train_tree`` method offers a significant speedup, its accuracy can sometimes be slightly lower than the standard one-vs-rest approach.
111+
# The ``train_ensemble_tree`` method can help bridge this gap by training multiple tree models and averaging their predictions.
112+
#
113+
# Users can use the following command to easily apply the ``train_ensemble_tree`` method.
114+
# The number of trees in the ensemble can be controlled with the ``--tree_ensemble_models`` argument.
115+
#
116+
# .. code-block:: bash
117+
#
118+
# $ python3 main.py --training_file data/eur-lex/train.txt \\
119+
# --test_file data/eur-lex/test.txt \\
120+
# --linear \\
121+
# --linear_technique tree \\
122+
# --tree_ensemble_models 3
123+
#
124+
# This command trains an ensemble of 3 tree models. If ``--tree_ensemble_models`` is not specified, it defaults to 1 (a single tree).
125+
#
126+
# Besides CLI usage, users can also use the API to apply the ``train_ensemble_tree`` method.
127+
# Below is an example.
128+
129+
# We have already trained a single tree model as a baseline.
130+
# Now, let's train an ensemble of 3 tree models.
131+
training_start = time.time()
132+
ensemble_model = linear.train_ensemble_tree(
133+
datasets["train"]["y"], datasets["train"]["x"], n_trees=3
134+
)
135+
training_end = time.time()
136+
print("Training time of ensemble tree: {:10.2f}".format(training_end - training_start))
137+
138+
######################################################################
139+
# On a machine with an AMD-7950X CPU,
140+
# the ``train_ensemble_tree`` function with 3 trees took `421.15` seconds,
141+
# while the single tree took `144.37` seconds.
142+
# As expected, training an ensemble takes longer, roughly proportional to the number of trees.
143+
#
144+
# Now, let's see if this additional training time translates to better performance.
145+
# We'll compute the same P@K metrics on the test set for both the single tree and the ensemble model.
146+
147+
# `tree_preds` and `target` are already computed in the previous section.
148+
ensemble_preds = linear.predict_values(ensemble_model, datasets["test"]["x"])
149+
150+
# `tree_score` is already computed.
151+
print("Score of single tree:", tree_score)
152+
153+
ensemble_score = linear.compute_metrics(ensemble_preds, target, ["P@1", "P@3", "P@5"])
154+
print("Score of ensemble tree:", ensemble_score)
155+
156+
######################################################################
157+
# While training an ensemble takes longer, it often leads to better predictive performance.
158+
# The following table shows a comparison between a single tree and ensembles
159+
# of 3, 10, and 15 trees on several benchmark datasets.
160+
#
161+
# .. table:: Benchmark Results for Single and Ensemble Tree Models (P@K in %)
162+
#
163+
# +---------------+-----------------+-------+-------+-------+
164+
# | Dataset | Model | P@1 | P@3 | P@5 |
165+
# +===============+=================+=======+=======+=======+
166+
# | EURLex-4k | Single Tree | 82.35 | 68.98 | 57.62 |
167+
# | +-----------------+-------+-------+-------+
168+
# | | Ensemble-3 | 82.38 | 69.28 | 58.01 |
169+
# | +-----------------+-------+-------+-------+
170+
# | | Ensemble-10 | 82.74 | 69.66 | 58.39 |
171+
# | +-----------------+-------+-------+-------+
172+
# | | Ensemble-15 | 82.61 | 69.56 | 58.29 |
173+
# +---------------+-----------------+-------+-------+-------+
174+
# | EURLex-57k | Single Tree | 90.77 | 80.81 | 67.82 |
175+
# | +-----------------+-------+-------+-------+
176+
# | | Ensemble-3 | 91.02 | 81.06 | 68.26 |
177+
# | +-----------------+-------+-------+-------+
178+
# | | Ensemble-10 | 91.23 | 81.22 | 68.34 |
179+
# | +-----------------+-------+-------+-------+
180+
# | | Ensemble-15 | 91.25 | 81.31 | 68.34 |
181+
# +---------------+-----------------+-------+-------+-------+
182+

libmultilabel/linear/tree.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from . import linear
1414
from . import metrics
1515

16-
__all__ = ["train_tree", "TreeModel"]
16+
__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]
17+
18+
DEFAULT_K = 100
19+
DEFAULT_DMAX = 10
1720

1821

1922
class Node:
@@ -255,8 +258,8 @@ def train_tree(
255258
y: sparse.csr_matrix,
256259
x: sparse.csr_matrix,
257260
options: str = "",
258-
K=100,
259-
dmax=10,
261+
K=DEFAULT_K,
262+
dmax=DEFAULT_DMAX,
260263
verbose: bool = True,
261264
) -> TreeModel:
262265
"""Train a linear model for multi-label data using a divide-and-conquer strategy.
@@ -440,3 +443,70 @@ def visit(node):
440443
node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights)))
441444

442445
return model, node_ptr
446+
447+
448+
class EnsembleTreeModel:
449+
"""An ensemble of tree models.
450+
The ensemble aggregates predictions from multiple trees to improve accuracy and robustness.
451+
"""
452+
453+
def __init__(self, tree_models: list[TreeModel]):
454+
"""
455+
Args:
456+
tree_models (list[TreeModel]): A list of trained tree models.
457+
"""
458+
self.name = "ensemble-tree"
459+
self.tree_models = tree_models
460+
self.multiclass = False
461+
462+
def predict_values(self, x: sparse.csr_matrix, beam_width: int = 10) -> np.ndarray:
463+
"""Calculates the averaged probability estimates from all trees in the ensemble.
464+
465+
Args:
466+
x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
467+
beam_width (int, optional): Number of candidates considered during beam search for each tree. Defaults to 10.
468+
469+
Returns:
470+
np.ndarray: A matrix with dimension number of instances * number of classes, containing averaged scores.
471+
"""
472+
all_predictions = [model.predict_values(x, beam_width) for model in self.tree_models]
473+
return np.mean(all_predictions, axis=0)
474+
475+
476+
def train_ensemble_tree(
477+
y: sparse.csr_matrix,
478+
x: sparse.csr_matrix,
479+
options: str = "",
480+
K: int = DEFAULT_K,
481+
dmax: int = DEFAULT_DMAX,
482+
n_trees: int = 3,
483+
verbose: bool = True,
484+
seed: int = None,
485+
) -> EnsembleTreeModel:
486+
"""Trains an ensemble of tree models (Parabel/Bonsai-style).
487+
Args:
488+
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
489+
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
490+
options (str, optional): The option string passed to liblinear. Defaults to ''.
491+
K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
492+
dmax (int, optional): Maximum depth of the tree. Defaults to 10.
493+
n_trees (int, optional): Number of trees in the ensemble. Defaults to 3.
494+
verbose (bool, optional): Output extra progress information. Defaults to True.
495+
seed (int, optional): The base random seed for the ensemble. Defaults to None, which will use 42.
496+
497+
Returns:
498+
EnsembleTreeModel: An ensemble model which can be used for prediction.
499+
"""
500+
if seed is None:
501+
seed = 42
502+
503+
tree_models = []
504+
for i in range(n_trees):
505+
np.random.seed(seed + i)
506+
507+
tree_model = train_tree(y, x, options, K, dmax, verbose)
508+
tree_models.append(tree_model)
509+
510+
print("Ensemble training completed.")
511+
512+
return EnsembleTreeModel(tree_models)

linear_trainer.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import libmultilabel.linear as linear
88
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
9+
from libmultilabel.linear.tree import EnsembleTreeModel, TreeModel, train_ensemble_tree
910
from libmultilabel.linear.utils import LINEAR_TECHNIQUES
1011

1112

@@ -21,7 +22,7 @@ def linear_test(config, model, datasets, label_mapping):
2122
scores = []
2223

2324
predict_kwargs = {}
24-
if model.name == "tree":
25+
if isinstance(model, (TreeModel, EnsembleTreeModel)):
2526
predict_kwargs["beam_width"] = config.beam_width
2627

2728
for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
@@ -48,13 +49,24 @@ def linear_train(datasets, config):
4849
if multiclass:
4950
raise ValueError("Tree model should only be used with multilabel datasets.")
5051

51-
model = LINEAR_TECHNIQUES[config.linear_technique](
52-
datasets["train"]["y"],
53-
datasets["train"]["x"],
54-
options=config.liblinear_options,
55-
K=config.tree_degree,
56-
dmax=config.tree_max_depth,
57-
)
52+
if config.tree_ensemble_models > 1:
53+
model = train_ensemble_tree(
54+
datasets["train"]["y"],
55+
datasets["train"]["x"],
56+
options=config.liblinear_options,
57+
K=config.tree_degree,
58+
dmax=config.tree_max_depth,
59+
n_trees=config.tree_ensemble_models,
60+
seed=config.seed,
61+
)
62+
else:
63+
model = LINEAR_TECHNIQUES[config.linear_technique](
64+
datasets["train"]["y"],
65+
datasets["train"]["x"],
66+
options=config.liblinear_options,
67+
K=config.tree_degree,
68+
dmax=config.tree_max_depth,
69+
)
5870
else:
5971
model = LINEAR_TECHNIQUES[config.linear_technique](
6072
datasets["train"]["y"],

main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ def add_all_arguments(parser):
223223
parser.add_argument(
224224
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
225225
)
226+
parser.add_argument(
227+
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
228+
)
226229
parser.add_argument(
227230
"--beam_width",
228231
type=int,

0 commit comments

Comments
 (0)