Skip to content

Commit 73e696d

Browse files
committed
change: add option to initialize CircHAC with CircKMeans
1 parent b790f7f commit 73e696d

File tree

1 file changed

+53
-50
lines changed

1 file changed

+53
-50
lines changed

pycircstat2/clustering.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ class CircHAC:
358358
----------
359359
n_clusters : int, default=2
360360
Number of clusters desired.
361+
n_init_clusters : int or None, default=None
362+
If None, every point starts as its own cluster (default HAC).
363+
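If an int smaller than the number of samples, `CircKMeans` is used to pre-cluster the data into that many clusters before HAC; otherwise every point starts as its own cluster.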
361364
unit : {"radian", "degree"}, default="degree"
362365
If "degree", data is converted to radians internally.
363366
n_intervals : int, default=360
@@ -388,12 +391,14 @@ class CircHAC:
388391
def __init__(
389392
self,
390393
n_clusters=2,
394+
n_init_clusters=None,
391395
unit="degree",
392396
n_intervals=360,
393397
metric="center",
394398
random_seed=None
395399
):
396400
self.n_clusters = n_clusters
401+
self.n_init_clusters = n_init_clusters
397402
self.unit = unit
398403
self.n_intervals = n_intervals
399404
self.metric = metric
@@ -404,18 +409,35 @@ def __init__(
404409
self.labels_ = None
405410
self.merges_ = None
406411

412+
def _initialize_clusters(self, X):
413+
"""Initializes clusters using CircKMeans or default HAC."""
414+
n_samples = len(X)
415+
416+
# Default HAC: every point is its own cluster
417+
if self.n_init_clusters is None or self.n_init_clusters >= n_samples:
418+
return np.arange(n_samples), X # Standard HAC
419+
420+
# Use CircKMeans for pre-clustering
421+
kmeans = CircKMeans(n_clusters=self.n_init_clusters, unit="radian", metric=self.metric, random_seed=self.random_seed)
422+
kmeans.fit(X)
423+
424+
init_labels = kmeans.labels_
425+
init_centers = kmeans.centers_
426+
427+
return init_labels, init_centers
428+
407429
def fit(self, X):
408430
"""
409-
Perform agglomerative clustering on `alpha`.
431+
Perform agglomerative clustering on `X`.
410432
411433
Parameters
412434
----------
413-
alpha : array-like of shape (n_samples,)
414-
Angles in degrees or radians, depending on self.unit.
435+
X : array-like of shape (n_samples,)
436+
Input angles in degrees or radians, depending on `self.unit`.
415437
416438
Returns
417439
-------
418-
self : AggCluster1D
440+
self : CircHAC
419441
"""
420442
self.data = X = np.asarray(X)
421443
if self.unit == "degree":
@@ -425,71 +447,53 @@ def fit(self, X):
425447

426448
n = len(alpha)
427449
if n <= self.n_clusters:
428-
# trivial case
429450
self.labels_ = np.arange(n)
430451
self.centers_ = alpha.copy()
431452
self.r_ = np.ones(n)
432-
# no merges
433453
self.merges_ = np.empty((0, 4))
434454
return self
435455

436-
# each point is its own cluster
437-
cid = np.arange(n, dtype=int)
438-
nu = n # number of active clusters
456+
# Step 1: Initialize with pre-clustering or start from scratch
457+
cluster_ids, cluster_means = self._initialize_clusters(alpha)
458+
cluster_sizes = np.ones(len(cluster_means), dtype=int)
439459

440-
merges = [] # we'll accumulate (i, j, dist, new_size) here
460+
merges = [] # Track merge history
441461

442-
while nu > self.n_clusters:
443-
# compute cluster means
444-
cluster_means = np.full(n, np.nan, dtype=float)
445-
cluster_sizes = np.zeros(n, dtype=int)
446-
for cval in np.unique(cid):
447-
subset = alpha[cid == cval]
448-
if len(subset) == 0:
449-
continue
450-
m, _ = circ_mean_and_r(subset)
451-
cluster_means[cval] = m
452-
cluster_sizes[cval] = len(subset)
462+
while len(np.unique(cluster_ids)) > self.n_clusters:
463+
# Compute cluster means
464+
unique_clusters = np.unique(cluster_ids)
465+
cluster_means_dict = {c: cluster_means[c] for c in unique_clusters}
453466

454-
# find best pair to merge
467+
# Find best pair to merge
455468
best_dist = np.inf
456469
best_i, best_j = None, None
457-
unique_ids = np.unique(cid)
458-
if len(unique_ids) <= self.n_clusters:
459-
# done
460-
break
461-
462-
for i in unique_ids:
463-
if np.isnan(cluster_means[i]):
464-
continue
465-
for j in unique_ids:
466-
if j <= i or np.isnan(cluster_means[j]):
470+
for i in unique_clusters:
471+
for j in unique_clusters:
472+
if j <= i:
467473
continue
468-
dist_ij = circ_dist(cluster_means[i], cluster_means[j], metric=self.metric)
469-
dval = float(abs(dist_ij)) # ensure it's nonnegative
470-
if dval < best_dist:
471-
best_dist = dval
472-
best_i = i
473-
best_j = j
474+
dist_ij = circ_dist(cluster_means_dict[i], cluster_means_dict[j], metric=self.metric)
475+
if dist_ij < best_dist:
476+
best_dist = dist_ij
477+
best_i, best_j = i, j
474478

475479
if best_i is None or best_j is None:
476-
# can't find a merge => break
477-
break
480+
break # No valid merge found
478481

479-
# record the merge in merges array
482+
# Record merge
480483
new_size = cluster_sizes[best_i] + cluster_sizes[best_j]
481484
merges.append([best_i, best_j, best_dist, new_size])
482485

483-
# merge best_i into best_j
484-
cid[cid == best_i] = best_j
485-
nu -= 1
486+
# Merge clusters
487+
cluster_ids[cluster_ids == best_j] = best_i
488+
cluster_sizes[best_i] = new_size
489+
cluster_means[best_i] = circ_mean_and_r(alpha[cluster_ids == best_i])[0]
486490

487-
# at this point we have at most n_clusters distinct IDs
488-
unique_ids = np.unique(cid)
491+
# Assign final cluster labels
492+
unique_ids = np.unique(cluster_ids)
489493
label_map = {old_id: new_id for new_id, old_id in enumerate(unique_ids)}
490-
self.labels_ = np.array([label_map[x] for x in cid], dtype=int)
494+
self.labels_ = np.array([label_map[c] for c in cluster_ids], dtype=int)
491495

492-
# final centers, r_
496+
# Compute final cluster centers and resultant lengths
493497
k = len(unique_ids)
494498
self.centers_ = np.zeros(k, dtype=float)
495499
self.r_ = np.zeros(k, dtype=float)
@@ -499,8 +503,7 @@ def fit(self, X):
499503
self.centers_[i] = mean_i
500504
self.r_[i] = r_i
501505

502-
# store merges array
503-
# shape: (# merges, 4)
506+
# Store merges
504507
self.merges_ = np.array(merges, dtype=object)
505508

506509
def predict(self, alpha):

0 commit comments

Comments
 (0)