Skip to content

Commit 774e4c8

Browse files
committed
Refactor imports and enhance modularity in torchsom package: visualization and hexagonal settings.
- Updated import statements across multiple modules to use absolute imports instead of relative imports, improving clarity and consistency. - Introduced a new hexagonal_coordinates module to centralize hexagonal coordinate conversion and distance calculation functions, enhancing code organization. - Refactored visualization modules to include specialized visualizers for hexagonal and rectangular topologies, promoting modular design. - Improved neighborhood functions to utilize proper Euclidean distance calculations for both rectangular and hexagonal topologies, ensuring accuracy in weight updates. - Added comprehensive documentation and inline comments to clarify the purpose and functionality of new and modified methods.
1 parent daaa6c2 commit 774e4c8

17 files changed

+1494
-804
lines changed

torchsom/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Torchsom package."""
22

3-
from .core import SOM, BaseSOM
4-
from .utils.decay import DECAY_FUNCTIONS
5-
from .utils.distances import DISTANCE_FUNCTIONS
6-
from .utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
7-
from .visualization import SOMVisualizer, VisualizationConfig
3+
from torchsom.core import SOM, BaseSOM
4+
from torchsom.utils.decay import DECAY_FUNCTIONS
5+
from torchsom.utils.distances import DISTANCE_FUNCTIONS
6+
from torchsom.utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
7+
from torchsom.visualization import SOMVisualizer, VisualizationConfig
88

99
# from .version import __version__
1010

torchsom/configs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Configuration module for torchsom."""
22

3-
from .som_config import SOMConfig
3+
from torchsom.configs.som_config import SOMConfig
44

55
__all__ = ["SOMConfig"]

torchsom/core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Core module for torchsom."""
22

3-
from .base_som import BaseSOM
4-
from .som import SOM
3+
from torchsom.core.base_som import BaseSOM
4+
from torchsom.core.som import SOM
55

66
__all__ = ["SOM", "BaseSOM"]

torchsom/core/som.py

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@
1111
from torch.utils.data import DataLoader, TensorDataset
1212
from tqdm import tqdm
1313

14-
from ..utils.decay import DECAY_FUNCTIONS
15-
from ..utils.distances import DISTANCE_FUNCTIONS
16-
from ..utils.grid import adjust_meshgrid_topology, create_mesh_grid
17-
from ..utils.initialization import initialize_weights
18-
from ..utils.metrics import calculate_quantization_error, calculate_topographic_error
19-
from ..utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
20-
from ..utils.topology import get_all_neighbors_up_to_order
21-
from .base_som import BaseSOM
14+
from torchsom.core.base_som import BaseSOM
15+
from torchsom.utils.decay import DECAY_FUNCTIONS
16+
from torchsom.utils.distances import DISTANCE_FUNCTIONS
17+
from torchsom.utils.grid import adjust_meshgrid_topology, create_mesh_grid
18+
from torchsom.utils.initialization import initialize_weights
19+
from torchsom.utils.metrics import (
20+
calculate_quantization_error,
21+
calculate_topographic_error,
22+
)
23+
from torchsom.utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
24+
from torchsom.utils.topology import get_all_neighbors_up_to_order
2225

2326

2427
class SOM(BaseSOM):
@@ -163,10 +166,6 @@ def _update_weights(
163166
for row, col in bmus
164167
]
165168
) # [batch_size, row_neurons, col_neurons]
166-
# ! Modification to test
167-
# # Vectorised: build a tensor of BMU coordinates and compute in one shot
168-
# coords = torch.stack([bmus[:, 0], bmus[:, 1]], dim=1).to(torch.long)
169-
# neighborhoods = self.neighborhood_fn(coords, sigma) # update neighborhood_fn to accept batched coords # [batch_size, row_neurons, col_neurons]
170169

171170
# Reshape for broadcasting
172171
neighborhoods = neighborhoods.view(batch_size, self.x, self.y, 1)
@@ -977,20 +976,45 @@ def build_classification_map(
977976
else:
978977
neighbor_labels = []
979978
row, col = bmu_pos
980-
for dx, dy in neighborhood_offsets:
981-
neighbor_row = row + dx
982-
neighbor_col = col + dy
983-
if (
984-
0 <= neighbor_row < self.x
985-
and 0 <= neighbor_col < self.y
986-
and (neighbor_row, neighbor_col) in bmus_map
987-
):
988-
neighbor_samples_indices = bmus_map[
989-
(neighbor_row, neighbor_col)
990-
]
991-
neighbor_labels.extend(
992-
target[neighbor_samples_indices].cpu().numpy()
993-
)
979+
980+
# Handle topology-specific neighborhood processing
981+
if self.topology == "hexagonal":
982+
# Use appropriate offsets based on row parity (even/odd)
983+
row_offsets = (
984+
neighborhood_offsets["even"]
985+
if row % 2 == 0
986+
else neighborhood_offsets["odd"]
987+
)
988+
for dx, dy in row_offsets:
989+
neighbor_row = row + dx
990+
neighbor_col = col + dy
991+
if (
992+
0 <= neighbor_row < self.x
993+
and 0 <= neighbor_col < self.y
994+
and (neighbor_row, neighbor_col) in bmus_map
995+
):
996+
neighbor_samples_indices = bmus_map[
997+
(neighbor_row, neighbor_col)
998+
]
999+
neighbor_labels.extend(
1000+
target[neighbor_samples_indices].cpu().numpy()
1001+
)
1002+
else:
1003+
# Rectangular topology - process all offsets directly
1004+
for dx, dy in neighborhood_offsets:
1005+
neighbor_row = row + dx
1006+
neighbor_col = col + dy
1007+
if (
1008+
0 <= neighbor_row < self.x
1009+
and 0 <= neighbor_col < self.y
1010+
and (neighbor_row, neighbor_col) in bmus_map
1011+
):
1012+
neighbor_samples_indices = bmus_map[
1013+
(neighbor_row, neighbor_col)
1014+
]
1015+
neighbor_labels.extend(
1016+
target[neighbor_samples_indices].cpu().numpy()
1017+
)
9941018

9951019
# After collecting all neighbor labels, recompute label counts with neighborhood labels.
9961020
if neighbor_labels:

torchsom/utils/__init__.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
"""Utility functions for torchsom."""
22

3-
from .decay import DECAY_FUNCTIONS
4-
from .distances import DISTANCE_FUNCTIONS
5-
from .grid import (
6-
adjust_meshgrid_topology,
7-
axial_distance,
8-
convert_to_axial_coords,
9-
create_mesh_grid,
3+
from torchsom.utils.decay import DECAY_FUNCTIONS
4+
from torchsom.utils.distances import DISTANCE_FUNCTIONS
5+
from torchsom.utils.grid import adjust_meshgrid_topology, create_mesh_grid
6+
from torchsom.utils.hexagonal_coordinates import (
7+
axial_to_offset_coords,
8+
grid_to_display_coords,
9+
hexagonal_distance_axial,
10+
hexagonal_distance_offset,
1011
offset_to_axial_coords,
1112
)
12-
from .initialization import initialize_weights, pca_init, random_init
13-
from .metrics import calculate_quantization_error, calculate_topographic_error
14-
from .neighborhood import NEIGHBORHOOD_FUNCTIONS
15-
from .topology import (
13+
from torchsom.utils.initialization import initialize_weights, pca_init, random_init
14+
from torchsom.utils.metrics import (
15+
calculate_quantization_error,
16+
calculate_topographic_error,
17+
)
18+
from torchsom.utils.neighborhood import NEIGHBORHOOD_FUNCTIONS
19+
from torchsom.utils.topology import (
1620
get_all_neighbors_up_to_order,
1721
get_hexagonal_offsets,
1822
get_rectangular_offsets,
@@ -24,9 +28,11 @@
2428
"NEIGHBORHOOD_FUNCTIONS",
2529
"create_mesh_grid",
2630
"adjust_meshgrid_topology",
27-
"convert_to_axial_coords",
2831
"offset_to_axial_coords",
29-
"axial_distance",
32+
"axial_to_offset_coords",
33+
"hexagonal_distance_axial",
34+
"hexagonal_distance_offset",
35+
"grid_to_display_coords",
3036
"initialize_weights",
3137
"random_init",
3238
"pca_init",

torchsom/utils/grid.py

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Utility functions for grid operations."""
22

3+
import math
4+
35
import torch
46

7+
# NOTE: Coordinate conversion functions moved to hexagonal_coordinates.py to eliminate duplication and provide single source of truth.
8+
59

610
def create_mesh_grid(
711
x: int,
@@ -54,78 +58,81 @@ def adjust_meshgrid_topology(
5458
adjusted_xx = xx.clone()
5559
adjusted_yy = yy.clone()
5660

57-
adjusted_xx[::2] -= 0.5 # Adjust x-coordinates for even-indexed rows
58-
adjusted_yy *= (3.0 / 2.0) / torch.sqrt(
59-
torch.tensor(3.0)
60-
) # Adjust all y-coordinates
61-
62-
return adjusted_xx, adjusted_yy # Return the modified copies
63-
64-
return xx, yy # If not hexagonal, return the original tensors
65-
66-
67-
def convert_to_axial_coords(
68-
row: int,
69-
col: int,
70-
) -> tuple[float, float]:
71-
"""Convert grid coordinates to axial coordinates for hexagonal grid.
72-
73-
Uses even-r layout where even rows are shifted left by 0.5.
74-
This matches the layout used in adjust_meshgrid_topology.
75-
76-
Args:
77-
row (int): Grid row coordinate
78-
col (int): Grid column coordinate
79-
80-
Returns:
81-
tuple[float, float]: Axial coordinates (q, r)
82-
"""
83-
q = col - 0.5 - row // 2 if row % 2 == 0 else col - row // 2
84-
r = row
85-
return q, r
86-
87-
88-
def offset_to_axial_coords(
89-
row: int,
90-
col: int,
91-
) -> tuple[float, float]: # pragma: no cover
92-
"""Convert offset coordinates to axial coordinates for hexagonal grid.
93-
94-
Alternative implementation that directly matches the mesh grid adjustment.
95-
96-
Args:
97-
row (int): Grid row coordinate
98-
col (int): Grid column coordinate
99-
100-
Returns:
101-
tuple[float, float]: Axial coordinates (q, r)
102-
"""
103-
# Direct conversion matching adjust_meshgrid_topology
104-
q = col - (row - (row & 1)) / 2
105-
r = row
106-
return q, r
107-
108-
109-
def axial_distance(
110-
q1: float,
111-
r1: float,
112-
q2: float,
113-
r2: float,
114-
) -> int:
115-
"""Calculate the distance between two hexes in axial coordinates.
116-
117-
Args:
118-
q1 (float): column of first hex
119-
r1 (float): row of first hex
120-
q2 (float): column of second hex
121-
r2 (float): row of second hex
122-
123-
Returns:
124-
int: Distance in hex steps
125-
"""
126-
# Convert axial to cube coordinates
127-
x1, y1, z1 = q1, r1, -q1 - r1
128-
x2, y2, z2 = q2, r2, -q2 - r2
129-
130-
# Manhattan distance divided by 2
131-
return int((abs(x1 - x2) + abs(y1 - y2) + abs(z1 - z2)) / 2)
61+
"""
62+
Use even-r offset coordinate system (consistent with visualization)
63+
1. Even rows (0, 2, 4...): no horizontal shift
64+
2. Odd rows (1, 3, 5...): shift right by 0.5
65+
"""
66+
adjusted_xx[1::2] += 0.5
67+
adjusted_yy *= math.sqrt(3) / 2
68+
69+
return adjusted_xx, adjusted_yy
70+
71+
return xx, yy
72+
73+
74+
# def convert_to_axial_coords(
75+
# row: int,
76+
# col: int,
77+
# ) -> tuple[float, float]:
78+
# """Convert grid coordinates to axial coordinates for hexagonal grid.
79+
80+
# Uses even-r layout where even rows are shifted left by 0.5.
81+
# This matches the layout used in adjust_meshgrid_topology.
82+
83+
# Args:
84+
# row (int): Grid row coordinate
85+
# col (int): Grid column coordinate
86+
87+
# Returns:
88+
# tuple[float, float]: Axial coordinates (q, r)
89+
# """
90+
# q = col - 0.5 - row // 2 if row % 2 == 0 else col - row // 2
91+
# r = row
92+
# return q, r
93+
94+
95+
# def offset_to_axial_coords(
96+
# row: int,
97+
# col: int,
98+
# ) -> tuple[float, float]: # pragma: no cover
99+
# """Convert offset coordinates to axial coordinates for hexagonal grid.
100+
101+
# Alternative implementation that directly matches the mesh grid adjustment.
102+
103+
# Args:
104+
# row (int): Grid row coordinate
105+
# col (int): Grid column coordinate
106+
107+
# Returns:
108+
# tuple[float, float]: Axial coordinates (q, r)
109+
# """
110+
# # Direct conversion matching adjust_meshgrid_topology
111+
# q = col - (row - (row & 1)) / 2
112+
# r = row
113+
# return q, r
114+
115+
116+
# def axial_distance(
117+
# q1: float,
118+
# r1: float,
119+
# q2: float,
120+
# r2: float,
121+
# ) -> int:
122+
# """Calculate the distance between two hexes in axial coordinates.
123+
124+
# Args:
125+
# q1 (float): column of first hex
126+
# r1 (float): row of first hex
127+
# q2 (float): column of second hex
128+
# r2 (float): row of second hex
129+
130+
# Returns:
131+
# int: Distance in hex steps
132+
# """
133+
# # Convert axial to cube coordinates
134+
# x1, y1, z1 = q1, r1, -q1 - r1
135+
# x2, y2, z2 = q2, r2, -q2 - r2
136+
137+
# # Manhattan distance divided by 2
138+
# return int((abs(x1 - x2) + abs(y1 - y2) + abs(z1 - z2)) / 2)

0 commit comments

Comments
 (0)