Skip to content

Commit 0e5c4f6

Browse files
authored
Merge pull request #19 from khoinpd0411/build-tree-sparsekmeans
Update build_tree function with SparseKmeans implementation
2 parents 93b605f + ad4d0e6 commit 0e5c4f6

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

libmultilabel/linear/tree.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import scipy.sparse as sparse
7-
import sklearn.cluster
7+
from sparsekmeans import LloydKmeans, ElkanKmeans
88
import sklearn.preprocessing
99
from tqdm import tqdm
1010
import psutil
@@ -274,28 +274,29 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
274274
Returns:
275275
Node: Root of the (sub)tree built from label_representation.
276276
"""
277-
if d >= dmax or label_representation.shape[0] <= K:
278-
return Node(label_map=label_map, children=[])
279-
280-
metalabels = (
281-
sklearn.cluster.KMeans(
282-
K,
283-
random_state=np.random.randint(2**31 - 1),
284-
n_init=1,
285-
max_iter=300,
286-
tol=0.0001,
287-
algorithm="elkan",
277+
children = []
278+
if d < dmax and label_representation.shape[0] > K:
279+
if label_representation.shape[0] > 10000:
280+
kmeans_algo = ElkanKmeans
281+
else:
282+
kmeans_algo = LloydKmeans
283+
284+
kmeans = kmeans_algo(
285+
n_clusters=K, max_iter=300, tol=0.0001, random_state=np.random.randint(2**31 - 1), verbose=True
288286
)
289-
.fit(label_representation)
290-
.labels_
291-
)
287+
metalabels = kmeans.fit(label_representation)
292288

293-
children = []
294-
for i in range(K):
295-
child_representation = label_representation[metalabels == i]
296-
child_map = label_map[metalabels == i]
297-
child = _build_tree(child_representation, child_map, d + 1, K, dmax)
298-
children.append(child)
289+
unique_labels = np.unique(metalabels)
290+
if len(unique_labels) == K:
291+
create_child_node = lambda i: _build_tree(
292+
label_representation[metalabels == i], label_map[metalabels == i], d + 1, K, dmax
293+
)
294+
else:
295+
create_child_node = lambda i: Node(label_map=label_map[metalabels == i], children=[])
296+
297+
for i in range(K):
298+
child = create_child_node(i)
299+
children.append(child)
299300

300301
return Node(label_map=label_map, children=children)
301302

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ scikit-learn
66
scipy
77
tqdm
88
psutil
9+
sparsekmeans

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ install_requires =
3333
scipy
3434
tqdm
3535
psutil
36+
sparsekmeans
3637

3738
python_requires = >=3.10
3839

0 commit comments

Comments
 (0)