Coverage for src / tracekit / analyzers / patterns / clustering.py: 96%
356 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Pattern clustering by similarity.
3This module implements algorithms for clustering similar patterns/messages
4using various distance metrics and clustering approaches.
7Author: TraceKit Development Team
8"""
10from dataclasses import dataclass
11from typing import Literal
13import numpy as np
16@dataclass
17class ClusterResult:
18 """Result of pattern clustering.
20 Attributes:
21 cluster_id: Unique cluster identifier
22 patterns: List of patterns in this cluster
23 centroid: Representative pattern (centroid)
24 size: Number of patterns in cluster
25 variance: Within-cluster variance
26 common_bytes: Byte positions that are constant across all patterns
27 variable_bytes: Byte positions that vary across patterns
28 """
30 cluster_id: int
31 patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]]
32 centroid: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]
33 size: int
34 variance: float
35 common_bytes: list[int]
36 variable_bytes: list[int]
38 def __post_init__(self) -> None:
39 """Validate cluster result."""
40 if self.cluster_id < 0:
41 raise ValueError("cluster_id must be non-negative")
42 if self.size < 0:
43 raise ValueError("size must be non-negative")
44 if len(self.patterns) != self.size:
45 raise ValueError("patterns length must match size")
@dataclass
class ClusteringResult:
    """Complete clustering result.

    Attributes:
        clusters: List of ClusterResult objects
        labels: Cluster assignment for each input pattern
        num_clusters: Total number of clusters
        silhouette_score: Clustering quality metric (-1 to 1, higher = better)
    """

    clusters: list[ClusterResult]
    labels: np.ndarray[tuple[int], np.dtype[np.int_]]
    num_clusters: int
    silhouette_score: float

    def __post_init__(self) -> None:
        """Check that the reported cluster count agrees with the cluster list."""
        if len(self.clusters) != self.num_clusters:
            raise ValueError("num_clusters must match clusters length")
def cluster_by_hamming(
    patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]],
    threshold: float = 0.2,
    min_cluster_size: int = 2,
) -> ClusteringResult:
    """Cluster fixed-length patterns by Hamming distance.

    Greedy threshold clustering: each unassigned pattern seeds a new
    cluster, which then absorbs every later unassigned pattern whose
    distance to ALL current members is <= threshold (complete-linkage
    style). Clusters smaller than min_cluster_size are dissolved and
    their members labelled as noise (-1).

    Args:
        patterns: List of patterns (all must have same length)
        threshold: Maximum normalized Hamming distance within cluster (0-1)
        min_cluster_size: Minimum patterns per cluster

    Returns:
        ClusteringResult with cluster assignments

    Raises:
        ValueError: If patterns have different lengths or invalid parameters

    Examples:
        >>> patterns = [b"ABCD", b"ABCE", b"ABCF", b"XYZA"]
        >>> result = cluster_by_hamming(patterns, threshold=0.3)
        >>> assert result.num_clusters >= 1
    """
    if not patterns:
        # Fix: use an int-typed empty label array to match the declared
        # np.int_ dtype (np.array([]) would silently be float64).
        return ClusteringResult(
            clusters=[], labels=np.array([], dtype=int), num_clusters=0, silhouette_score=0.0
        )

    # Validate all patterns have the same length (Hamming requires it).
    pattern_length = len(patterns[0])
    for i, p in enumerate(patterns):
        if len(p) != pattern_length:
            raise ValueError(f"Pattern {i} has length {len(p)}, expected {pattern_length}")

    # Convert once up front; reused for centroid/variance analysis below.
    pattern_arrays = [_to_array(p) for p in patterns]
    n = len(pattern_arrays)

    # Pairwise normalized Hamming distances.
    dist_matrix = compute_distance_matrix(patterns, metric="hamming")

    # Greedy threshold-based clustering; -1 means "unassigned/noise".
    labels = np.full(n, -1, dtype=int)
    cluster_id = 0

    for i in range(n):
        if labels[i] != -1:
            continue  # Already assigned to an earlier cluster

        # Start a new cluster seeded at pattern i.
        cluster_members = [i]
        labels[i] = cluster_id

        # Absorb later unassigned patterns close to every current member.
        for j in range(i + 1, n):
            if labels[j] != -1:
                continue

            max_dist = max(dist_matrix[j, m] for m in cluster_members)
            if max_dist <= threshold:
                cluster_members.append(j)
                labels[j] = cluster_id

        # Dissolve clusters that are too small: members become noise.
        if len(cluster_members) < min_cluster_size:
            for m in cluster_members:
                labels[m] = -1
        else:
            cluster_id += 1

    num_clusters = cluster_id

    # Build per-cluster results.
    clusters = []
    for cid in range(num_clusters):
        cluster_indices = np.where(labels == cid)[0]
        cluster_patterns = [patterns[i] for i in cluster_indices]

        # Centroid via per-position majority vote.
        centroid = _compute_centroid_hamming([pattern_arrays[i] for i in cluster_indices])

        # Which byte positions are constant vs varying within the cluster.
        common, variable = _analyze_pattern_variance([pattern_arrays[i] for i in cluster_indices])

        # Mean pairwise distance within the cluster (0.0 for singletons).
        variance = (
            np.mean([dist_matrix[i, j] for i in cluster_indices for j in cluster_indices if i < j])
            if len(cluster_indices) > 1
            else 0.0
        )

        clusters.append(
            ClusterResult(
                cluster_id=cid,
                patterns=cluster_patterns,
                centroid=bytes(centroid) if isinstance(patterns[0], bytes) else centroid,
                size=len(cluster_patterns),
                variance=float(variance),
                common_bytes=common,
                variable_bytes=variable,
            )
        )

    # Silhouette is only meaningful with at least two clusters.
    silhouette = _compute_silhouette_score(dist_matrix, labels) if num_clusters > 1 else 0.0

    return ClusteringResult(
        clusters=clusters, labels=labels, num_clusters=num_clusters, silhouette_score=silhouette
    )
def cluster_by_edit_distance(
    patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]],
    threshold: float = 0.3,
    min_cluster_size: int = 2,
) -> ClusteringResult:
    """Cluster variable-length patterns by edit distance.

    Greedy threshold clustering on normalized Levenshtein distance:
    each unassigned pattern seeds a cluster that absorbs every later
    unassigned pattern within `threshold` of ALL current members.
    Clusters smaller than min_cluster_size are dissolved and their
    members labelled as noise (-1). Works with variable-length patterns.

    Args:
        patterns: List of patterns (can have different lengths)
        threshold: Maximum normalized edit distance (0-1)
        min_cluster_size: Minimum patterns per cluster

    Returns:
        ClusteringResult with cluster assignments

    Examples:
        >>> patterns = [b"ABCD", b"ABCDE", b"ABCDF", b"XYZ"]
        >>> result = cluster_by_edit_distance(patterns, threshold=0.4)
    """
    if not patterns:
        # Fix: use an int-typed empty label array to match the declared
        # np.int_ dtype (np.array([]) would silently be float64).
        return ClusteringResult(
            clusters=[], labels=np.array([], dtype=int), num_clusters=0, silhouette_score=0.0
        )

    n = len(patterns)

    # Pairwise normalized Levenshtein distances.
    dist_matrix = compute_distance_matrix(patterns, metric="levenshtein")

    # Greedy threshold-based clustering; -1 means "unassigned/noise".
    labels = np.full(n, -1, dtype=int)
    cluster_id = 0

    for i in range(n):
        if labels[i] != -1:
            continue

        # Start a new cluster seeded at pattern i.
        cluster_members = [i]
        labels[i] = cluster_id

        # Absorb later unassigned patterns close to every current member.
        for j in range(i + 1, n):
            if labels[j] != -1:
                continue

            max_dist = max(dist_matrix[j, m] for m in cluster_members)
            if max_dist <= threshold:
                cluster_members.append(j)
                labels[j] = cluster_id

        # Dissolve clusters that are too small: members become noise.
        if len(cluster_members) < min_cluster_size:
            for m in cluster_members:
                labels[m] = -1
        else:
            cluster_id += 1

    num_clusters = cluster_id

    # Build per-cluster results.
    clusters = []
    for cid in range(num_clusters):
        cluster_indices = np.where(labels == cid)[0]
        cluster_patterns = [patterns[i] for i in cluster_indices]

        # For variable-length data, the most common member is the centroid.
        centroid = _compute_centroid_edit(cluster_patterns)

        # Zero-pad members to a common length so positions can be compared.
        max_len = max(len(p) for p in cluster_patterns)
        padded = [_to_array(p, target_length=max_len) for p in cluster_patterns]
        common, variable = _analyze_pattern_variance(padded)

        # Mean pairwise distance within the cluster (0.0 for singletons).
        variance = (
            np.mean([dist_matrix[i, j] for i in cluster_indices for j in cluster_indices if i < j])
            if len(cluster_indices) > 1
            else 0.0
        )

        clusters.append(
            ClusterResult(
                cluster_id=cid,
                patterns=cluster_patterns,
                centroid=centroid,
                size=len(cluster_patterns),
                variance=float(variance),
                common_bytes=common,
                variable_bytes=variable,
            )
        )

    # Silhouette is only meaningful with at least two clusters.
    silhouette = _compute_silhouette_score(dist_matrix, labels) if num_clusters > 1 else 0.0

    return ClusteringResult(
        clusters=clusters, labels=labels, num_clusters=num_clusters, silhouette_score=silhouette
    )
def cluster_hierarchical(
    patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]],
    method: Literal["single", "complete", "average", "upgma"] = "upgma",
    num_clusters: int | None = None,
    distance_threshold: float | None = None,
) -> ClusteringResult:
    """Hierarchical (agglomerative) clustering of patterns.

    Runs agglomerative clustering over the Hamming distance matrix with
    the chosen linkage method; 'upgma' is an alias for 'average'.

    Args:
        patterns: List of patterns
        method: Linkage method ('single', 'complete', 'average', 'upgma')
        num_clusters: Desired number of clusters (if None, use distance_threshold)
        distance_threshold: Distance threshold for cutting dendrogram

    Returns:
        ClusteringResult with cluster assignments

    Raises:
        ValueError: If neither num_clusters nor distance_threshold is specified

    Examples:
        >>> patterns = [b"AAA", b"AAB", b"BBB", b"BBC"]
        >>> result = cluster_hierarchical(patterns, method='average', num_clusters=2)
    """
    if num_clusters is None and distance_threshold is None:
        raise ValueError("Must specify either num_clusters or distance_threshold")

    if not patterns:
        # Fix: use an int-typed empty label array to match the declared
        # np.int_ dtype (np.array([]) would silently be float64).
        return ClusteringResult(
            clusters=[], labels=np.array([], dtype=int), num_clusters=0, silhouette_score=0.0
        )

    # UPGMA is average linkage under another name.
    if method == "upgma":
        method = "average"

    # Pairwise Hamming distances (shorter patterns are zero-padded).
    dist_matrix = compute_distance_matrix(patterns, metric="hamming")

    # Agglomerative merging until the requested cut condition is met.
    labels = _hierarchical_clustering(
        dist_matrix, method=method, num_clusters=num_clusters, distance_threshold=distance_threshold
    )

    # Count the clusters that actually survived.
    unique_labels = set(labels[labels >= 0])
    num_clusters_actual = len(unique_labels)

    # Build per-cluster results.
    clusters = []
    for cid in sorted(unique_labels):
        cluster_indices = np.where(labels == cid)[0]
        cluster_patterns = [patterns[i] for i in cluster_indices]

        # Centroid: majority vote when all lengths agree, otherwise the
        # most common member (variable-length fallback).
        pattern_arrays = [_to_array(p) for p in cluster_patterns]
        if len({len(p) for p in pattern_arrays}) == 1:
            centroid_array = _compute_centroid_hamming(pattern_arrays)
            centroid = bytes(centroid_array) if isinstance(patterns[0], bytes) else centroid_array
        else:
            centroid = _compute_centroid_edit(cluster_patterns)

        # Zero-pad members to a common length so positions can be compared.
        max_len = max(len(p) for p in pattern_arrays)
        padded = [_to_array(p, target_length=max_len) for p in pattern_arrays]
        common, variable = _analyze_pattern_variance(padded)

        # Mean pairwise distance within the cluster (0.0 for singletons).
        variance = (
            np.mean([dist_matrix[i, j] for i in cluster_indices for j in cluster_indices if i < j])
            if len(cluster_indices) > 1
            else 0.0
        )

        clusters.append(
            ClusterResult(
                cluster_id=cid,
                patterns=cluster_patterns,
                centroid=centroid,
                size=len(cluster_patterns),
                variance=float(variance),
                common_bytes=common,
                variable_bytes=variable,
            )
        )

    # Silhouette is only meaningful with at least two clusters.
    silhouette = _compute_silhouette_score(dist_matrix, labels) if num_clusters_actual > 1 else 0.0

    return ClusteringResult(
        clusters=clusters,
        labels=labels,
        num_clusters=num_clusters_actual,
        silhouette_score=silhouette,
    )
def analyze_cluster(cluster: ClusterResult) -> dict[str, list[int] | list[float] | bytes]:
    """Analyze a cluster to find common vs variable byte regions.

    Computes the per-position Shannon entropy across the cluster's
    patterns (zero-padded to the longest member) and derives which
    positions are effectively constant versus varying, plus a
    majority-vote consensus pattern.

    Args:
        cluster: ClusterResult to analyze

    Returns:
        Dictionary with analysis results including:
        - common_bytes: List of byte positions that are constant
        - variable_bytes: List of byte positions that vary
        - entropy_per_byte: Entropy at each byte position
        - consensus: Consensus pattern with variable bytes marked

    Examples:
        >>> # Assume we have a cluster
        >>> analysis = analyze_cluster(cluster)
        >>> print(f"Common positions: {analysis['common_bytes']}")
    """
    if cluster.size == 0:
        return {"common_bytes": [], "variable_bytes": [], "entropy_per_byte": [], "consensus": b""}

    # Zero-pad every member to the width of the longest one.
    arrays = [_to_array(p) for p in cluster.patterns]
    width = max(len(a) for a in arrays)
    aligned = [_to_array(a, target_length=width) for a in arrays]

    # Shannon entropy of the byte values seen at each position.
    entropy_per_byte = [
        _compute_byte_entropy([row[pos] for row in aligned]) for pos in range(width)
    ]

    # Positions with entropy below this are treated as constant.
    low_entropy = 0.1
    common_bytes = [pos for pos, ent in enumerate(entropy_per_byte) if ent < low_entropy]
    variable_bytes = [pos for pos, ent in enumerate(entropy_per_byte) if ent >= low_entropy]

    # Consensus pattern: most common byte at each position.
    consensus = np.zeros(width, dtype=np.uint8)
    for pos in range(width):
        column = [row[pos] for row in aligned]
        consensus[pos] = max(set(column), key=column.count)

    return {
        "common_bytes": common_bytes,
        "variable_bytes": variable_bytes,
        "entropy_per_byte": entropy_per_byte,
        "consensus": bytes(consensus),
    }
def compute_distance_matrix(
    patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]],
    metric: Literal["hamming", "levenshtein", "jaccard"] = "hamming",
) -> np.ndarray[tuple[int, int], np.dtype[np.float64]]:
    """Compute pairwise distance matrix.

    Computes all pairwise distances between patterns using the specified
    metric and returns a symmetric matrix with a zero diagonal.

    Args:
        patterns: List of patterns
        metric: Distance metric ('hamming', 'levenshtein', 'jaccard')

    Returns:
        Symmetric distance matrix (n x n)

    Raises:
        ValueError: If unknown metric is specified

    Examples:
        >>> patterns = [b"ABC", b"ABD", b"XYZ"]
        >>> dist = compute_distance_matrix(patterns, metric='hamming')
        >>> assert dist.shape == (3, 3)
    """
    # Resolve the metric once, up front. This also reports an unknown
    # metric when there are fewer than two patterns (previously the
    # check only happened inside the pairwise loop, so it never ran
    # for 0 or 1 patterns and was re-evaluated O(n^2) times otherwise).
    metric_funcs = {
        "hamming": _hamming_distance,
        "levenshtein": _edit_distance,
        "jaccard": _jaccard_distance,
    }
    try:
        distance = metric_funcs[metric]
    except KeyError:
        raise ValueError(f"Unknown metric: {metric}") from None

    n = len(patterns)
    dist_matrix = np.zeros((n, n), dtype=float)

    for i in range(n):
        for j in range(i + 1, n):
            d = distance(patterns[i], patterns[j])
            # Distances are symmetric; fill both triangles.
            dist_matrix[i, j] = d
            dist_matrix[j, i] = d

    return dist_matrix
506# Helper functions
509def _to_array(
510 data: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]] | memoryview | bytearray,
511 target_length: int | None = None,
512) -> np.ndarray[tuple[int], np.dtype[np.uint8]]:
513 """Convert to numpy array, optionally padding to target length.
515 Args:
516 data: Input data (bytes, bytearray, memoryview, or numpy array)
517 target_length: If specified, pad to this length
519 Returns:
520 Numpy array of uint8
522 Raises:
523 TypeError: If data type is not supported
524 """
525 if isinstance(data, bytes):
526 arr = np.frombuffer(data, dtype=np.uint8)
527 elif isinstance(data, bytearray | memoryview):
528 arr = np.frombuffer(bytes(data), dtype=np.uint8)
529 elif isinstance(data, np.ndarray):
530 arr = data.astype(np.uint8)
531 else:
532 raise TypeError(f"Unsupported type: {type(data)}")
534 if target_length is not None and len(arr) < target_length:
535 # Pad with zeros
536 padded = np.zeros(target_length, dtype=np.uint8)
537 padded[: len(arr)] = arr
538 return padded
540 return arr
def _hamming_distance(
    a: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
    b: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
) -> float:
    """Fraction of byte positions that differ (shorter input zero-padded)."""
    # Align both inputs to the longer length before comparing.
    width = max(len(_to_array(a)), len(_to_array(b)))
    arr_a = _to_array(a, target_length=width)
    arr_b = _to_array(b, target_length=width)

    # Normalized count of mismatching positions.
    return float(np.sum(arr_a != arr_b)) / width
562def _edit_distance(
563 a: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
564 b: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
565) -> float:
566 """Compute normalized Levenshtein edit distance."""
567 bytes_a = bytes(a) if isinstance(a, np.ndarray) else a
568 bytes_b = bytes(b) if isinstance(b, np.ndarray) else b
570 m, n = len(bytes_a), len(bytes_b)
572 if m == 0 and n == 0:
573 return 0.0
574 if m == 0:
575 return 1.0
576 if n == 0:
577 return 1.0
579 # DP table
580 prev_row = list(range(n + 1))
581 curr_row = [0] * (n + 1)
583 for i in range(1, m + 1):
584 curr_row[0] = i
585 for j in range(1, n + 1):
586 if bytes_a[i - 1] == bytes_b[j - 1]:
587 curr_row[j] = prev_row[j - 1]
588 else:
589 curr_row[j] = 1 + min(prev_row[j], curr_row[j - 1], prev_row[j - 1])
590 prev_row, curr_row = curr_row, prev_row
592 # Normalize by max length
593 return prev_row[n] / max(m, n)
def _jaccard_distance(
    a: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
    b: bytes | np.ndarray[tuple[int], np.dtype[np.uint8]],
) -> float:
    """Jaccard distance over the sets of distinct byte values."""
    values_a = set(_to_array(a))
    values_b = set(_to_array(b))

    # Two empty sequences are considered identical.
    if not values_a and not values_b:
        return 0.0

    union = values_a | values_b
    if not union:
        # Defensive; unreachable given the emptiness check above.
        return 0.0

    # Jaccard distance = 1 - |intersection| / |union|.
    overlap = values_a & values_b
    return 1.0 - len(overlap) / len(union)
617def _compute_centroid_hamming(
618 patterns: list[np.ndarray[tuple[int], np.dtype[np.uint8]]],
619) -> np.ndarray[tuple[int], np.dtype[np.uint8]]:
620 """Compute centroid using majority vote (for fixed-length patterns)."""
621 if not patterns:
622 return np.array([], dtype=np.uint8)
624 _n = len(patterns)
625 length = len(patterns[0])
627 centroid = np.zeros(length, dtype=np.uint8)
628 for pos in range(length):
629 bytes_at_pos = [p[pos] for p in patterns]
630 # Most common byte
631 centroid[pos] = max(set(bytes_at_pos), key=bytes_at_pos.count)
633 return centroid
636def _compute_centroid_edit(
637 patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]],
638) -> bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]:
639 """Compute centroid for variable-length patterns (most central pattern)."""
640 if not patterns: 640 ↛ 641line 640 didn't jump to line 641 because the condition on line 640 was never true
641 return b"" if isinstance(patterns[0], bytes) else np.array([])
643 # Use most common pattern as centroid
644 from collections import Counter
646 pattern_counts = Counter(bytes(p) if isinstance(p, np.ndarray) else p for p in patterns)
647 most_common = pattern_counts.most_common(1)[0][0]
649 # Return in original type
650 if isinstance(patterns[0], bytes): 650 ↛ 653line 650 didn't jump to line 653 because the condition on line 650 was always true
651 return most_common
652 else:
653 return np.frombuffer(most_common, dtype=np.uint8)
656def _analyze_pattern_variance(
657 patterns: list[np.ndarray[tuple[int], np.dtype[np.uint8]]],
658) -> tuple[list[int], list[int]]:
659 """Analyze which byte positions are common vs variable."""
660 if not patterns or len(patterns) == 0:
661 return [], []
663 length = len(patterns[0])
664 common_bytes = []
665 variable_bytes = []
667 for pos in range(length):
668 bytes_at_pos = [p[pos] for p in patterns]
669 unique_values = len(set(bytes_at_pos))
671 if unique_values == 1:
672 common_bytes.append(pos)
673 else:
674 variable_bytes.append(pos)
676 return common_bytes, variable_bytes
679def _compute_byte_entropy(byte_values: list[int]) -> float:
680 """Compute Shannon entropy of byte values."""
681 if not byte_values:
682 return 0.0
684 from collections import Counter
686 counts = Counter(byte_values)
687 n = len(byte_values)
689 entropy = 0.0
690 for count in counts.values():
691 if count > 0: 691 ↛ 690line 691 didn't jump to line 690 because the condition on line 691 was always true
692 prob = count / n
693 entropy -= prob * np.log2(prob)
695 return entropy
698def _compute_silhouette_score(
699 dist_matrix: np.ndarray[tuple[int, int], np.dtype[np.float64]],
700 labels: np.ndarray[tuple[int], np.dtype[np.int_]],
701) -> float:
702 """Compute average silhouette score for clustering quality."""
703 n = len(labels)
704 if n <= 1:
705 return 0.0
707 # Filter out noise points (-1 labels)
708 valid_mask = labels >= 0
709 if np.sum(valid_mask) <= 1:
710 return 0.0
712 unique_labels = set(labels[valid_mask])
713 if len(unique_labels) <= 1: 713 ↛ 714line 713 didn't jump to line 714 because the condition on line 713 was never true
714 return 0.0
716 silhouette_scores = []
718 for i in range(n):
719 if labels[i] == -1:
720 continue
722 # a(i): average distance to points in same cluster
723 same_cluster = (labels == labels[i]) & (np.arange(n) != i)
724 if np.sum(same_cluster) == 0:
725 continue
727 a_i = np.mean(dist_matrix[i, same_cluster])
729 # b(i): minimum average distance to points in other clusters
730 b_i = float("inf")
731 for other_label in unique_labels:
732 if other_label == labels[i]:
733 continue
735 other_cluster = labels == other_label
736 if np.sum(other_cluster) > 0: 736 ↛ 731line 736 didn't jump to line 731 because the condition on line 736 was always true
737 avg_dist = np.mean(dist_matrix[i, other_cluster])
738 b_i = min(b_i, avg_dist)
740 # Silhouette coefficient
741 if b_i == float("inf"): 741 ↛ 742line 741 didn't jump to line 742 because the condition on line 741 was never true
742 s_i = 0.0
743 else:
744 s_i = (b_i - a_i) / max(a_i, b_i)
746 silhouette_scores.append(s_i)
748 return float(np.mean(silhouette_scores)) if silhouette_scores else 0.0
def _hierarchical_clustering(
    dist_matrix: np.ndarray[tuple[int, int], np.dtype[np.float64]],
    method: str,
    num_clusters: int | None,
    distance_threshold: float | None,
) -> np.ndarray[tuple[int], np.dtype[np.int_]]:
    """Perform agglomerative hierarchical clustering.

    Args:
        dist_matrix: Symmetric pairwise distance matrix (n x n).
        method: Linkage method forwarded to _linkage_distance.
        num_clusters: Stop merging once this many clusters remain (if set).
        distance_threshold: Stop merging once the closest pair of clusters
            is farther apart than this (if set).

    Returns:
        Integer label array of length n (one cluster id per point).
    """
    n = dist_matrix.shape[0]

    # Start with every point in its own singleton cluster.
    # (Removed an unused full copy of the distance matrix that wasted
    # O(n^2) memory per call.)
    clusters = [[i] for i in range(n)]

    # Greedily merge the two closest clusters until a stop condition hits.
    while len(clusters) > 1:
        if num_clusters is not None and len(clusters) <= num_clusters:
            break

        # Find the closest pair of clusters under the linkage method.
        min_dist = float("inf")
        merge_i, merge_j = -1, -1
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                dist = _linkage_distance(clusters[i], clusters[j], dist_matrix, method)
                if dist < min_dist:
                    min_dist = dist
                    merge_i, merge_j = i, j

        # Stop if even the closest pair is too far apart.
        if distance_threshold is not None and min_dist > distance_threshold:
            break

        # Merge cluster j into cluster i (j > i, so deleting j is safe).
        if merge_i >= 0 and merge_j >= 0:
            clusters[merge_i].extend(clusters[merge_j])
            del clusters[merge_j]

    # Assign one label per surviving cluster.
    labels = np.full(n, -1, dtype=int)
    for cid, cluster in enumerate(clusters):
        for idx in cluster:
            labels[idx] = cid

    return labels
800def _linkage_distance(
801 cluster_a: list[int],
802 cluster_b: list[int],
803 dist_matrix: np.ndarray[tuple[int, int], np.dtype[np.float64]],
804 method: str,
805) -> float:
806 """Compute distance between two clusters using linkage method."""
807 distances = [dist_matrix[i, j] for i in cluster_a for j in cluster_b]
809 if not distances: 809 ↛ 810line 809 didn't jump to line 810 because the condition on line 809 was never true
810 return 0.0
812 if method == "single":
813 return float(min(distances))
814 elif method == "complete":
815 return float(max(distances))
816 elif method == "average": 816 ↛ 819line 816 didn't jump to line 819 because the condition on line 816 was always true
817 return float(np.mean(distances))
818 else:
819 return float(np.mean(distances)) # Default to average
class PatternClusterer:
    """Object-oriented wrapper for pattern clustering functionality.

    A thin, sklearn-flavoured facade over the module-level clustering
    functions, keeping the last full result available for inspection.

    Example:
        >>> clusterer = PatternClusterer(n_clusters=3)
        >>> labels = clusterer.cluster(messages)
    """

    def __init__(
        self,
        n_clusters: int = 3,
        method: Literal["hamming", "edit", "hierarchical"] = "hamming",
        distance_metric: Literal["hamming", "levenshtein", "jaccard"] = "hamming",
        threshold: float = 0.3,
        min_cluster_size: int = 2,
    ):
        """Initialize pattern clusterer.

        Args:
            n_clusters: Desired number of clusters.
            method: Clustering method ('hamming', 'edit', or 'hierarchical').
            distance_metric: Distance metric to use.
            threshold: Distance threshold for clustering.
            min_cluster_size: Minimum patterns per cluster.
        """
        self.n_clusters = n_clusters
        self.method = method
        self.distance_metric = distance_metric
        self.threshold = threshold
        self.min_cluster_size = min_cluster_size
        # Full result of the most recent cluster() call; None before any run.
        self.result_: ClusteringResult | None = None

    def cluster(
        self, patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]]
    ) -> np.ndarray[tuple[int], np.dtype[np.int_]]:
        """Cluster patterns and return per-pattern labels.

        Args:
            patterns: List of patterns to cluster.

        Returns:
            Array of cluster labels (one per pattern).

        Example:
            >>> clusterer = PatternClusterer(n_clusters=3)
            >>> labels = clusterer.cluster(messages)
        """
        if self.method == "hamming":
            outcome = cluster_by_hamming(
                patterns, threshold=self.threshold, min_cluster_size=self.min_cluster_size
            )
        elif self.method == "edit":
            outcome = cluster_by_edit_distance(
                patterns, threshold=self.threshold, min_cluster_size=self.min_cluster_size
            )
        else:
            # 'hierarchical' (and anything unrecognized) uses average linkage.
            outcome = cluster_hierarchical(
                patterns, method="average", num_clusters=self.n_clusters
            )

        self.result_ = outcome
        return outcome.labels

    def fit(
        self, patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]]
    ) -> "PatternClusterer":
        """Fit the clusterer to patterns (sklearn-style interface).

        Args:
            patterns: List of patterns to cluster.

        Returns:
            Self (for method chaining).
        """
        self.cluster(patterns)
        return self

    def fit_predict(
        self, patterns: list[bytes | np.ndarray[tuple[int], np.dtype[np.uint8]]]
    ) -> np.ndarray[tuple[int], np.dtype[np.int_]]:
        """Fit and return cluster labels (sklearn-style interface).

        Args:
            patterns: List of patterns to cluster.

        Returns:
            Array of cluster labels.
        """
        return self.cluster(patterns)

    def get_clusters(self) -> list[ClusterResult]:
        """Return detailed per-cluster results of the most recent run.

        Returns:
            List of ClusterResult objects with full cluster analysis.

        Raises:
            ValueError: If cluster() hasn't been called yet.
        """
        if self.result_ is None:
            raise ValueError("Must call cluster() before get_clusters()")
        return self.result_.clusters

    def get_silhouette_score(self) -> float:
        """Return the silhouette score of the most recent clustering run.

        Returns:
            Silhouette score (-1 to 1, higher is better).

        Raises:
            ValueError: If cluster() hasn't been called yet.
        """
        if self.result_ is None:
            raise ValueError("Must call cluster() before get_silhouette_score()")
        return self.result_.silhouette_score