@@ -358,6 +358,9 @@ class CircHAC:
358
358
----------
359
359
n_clusters : int, default=2
360
360
Number of clusters desired.
361
+ n_init_clusters : int or None, default=None
362
+ If None, every point starts as its own cluster (default HAC).
363
+ If a number, `CircKMeans` is used to pre-cluster data before HAC.
361
364
unit : {"radian", "degree"}, default="degree"
362
365
If "degree", data is converted to radians internally.
363
366
n_intervals : int, default=360
@@ -388,12 +391,14 @@ class CircHAC:
388
391
def __init__ (
389
392
self ,
390
393
n_clusters = 2 ,
394
+ n_init_clusters = None ,
391
395
unit = "degree" ,
392
396
n_intervals = 360 ,
393
397
metric = "center" ,
394
398
random_seed = None
395
399
):
396
400
self .n_clusters = n_clusters
401
+ self .n_init_clusters = n_init_clusters
397
402
self .unit = unit
398
403
self .n_intervals = n_intervals
399
404
self .metric = metric
@@ -404,18 +409,35 @@ def __init__(
404
409
self .labels_ = None
405
410
self .merges_ = None
406
411
412
+ def _initialize_clusters (self , X ):
413
+ """Initializes clusters using CircKMeans or default HAC."""
414
+ n_samples = len (X )
415
+
416
+ # Default HAC: every point is its own cluster
417
+ if self .n_init_clusters is None or self .n_init_clusters >= n_samples :
418
+ return np .arange (n_samples ), X # Standard HAC
419
+
420
+ # Use CircKMeans for pre-clustering
421
+ kmeans = CircKMeans (n_clusters = self .n_init_clusters , unit = "radian" , metric = self .metric , random_seed = self .random_seed )
422
+ kmeans .fit (X )
423
+
424
+ init_labels = kmeans .labels_
425
+ init_centers = kmeans .centers_
426
+
427
+ return init_labels , init_centers
428
+
407
429
def fit (self , X ):
408
430
"""
409
- Perform agglomerative clustering on `alpha `.
431
+ Perform agglomerative clustering on `X `.
410
432
411
433
Parameters
412
434
----------
413
- alpha : array-like of shape (n_samples,)
414
- Angles in degrees or radians, depending on self.unit .
435
+ X : np.ndarray
436
+ Input angles in degrees or radians.
415
437
416
438
Returns
417
439
-------
418
- self : AggCluster1D
440
+ self : CircHAC
419
441
"""
420
442
self .data = X = np .asarray (X )
421
443
if self .unit == "degree" :
@@ -425,71 +447,53 @@ def fit(self, X):
425
447
426
448
n = len (alpha )
427
449
if n <= self .n_clusters :
428
- # trivial case
429
450
self .labels_ = np .arange (n )
430
451
self .centers_ = alpha .copy ()
431
452
self .r_ = np .ones (n )
432
- # no merges
433
453
self .merges_ = np .empty ((0 , 4 ))
434
454
return self
435
455
436
- # each point is its own cluster
437
- cid = np . arange ( n , dtype = int )
438
- nu = n # number of active clusters
456
+ # Step 1: Initialize with pre-clustering or start from scratch
457
+ cluster_ids , cluster_means = self . _initialize_clusters ( alpha )
458
+ cluster_sizes = np . ones ( len ( cluster_means ), dtype = int )
439
459
440
- merges = [] # we'll accumulate (i, j, dist, new_size) here
460
+ merges = [] # Track merge history
441
461
442
- while nu > self .n_clusters :
443
- # compute cluster means
444
- cluster_means = np .full (n , np .nan , dtype = float )
445
- cluster_sizes = np .zeros (n , dtype = int )
446
- for cval in np .unique (cid ):
447
- subset = alpha [cid == cval ]
448
- if len (subset ) == 0 :
449
- continue
450
- m , _ = circ_mean_and_r (subset )
451
- cluster_means [cval ] = m
452
- cluster_sizes [cval ] = len (subset )
462
+ while len (np .unique (cluster_ids )) > self .n_clusters :
463
+ # Compute cluster means
464
+ unique_clusters = np .unique (cluster_ids )
465
+ cluster_means_dict = {c : cluster_means [c ] for c in unique_clusters }
453
466
454
- # find best pair to merge
467
+ # Find best pair to merge
455
468
best_dist = np .inf
456
469
best_i , best_j = None , None
457
- unique_ids = np .unique (cid )
458
- if len (unique_ids ) <= self .n_clusters :
459
- # done
460
- break
461
-
462
- for i in unique_ids :
463
- if np .isnan (cluster_means [i ]):
464
- continue
465
- for j in unique_ids :
466
- if j <= i or np .isnan (cluster_means [j ]):
470
+ for i in unique_clusters :
471
+ for j in unique_clusters :
472
+ if j <= i :
467
473
continue
468
- dist_ij = circ_dist (cluster_means [i ], cluster_means [j ], metric = self .metric )
469
- dval = float (abs (dist_ij )) # ensure it's nonnegative
470
- if dval < best_dist :
471
- best_dist = dval
472
- best_i = i
473
- best_j = j
474
+ dist_ij = circ_dist (cluster_means_dict [i ], cluster_means_dict [j ], metric = self .metric )
475
+ if dist_ij < best_dist :
476
+ best_dist = dist_ij
477
+ best_i , best_j = i , j
474
478
475
479
if best_i is None or best_j is None :
476
- # can't find a merge => break
477
- break
480
+ break # No valid merge found
478
481
479
- # record the merge in merges array
482
+ # Record merge
480
483
new_size = cluster_sizes [best_i ] + cluster_sizes [best_j ]
481
484
merges .append ([best_i , best_j , best_dist , new_size ])
482
485
483
- # merge best_i into best_j
484
- cid [cid == best_i ] = best_j
485
- nu -= 1
486
+ # Merge clusters
487
+ cluster_ids [cluster_ids == best_j ] = best_i
488
+ cluster_sizes [best_i ] = new_size
489
+ cluster_means [best_i ] = circ_mean_and_r (alpha [cluster_ids == best_i ])[0 ]
486
490
487
- # at this point we have at most n_clusters distinct IDs
488
- unique_ids = np .unique (cid )
491
+ # Assign final cluster labels
492
+ unique_ids = np .unique (cluster_ids )
489
493
label_map = {old_id : new_id for new_id , old_id in enumerate (unique_ids )}
490
- self .labels_ = np .array ([label_map [x ] for x in cid ], dtype = int )
494
+ self .labels_ = np .array ([label_map [c ] for c in cluster_ids ], dtype = int )
491
495
492
- # final centers, r_
496
+ # Compute final cluster centers and resultant lengths
493
497
k = len (unique_ids )
494
498
self .centers_ = np .zeros (k , dtype = float )
495
499
self .r_ = np .zeros (k , dtype = float )
@@ -499,8 +503,7 @@ def fit(self, X):
499
503
self .centers_ [i ] = mean_i
500
504
self .r_ [i ] = r_i
501
505
502
- # store merges array
503
- # shape: (# merges, 4)
506
+ # Store merges
504
507
self .merges_ = np .array (merges , dtype = object )
505
508
506
509
def predict (self , alpha ):
0 commit comments