amachine.am_msp
1from __future__ import annotations 2 3import copy 4from fractions import Fraction 5import gc 6import warnings 7from collections import deque 8 9import numpy as np 10import scipy.sparse as sp_sparse 11import scipy.sparse.linalg as sp_linalg 12 13from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu 14 15# --------- Conditionally use cupy & cupyx if available else numpy and scipy 16 17try : 18 import cupy as cp 19 CUPY_AVAILABLE = True 20except ImportError : 21 CUPY_AVAILABLE = False 22 23if CUPY_AVAILABLE : 24 import cupyx.scipy.sparse as cpx_sparse 25 import cupyx.scipy.sparse.linalg as cpx_linalg 26 from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence_device 27else : 28 cp = None 29 cpx_sparse = None 30 cpx_linalg = None 31 32def to_cpu( array ): 33 return array.get() if CUPY_AVAILABLE and hasattr( array, 'get' ) else array 34 35# --------- Conditionally use cagra if available else hnswlib 36 37if CUPY_AVAILABLE : 38 try : 39 from cuvs.neighbors import cagra 40 CAGRA_AVAILABLE = True 41 except ImportError : 42 CAGRA_AVAILABLE = False 43 44if not CAGRA_AVAILABLE : 45 import hnswlib 46else : 47 hnswlib = None 48 49# ------------------------------------------------------------- 50 51from .am_causal_state import CausalState 52from .am_transition import Transition 53 54class MSP: 55 """ 56 Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity. 57 58 Computes E (excess entropy), S (synchronization information), and T (transient information) 59 using spectral decomposition of the fundamental matrix. 60 61 See 62 Exact Complexity: The Spectral Decomposition of Intrinsic Computation 63 https://arxiv.org/abs/1309.3792 64 """ 65 66 def __init__( 67 self, 68 states, 69 mixed_state_vectors, 70 transitions, 71 alphabet, 72 start_state_idx: int = 0, 73 gmres_rtol: float = 1e-10, 74 gmres_maxiter: int = 1000, 75 eps: float = 1e-12, 76 ): 77 self.states = states 78 self.mixed_state_vectors = np.array(mixed_state_vectors) 79 self.transitions = transitions 80 self.alphabet = alphabet 81 self.start_state_idx = start_state_idx 82 self.gmres_rtol = gmres_rtol 83 self.gmres_maxiter = gmres_maxiter 84 self.EPS = eps 85 self._cache = {} 86 87 M = len(states) 88 if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M: 89 raise ValueError( 90 f"mixed_state_vectors must have shape (M={M}, N_causal), " 91 f"got {self.mixed_state_vectors.shape}." 92 ) 93 if not (0 <= start_state_idx < M): 94 raise ValueError( 95 f"start_state_idx={start_state_idx} is out of range for M={M} states." 96 ) 97 if gmres_rtol <= 0: 98 raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.") 99 if eps <= 0: 100 raise ValueError(f"eps must be positive, got {eps}.") 101 102 def _build_W(self) -> sp_sparse.csr_matrix: 103 """ 104 Build and return the sparse row-stochastic transition matrix W (M × M). 105 Rows are re-normalised to absorb floating-point drift. 106 """ 107 if "W" in self._cache: 108 return self._cache["W"] 109 110 M = len(self.states) 111 112 rows = [tr.origin_state_idx for tr in self.transitions] 113 cols = [tr.target_state_idx for tr in self.transitions] 114 data = [tr.prob for tr in self.transitions] 115 116 W = sp_sparse.csr_matrix((data, (rows, cols)), shape=(M, M)) 117 118 row_sums = np.array(W.sum(axis=1)).ravel() 119 120 if np.any(row_sums < self.EPS): 121 bad = np.where(row_sums < self.EPS)[0].tolist() 122 raise ValueError( 123 f"States at indices {bad} have zero total outgoing probability. " 124 "Check that every MSP state has at least one outgoing transition." 125 ) 126 127 W = (sp_sparse.diags(1.0 / row_sums) @ W).tocsr() 128 129 self._cache["W"] = W 130 return W 131 132 def _compute_pi_W(self, W) -> np.ndarray: 133 """ 134 Solve for the stationary distribution of W via GPU GMRES, then free 135 all intermediate GPU arrays before returning the CPU result. 136 """ 137 if "pi_W" in self._cache: 138 return self._cache["pi_W"] 139 140 xp = cp if CUPY_AVAILABLE else np 141 142 xp_sparse = cpx_sparse if CUPY_AVAILABLE else sp_sparse 143 xp_linalg = cpx_linalg if CUPY_AVAILABLE else sp_linalg 144 145 W_device = cpx_sparse.csr_matrix(W) 146 M = W_device.shape[0] 147 148 I_device = xp_sparse.eye(M, format="csr", dtype=W_device.dtype) 149 A_device = (W_device.T - I_device).tocsr() 150 151 ones_row = xp_sparse.csr_matrix(xp.ones((1, M), dtype=W_device.dtype)) 152 A_aug = xp_sparse.vstack([A_device[:-1, :], ones_row]).tocsr() 153 154 b_device = xp.zeros(M, dtype=W_device.dtype) 155 b_device[-1] = 1.0 156 157 diag_vals = A_aug.diagonal() 158 diag_vals[diag_vals == 0] = 1.0 159 M_inv = xp_sparse.diags(1.0 / diag_vals) 160 161 pi_W_device, info = xp_linalg.gmres( 162 A_aug, b_device, 163 M=M_inv, 164 restart=100, 165 maxiter=1000, 166 ) 167 168 if info > 0: 169 print(f"Warning: GMRES did not converge after {info} iterations") 170 elif info < 0: 171 raise ValueError("GMRES failed due to illegal input or breakdown.") 172 173 pi_W_device = xp.clip(pi_W_device, 0.0, None) 174 s = pi_W_device.sum() 175 176 if s < self.EPS: 177 raise ValueError("Stationary distribution collapsed.") 178 179 pi_W_device /= s 180 181 # Download before freeing 182 result = to_cpu( pi_W_device ) 183 184 # -- Free every GPU object created in this method ---------------------- 185 del W_device, I_device, A_device, ones_row, A_aug, b_device, M_inv, diag_vals, pi_W_device 186 187 if CUPY_AVAILABLE : 188 cp.get_default_memory_pool().free_all_blocks() 189 cp.get_default_pinned_memory_pool().free_all_blocks() 190 191 gc.collect() 192 193 self._cache["pi_W"] = result 194 195 return result 196 197 def _fundamental( 198 self, 199 W: sp_sparse.csr_matrix, 200 pi_W: np.ndarray, 201 ) -> tuple[sp_linalg.LinearOperator, np.ndarray, np.ndarray, LinearOperator]: 202 """ 203 Compute the fundamental matrix 204 """ 205 if "fundamental" in self._cache: 206 return self._cache["fundamental"] 207 208 M = W.shape[0] 209 WT = W.T.tocsc() 210 211 e0 = np.zeros(M) 212 e0[self.start_state_idx] = 1.0 213 214 def Q_T_matvec(v: np.ndarray) -> np.ndarray: 215 return v - WT @ v + pi_W * v.sum() 216 217 Q_LO = sp_linalg.LinearOperator((M, M), matvec=Q_T_matvec, dtype=float) 218 219 W_diag = np.array(W.diagonal()) 220 M_diag = 1.0 - W_diag + pi_W 221 M_diag = np.where(np.abs(M_diag) > self.EPS, M_diag, 1.0) 222 precond = sp_linalg.LinearOperator((M, M), matvec=lambda v: v / M_diag, dtype=float) 223 224 Z_row, exit_code = sp_linalg.lgmres( 225 Q_LO, 226 e0, 227 M=precond, 228 rtol=self.gmres_rtol, 229 atol=self.gmres_rtol, 230 maxiter=self.gmres_maxiter, 231 ) 232 if exit_code != 0: 233 raise RuntimeError( 234 f"GMRES failed on the fundamental matrix solve (exit code {exit_code}). " 235 "Try increasing gmres_maxiter, loosening gmres_rtol, or verify that " 236 "the transition matrix is irreducible." 237 ) 238 239 fund_row = Z_row - pi_W 240 241 self._cache["fundamental"] = (Q_LO, Z_row, fund_row, precond) 242 return Q_LO, Z_row, fund_row, precond 243 244 def _get_H_WA(self) -> np.ndarray: 245 """ 246 Per-state Shannon entropy of the output symbol distribution, 247 marginalised over target states. Rows are normalised before use. 248 """ 249 if "H_WA" in self._cache: 250 return self._cache["H_WA"] 251 252 M, A = len(self.states), len(self.alphabet) 253 state_sym = np.zeros((M, A)) 254 for tr in self.transitions: 255 state_sym[tr.origin_state_idx, tr.symbol_idx] += tr.prob 256 257 sym_row_sums = state_sym.sum(axis=1, keepdims=True) 258 sym_row_sums = np.where(sym_row_sums > self.EPS, sym_row_sums, 1.0) 259 state_sym /= sym_row_sums 260 261 safe = state_sym > self.EPS 262 log_vals = np.where(safe, np.log2(np.where(safe, state_sym, 1.0)), 0.0) 263 H_WA = -np.sum(state_sym * log_vals, axis=1) 264 265 self._cache["H_WA"] = H_WA 266 return H_WA 267 268 def _get_H_eta(self) -> np.ndarray: 269 """ 270 Per-state Shannon entropy of the mixed-state (belief) vector. 271 Rows are normalised before entropy computation. 272 """ 273 if "H_eta" in self._cache: 274 return self._cache["H_eta"] 275 276 msv = self.mixed_state_vectors.copy().astype(float) 277 row_sums = msv.sum(axis=1, keepdims=True) 278 row_sums = np.where(row_sums > self.EPS, row_sums, 1.0) 279 msv /= row_sums 280 281 safe = msv > self.EPS 282 log_vals = np.where(safe, np.log2(np.where(safe, msv, 1.0)), 0.0) 283 H_eta = -np.sum(msv * log_vals, axis=1) 284 285 self._cache["H_eta"] = H_eta 286 return H_eta 287 288 def get_E_S_T(self) -> tuple[float, float, float]: 289 290 try: 291 W = self._build_W() 292 pi_W = self._compute_pi_W(W) # GPU memory freed inside _compute_pi_W 293 H_WA = self._get_H_WA() 294 H_eta = self._get_H_eta() 295 296 Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W) 297 298 E = float(fund_row @ H_WA) 299 S = float(fund_row @ H_eta) 300 301 ZZ_row, exit_code = sp_linalg.lgmres( 302 Q_LO, 303 Z_row, 304 x0=Z_row, 305 M=precond, 306 rtol=self.gmres_rtol, 307 atol=self.gmres_rtol, 308 maxiter=self.gmres_maxiter, 309 ) 310 if exit_code != 0: 311 raise RuntimeError( 312 f"GMRES failed on the Z² solve for T (exit code {exit_code}). " 313 "Try increasing gmres_maxiter or loosening gmres_rtol." 314 ) 315 316 t_row = ZZ_row - pi_W 317 T = float(t_row @ H_WA) 318 319 return E, S, T 320 321 finally: 322 323 if CUPY_AVAILABLE : 324 cp.get_default_memory_pool().free_all_blocks() 325 cp.get_default_pinned_memory_pool().free_all_blocks() 326 327 gc.collect() 328 329 def clear_cache(self) -> None: 330 self._cache.clear() 331 332 333def compute_msp_exact( 334 T_x : list[ list[ list[ Fraction ] ] ], 335 pi : tuple[Fraction], 336 n_states : int, 337 alphabet : list[str], 338 exact_state_cap: int = 1000, 339 verbose: bool = True, 340): 341 342 n_symbols = len(alphabet) 343 344 def _apply(mu: tuple[Fraction, ...], x: int) : 345 346 Tx = T_x[x] 347 mu_Tx = [Fraction(0)] * n_states 348 for i, w in enumerate(mu): 349 if w == Fraction(0): 350 continue 351 row = Tx[i] 352 for j in range(n_states): 353 if row[j]: 354 mu_Tx[j] += w * row[j] 355 356 prob = sum(mu_Tx) 357 if prob == Fraction(0): 358 return None, Fraction(0) 359 360 mu_next = tuple(v / prob for v in mu_Tx) 361 return mu_next, prob 362 363 belief_vectors: list[tuple[Fraction, ...]] = [] 364 seen_states: dict[tuple, int] = {} 365 msp_transitions: list = [] 366 367 def _register(mu: tuple[Fraction, ...]) -> int: 368 idx = len(belief_vectors) 369 belief_vectors.append(mu) 370 seen_states[mu] = idx 371 return idx 372 373 start_mu = tuple(pi) 374 _register(start_mu) 375 frontier = deque([0]) 376 377 while frontier and len(belief_vectors) < exact_state_cap: 378 379 idx = frontier.popleft() 380 mu = belief_vectors[idx] 381 382 for x in range(n_symbols): 383 384 mu_next, prob = _apply(mu, x) 385 386 if mu_next is None: 387 continue 388 389 if mu_next not in seen_states: 390 next_idx = _register(mu_next) 391 frontier.append(next_idx) 392 else: 393 next_idx = seen_states[mu_next] 394 395 msp_transitions.append( 396 Transition( 397 origin_state_idx = idx, 398 target_state_idx = next_idx, 399 prob = float(prob), 400 symbol_idx = x, 401 pq = prob 402 ) 403 ) 404 405 if frontier: 406 raise RuntimeError( f"Exact state cap {exact_state_cap} exceded." ) 407 408 if verbose: 409 print(f"Exact MSP found: {len(belief_vectors)} states.") 410 411 msp = MSP( 412 states = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)], 413 mixed_state_vectors = mixed_state_vectors, 414 transitions = msp_transitions, 415 alphabet = alphabet.copy(), 416 ) 417 418 return msp 419 420def compute_msp( 421 T_x : list[np.ndarray], 422 pi : np.ndarray, 423 n_states : int, 424 alphabet : list[str], 425 exact_state_cap: int = 175_000, 426 jsd_eps: float = 1e-7, 427 k_ann: int = 50, 428 EPS : float = 1e-12, 429 verbose = True, 430) : 431 432 xp = cp if CUPY_AVAILABLE else np 433 434 n_input_states = T_x[0].shape[0] 435 n_symbols = len(alphabet) 436 437 T_flat_device = xp.asarray( 438 np.stack(T_x).transpose(1, 0, 2).reshape(n_input_states, n_symbols * n_input_states) 439 ) 440 441 # -- Sparse Belief Vectors ------------------------------------------- 442 443 sparse_belief_vector_indices: list[np.ndarray] = [] 444 sparse_belief_vector_values: list[np.ndarray] = [] 445 446 def _to_sparse(v: np.ndarray): 447 448 non_zero = np.where( np.abs( v ) > 1e-20 )[ 0 ] 449 return non_zero.astype(np.int32), v[ non_zero ] 450 451 def _to_dense( sparse_indices: np.ndarray, sparse_values: np.ndarray ): 452 dense = np.zeros( n_input_states ) 453 dense[ sparse_indices ] = sparse_values 454 return dense 455 456 def _batch_to_dense( indicies: list[int] ) -> np.ndarray: 457 458 out = np.zeros((len(indicies), n_input_states), dtype=np.float64) 459 for row, i in enumerate(indicies): 460 out[row, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i] 461 462 return out 463 464 # -- Transitions ---------------------------------------------- 465 466 t_origin: list[int] = [] 467 t_target: list[int] = [] 468 t_prob: list[float] = [] 469 t_symbol: list[int] = [] 470 471 # -- Sparse dual bucket coarse hash ---------------------------------- 472 473 # Since we use dual bins 474 # Some x, y in p, q needs to differ by more than 1/(2*coarse_scale) to not share a bin 475 # Thus minimal total variation between non-bin sharing distributions is 476 # To ensure jsd(p,q) < jsd_eps -> p,q share a bin, we use Pinsker's inequality 477 # 1 /(2*coarse_scale) <= 2*sqrt(2*jsd_eps) 478 # 1 /( 4*sqrt(2*jsd_eps) ) <= coarse_scale 479 480 coarse_scale = 1 / ( 4*np.sqrt( 2 * jsd_eps ) ) 481 482 def _coarse_keys( sparse_indices: np.ndarray, sparse_values: np.ndarray) -> list[bytes]: 483 scaled = _to_dense( sparse_indices, sparse_values*coarse_scale ) 484 return [ 485 np.round(scaled).astype(np.int32).tobytes(), 486 np.round(scaled + 0.5).astype(np.int32).tobytes() 487 ] 488 489 buckets: dict[bytes, list[int]] = {} 490 491 # Jenson Shannon Divergence for sparse vectors 492 def _jsd_sparse( 493 ai: np.ndarray, av: np.ndarray, 494 bi: np.ndarray, bv: np.ndarray, 495 ) -> float: 496 497 # get sorted union of the [non-zero] indices over both vectors 498 all_i = np.union1d( ai, bi ) 499 500 # form dense vectors 501 p = np.zeros( len( all_i ), dtype=np.float64 ) 502 q = np.zeros( len( all_i ), dtype=np.float64 ) 503 504 # np.searchsorted( all_i, ai ) finds position in all_i where each ai is 505 p[ np.searchsorted( all_i, ai ) ] = av 506 q[ np.searchsorted( all_i, bi ) ] = bv 507 508 # Jenson Shannon divergence (JDS) 509 return af_jensenshannondivergence_cpu( p, q ) 510 511 # Check if an existing belief vector is within jsd_eps 512 def _lookup( sparse_indices: np.ndarray, sparse_values: np.ndarray ) -> int: 513 514 keys = _coarse_keys( sparse_indices, sparse_values ) 515 candidates = list( { idx for key in keys for idx in buckets.get( key, [] ) } ) 516 517 if not candidates : 518 return -1 519 520 dense = _to_dense( sparse_indices, sparse_values ).reshape( 1, n_input_states ) 521 candidate_vectors = _batch_to_dense( candidates ) 522 523 distances = af_jensenshannondivergence_cpu( candidate_vectors, dense, axis=1 ) 524 min_idx = np.argmin( distances ) 525 min_dist = distances[ min_idx ] 526 527 if min_dist < jsd_eps : 528 return candidates[ min_idx ] 529 530 return -1 531 532 # Map belief vector to course hash buckets for subsequent fast lookup 533 def _register( sparse_indices: np.ndarray, sparse_values: np.ndarray, idx: int): 534 for key in _coarse_keys( sparse_indices, sparse_values ): 535 buckets.setdefault( key, [] ).append( idx ) 536 537 pi_indicies, pi_values = _to_sparse( pi ) 538 sparse_belief_vector_indices.append( pi_indicies ) 539 sparse_belief_vector_values.append( pi_values ) 540 n_belief_vectors = 1 541 542 _register( pi_indicies, pi_values, 0 ) 543 544 frontier = deque([0]) 545 546 FRONTIER_CHUNK = 256 547 548 try: 549 while frontier and n_belief_vectors < exact_state_cap: 550 551 batch_indices = [] 552 while frontier and len(batch_indices) < FRONTIER_CHUNK: 553 batch_indices.append(frontier.popleft()) 554 555 B = len(batch_indices) 556 557 mus_device = xp.asarray(_batch_to_dense( batch_indices ) ) 558 flat_device = mus_device @ T_flat_device 559 all_mu_Tx_device = flat_device.reshape( B, n_symbols, n_input_states ) 560 probs_device = all_mu_Tx_device.sum(axis=2) 561 562 vb_device, vx_device = xp.where(probs_device > EPS) 563 564 if vb_device.size == 0: 565 del mus_device, flat_device, all_mu_Tx_device, probs_device, vb_device, vx_device 566 continue 567 568 vp_device = probs_device[vb_device, vx_device] 569 mu_next_device = all_mu_Tx_device[vb_device, vx_device] / vp_device[:, None] 570 571 mu_next_cpu = to_cpu( mu_next_device ) 572 573 vp_cpu = to_cpu( vp_device ) 574 vb_cpu = to_cpu( vb_device ) 575 vx_cpu = to_cpu( vx_device ) 576 577 # Free GPU arrays for this batch immediately after download 578 del mus_device, flat_device, all_mu_Tx_device, probs_device 579 del vb_device, vx_device, vp_device, mu_next_device 580 581 if not np.isfinite(mu_next_cpu).all(): 582 raise ValueError("Non-finite belief vectors in BFS batch") 583 584 for q in range( len( vb_cpu ) ): 585 586 indicies, values = _to_sparse(mu_next_cpu[q]) 587 orig = int( batch_indices[ int( vb_cpu[ q ] ) ] ) 588 idx = _lookup( indicies, values ) 589 590 if idx == -1: 591 592 idx = n_belief_vectors 593 sparse_belief_vector_indices.append( indicies ) 594 sparse_belief_vector_values.append( values ) 595 _register( indicies, values, idx ) 596 597 n_belief_vectors += 1 598 frontier.append(idx) 599 600 t_origin.append(orig) 601 t_target.append(idx) 602 t_prob.append(float(vp_cpu[q])) 603 t_symbol.append(int(vx_cpu[q])) 604 605 # -- Close the graph by mapping each dangling state to the nearest closed state 606 607 if frontier: 608 609 if verbose: 610 print(f"Max exact states ({exact_state_cap}) exceded ({n_belief_vectors}). Closing graph...") 611 612 dense_all = np.zeros((n_belief_vectors, n_input_states), dtype=np.float32) 613 614 for i in range(n_belief_vectors): 615 dense_all[i, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i].astype(np.float32) 616 617 ixp = xp if CAGRA_AVAILABLE else np 618 619 if CAGRA_AVAILABLE: 620 621 bv_device_f32 = xp.asarray(dense_all) 622 del dense_all 623 624 if verbose : 625 print("Building CAGRA index...") 626 627 index = cagra.build( 628 cagra.IndexParams(graph_degree=32, metric='sqeuclidean'), 629 bv_device_f32, 630 ) 631 632 else : 633 634 if verbose : 635 print("Building hnswlib index...") 636 637 T_flat_device = to_cpu( T_flat_device ) 638 bv_device_f32 = dense_all 639 index = hnswlib.Index(space='l2', dim=bv_device_f32[0].size) 640 index.init_index(max_elements=len(bv_device_f32), ef_construction=500, M=50) 641 index.set_ef(50) 642 index.add_items(bv_device_f32) 643 644 if verbose: 645 print("Index built.") 646 647 closure_buckets: dict[bytes, list[tuple]] = {} 648 max_error = 0.0 649 avg_error = 0.0 650 n_mapped = 0 651 652 def _closure_lookup( indicies: np.ndarray, values: np.ndarray) -> int: 653 seen = set() 654 for key in _coarse_keys( indicies, values): 655 for entry in closure_buckets.get(key, []): 656 eid = id(entry) 657 if eid not in seen: 658 seen.add(eid) 659 s_idx, s_val, target = entry 660 if _jsd_sparse(indicies, values, s_idx, s_val) < jsd_eps: 661 return target 662 return -1 663 664 def _closure_register(indicies: np.ndarray, values: np.ndarray, target: int): 665 entry = (indicies, values, target) 666 for key in _coarse_keys(indicies, values): 667 closure_buckets.setdefault(key, []).append(entry) 668 669 CHUNK = 10_000 670 frontier_list = list(frontier) 671 672 for start in range(0, len(frontier_list), CHUNK): 673 674 chunk_idxs = frontier_list[start : start + CHUNK] 675 C = len(chunk_idxs) 676 677 if CAGRA_AVAILABLE : 678 chunk_device = cp.asarray(_batch_to_dense(chunk_idxs)) 679 else : 680 chunk_device = _batch_to_dense(chunk_idxs) 681 682 flat_device = chunk_device @ T_flat_device 683 684 mu_Tx_device = flat_device.reshape(C, n_symbols, n_input_states) 685 686 probs_device = mu_Tx_device.sum(axis=2) 687 688 # Free intermediates we no longer need 689 del chunk_device, flat_device 690 691 vf_device, vx_device = ixp.where(probs_device > EPS) 692 693 if vf_device.size == 0: 694 del mu_Tx_device, probs_device, vf_device, vx_device 695 continue 696 697 vp_device = probs_device[vf_device, vx_device] 698 mu_v_device = mu_Tx_device[vf_device, vx_device] / vp_device[:, None] 699 700 # Free now-consumed GPU arrays 701 del mu_Tx_device, probs_device 702 703 mu_v_cpu = to_cpu( mu_v_device ) 704 vp_cpu = to_cpu( vp_device ) 705 vf_cpu = to_cpu( vf_device ) 706 vx_cpu = to_cpu( vx_device ) 707 708 del vf_device, vx_device, vp_device 709 710 # float32 copy for CAGRA; free float64 immediately after cast 711 unknown_device = mu_v_device.astype(ixp.float32) 712 del mu_v_device 713 714 N = len(vf_cpu) 715 orig_arr = np.array([chunk_idxs[f] for f in vf_cpu], dtype=np.int32) 716 next_idx = np.full(N, -1, dtype=np.int32) 717 unknown_mask = np.ones(N, dtype=bool) 718 719 sparse_cache = [_to_sparse(mu_v_cpu[q]) for q in range(N)] 720 721 for q in range(N): 722 indicies, values = sparse_cache[q] 723 target = _closure_lookup(indicies, values) 724 if target != -1: 725 next_idx[q] = target 726 unknown_mask[q] = False 727 728 if unknown_mask.any(): 729 730 unknown_rows = unknown_device[ixp.asarray(unknown_mask)] 731 732 if CAGRA_AVAILABLE : 733 734 if verbose: 735 print(f"CARGA search for {unknown_mask.sum()} unknown states...") 736 737 _, indices_device = cagra.search( 738 cagra.SearchParams(itopk_size=128), 739 index, unknown_rows, k_ann, 740 ) 741 742 else : 743 if verbose : 744 print(f"hnswlib search for {unknown_mask.sum()} unknown states...") 745 746 indices_device, _ = index.knn_query( to_cpu( unknown_rows ), k=k_ann ) 747 748 output_indices = ixp.asarray( indices_device ) 749 750 min_dist = ixp.full(unknown_rows.shape[0], ixp.inf, dtype=ixp.float32) 751 min_ks = ixp.zeros(unknown_rows.shape[0], dtype=ixp.int32) 752 753 for idx_k in range(output_indices.shape[1]): 754 sv = bv_device_f32[output_indices[:, idx_k], :] 755 756 if CAGRA_AVAILABLE : 757 d = af_jensenshannondivergence_device(unknown_rows, sv, axis=1) 758 else : 759 d = af_jensenshannondivergence_cpu(unknown_rows, sv, axis=1) 760 761 mask = d < min_dist 762 min_dist = ixp.where(mask, d, min_dist) 763 min_ks = ixp.where(mask, idx_k, min_ks) 764 765 max_dist = to_cpu( ixp.max(min_dist) ) 766 mean_dist = to_cpu( ixp.mean(min_dist) ) 767 768 n_mapped += N 769 avg_error += mean_dist 770 max_error = max(max_dist, max_error) 771 772 resolved = to_cpu( output_indices[ixp.arange(unknown_rows.shape[0]), min_ks] ) 773 unknown_qs = np.where(unknown_mask)[0] 774 775 for i, q in enumerate(unknown_qs): 776 nz_idx, nz_val = sparse_cache[q] 777 next_idx[q] = int(resolved[i]) 778 _closure_register(nz_idx, nz_val, int(resolved[i])) 779 780 # Free ANN-related GPU arrays 781 del unknown_rows, indices_device, output_indices 782 del min_dist, min_ks 783 784 del unknown_device 785 786 for q in range(N): 787 t_origin.append(int(orig_arr[q])) 788 t_target.append(int(next_idx[q])) 789 t_prob.append(float(vp_cpu[q])) 790 t_symbol.append(int(vx_cpu[q])) 791 792 # -- Free CAGRA index and float32 belief matrix -------------------- 793 794 # CAGRA memory lives outside CuPy's pool and is not released by free_all_blocks() 795 del index 796 del bv_device_f32 797 798 if CUPY_AVAILABLE : 799 800 cp.get_default_memory_pool().free_all_blocks() 801 cp.get_default_pinned_memory_pool().free_all_blocks() 802 803 gc.collect() 804 805 if verbose: 806 print(f"Closed: {n_mapped} dangling states") 807 print(f"Avg JSD state closure error: {(1.0 / np.log(2)) * avg_error / n_mapped} bits") 808 print(f"Max JSD state closure error: {(1.0 / np.log(2)) * max_error} bits") 809 810 # ------------------ Create the MSP --------------------# 811 812 msp_transitions = [ 813 Transition( 814 origin_state_idx = t_origin[i], 815 target_state_idx = t_target[i], 816 prob = t_prob[i], 817 symbol_idx = t_symbol[i], 818 ) 819 for i in range(len(t_origin)) 820 ] 821 822 mixed_state_vectors = list(_batch_to_dense(list(range(n_belief_vectors)))) 823 824 if verbose: 825 print("Assembled msp_transitions.") 826 827 msp = MSP( 828 states = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)], 829 mixed_state_vectors = mixed_state_vectors, 830 transitions = msp_transitions, 831 alphabet = alphabet.copy(), 832 ) 833 834 if verbose: 835 print("\nDone.\n") 836 837 return msp 838 839 finally: 840 841 # ------------------ Cleanup --------------------# 842 843 if 'T_flat_device' in dir(): 844 del T_flat_device 845 846 if 'index' in locals(): 847 del index 848 849 if 'bv_device_f32' in locals(): 850 del bv_device_f32 851 852 if CUPY_AVAILABLE : 853 854 cp.get_default_memory_pool().free_all_blocks() 855 cp.get_default_pinned_memory_pool().free_all_blocks() 856 857 gc.collect()
def
to_cpu(array):
class
MSP:
55class MSP: 56 """ 57 Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity. 58 59 Computes E (excess entropy), S (synchronization information), and T (transient information) 60 using spectral decomposition of the fundamental matrix. 61 62 See 63 Exact Complexity: The Spectral Decomposition of Intrinsic Computation 64 https://arxiv.org/abs/1309.3792 65 """ 66 67 def __init__( 68 self, 69 states, 70 mixed_state_vectors, 71 transitions, 72 alphabet, 73 start_state_idx: int = 0, 74 gmres_rtol: float = 1e-10, 75 gmres_maxiter: int = 1000, 76 eps: float = 1e-12, 77 ): 78 self.states = states 79 self.mixed_state_vectors = np.array(mixed_state_vectors) 80 self.transitions = transitions 81 self.alphabet = alphabet 82 self.start_state_idx = start_state_idx 83 self.gmres_rtol = gmres_rtol 84 self.gmres_maxiter = gmres_maxiter 85 self.EPS = eps 86 self._cache = {} 87 88 M = len(states) 89 if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M: 90 raise ValueError( 91 f"mixed_state_vectors must have shape (M={M}, N_causal), " 92 f"got {self.mixed_state_vectors.shape}." 93 ) 94 if not (0 <= start_state_idx < M): 95 raise ValueError( 96 f"start_state_idx={start_state_idx} is out of range for M={M} states." 97 ) 98 if gmres_rtol <= 0: 99 raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.") 100 if eps <= 0: 101 raise ValueError(f"eps must be positive, got {eps}.") 102 103 def _build_W(self) -> sp_sparse.csr_matrix: 104 """ 105 Build and return the sparse row-stochastic transition matrix W (M × M). 106 Rows are re-normalised to absorb floating-point drift. 107 """ 108 if "W" in self._cache: 109 return self._cache["W"] 110 111 M = len(self.states) 112 113 rows = [tr.origin_state_idx for tr in self.transitions] 114 cols = [tr.target_state_idx for tr in self.transitions] 115 data = [tr.prob for tr in self.transitions] 116 117 W = sp_sparse.csr_matrix((data, (rows, cols)), shape=(M, M)) 118 119 row_sums = np.array(W.sum(axis=1)).ravel() 120 121 if np.any(row_sums < self.EPS): 122 bad = np.where(row_sums < self.EPS)[0].tolist() 123 raise ValueError( 124 f"States at indices {bad} have zero total outgoing probability. " 125 "Check that every MSP state has at least one outgoing transition." 126 ) 127 128 W = (sp_sparse.diags(1.0 / row_sums) @ W).tocsr() 129 130 self._cache["W"] = W 131 return W 132 133 def _compute_pi_W(self, W) -> np.ndarray: 134 """ 135 Solve for the stationary distribution of W via GPU GMRES, then free 136 all intermediate GPU arrays before returning the CPU result. 137 """ 138 if "pi_W" in self._cache: 139 return self._cache["pi_W"] 140 141 xp = cp if CUPY_AVAILABLE else np 142 143 xp_sparse = cpx_sparse if CUPY_AVAILABLE else sp_sparse 144 xp_linalg = cpx_linalg if CUPY_AVAILABLE else sp_linalg 145 146 W_device = cpx_sparse.csr_matrix(W) 147 M = W_device.shape[0] 148 149 I_device = xp_sparse.eye(M, format="csr", dtype=W_device.dtype) 150 A_device = (W_device.T - I_device).tocsr() 151 152 ones_row = xp_sparse.csr_matrix(xp.ones((1, M), dtype=W_device.dtype)) 153 A_aug = xp_sparse.vstack([A_device[:-1, :], ones_row]).tocsr() 154 155 b_device = xp.zeros(M, dtype=W_device.dtype) 156 b_device[-1] = 1.0 157 158 diag_vals = A_aug.diagonal() 159 diag_vals[diag_vals == 0] = 1.0 160 M_inv = xp_sparse.diags(1.0 / diag_vals) 161 162 pi_W_device, info = xp_linalg.gmres( 163 A_aug, b_device, 164 M=M_inv, 165 restart=100, 166 maxiter=1000, 167 ) 168 169 if info > 0: 170 print(f"Warning: GMRES did not converge after {info} iterations") 171 elif info < 0: 172 raise ValueError("GMRES failed due to illegal input or breakdown.") 173 174 pi_W_device = xp.clip(pi_W_device, 0.0, None) 175 s = pi_W_device.sum() 176 177 if s < self.EPS: 178 raise ValueError("Stationary distribution collapsed.") 179 180 pi_W_device /= s 181 182 # Download before freeing 183 result = to_cpu( pi_W_device ) 184 185 # -- Free every GPU object created in this method ---------------------- 186 del W_device, I_device, A_device, ones_row, A_aug, b_device, M_inv, diag_vals, pi_W_device 187 188 if CUPY_AVAILABLE : 189 cp.get_default_memory_pool().free_all_blocks() 190 cp.get_default_pinned_memory_pool().free_all_blocks() 191 192 gc.collect() 193 194 self._cache["pi_W"] = result 195 196 return result 197 198 def _fundamental( 199 self, 200 W: sp_sparse.csr_matrix, 201 pi_W: np.ndarray, 202 ) -> tuple[sp_linalg.LinearOperator, np.ndarray, np.ndarray, LinearOperator]: 203 """ 204 Compute the fundamental matrix 205 """ 206 if "fundamental" in self._cache: 207 return self._cache["fundamental"] 208 209 M = W.shape[0] 210 WT = W.T.tocsc() 211 212 e0 = np.zeros(M) 213 e0[self.start_state_idx] = 1.0 214 215 def Q_T_matvec(v: np.ndarray) -> np.ndarray: 216 return v - WT @ v + pi_W * v.sum() 217 218 Q_LO = sp_linalg.LinearOperator((M, M), matvec=Q_T_matvec, dtype=float) 219 220 W_diag = np.array(W.diagonal()) 221 M_diag = 1.0 - W_diag + pi_W 222 M_diag = np.where(np.abs(M_diag) > self.EPS, M_diag, 1.0) 223 precond = sp_linalg.LinearOperator((M, M), matvec=lambda v: v / M_diag, dtype=float) 224 225 Z_row, exit_code = sp_linalg.lgmres( 226 Q_LO, 227 e0, 228 M=precond, 229 rtol=self.gmres_rtol, 230 atol=self.gmres_rtol, 231 maxiter=self.gmres_maxiter, 232 ) 233 if exit_code != 0: 234 raise RuntimeError( 235 f"GMRES failed on the fundamental matrix solve (exit code {exit_code}). " 236 "Try increasing gmres_maxiter, loosening gmres_rtol, or verify that " 237 "the transition matrix is irreducible." 238 ) 239 240 fund_row = Z_row - pi_W 241 242 self._cache["fundamental"] = (Q_LO, Z_row, fund_row, precond) 243 return Q_LO, Z_row, fund_row, precond 244 245 def _get_H_WA(self) -> np.ndarray: 246 """ 247 Per-state Shannon entropy of the output symbol distribution, 248 marginalised over target states. Rows are normalised before use. 249 """ 250 if "H_WA" in self._cache: 251 return self._cache["H_WA"] 252 253 M, A = len(self.states), len(self.alphabet) 254 state_sym = np.zeros((M, A)) 255 for tr in self.transitions: 256 state_sym[tr.origin_state_idx, tr.symbol_idx] += tr.prob 257 258 sym_row_sums = state_sym.sum(axis=1, keepdims=True) 259 sym_row_sums = np.where(sym_row_sums > self.EPS, sym_row_sums, 1.0) 260 state_sym /= sym_row_sums 261 262 safe = state_sym > self.EPS 263 log_vals = np.where(safe, np.log2(np.where(safe, state_sym, 1.0)), 0.0) 264 H_WA = -np.sum(state_sym * log_vals, axis=1) 265 266 self._cache["H_WA"] = H_WA 267 return H_WA 268 269 def _get_H_eta(self) -> np.ndarray: 270 """ 271 Per-state Shannon entropy of the mixed-state (belief) vector. 272 Rows are normalised before entropy computation. 273 """ 274 if "H_eta" in self._cache: 275 return self._cache["H_eta"] 276 277 msv = self.mixed_state_vectors.copy().astype(float) 278 row_sums = msv.sum(axis=1, keepdims=True) 279 row_sums = np.where(row_sums > self.EPS, row_sums, 1.0) 280 msv /= row_sums 281 282 safe = msv > self.EPS 283 log_vals = np.where(safe, np.log2(np.where(safe, msv, 1.0)), 0.0) 284 H_eta = -np.sum(msv * log_vals, axis=1) 285 286 self._cache["H_eta"] = H_eta 287 return H_eta 288 289 def get_E_S_T(self) -> tuple[float, float, float]: 290 291 try: 292 W = self._build_W() 293 pi_W = self._compute_pi_W(W) # GPU memory freed inside _compute_pi_W 294 H_WA = self._get_H_WA() 295 H_eta = self._get_H_eta() 296 297 Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W) 298 299 E = float(fund_row @ H_WA) 300 S = float(fund_row @ H_eta) 301 302 ZZ_row, exit_code = sp_linalg.lgmres( 303 Q_LO, 304 Z_row, 305 x0=Z_row, 306 M=precond, 307 rtol=self.gmres_rtol, 308 atol=self.gmres_rtol, 309 maxiter=self.gmres_maxiter, 310 ) 311 if exit_code != 0: 312 raise RuntimeError( 313 f"GMRES failed on the Z² solve for T (exit code {exit_code}). " 314 "Try increasing gmres_maxiter or loosening gmres_rtol." 315 ) 316 317 t_row = ZZ_row - pi_W 318 T = float(t_row @ H_WA) 319 320 return E, S, T 321 322 finally: 323 324 if CUPY_AVAILABLE : 325 cp.get_default_memory_pool().free_all_blocks() 326 cp.get_default_pinned_memory_pool().free_all_blocks() 327 328 gc.collect() 329 330 def clear_cache(self) -> None: 331 self._cache.clear()
Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity.
Computes E (excess entropy), S (synchronization information), and T (transient information) using spectral decomposition of the fundamental matrix.
See Exact Complexity: The Spectral Decomposition of Intrinsic Computation https://arxiv.org/abs/1309.3792
MSP( states, mixed_state_vectors, transitions, alphabet, start_state_idx: int = 0, gmres_rtol: float = 1e-10, gmres_maxiter: int = 1000, eps: float = 1e-12)
67 def __init__( 68 self, 69 states, 70 mixed_state_vectors, 71 transitions, 72 alphabet, 73 start_state_idx: int = 0, 74 gmres_rtol: float = 1e-10, 75 gmres_maxiter: int = 1000, 76 eps: float = 1e-12, 77 ): 78 self.states = states 79 self.mixed_state_vectors = np.array(mixed_state_vectors) 80 self.transitions = transitions 81 self.alphabet = alphabet 82 self.start_state_idx = start_state_idx 83 self.gmres_rtol = gmres_rtol 84 self.gmres_maxiter = gmres_maxiter 85 self.EPS = eps 86 self._cache = {} 87 88 M = len(states) 89 if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M: 90 raise ValueError( 91 f"mixed_state_vectors must have shape (M={M}, N_causal), " 92 f"got {self.mixed_state_vectors.shape}." 93 ) 94 if not (0 <= start_state_idx < M): 95 raise ValueError( 96 f"start_state_idx={start_state_idx} is out of range for M={M} states." 97 ) 98 if gmres_rtol <= 0: 99 raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.") 100 if eps <= 0: 101 raise ValueError(f"eps must be positive, got {eps}.")
def
get_E_S_T(self) -> tuple[float, float, float]:
289 def get_E_S_T(self) -> tuple[float, float, float]: 290 291 try: 292 W = self._build_W() 293 pi_W = self._compute_pi_W(W) # GPU memory freed inside _compute_pi_W 294 H_WA = self._get_H_WA() 295 H_eta = self._get_H_eta() 296 297 Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W) 298 299 E = float(fund_row @ H_WA) 300 S = float(fund_row @ H_eta) 301 302 ZZ_row, exit_code = sp_linalg.lgmres( 303 Q_LO, 304 Z_row, 305 x0=Z_row, 306 M=precond, 307 rtol=self.gmres_rtol, 308 atol=self.gmres_rtol, 309 maxiter=self.gmres_maxiter, 310 ) 311 if exit_code != 0: 312 raise RuntimeError( 313 f"GMRES failed on the Z² solve for T (exit code {exit_code}). " 314 "Try increasing gmres_maxiter or loosening gmres_rtol." 315 ) 316 317 t_row = ZZ_row - pi_W 318 T = float(t_row @ H_WA) 319 320 return E, S, T 321 322 finally: 323 324 if CUPY_AVAILABLE : 325 cp.get_default_memory_pool().free_all_blocks() 326 cp.get_default_pinned_memory_pool().free_all_blocks() 327 328 gc.collect()
def
compute_msp_exact( T_x: list[list[list[fractions.Fraction]]], pi: tuple[fractions.Fraction], n_states: int, alphabet: list[str], exact_state_cap: int = 1000, verbose: bool = True):
334def compute_msp_exact( 335 T_x : list[ list[ list[ Fraction ] ] ], 336 pi : tuple[Fraction], 337 n_states : int, 338 alphabet : list[str], 339 exact_state_cap: int = 1000, 340 verbose: bool = True, 341): 342 343 n_symbols = len(alphabet) 344 345 def _apply(mu: tuple[Fraction, ...], x: int) : 346 347 Tx = T_x[x] 348 mu_Tx = [Fraction(0)] * n_states 349 for i, w in enumerate(mu): 350 if w == Fraction(0): 351 continue 352 row = Tx[i] 353 for j in range(n_states): 354 if row[j]: 355 mu_Tx[j] += w * row[j] 356 357 prob = sum(mu_Tx) 358 if prob == Fraction(0): 359 return None, Fraction(0) 360 361 mu_next = tuple(v / prob for v in mu_Tx) 362 return mu_next, prob 363 364 belief_vectors: list[tuple[Fraction, ...]] = [] 365 seen_states: dict[tuple, int] = {} 366 msp_transitions: list = [] 367 368 def _register(mu: tuple[Fraction, ...]) -> int: 369 idx = len(belief_vectors) 370 belief_vectors.append(mu) 371 seen_states[mu] = idx 372 return idx 373 374 start_mu = tuple(pi) 375 _register(start_mu) 376 frontier = deque([0]) 377 378 while frontier and len(belief_vectors) < exact_state_cap: 379 380 idx = frontier.popleft() 381 mu = belief_vectors[idx] 382 383 for x in range(n_symbols): 384 385 mu_next, prob = _apply(mu, x) 386 387 if mu_next is None: 388 continue 389 390 if mu_next not in seen_states: 391 next_idx = _register(mu_next) 392 frontier.append(next_idx) 393 else: 394 next_idx = seen_states[mu_next] 395 396 msp_transitions.append( 397 Transition( 398 origin_state_idx = idx, 399 target_state_idx = next_idx, 400 prob = float(prob), 401 symbol_idx = x, 402 pq = prob 403 ) 404 ) 405 406 if frontier: 407 raise RuntimeError( f"Exact state cap {exact_state_cap} exceded." ) 408 409 if verbose: 410 print(f"Exact MSP found: {len(belief_vectors)} states.") 411 412 msp = MSP( 413 states = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)], 414 mixed_state_vectors = mixed_state_vectors, 415 transitions = msp_transitions, 416 alphabet = alphabet.copy(), 417 ) 418 419 return msp
def
compute_msp( T_x: list[numpy.ndarray], pi: numpy.ndarray, n_states: int, alphabet: list[str], exact_state_cap: int = 175000, jsd_eps: float = 1e-07, k_ann: int = 50, EPS: float = 1e-12, verbose=True):
421def compute_msp( 422 T_x : list[np.ndarray], 423 pi : np.ndarray, 424 n_states : int, 425 alphabet : list[str], 426 exact_state_cap: int = 175_000, 427 jsd_eps: float = 1e-7, 428 k_ann: int = 50, 429 EPS : float = 1e-12, 430 verbose = True, 431) : 432 433 xp = cp if CUPY_AVAILABLE else np 434 435 n_input_states = T_x[0].shape[0] 436 n_symbols = len(alphabet) 437 438 T_flat_device = xp.asarray( 439 np.stack(T_x).transpose(1, 0, 2).reshape(n_input_states, n_symbols * n_input_states) 440 ) 441 442 # -- Sparse Belief Vectors ------------------------------------------- 443 444 sparse_belief_vector_indices: list[np.ndarray] = [] 445 sparse_belief_vector_values: list[np.ndarray] = [] 446 447 def _to_sparse(v: np.ndarray): 448 449 non_zero = np.where( np.abs( v ) > 1e-20 )[ 0 ] 450 return non_zero.astype(np.int32), v[ non_zero ] 451 452 def _to_dense( sparse_indices: np.ndarray, sparse_values: np.ndarray ): 453 dense = np.zeros( n_input_states ) 454 dense[ sparse_indices ] = sparse_values 455 return dense 456 457 def _batch_to_dense( indicies: list[int] ) -> np.ndarray: 458 459 out = np.zeros((len(indicies), n_input_states), dtype=np.float64) 460 for row, i in enumerate(indicies): 461 out[row, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i] 462 463 return out 464 465 # -- Transitions ---------------------------------------------- 466 467 t_origin: list[int] = [] 468 t_target: list[int] = [] 469 t_prob: list[float] = [] 470 t_symbol: list[int] = [] 471 472 # -- Sparse dual bucket coarse hash ---------------------------------- 473 474 # Since we use dual bins 475 # Some x, y in p, q needs to differ by more than 1/(2*coarse_scale) to not share a bin 476 # Thus minimal total variation between non-bin sharing distributions is 477 # To ensure jsd(p,q) < jsd_eps -> p,q share a bin, we use Pinsker's inequality 478 # 1 /(2*coarse_scale) <= 2*sqrt(2*jsd_eps) 479 # 1 /( 4*sqrt(2*jsd_eps) ) <= coarse_scale 480 481 coarse_scale = 1 / ( 4*np.sqrt( 2 * jsd_eps ) ) 482 483 def _coarse_keys( sparse_indices: np.ndarray, sparse_values: np.ndarray) -> list[bytes]: 484 scaled = _to_dense( sparse_indices, sparse_values*coarse_scale ) 485 return [ 486 np.round(scaled).astype(np.int32).tobytes(), 487 np.round(scaled + 0.5).astype(np.int32).tobytes() 488 ] 489 490 buckets: dict[bytes, list[int]] = {} 491 492 # Jenson Shannon Divergence for sparse vectors 493 def _jsd_sparse( 494 ai: np.ndarray, av: np.ndarray, 495 bi: np.ndarray, bv: np.ndarray, 496 ) -> float: 497 498 # get sorted union of the [non-zero] indices over both vectors 499 all_i = np.union1d( ai, bi ) 500 501 # form dense vectors 502 p = np.zeros( len( all_i ), dtype=np.float64 ) 503 q = np.zeros( len( all_i ), dtype=np.float64 ) 504 505 # np.searchsorted( all_i, ai ) finds position in all_i where each ai is 506 p[ np.searchsorted( all_i, ai ) ] = av 507 q[ np.searchsorted( all_i, bi ) ] = bv 508 509 # Jenson Shannon divergence (JDS) 510 return af_jensenshannondivergence_cpu( p, q ) 511 512 # Check if an existing belief vector is within jsd_eps 513 def _lookup( sparse_indices: np.ndarray, sparse_values: np.ndarray ) -> int: 514 515 keys = _coarse_keys( sparse_indices, sparse_values ) 516 candidates = list( { idx for key in keys for idx in buckets.get( key, [] ) } ) 517 518 if not candidates : 519 return -1 520 521 dense = _to_dense( sparse_indices, sparse_values ).reshape( 1, n_input_states ) 522 candidate_vectors = _batch_to_dense( candidates ) 523 524 distances = af_jensenshannondivergence_cpu( candidate_vectors, dense, axis=1 ) 525 min_idx = np.argmin( distances ) 526 min_dist = distances[ min_idx ] 527 528 if min_dist < jsd_eps : 529 return candidates[ min_idx ] 530 531 return -1 532 533 # Map belief vector to course hash buckets for subsequent fast lookup 534 def _register( sparse_indices: np.ndarray, sparse_values: np.ndarray, idx: int): 535 for key in _coarse_keys( sparse_indices, sparse_values ): 536 buckets.setdefault( key, [] ).append( idx ) 537 538 pi_indicies, pi_values = _to_sparse( pi ) 539 sparse_belief_vector_indices.append( pi_indicies ) 540 sparse_belief_vector_values.append( pi_values ) 541 n_belief_vectors = 1 542 543 _register( pi_indicies, pi_values, 0 ) 544 545 frontier = deque([0]) 546 547 FRONTIER_CHUNK = 256 548 549 try: 550 while frontier and n_belief_vectors < exact_state_cap: 551 552 batch_indices = [] 553 while frontier and len(batch_indices) < FRONTIER_CHUNK: 554 batch_indices.append(frontier.popleft()) 555 556 B = len(batch_indices) 557 558 mus_device = xp.asarray(_batch_to_dense( batch_indices ) ) 559 flat_device = mus_device @ T_flat_device 560 all_mu_Tx_device = flat_device.reshape( B, n_symbols, n_input_states ) 561 probs_device = all_mu_Tx_device.sum(axis=2) 562 563 vb_device, vx_device = xp.where(probs_device > EPS) 564 565 if vb_device.size == 0: 566 del mus_device, flat_device, all_mu_Tx_device, probs_device, vb_device, vx_device 567 continue 568 569 vp_device = probs_device[vb_device, vx_device] 570 mu_next_device = all_mu_Tx_device[vb_device, vx_device] / vp_device[:, None] 571 572 mu_next_cpu = to_cpu( mu_next_device ) 573 574 vp_cpu = to_cpu( vp_device ) 575 vb_cpu = to_cpu( vb_device ) 576 vx_cpu = to_cpu( vx_device ) 577 578 # Free GPU arrays for this batch immediately after download 579 del mus_device, flat_device, all_mu_Tx_device, probs_device 580 del vb_device, vx_device, vp_device, mu_next_device 581 582 if not np.isfinite(mu_next_cpu).all(): 583 raise ValueError("Non-finite belief vectors in BFS batch") 584 585 for q in range( len( vb_cpu ) ): 586 587 indicies, values = _to_sparse(mu_next_cpu[q]) 588 orig = int( batch_indices[ int( vb_cpu[ q ] ) ] ) 589 idx = _lookup( indicies, values ) 590 591 if idx == -1: 592 593 idx = n_belief_vectors 594 sparse_belief_vector_indices.append( indicies ) 595 sparse_belief_vector_values.append( values ) 596 _register( indicies, values, idx ) 597 598 n_belief_vectors += 1 599 frontier.append(idx) 600 601 t_origin.append(orig) 602 t_target.append(idx) 603 t_prob.append(float(vp_cpu[q])) 604 t_symbol.append(int(vx_cpu[q])) 605 606 # -- Close the graph by mapping each dangling state to the nearest closed state 607 608 if frontier: 609 610 if verbose: 611 print(f"Max exact states ({exact_state_cap}) exceded ({n_belief_vectors}). Closing graph...") 612 613 dense_all = np.zeros((n_belief_vectors, n_input_states), dtype=np.float32) 614 615 for i in range(n_belief_vectors): 616 dense_all[i, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i].astype(np.float32) 617 618 ixp = xp if CAGRA_AVAILABLE else np 619 620 if CAGRA_AVAILABLE: 621 622 bv_device_f32 = xp.asarray(dense_all) 623 del dense_all 624 625 if verbose : 626 print("Building CAGRA index...") 627 628 index = cagra.build( 629 cagra.IndexParams(graph_degree=32, metric='sqeuclidean'), 630 bv_device_f32, 631 ) 632 633 else : 634 635 if verbose : 636 print("Building hnswlib index...") 637 638 T_flat_device = to_cpu( T_flat_device ) 639 bv_device_f32 = dense_all 640 index = hnswlib.Index(space='l2', dim=bv_device_f32[0].size) 641 index.init_index(max_elements=len(bv_device_f32), ef_construction=500, M=50) 642 index.set_ef(50) 643 index.add_items(bv_device_f32) 644 645 if verbose: 646 print("Index built.") 647 648 closure_buckets: dict[bytes, list[tuple]] = {} 649 max_error = 0.0 650 avg_error = 0.0 651 n_mapped = 0 652 653 def _closure_lookup( indicies: np.ndarray, values: np.ndarray) -> int: 654 seen = set() 655 for key in _coarse_keys( indicies, values): 656 for entry in closure_buckets.get(key, []): 657 eid = id(entry) 658 if eid not in seen: 659 seen.add(eid) 660 s_idx, s_val, target = entry 661 if _jsd_sparse(indicies, values, s_idx, s_val) < jsd_eps: 662 return target 663 return -1 664 665 def _closure_register(indicies: np.ndarray, values: np.ndarray, target: int): 666 entry = (indicies, values, target) 667 for key in _coarse_keys(indicies, values): 668 closure_buckets.setdefault(key, []).append(entry) 669 670 CHUNK = 10_000 671 frontier_list = list(frontier) 672 673 for start in range(0, len(frontier_list), CHUNK): 674 675 chunk_idxs = frontier_list[start : start + CHUNK] 676 C = len(chunk_idxs) 677 678 if CAGRA_AVAILABLE : 679 chunk_device = cp.asarray(_batch_to_dense(chunk_idxs)) 680 else : 681 chunk_device = _batch_to_dense(chunk_idxs) 682 683 flat_device = chunk_device @ T_flat_device 684 685 mu_Tx_device = flat_device.reshape(C, n_symbols, n_input_states) 686 687 probs_device = mu_Tx_device.sum(axis=2) 688 689 # Free intermediates we no longer need 690 del chunk_device, flat_device 691 692 vf_device, vx_device = ixp.where(probs_device > EPS) 693 694 if vf_device.size == 0: 695 del mu_Tx_device, probs_device, vf_device, vx_device 696 continue 697 698 vp_device = probs_device[vf_device, vx_device] 699 mu_v_device = mu_Tx_device[vf_device, vx_device] / vp_device[:, None] 700 701 # Free now-consumed GPU arrays 702 del mu_Tx_device, probs_device 703 704 mu_v_cpu = to_cpu( mu_v_device ) 705 vp_cpu = to_cpu( vp_device ) 706 vf_cpu = to_cpu( vf_device ) 707 vx_cpu = to_cpu( vx_device ) 708 709 del vf_device, vx_device, vp_device 710 711 # float32 copy for CAGRA; free float64 immediately after cast 712 unknown_device = mu_v_device.astype(ixp.float32) 713 del mu_v_device 714 715 N = len(vf_cpu) 716 orig_arr = np.array([chunk_idxs[f] for f in vf_cpu], dtype=np.int32) 717 next_idx = np.full(N, -1, dtype=np.int32) 718 unknown_mask = np.ones(N, dtype=bool) 719 720 sparse_cache = [_to_sparse(mu_v_cpu[q]) for q in range(N)] 721 722 for q in range(N): 723 indicies, values = sparse_cache[q] 724 target = _closure_lookup(indicies, values) 725 if target != -1: 726 next_idx[q] = target 727 unknown_mask[q] = False 728 729 if unknown_mask.any(): 730 731 unknown_rows = unknown_device[ixp.asarray(unknown_mask)] 732 733 if CAGRA_AVAILABLE : 734 735 if verbose: 736 print(f"CARGA search for {unknown_mask.sum()} unknown states...") 737 738 _, indices_device = cagra.search( 739 cagra.SearchParams(itopk_size=128), 740 index, unknown_rows, k_ann, 741 ) 742 743 else : 744 if verbose : 745 print(f"hnswlib search for {unknown_mask.sum()} unknown states...") 746 747 indices_device, _ = index.knn_query( to_cpu( unknown_rows ), k=k_ann ) 748 749 output_indices = ixp.asarray( indices_device ) 750 751 min_dist = ixp.full(unknown_rows.shape[0], ixp.inf, dtype=ixp.float32) 752 min_ks = ixp.zeros(unknown_rows.shape[0], dtype=ixp.int32) 753 754 for idx_k in range(output_indices.shape[1]): 755 sv = bv_device_f32[output_indices[:, idx_k], :] 756 757 if CAGRA_AVAILABLE : 758 d = af_jensenshannondivergence_device(unknown_rows, sv, axis=1) 759 else : 760 d = af_jensenshannondivergence_cpu(unknown_rows, sv, axis=1) 761 762 mask = d < min_dist 763 min_dist = ixp.where(mask, d, min_dist) 764 min_ks = ixp.where(mask, idx_k, min_ks) 765 766 max_dist = to_cpu( ixp.max(min_dist) ) 767 mean_dist = to_cpu( ixp.mean(min_dist) ) 768 769 n_mapped += N 770 avg_error += mean_dist 771 max_error = max(max_dist, max_error) 772 773 resolved = to_cpu( output_indices[ixp.arange(unknown_rows.shape[0]), min_ks] ) 774 unknown_qs = np.where(unknown_mask)[0] 775 776 for i, q in enumerate(unknown_qs): 777 nz_idx, nz_val = sparse_cache[q] 778 next_idx[q] = int(resolved[i]) 779 _closure_register(nz_idx, nz_val, int(resolved[i])) 780 781 # Free ANN-related GPU arrays 782 del unknown_rows, indices_device, output_indices 783 del min_dist, min_ks 784 785 del unknown_device 786 787 for q in range(N): 788 t_origin.append(int(orig_arr[q])) 789 t_target.append(int(next_idx[q])) 790 t_prob.append(float(vp_cpu[q])) 791 t_symbol.append(int(vx_cpu[q])) 792 793 # -- Free CAGRA index and float32 belief matrix -------------------- 794 795 # CAGRA memory lives outside CuPy's pool and is not released by free_all_blocks() 796 del index 797 del bv_device_f32 798 799 if CUPY_AVAILABLE : 800 801 cp.get_default_memory_pool().free_all_blocks() 802 cp.get_default_pinned_memory_pool().free_all_blocks() 803 804 gc.collect() 805 806 if verbose: 807 print(f"Closed: {n_mapped} dangling states") 808 print(f"Avg JSD state closure error: {(1.0 / np.log(2)) * avg_error / n_mapped} bits") 809 print(f"Max JSD state closure error: {(1.0 / np.log(2)) * max_error} bits") 810 811 # ------------------ Create the MSP --------------------# 812 813 msp_transitions = [ 814 Transition( 815 origin_state_idx = t_origin[i], 816 target_state_idx = t_target[i], 817 prob = t_prob[i], 818 symbol_idx = t_symbol[i], 819 ) 820 for i in range(len(t_origin)) 821 ] 822 823 mixed_state_vectors = list(_batch_to_dense(list(range(n_belief_vectors)))) 824 825 if verbose: 826 print("Assembled msp_transitions.") 827 828 msp = MSP( 829 states = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)], 830 mixed_state_vectors = mixed_state_vectors, 831 transitions = msp_transitions, 832 alphabet = alphabet.copy(), 833 ) 834 835 if verbose: 836 print("\nDone.\n") 837 838 return msp 839 840 finally: 841 842 # ------------------ Cleanup --------------------# 843 844 if 'T_flat_device' in dir(): 845 del T_flat_device 846 847 if 'index' in locals(): 848 del index 849 850 if 'bv_device_f32' in locals(): 851 del bv_device_f32 852 853 if CUPY_AVAILABLE : 854 855 cp.get_default_memory_pool().free_all_blocks() 856 cp.get_default_pinned_memory_pool().free_all_blocks() 857 858 gc.collect()