Skip to content

Commit 464284c

Browse files
committed
refactor to separate out plotting
1 parent 7d1f470 commit 464284c

8 files changed

Lines changed: 590 additions & 264 deletions

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ uv.lock
1616
coverage.xml
1717
.coverage
1818
sort-by-distance-output.csv
19+
cluster-kmeans-output.csv
20+
test-kmeans-haversine.csv
21+
test-kmeans-output.csv
22+
cluster-kmeans-centroids-output.csv

allocator/algorithms.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""
2+
Pure algorithm implementations without CLI, plotting, or file I/O.
3+
"""
4+
from __future__ import annotations
5+
6+
import numpy as np
7+
import pandas as pd
8+
9+
from .distance_matrix import get_distance_matrix
10+
11+
12+
def initialize_centroids(points: np.ndarray, k: int, random_state: int | None = None) -> np.ndarray:
13+
"""
14+
Initialize k centroids by randomly selecting from points.
15+
16+
Args:
17+
points: Input points array with shape [n, 2]
18+
k: Number of centroids
19+
random_state: Random seed for reproducibility
20+
21+
Returns:
22+
Array of k initial centroids
23+
"""
24+
if random_state is not None:
25+
rng = np.random.RandomState(random_state)
26+
rng_state = rng.get_state()
27+
np.random.set_state(rng_state)
28+
29+
centroids = points.copy()
30+
np.random.shuffle(centroids)
31+
return centroids[:k]
32+
33+
34+
def move_centroids(points: np.ndarray, closest: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """
    Update centroids to the mean of their assigned points.

    Args:
        points: All data points
        closest: Array indicating which centroid each point belongs to
        centroids: Current centroids

    Returns:
        Updated centroids (a cluster with no assigned points keeps its
        previous centroid)
    """
    updated = []
    for idx in range(centroids.shape[0]):
        members = points[closest == idx]
        # Check for the empty cluster BEFORE averaging: the original took
        # mean(axis=0) of an empty slice, which emits a RuntimeWarning and
        # yields NaN that then had to be patched back. Same result, no noise.
        if len(members) == 0:
            updated.append(centroids[idx])
        else:
            updated.append(members.mean(axis=0))
    return np.array(updated)
56+
57+
def kmeans_cluster(data: pd.DataFrame | np.ndarray, n_clusters: int,
                   distance_method: str = 'euclidean', max_iter: int = 300,
                   random_state: int | None = None, **distance_kwargs) -> dict:
    """
    Pure K-means clustering implementation.

    Args:
        data: Input data as DataFrame with start_long/start_lat or numpy array [n, 2]
        n_clusters: Number of clusters
        distance_method: Distance calculation method
        max_iter: Maximum iterations
        random_state: Random seed
        **distance_kwargs: Additional arguments for distance calculation

    Returns:
        Dictionary with 'labels', 'centroids', 'iterations', 'converged'

    Raises:
        ValueError: If a DataFrame is supplied without the expected columns.
    """
    # Accept either a coordinate DataFrame or a raw [n, 2] array.
    if isinstance(data, pd.DataFrame):
        if not {'start_long', 'start_lat'}.issubset(data.columns):
            raise ValueError("DataFrame must contain 'start_long' and 'start_lat' columns")
        X = data[['start_long', 'start_lat']].values
    else:
        X = data

    centroids = initialize_centroids(X, n_clusters, random_state)
    previous = centroids.copy()

    for iteration in range(max_iter):
        # Assign every point to its nearest centroid, then recompute centroids.
        distances = get_distance_matrix(centroids, X, method=distance_method, **distance_kwargs)
        labels = np.argmin(distances, axis=0)
        centroids = move_centroids(X, labels, centroids)

        # Converged once the centroids stop moving (within relative tolerance).
        if np.allclose(previous, centroids, rtol=1e-4):
            return {
                'labels': labels,
                'centroids': centroids,
                'iterations': iteration + 1,
                'converged': True,
            }

        previous = centroids.copy()

    # Ran out of iterations without the centroids settling.
    return {
        'labels': labels,
        'centroids': centroids,
        'iterations': max_iter,
        'converged': False,
    }
113+
114+
def sort_by_distance_assignment(data: pd.DataFrame | np.ndarray,
                                centroids: np.ndarray, distance_method: str = 'euclidean',
                                **distance_kwargs) -> np.ndarray:
    """
    Assign points to closest centroids (used by sort_by_distance).

    Args:
        data: Input data as DataFrame or numpy array
        centroids: Centroid locations
        distance_method: Distance calculation method
        **distance_kwargs: Additional arguments for distance calculation

    Returns:
        Array of cluster assignments

    Raises:
        ValueError: If a DataFrame is supplied without the expected columns.
    """
    # Accept either a coordinate DataFrame or a raw [n, 2] array.
    if isinstance(data, pd.DataFrame):
        if not {'start_long', 'start_lat'}.issubset(data.columns):
            raise ValueError("DataFrame must contain 'start_long' and 'start_lat' columns")
        X = data[['start_long', 'start_lat']].values
    else:
        X = data

    # Row i of the matrix holds point i's distance to every centroid;
    # the argmin along each row is that point's cluster.
    distances = get_distance_matrix(X, centroids, method=distance_method, **distance_kwargs)
    return np.argmin(distances, axis=1)
144+
145+
def calculate_cluster_statistics(data: pd.DataFrame, labels: np.ndarray,
                                 distance_method: str = 'euclidean',
                                 **distance_kwargs) -> list[dict]:
    """
    Calculate statistics for each cluster (used by comparison functions).

    Args:
        data: Input data with coordinates
        labels: Cluster assignments
        distance_method: Distance calculation method
        **distance_kwargs: Additional arguments for distance calculation

    Returns:
        List of dictionaries with cluster statistics (clusters with fewer
        than two points are skipped)
    """
    import networkx as nx

    # networkx 3.0 removed from_numpy_matrix; use its replacement when
    # available, falling back for older installs.
    to_graph = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix

    results = []
    X = data[['start_long', 'start_lat']].values

    for cluster_id in sorted(np.unique(labels)):
        cluster_points = X[labels == cluster_id]
        n_points = len(cluster_points)

        # Skip clusters with 0 or 1 points: no pairwise distances to summarize.
        if n_points <= 1:
            continue

        # Pairwise distance matrix within this cluster.
        distances = get_distance_matrix(cluster_points, cluster_points,
                                        method=distance_method, **distance_kwargs)
        if distances is None:
            continue

        # Build the complete weighted graph and its minimum spanning tree.
        G = to_graph(distances)
        T = nx.minimum_spanning_tree(G)

        # Weights are divided by 1000 (presumably meters -> kilometers —
        # TODO confirm against get_distance_matrix's units).
        graph_weight = int(G.size(weight='weight') / 1000)
        mst_weight = int(T.size(weight='weight') / 1000)

        results.append({
            'label': cluster_id,
            'n': n_points,
            'graph_weight': graph_weight,
            'mst_weight': mst_weight,
        })

    return results

0 commit comments

Comments
 (0)