|
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | import scipy.sparse as sparse
|
7 |
| -import sklearn.cluster |
| 7 | +from sparsekmeans import LloydKmeans, ElkanKmeans |
8 | 8 | import sklearn.preprocessing
|
9 | 9 | from tqdm import tqdm
|
10 | 10 | import psutil
|
@@ -274,28 +274,29 @@ def _build_tree(label_representation: sparse.csr_matrix, label_map: np.ndarray,
|
274 | 274 | Returns:
|
275 | 275 | Node: Root of the (sub)tree built from label_representation.
|
276 | 276 | """
|
277 |
| - if d >= dmax or label_representation.shape[0] <= K: |
278 |
| - return Node(label_map=label_map, children=[]) |
279 |
| - |
280 |
| - metalabels = ( |
281 |
| - sklearn.cluster.KMeans( |
282 |
| - K, |
283 |
| - random_state=np.random.randint(2**31 - 1), |
284 |
| - n_init=1, |
285 |
| - max_iter=300, |
286 |
| - tol=0.0001, |
287 |
| - algorithm="elkan", |
| 277 | + children = [] |
| 278 | + if d < dmax and label_representation.shape[0] > K: |
| 279 | + if label_representation.shape[0] > 10000: |
| 280 | + kmeans_algo = ElkanKmeans |
| 281 | + else: |
| 282 | + kmeans_algo = LloydKmeans |
| 283 | + |
| 284 | + kmeans = kmeans_algo( |
| 285 | + n_clusters=K, max_iter=300, tol=0.0001, random_state=np.random.randint(2**31 - 1), verbose=True |
288 | 286 | )
|
289 |
| - .fit(label_representation) |
290 |
| - .labels_ |
291 |
| - ) |
| 287 | + metalabels = kmeans.fit(label_representation) |
292 | 288 |
|
293 |
| - children = [] |
294 |
| - for i in range(K): |
295 |
| - child_representation = label_representation[metalabels == i] |
296 |
| - child_map = label_map[metalabels == i] |
297 |
| - child = _build_tree(child_representation, child_map, d + 1, K, dmax) |
298 |
| - children.append(child) |
| 289 | + unique_labels = np.unique(metalabels) |
| 290 | + if len(unique_labels) == K: |
| 291 | + create_child_node = lambda i: _build_tree( |
| 292 | + label_representation[metalabels == i], label_map[metalabels == i], d + 1, K, dmax |
| 293 | + ) |
| 294 | + else: |
| 295 | + create_child_node = lambda i: Node(label_map=label_map[metalabels == i], children=[]) |
| 296 | + |
| 297 | + for i in range(K): |
| 298 | + child = create_child_node(i) |
| 299 | + children.append(child) |
299 | 300 |
|
300 | 301 | return Node(label_map=label_map, children=children)
|
301 | 302 |
|
|
0 commit comments