amachine.am_msp

  1from __future__ import annotations
  2
  3import copy
  4from fractions import Fraction
  5import gc
  6import warnings
  7from collections import deque
  8
  9import numpy as np
 10import scipy.sparse as sp_sparse
 11import scipy.sparse.linalg as sp_linalg
 12
 13from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu
 14
 15# --------- Conditionally use cupy & cupyx if available else numpy and scipy
 16
 17try : 
 18	import cupy as cp
 19	CUPY_AVAILABLE = True
 20except ImportError :
 21	CUPY_AVAILABLE = False
 22
 23if CUPY_AVAILABLE :
 24	import cupyx.scipy.sparse as cpx_sparse
 25	import cupyx.scipy.sparse.linalg as cpx_linalg
 26	from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence_device
 27else :
 28	cp = None
 29	cpx_sparse = None
 30	cpx_linalg = None
 31
 32def to_cpu( array ):
 33    return array.get() if CUPY_AVAILABLE and hasattr( array, 'get' ) else array
 34
 35# --------- Conditionally use cagra if available else hnswlib
 36
 37if CUPY_AVAILABLE :
 38	try :
 39		from cuvs.neighbors import cagra
 40		CAGRA_AVAILABLE = True
 41	except ImportError :
 42		CAGRA_AVAILABLE = False
 43
 44if not CAGRA_AVAILABLE :
 45	import hnswlib
 46else :
 47	hnswlib = None
 48
 49# -------------------------------------------------------------
 50
 51from .am_causal_state import CausalState
 52from .am_transition import Transition
 53
 54class MSP:
 55	"""
 56	Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity.
 57 
 58	Computes E (excess entropy),  S (synchronization information), and T (transient information) 
 59	using spectral decomposition of the fundamental matrix.
 60 
 61	See
 62		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 63		https://arxiv.org/abs/1309.3792
 64	"""
 65
 66	def __init__(
 67		self,
 68		states,
 69		mixed_state_vectors,
 70		transitions,
 71		alphabet,
 72		start_state_idx: int   = 0,
 73		gmres_rtol:      float = 1e-10,
 74		gmres_maxiter:   int   = 1000,
 75		eps:             float = 1e-12,
 76	):
 77		self.states              = states
 78		self.mixed_state_vectors = np.array(mixed_state_vectors)
 79		self.transitions         = transitions
 80		self.alphabet            = alphabet
 81		self.start_state_idx     = start_state_idx
 82		self.gmres_rtol          = gmres_rtol
 83		self.gmres_maxiter       = gmres_maxiter
 84		self.EPS                 = eps
 85		self._cache              = {}
 86 
 87		M = len(states)
 88		if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M:
 89			raise ValueError(
 90				f"mixed_state_vectors must have shape (M={M}, N_causal), "
 91				f"got {self.mixed_state_vectors.shape}."
 92			)
 93		if not (0 <= start_state_idx < M):
 94			raise ValueError(
 95				f"start_state_idx={start_state_idx} is out of range for M={M} states."
 96			)
 97		if gmres_rtol <= 0:
 98			raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.")
 99		if eps <= 0:
100			raise ValueError(f"eps must be positive, got {eps}.")
101 
102	def _build_W(self) -> sp_sparse.csr_matrix:
103		"""
104		Build and return the sparse row-stochastic transition matrix W (M × M).
105		Rows are re-normalised to absorb floating-point drift.
106		"""
107		if "W" in self._cache:
108			return self._cache["W"]
109 
110		M    = len(self.states)
111		
112		rows = [tr.origin_state_idx for tr in self.transitions]
113		cols = [tr.target_state_idx for tr in self.transitions]
114		data = [tr.prob              for tr in self.transitions]
115		
116		W    = sp_sparse.csr_matrix((data, (rows, cols)), shape=(M, M))
117 
118		row_sums = np.array(W.sum(axis=1)).ravel()
119		
120		if np.any(row_sums < self.EPS):
121			bad = np.where(row_sums < self.EPS)[0].tolist()
122			raise ValueError(
123				f"States at indices {bad} have zero total outgoing probability. "
124				"Check that every MSP state has at least one outgoing transition."
125			)
126		
127		W = (sp_sparse.diags(1.0 / row_sums) @ W).tocsr()
128 
129		self._cache["W"] = W
130		return W
131 
132	def _compute_pi_W(self, W) -> np.ndarray:
133		"""
134		Solve for the stationary distribution of W via GPU GMRES, then free
135		all intermediate GPU arrays before returning the CPU result.
136		"""
137		if "pi_W" in self._cache:
138			return self._cache["pi_W"]
139 
140		xp = cp if CUPY_AVAILABLE else np
141		
142		xp_sparse = cpx_sparse if CUPY_AVAILABLE else sp_sparse
143		xp_linalg = cpx_linalg if CUPY_AVAILABLE else sp_linalg
144
145		W_device = cpx_sparse.csr_matrix(W)
146		M     = W_device.shape[0]
147 
148		I_device  = xp_sparse.eye(M, format="csr", dtype=W_device.dtype)
149		A_device  = (W_device.T - I_device).tocsr()
150 
151		ones_row = xp_sparse.csr_matrix(xp.ones((1, M), dtype=W_device.dtype))
152		A_aug    = xp_sparse.vstack([A_device[:-1, :], ones_row]).tocsr()
153 
154		b_device = xp.zeros(M, dtype=W_device.dtype)
155		b_device[-1] = 1.0
156 
157		diag_vals = A_aug.diagonal()
158		diag_vals[diag_vals == 0] = 1.0
159		M_inv = xp_sparse.diags(1.0 / diag_vals)
160 
161		pi_W_device, info = xp_linalg.gmres(
162			A_aug, b_device,
163			M=M_inv,
164			restart=100,
165			maxiter=1000,
166		)
167 
168		if info > 0:
169			print(f"Warning: GMRES did not converge after {info} iterations")
170		elif info < 0:
171			raise ValueError("GMRES failed due to illegal input or breakdown.")
172 
173		pi_W_device = xp.clip(pi_W_device, 0.0, None)
174		s = pi_W_device.sum()
175		
176		if s < self.EPS:
177			raise ValueError("Stationary distribution collapsed.")
178		
179		pi_W_device /= s
180 
181		# Download before freeing
182		result = to_cpu( pi_W_device )
183 
184		# -- Free every GPU object created in this method ----------------------
185		del W_device, I_device, A_device, ones_row, A_aug, b_device, M_inv, diag_vals, pi_W_device
186
187		if CUPY_AVAILABLE :
188			cp.get_default_memory_pool().free_all_blocks()
189			cp.get_default_pinned_memory_pool().free_all_blocks()
190		
191		gc.collect()
192 
193		self._cache["pi_W"] = result
194
195		return result
196 
197	def _fundamental(
198		self,
199		W:     sp_sparse.csr_matrix,
200		pi_W:  np.ndarray,
201	) -> tuple[sp_linalg.LinearOperator, np.ndarray, np.ndarray, LinearOperator]:
202		"""
203		Compute the fundamental matrix
204		"""
205		if "fundamental" in self._cache:
206			return self._cache["fundamental"]
207 
208		M  = W.shape[0]
209		WT = W.T.tocsc()
210 
211		e0          = np.zeros(M)
212		e0[self.start_state_idx] = 1.0
213 
214		def Q_T_matvec(v: np.ndarray) -> np.ndarray:
215			return v - WT @ v + pi_W * v.sum()
216 
217		Q_LO = sp_linalg.LinearOperator((M, M), matvec=Q_T_matvec, dtype=float)
218 
219		W_diag  = np.array(W.diagonal())
220		M_diag  = 1.0 - W_diag + pi_W
221		M_diag  = np.where(np.abs(M_diag) > self.EPS, M_diag, 1.0)
222		precond = sp_linalg.LinearOperator((M, M), matvec=lambda v: v / M_diag, dtype=float)
223 
224		Z_row, exit_code = sp_linalg.lgmres(
225			Q_LO,
226			e0,
227			M=precond,
228			rtol=self.gmres_rtol,
229			atol=self.gmres_rtol,
230			maxiter=self.gmres_maxiter,
231		)
232		if exit_code != 0:
233			raise RuntimeError(
234				f"GMRES failed on the fundamental matrix solve (exit code {exit_code}). "
235				"Try increasing gmres_maxiter, loosening gmres_rtol, or verify that "
236				"the transition matrix is irreducible."
237			)
238 
239		fund_row = Z_row - pi_W
240 
241		self._cache["fundamental"] = (Q_LO, Z_row, fund_row, precond)
242		return Q_LO, Z_row, fund_row, precond
243 
244	def _get_H_WA(self) -> np.ndarray:
245		"""
246		Per-state Shannon entropy of the output symbol distribution,
247		marginalised over target states. Rows are normalised before use.
248		"""
249		if "H_WA" in self._cache:
250			return self._cache["H_WA"]
251 
252		M, A       = len(self.states), len(self.alphabet)
253		state_sym  = np.zeros((M, A))
254		for tr in self.transitions:
255			state_sym[tr.origin_state_idx, tr.symbol_idx] += tr.prob
256 
257		sym_row_sums = state_sym.sum(axis=1, keepdims=True)
258		sym_row_sums = np.where(sym_row_sums > self.EPS, sym_row_sums, 1.0)
259		state_sym   /= sym_row_sums
260 
261		safe     = state_sym > self.EPS
262		log_vals = np.where(safe, np.log2(np.where(safe, state_sym, 1.0)), 0.0)
263		H_WA     = -np.sum(state_sym * log_vals, axis=1)
264 
265		self._cache["H_WA"] = H_WA
266		return H_WA
267 
268	def _get_H_eta(self) -> np.ndarray:
269		"""
270		Per-state Shannon entropy of the mixed-state (belief) vector.
271		Rows are normalised before entropy computation.
272		"""
273		if "H_eta" in self._cache:
274			return self._cache["H_eta"]
275 
276		msv      = self.mixed_state_vectors.copy().astype(float)
277		row_sums = msv.sum(axis=1, keepdims=True)
278		row_sums = np.where(row_sums > self.EPS, row_sums, 1.0)
279		msv     /= row_sums
280 
281		safe     = msv > self.EPS
282		log_vals = np.where(safe, np.log2(np.where(safe, msv, 1.0)), 0.0)
283		H_eta    = -np.sum(msv * log_vals, axis=1)
284 
285		self._cache["H_eta"] = H_eta
286		return H_eta
287 
288	def get_E_S_T(self) -> tuple[float, float, float]:
289
290		try:
291			W    = self._build_W()
292			pi_W = self._compute_pi_W(W)   # GPU memory freed inside _compute_pi_W
293			H_WA  = self._get_H_WA()
294			H_eta = self._get_H_eta()
295 
296			Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W)
297 
298			E = float(fund_row @ H_WA)
299			S = float(fund_row @ H_eta)
300 
301			ZZ_row, exit_code = sp_linalg.lgmres(
302				Q_LO,
303				Z_row,
304				x0=Z_row,
305				M=precond,
306				rtol=self.gmres_rtol,
307				atol=self.gmres_rtol,
308				maxiter=self.gmres_maxiter,
309			)
310			if exit_code != 0:
311				raise RuntimeError(
312					f"GMRES failed on the Z² solve for T (exit code {exit_code}). "
313					"Try increasing gmres_maxiter or loosening gmres_rtol."
314				)
315 
316			t_row = ZZ_row - pi_W
317			T     = float(t_row @ H_WA)
318
319			return E, S, T
320 
321		finally:
322
323			if CUPY_AVAILABLE :
324				cp.get_default_memory_pool().free_all_blocks()
325				cp.get_default_pinned_memory_pool().free_all_blocks()
326
327			gc.collect()
328 
329	def clear_cache(self) -> None:
330		self._cache.clear()
331
332
def compute_msp_exact(
	T_x : list[ list[ list[ Fraction ] ] ],
	pi  : tuple[Fraction],
	n_states : int,
	alphabet : list[str],
	exact_state_cap: int = 1000,
	verbose: bool = True,
):
	"""
	Build the exact MSP by breadth-first search over belief vectors using
	exact rational (Fraction) arithmetic, so states are deduplicated by
	exact equality.

	Parameters
	----------
	T_x : per-symbol transition matrices as nested lists of Fractions.
	pi  : initial belief vector of length ``n_states``.
	n_states : number of causal states.
	alphabet : output alphabet; its length is the number of symbols.
	exact_state_cap : hard cap on the number of discovered belief states.
	verbose : print the number of states found.

	Returns
	-------
	MSP

	Raises
	------
	RuntimeError
		If the BFS does not close within ``exact_state_cap`` states.
	"""

	n_symbols  = len(alphabet)

	def _apply(mu: tuple[Fraction, ...], x: int) :
		# Push belief mu through T_x[x]; return (normalised next belief, prob),
		# or (None, 0) when symbol x has zero probability from mu.
		Tx = T_x[x]
		mu_Tx = [Fraction(0)] * n_states
		for i, w in enumerate(mu):
			if w == Fraction(0):
				continue
			row = Tx[i]
			for j in range(n_states):
				if row[j]:
					mu_Tx[j] += w * row[j]

		prob = sum(mu_Tx)
		if prob == Fraction(0):
			return None, Fraction(0)

		mu_next = tuple(v / prob for v in mu_Tx)
		return mu_next, prob

	belief_vectors: list[tuple[Fraction, ...]] = []
	seen_states:    dict[tuple, int]           = {}
	msp_transitions: list                      = []

	def _register(mu: tuple[Fraction, ...]) -> int:
		# Append a new belief vector, record its index, and return it.
		idx = len(belief_vectors)
		belief_vectors.append(mu)
		seen_states[mu] = idx
		return idx

	start_mu = tuple(pi)
	_register(start_mu)
	frontier = deque([0])

	while frontier and len(belief_vectors) < exact_state_cap:

		idx = frontier.popleft()
		mu  = belief_vectors[idx]

		for x in range(n_symbols):

			mu_next, prob = _apply(mu, x)

			if mu_next is None:
				continue

			if mu_next not in seen_states:
				next_idx = _register(mu_next)
				frontier.append(next_idx)
			else:
				next_idx = seen_states[mu_next]

			msp_transitions.append(
				Transition(
					origin_state_idx = idx,
					target_state_idx = next_idx,
					prob             = float(prob),
					symbol_idx       = x,
					pq               = prob
				)
			)

	if frontier:
		raise RuntimeError( f"Exact state cap {exact_state_cap} exceeded." )

	if verbose:
		print(f"Exact MSP found: {len(belief_vectors)} states.")

	# BUG FIX: the original referenced the undefined names `n_belief_vectors`
	# and `mixed_state_vectors` here; derive both from `belief_vectors`
	# (Fractions are converted to floats for the numeric MSP).
	msp = MSP(
		states              = [CausalState(name=f"MS_{i}") for i in range(len(belief_vectors))],
		mixed_state_vectors = [[float(v) for v in mu] for mu in belief_vectors],
		transitions         = msp_transitions,
		alphabet            = alphabet.copy(),
	)

	return msp
419
def compute_msp(
	T_x : list[np.ndarray],
	pi : np.ndarray,
	n_states : int,
	alphabet : list[str],
	exact_state_cap: int = 175_000,
	jsd_eps: float = 1e-7,
	k_ann: int   = 50,
	EPS : float = 1e-12,
	verbose = True,
) :
	"""
	Build an approximate MSP by breadth-first search over belief vectors.

	Belief vectors closer than ``jsd_eps`` in Jensen-Shannon divergence are
	identified via a coarse dual-bin hash. If more than ``exact_state_cap``
	states are discovered, each remaining (dangling) frontier state is
	"closed" onto its nearest already-known state using an ANN index
	(CAGRA on GPU when available, hnswlib otherwise).

	Parameters
	----------
	T_x : per-symbol transition matrices, each of shape (n, n).
	pi  : initial belief vector.
	n_states : number of causal states (shapes are actually taken from T_x[0]).
	alphabet : output alphabet; its length is the number of symbols.
	exact_state_cap : cap on discovered states before graph closure kicks in.
	jsd_eps : JSD threshold for identifying two belief states.
	k_ann : number of ANN neighbours examined per dangling state.
	EPS : probability floor for pruning zero-probability successors.
	verbose : print progress.

	Returns
	-------
	MSP
	"""

	xp = cp if CUPY_AVAILABLE else np

	n_input_states = T_x[0].shape[0]
	n_symbols      = len(alphabet)

	# Row i of T_flat concatenates, over symbols x, row i of every T_x[x],
	# so a single matmul mu @ T_flat yields all mu·T_x products at once.
	T_flat_device = xp.asarray(
		np.stack(T_x).transpose(1, 0, 2).reshape(n_input_states, n_symbols * n_input_states)
	)

	# -- Sparse Belief Vectors -------------------------------------------

	sparse_belief_vector_indices: list[np.ndarray] = []
	sparse_belief_vector_values: list[np.ndarray] = []

	def _to_sparse(v: np.ndarray):
		# Keep only entries carrying non-negligible mass.
		non_zero = np.where( np.abs( v ) > 1e-20 )[ 0 ]
		return non_zero.astype(np.int32), v[ non_zero ]

	def _to_dense( sparse_indices: np.ndarray, sparse_values: np.ndarray ):
		dense = np.zeros( n_input_states )
		dense[ sparse_indices ] = sparse_values
		return dense

	def _batch_to_dense( indicies: list[int] ) -> np.ndarray:
		# Densify the stored sparse belief vectors for the given state indices.
		out = np.zeros((len(indicies), n_input_states), dtype=np.float64)
		for row, i in enumerate(indicies):
			out[row, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i]

		return out

	# -- Transitions ----------------------------------------------

	t_origin: list[int]   = []
	t_target: list[int]   = []
	t_prob:   list[float] = []
	t_symbol: list[int]   = []

	# -- Sparse dual bucket coarse hash ----------------------------------

	# Since we use dual bins
	# Some x, y in p, q needs to differ by more than 1/(2*coarse_scale) to not share a bin
	# To ensure jsd(p,q) < jsd_eps -> p,q share a bin, we use Pinsker's inequality
	# 1 /(2*coarse_scale) <= 2*sqrt(2*jsd_eps)
	# 1 /( 4*sqrt(2*jsd_eps) ) <= coarse_scale

	coarse_scale = 1 / ( 4*np.sqrt( 2 * jsd_eps ) )

	def _coarse_keys( sparse_indices: np.ndarray, sparse_values: np.ndarray) -> list[bytes]:
		# Two half-bin-shifted quantisations: sufficiently close vectors are
		# guaranteed to share at least one of the two keys.
		scaled = _to_dense( sparse_indices, sparse_values*coarse_scale )
		return [
			np.round(scaled).astype(np.int32).tobytes(),
			np.round(scaled + 0.5).astype(np.int32).tobytes()
		]

	buckets: dict[bytes, list[int]] = {}

	# Jensen-Shannon divergence for sparse vectors
	def _jsd_sparse(
		ai: np.ndarray, av: np.ndarray,
		bi: np.ndarray, bv: np.ndarray,
	) -> float:

		# get sorted union of the [non-zero] indices over both vectors
		all_i = np.union1d( ai, bi )

		# form dense vectors restricted to the union support
		p = np.zeros( len( all_i ), dtype=np.float64 )
		q = np.zeros( len( all_i ), dtype=np.float64 )

		# np.searchsorted( all_i, ai ) finds position in all_i where each ai is
		p[ np.searchsorted( all_i, ai ) ] = av
		q[ np.searchsorted( all_i, bi ) ] = bv

		return af_jensenshannondivergence_cpu( p, q )

	# Check if an existing belief vector is within jsd_eps
	def _lookup( sparse_indices: np.ndarray, sparse_values: np.ndarray ) -> int:

		keys = _coarse_keys( sparse_indices, sparse_values )
		candidates = list( { idx for key in keys for idx in buckets.get( key, [] ) } )

		if not candidates :
			return -1

		dense = _to_dense( sparse_indices, sparse_values ).reshape( 1, n_input_states )
		candidate_vectors = _batch_to_dense( candidates )

		distances = af_jensenshannondivergence_cpu( candidate_vectors, dense, axis=1 )
		min_idx  = np.argmin( distances )
		min_dist = distances[ min_idx ]

		if min_dist < jsd_eps :
			return candidates[ min_idx ]

		return -1

	# Map belief vector to coarse hash buckets for subsequent fast lookup
	def _register( sparse_indices: np.ndarray, sparse_values: np.ndarray, idx: int):
		for key in _coarse_keys( sparse_indices, sparse_values ):
			buckets.setdefault( key, [] ).append( idx )

	pi_indicies, pi_values = _to_sparse( pi )
	sparse_belief_vector_indices.append( pi_indicies )
	sparse_belief_vector_values.append( pi_values )
	n_belief_vectors = 1

	_register( pi_indicies, pi_values, 0 )

	frontier = deque([0])

	FRONTIER_CHUNK = 256

	try:
		while frontier and n_belief_vectors < exact_state_cap:

			batch_indices = []
			while frontier and len(batch_indices) < FRONTIER_CHUNK:
				batch_indices.append(frontier.popleft())

			B = len(batch_indices)

			# One matmul advances the whole batch through every symbol at once.
			mus_device       = xp.asarray(_batch_to_dense( batch_indices ) )
			flat_device      = mus_device @ T_flat_device
			all_mu_Tx_device = flat_device.reshape( B, n_symbols, n_input_states )
			probs_device     = all_mu_Tx_device.sum(axis=2)

			vb_device, vx_device = xp.where(probs_device > EPS)

			if vb_device.size == 0:
				del mus_device, flat_device, all_mu_Tx_device, probs_device, vb_device, vx_device
				continue

			vp_device      = probs_device[vb_device, vx_device]
			mu_next_device = all_mu_Tx_device[vb_device, vx_device] / vp_device[:, None]

			mu_next_cpu = to_cpu( mu_next_device )

			vp_cpu      = to_cpu( vp_device      )
			vb_cpu      = to_cpu( vb_device      )
			vx_cpu      = to_cpu( vx_device      )

			# Free GPU arrays for this batch immediately after download
			del mus_device, flat_device, all_mu_Tx_device, probs_device
			del vb_device, vx_device, vp_device, mu_next_device

			if not np.isfinite(mu_next_cpu).all():
				raise ValueError("Non-finite belief vectors in BFS batch")

			for q in range( len( vb_cpu ) ):

				indicies, values = _to_sparse(mu_next_cpu[q])
				orig = int( batch_indices[ int( vb_cpu[ q ] ) ] )
				idx = _lookup( indicies, values )

				if idx == -1:

					idx = n_belief_vectors
					sparse_belief_vector_indices.append( indicies )
					sparse_belief_vector_values.append( values )
					_register( indicies, values, idx )

					n_belief_vectors += 1
					frontier.append(idx)

				t_origin.append(orig)
				t_target.append(idx)
				t_prob.append(float(vp_cpu[q]))
				t_symbol.append(int(vx_cpu[q]))

		# -- Close the graph by mapping each dangling state to the nearest closed state

		if frontier:

			if verbose:
				print(f"Max exact states ({exact_state_cap}) exceeded ({n_belief_vectors}). Closing graph...")

			dense_all = np.zeros((n_belief_vectors, n_input_states), dtype=np.float32)

			for i in range(n_belief_vectors):
				dense_all[i, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i].astype(np.float32)

			ixp = xp if CAGRA_AVAILABLE else np

			if CAGRA_AVAILABLE:

				bv_device_f32 = xp.asarray(dense_all)
				del dense_all

				if verbose :
					print("Building CAGRA index...")

				index = cagra.build(
					cagra.IndexParams(graph_degree=32, metric='sqeuclidean'),
					bv_device_f32,
				)

			else :

				if verbose :
					print("Building hnswlib index...")

				# Closure math runs on CPU in this branch: ensure T_flat is host-side.
				T_flat_device = to_cpu( T_flat_device )
				bv_device_f32 = dense_all
				index = hnswlib.Index(space='l2', dim=bv_device_f32[0].size)
				index.init_index(max_elements=len(bv_device_f32), ef_construction=500, M=50)
				index.set_ef(50)
				index.add_items(bv_device_f32)

			if verbose:
				print("Index built.")

			closure_buckets: dict[bytes, list[tuple]] = {}
			max_error = 0.0   # largest per-state closure JSD seen
			avg_error = 0.0   # running SUM of per-state closure JSDs
			n_mapped  = 0     # number of states resolved via the ANN index

			def _closure_lookup( indicies: np.ndarray, values: np.ndarray) -> int:
				# Return the closure target of a previously mapped dangling vector, or -1.
				seen = set()
				for key in _coarse_keys( indicies, values):
					for entry in closure_buckets.get(key, []):
						eid = id(entry)
						if eid not in seen:
							seen.add(eid)
							s_idx, s_val, target = entry
							if _jsd_sparse(indicies, values, s_idx, s_val) < jsd_eps:
								return target
				return -1

			def _closure_register(indicies: np.ndarray, values: np.ndarray, target: int):
				entry = (indicies, values, target)
				for key in _coarse_keys(indicies, values):
					closure_buckets.setdefault(key, []).append(entry)

			CHUNK = 10_000
			frontier_list = list(frontier)

			for start in range(0, len(frontier_list), CHUNK):

				chunk_idxs = frontier_list[start : start + CHUNK]
				C  = len(chunk_idxs)

				if CAGRA_AVAILABLE :
					chunk_device  = cp.asarray(_batch_to_dense(chunk_idxs))
				else :
					chunk_device  = _batch_to_dense(chunk_idxs)

				flat_device   = chunk_device @ T_flat_device

				mu_Tx_device  = flat_device.reshape(C, n_symbols, n_input_states)

				probs_device  = mu_Tx_device.sum(axis=2)

				# Free intermediates we no longer need
				del chunk_device, flat_device

				vf_device, vx_device = ixp.where(probs_device > EPS)

				if vf_device.size == 0:
					del mu_Tx_device, probs_device, vf_device, vx_device
					continue

				vp_device   = probs_device[vf_device, vx_device]
				mu_v_device = mu_Tx_device[vf_device, vx_device] / vp_device[:, None]

				# Free now-consumed GPU arrays
				del mu_Tx_device, probs_device

				mu_v_cpu = to_cpu( mu_v_device )
				vp_cpu   = to_cpu( vp_device )
				vf_cpu   = to_cpu( vf_device )
				vx_cpu   = to_cpu( vx_device )

				del vf_device, vx_device, vp_device

				# float32 copy for CAGRA; free float64 immediately after cast
				unknown_device = mu_v_device.astype(ixp.float32)
				del mu_v_device

				N            = len(vf_cpu)
				orig_arr     = np.array([chunk_idxs[f] for f in vf_cpu], dtype=np.int32)
				next_idx     = np.full(N, -1, dtype=np.int32)
				unknown_mask = np.ones(N, dtype=bool)

				sparse_cache = [_to_sparse(mu_v_cpu[q]) for q in range(N)]

				# First try cheap CPU closure via the coarse hash
				for q in range(N):
					indicies, values = sparse_cache[q]
					target = _closure_lookup(indicies, values)
					if target != -1:
						next_idx[q]     = target
						unknown_mask[q] = False

				if unknown_mask.any():

					n_unknown    = int(unknown_mask.sum())
					unknown_rows = unknown_device[ixp.asarray(unknown_mask)]

					if CAGRA_AVAILABLE :

						if verbose:
							print(f"CAGRA search for {n_unknown} unknown states...")

						_, indices_device = cagra.search(
							cagra.SearchParams(itopk_size=128),
							index, unknown_rows, k_ann,
						)

					else :
						if verbose :
							print(f"hnswlib search for {n_unknown} unknown states...")

						indices_device, _ = index.knn_query( to_cpu( unknown_rows ), k=k_ann )

					output_indices = ixp.asarray( indices_device )

					# Re-rank the k L2-nearest ANN candidates by true JSD
					min_dist = ixp.full(unknown_rows.shape[0], ixp.inf, dtype=ixp.float32)
					min_ks   = ixp.zeros(unknown_rows.shape[0], dtype=ixp.int32)

					for idx_k in range(output_indices.shape[1]):
						sv   = bv_device_f32[output_indices[:, idx_k], :]

						if CAGRA_AVAILABLE :
							d = af_jensenshannondivergence_device(unknown_rows, sv, axis=1)
						else :
							d = af_jensenshannondivergence_cpu(unknown_rows, sv, axis=1)

						mask = d < min_dist
						min_dist = ixp.where(mask, d, min_dist)
						min_ks   = ixp.where(mask, idx_k, min_ks)

					max_dist  = float(to_cpu( ixp.max(min_dist) ))
					mean_dist = float(to_cpu( ixp.mean(min_dist) ))

					# BUG FIX: accumulate a true sum of per-state errors and count
					# only ANN-resolved states; the original summed per-chunk means
					# but divided by a total-state count, skewing the reported average.
					n_mapped  += n_unknown
					avg_error += mean_dist * n_unknown
					max_error  = max(max_dist, max_error)

					resolved   = to_cpu( output_indices[ixp.arange(unknown_rows.shape[0]), min_ks] )
					unknown_qs = np.where(unknown_mask)[0]

					for i, q in enumerate(unknown_qs):
						nz_idx, nz_val = sparse_cache[q]
						next_idx[q]    = int(resolved[i])
						_closure_register(nz_idx, nz_val, int(resolved[i]))

					# Free ANN-related GPU arrays
					del unknown_rows, indices_device, output_indices
					del min_dist, min_ks

				del unknown_device

				for q in range(N):
					t_origin.append(int(orig_arr[q]))
					t_target.append(int(next_idx[q]))
					t_prob.append(float(vp_cpu[q]))
					t_symbol.append(int(vx_cpu[q]))

			# -- Free CAGRA index and float32 belief matrix --------------------

			# CAGRA memory lives outside CuPy's pool and is not released by free_all_blocks()
			del index
			del bv_device_f32

			if CUPY_AVAILABLE :

				cp.get_default_memory_pool().free_all_blocks()
				cp.get_default_pinned_memory_pool().free_all_blocks()

			gc.collect()

			if verbose:
				denom = max(n_mapped, 1)   # guard: nothing may have needed the ANN
				print(f"Closed: {n_mapped} dangling states")
				print(f"Avg JSD state closure error: {(1.0 / np.log(2)) * avg_error / denom} bits")
				print(f"Max JSD state closure error: {(1.0 / np.log(2)) * max_error} bits")

		# ------------------ Create the MSP --------------------#

		msp_transitions = [
			Transition(
				origin_state_idx = t_origin[i],
				target_state_idx = t_target[i],
				prob             = t_prob[i],
				symbol_idx       = t_symbol[i],
			)
			for i in range(len(t_origin))
		]

		mixed_state_vectors = list(_batch_to_dense(list(range(n_belief_vectors))))

		if verbose:
			print("Assembled msp_transitions.")

		msp = MSP(
			states              = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)],
			mixed_state_vectors = mixed_state_vectors,
			transitions         = msp_transitions,
			alphabet            = alphabet.copy(),
		)

		if verbose:
			print("\nDone.\n")

		return msp

	finally:

		# ------------------ Cleanup --------------------#

		if 'T_flat_device' in locals():
			del T_flat_device

		if 'index' in locals():
			del index

		if 'bv_device_f32' in locals():
			del bv_device_f32

		if CUPY_AVAILABLE :

			cp.get_default_memory_pool().free_all_blocks()
			cp.get_default_pinned_memory_pool().free_all_blocks()

		gc.collect()
def to_cpu(array):
33def to_cpu( array ):
34    return array.get() if CUPY_AVAILABLE and hasattr( array, 'get' ) else array
class MSP:
 55class MSP:
 56	"""
 57	Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity.
 58 
 59	Computes E (excess entropy),  S (synchronization information), and T (transient information) 
 60	using spectral decomposition of the fundamental matrix.
 61 
 62	See
 63		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 64		https://arxiv.org/abs/1309.3792
 65	"""
 66
 67	def __init__(
 68		self,
 69		states,
 70		mixed_state_vectors,
 71		transitions,
 72		alphabet,
 73		start_state_idx: int   = 0,
 74		gmres_rtol:      float = 1e-10,
 75		gmres_maxiter:   int   = 1000,
 76		eps:             float = 1e-12,
 77	):
 78		self.states              = states
 79		self.mixed_state_vectors = np.array(mixed_state_vectors)
 80		self.transitions         = transitions
 81		self.alphabet            = alphabet
 82		self.start_state_idx     = start_state_idx
 83		self.gmres_rtol          = gmres_rtol
 84		self.gmres_maxiter       = gmres_maxiter
 85		self.EPS                 = eps
 86		self._cache              = {}
 87 
 88		M = len(states)
 89		if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M:
 90			raise ValueError(
 91				f"mixed_state_vectors must have shape (M={M}, N_causal), "
 92				f"got {self.mixed_state_vectors.shape}."
 93			)
 94		if not (0 <= start_state_idx < M):
 95			raise ValueError(
 96				f"start_state_idx={start_state_idx} is out of range for M={M} states."
 97			)
 98		if gmres_rtol <= 0:
 99			raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.")
100		if eps <= 0:
101			raise ValueError(f"eps must be positive, got {eps}.")
102 
103	def _build_W(self) -> sp_sparse.csr_matrix:
104		"""
105		Build and return the sparse row-stochastic transition matrix W (M × M).
106		Rows are re-normalised to absorb floating-point drift.
107		"""
108		if "W" in self._cache:
109			return self._cache["W"]
110 
111		M    = len(self.states)
112		
113		rows = [tr.origin_state_idx for tr in self.transitions]
114		cols = [tr.target_state_idx for tr in self.transitions]
115		data = [tr.prob              for tr in self.transitions]
116		
117		W    = sp_sparse.csr_matrix((data, (rows, cols)), shape=(M, M))
118 
119		row_sums = np.array(W.sum(axis=1)).ravel()
120		
121		if np.any(row_sums < self.EPS):
122			bad = np.where(row_sums < self.EPS)[0].tolist()
123			raise ValueError(
124				f"States at indices {bad} have zero total outgoing probability. "
125				"Check that every MSP state has at least one outgoing transition."
126			)
127		
128		W = (sp_sparse.diags(1.0 / row_sums) @ W).tocsr()
129 
130		self._cache["W"] = W
131		return W
132 
133	def _compute_pi_W(self, W) -> np.ndarray:
134		"""
135		Solve for the stationary distribution of W via GPU GMRES, then free
136		all intermediate GPU arrays before returning the CPU result.
137		"""
138		if "pi_W" in self._cache:
139			return self._cache["pi_W"]
140 
141		xp = cp if CUPY_AVAILABLE else np
142		
143		xp_sparse = cpx_sparse if CUPY_AVAILABLE else sp_sparse
144		xp_linalg = cpx_linalg if CUPY_AVAILABLE else sp_linalg
145
146		W_device = cpx_sparse.csr_matrix(W)
147		M     = W_device.shape[0]
148 
149		I_device  = xp_sparse.eye(M, format="csr", dtype=W_device.dtype)
150		A_device  = (W_device.T - I_device).tocsr()
151 
152		ones_row = xp_sparse.csr_matrix(xp.ones((1, M), dtype=W_device.dtype))
153		A_aug    = xp_sparse.vstack([A_device[:-1, :], ones_row]).tocsr()
154 
155		b_device = xp.zeros(M, dtype=W_device.dtype)
156		b_device[-1] = 1.0
157 
158		diag_vals = A_aug.diagonal()
159		diag_vals[diag_vals == 0] = 1.0
160		M_inv = xp_sparse.diags(1.0 / diag_vals)
161 
162		pi_W_device, info = xp_linalg.gmres(
163			A_aug, b_device,
164			M=M_inv,
165			restart=100,
166			maxiter=1000,
167		)
168 
169		if info > 0:
170			print(f"Warning: GMRES did not converge after {info} iterations")
171		elif info < 0:
172			raise ValueError("GMRES failed due to illegal input or breakdown.")
173 
174		pi_W_device = xp.clip(pi_W_device, 0.0, None)
175		s = pi_W_device.sum()
176		
177		if s < self.EPS:
178			raise ValueError("Stationary distribution collapsed.")
179		
180		pi_W_device /= s
181 
182		# Download before freeing
183		result = to_cpu( pi_W_device )
184 
185		# -- Free every GPU object created in this method ----------------------
186		del W_device, I_device, A_device, ones_row, A_aug, b_device, M_inv, diag_vals, pi_W_device
187
188		if CUPY_AVAILABLE :
189			cp.get_default_memory_pool().free_all_blocks()
190			cp.get_default_pinned_memory_pool().free_all_blocks()
191		
192		gc.collect()
193 
194		self._cache["pi_W"] = result
195
196		return result
197 
198	def _fundamental(
199		self,
200		W:     sp_sparse.csr_matrix,
201		pi_W:  np.ndarray,
202	) -> tuple[sp_linalg.LinearOperator, np.ndarray, np.ndarray, LinearOperator]:
203		"""
204		Compute the fundamental matrix
205		"""
206		if "fundamental" in self._cache:
207			return self._cache["fundamental"]
208 
209		M  = W.shape[0]
210		WT = W.T.tocsc()
211 
212		e0          = np.zeros(M)
213		e0[self.start_state_idx] = 1.0
214 
215		def Q_T_matvec(v: np.ndarray) -> np.ndarray:
216			return v - WT @ v + pi_W * v.sum()
217 
218		Q_LO = sp_linalg.LinearOperator((M, M), matvec=Q_T_matvec, dtype=float)
219 
220		W_diag  = np.array(W.diagonal())
221		M_diag  = 1.0 - W_diag + pi_W
222		M_diag  = np.where(np.abs(M_diag) > self.EPS, M_diag, 1.0)
223		precond = sp_linalg.LinearOperator((M, M), matvec=lambda v: v / M_diag, dtype=float)
224 
225		Z_row, exit_code = sp_linalg.lgmres(
226			Q_LO,
227			e0,
228			M=precond,
229			rtol=self.gmres_rtol,
230			atol=self.gmres_rtol,
231			maxiter=self.gmres_maxiter,
232		)
233		if exit_code != 0:
234			raise RuntimeError(
235				f"GMRES failed on the fundamental matrix solve (exit code {exit_code}). "
236				"Try increasing gmres_maxiter, loosening gmres_rtol, or verify that "
237				"the transition matrix is irreducible."
238			)
239 
240		fund_row = Z_row - pi_W
241 
242		self._cache["fundamental"] = (Q_LO, Z_row, fund_row, precond)
243		return Q_LO, Z_row, fund_row, precond
244 
245	def _get_H_WA(self) -> np.ndarray:
246		"""
247		Per-state Shannon entropy of the output symbol distribution,
248		marginalised over target states. Rows are normalised before use.
249		"""
250		if "H_WA" in self._cache:
251			return self._cache["H_WA"]
252 
253		M, A       = len(self.states), len(self.alphabet)
254		state_sym  = np.zeros((M, A))
255		for tr in self.transitions:
256			state_sym[tr.origin_state_idx, tr.symbol_idx] += tr.prob
257 
258		sym_row_sums = state_sym.sum(axis=1, keepdims=True)
259		sym_row_sums = np.where(sym_row_sums > self.EPS, sym_row_sums, 1.0)
260		state_sym   /= sym_row_sums
261 
262		safe     = state_sym > self.EPS
263		log_vals = np.where(safe, np.log2(np.where(safe, state_sym, 1.0)), 0.0)
264		H_WA     = -np.sum(state_sym * log_vals, axis=1)
265 
266		self._cache["H_WA"] = H_WA
267		return H_WA
268 
269	def _get_H_eta(self) -> np.ndarray:
270		"""
271		Per-state Shannon entropy of the mixed-state (belief) vector.
272		Rows are normalised before entropy computation.
273		"""
274		if "H_eta" in self._cache:
275			return self._cache["H_eta"]
276 
277		msv      = self.mixed_state_vectors.copy().astype(float)
278		row_sums = msv.sum(axis=1, keepdims=True)
279		row_sums = np.where(row_sums > self.EPS, row_sums, 1.0)
280		msv     /= row_sums
281 
282		safe     = msv > self.EPS
283		log_vals = np.where(safe, np.log2(np.where(safe, msv, 1.0)), 0.0)
284		H_eta    = -np.sum(msv * log_vals, axis=1)
285 
286		self._cache["H_eta"] = H_eta
287		return H_eta
288 
289	def get_E_S_T(self) -> tuple[float, float, float]:
290
291		try:
292			W    = self._build_W()
293			pi_W = self._compute_pi_W(W)   # GPU memory freed inside _compute_pi_W
294			H_WA  = self._get_H_WA()
295			H_eta = self._get_H_eta()
296 
297			Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W)
298 
299			E = float(fund_row @ H_WA)
300			S = float(fund_row @ H_eta)
301 
302			ZZ_row, exit_code = sp_linalg.lgmres(
303				Q_LO,
304				Z_row,
305				x0=Z_row,
306				M=precond,
307				rtol=self.gmres_rtol,
308				atol=self.gmres_rtol,
309				maxiter=self.gmres_maxiter,
310			)
311			if exit_code != 0:
312				raise RuntimeError(
313					f"GMRES failed on the Z² solve for T (exit code {exit_code}). "
314					"Try increasing gmres_maxiter or loosening gmres_rtol."
315				)
316 
317			t_row = ZZ_row - pi_W
318			T     = float(t_row @ H_WA)
319
320			return E, S, T
321 
322		finally:
323
324			if CUPY_AVAILABLE :
325				cp.get_default_memory_pool().free_all_blocks()
326				cp.get_default_pinned_memory_pool().free_all_blocks()
327
328			gc.collect()
329 
	def clear_cache(self) -> None:
		"""Discard all memoized results ("pi_W", "fundamental", "H_WA", "H_eta")."""
		self._cache.clear()

Mixed-State Presentation (MSP) for computing 'exact' intrinsic complexity.

Computes E (excess entropy), S (synchronization information), and T (transient information) using spectral decomposition of the fundamental matrix.

See Exact Complexity: The Spectral Decomposition of Intrinsic Computation https://arxiv.org/abs/1309.3792

MSP( states, mixed_state_vectors, transitions, alphabet, start_state_idx: int = 0, gmres_rtol: float = 1e-10, gmres_maxiter: int = 1000, eps: float = 1e-12)
	def __init__(
		self,
		states,
		mixed_state_vectors,
		transitions,
		alphabet,
		start_state_idx: int   = 0,
		gmres_rtol:      float = 1e-10,
		gmres_maxiter:   int   = 1000,
		eps:             float = 1e-12,
	):
		"""
		Parameters
		----------
		states              : one entry per MSP state (length M).
		mixed_state_vectors : array-like of shape (M, N_causal); one belief
		                      vector per MSP state.
		transitions         : iterable of Transition-like objects exposing
		                      origin_state_idx, symbol_idx and prob.
		alphabet            : output symbols.
		start_state_idx     : index of the initial mixed state (row of Z solved for).
		gmres_rtol          : relative/absolute tolerance for the LGMRES solves.
		gmres_maxiter       : iteration cap for the LGMRES solves.
		eps                 : numerical floor used in normalisations and logs.

		Raises
		------
		ValueError if mixed_state_vectors has the wrong shape, start_state_idx
		is out of range, or gmres_rtol / eps are not positive.
		"""
		self.states              = states
		self.mixed_state_vectors = np.array(mixed_state_vectors)
		self.transitions         = transitions
		self.alphabet            = alphabet
		self.start_state_idx     = start_state_idx
		self.gmres_rtol          = gmres_rtol
		self.gmres_maxiter       = gmres_maxiter
		self.EPS                 = eps
		self._cache              = {}   # memoizes pi_W / fundamental / H_WA / H_eta

		M = len(states)
		if self.mixed_state_vectors.ndim != 2 or self.mixed_state_vectors.shape[0] != M:
			raise ValueError(
				f"mixed_state_vectors must have shape (M={M}, N_causal), "
				f"got {self.mixed_state_vectors.shape}."
			)
		if not (0 <= start_state_idx < M):
			raise ValueError(
				f"start_state_idx={start_state_idx} is out of range for M={M} states."
			)
		if gmres_rtol <= 0:
			raise ValueError(f"gmres_rtol must be positive, got {gmres_rtol}.")
		if eps <= 0:
			raise ValueError(f"eps must be positive, got {eps}.")
states
mixed_state_vectors
transitions
alphabet
start_state_idx
gmres_rtol
gmres_maxiter
EPS
def get_E_S_T(self) -> tuple[float, float, float]:
	def get_E_S_T(self) -> tuple[float, float, float]:
		"""
		Compute (E, S, T): excess entropy, synchronization information and
		transient information of the mixed-state presentation.
		"""
		try:
			W    = self._build_W()
			pi_W = self._compute_pi_W(W)   # GPU memory freed inside _compute_pi_W
			H_WA  = self._get_H_WA()
			H_eta = self._get_H_eta()

			Q_LO, Z_row, fund_row, precond = self._fundamental(W, pi_W)

			# E and S: fundamental row dotted with per-state entropies.
			E = float(fund_row @ H_WA)
			S = float(fund_row @ H_eta)

			# Second solve against Z_row (the Z² solve) for T.
			ZZ_row, exit_code = sp_linalg.lgmres(
				Q_LO,
				Z_row,
				x0=Z_row,
				M=precond,
				rtol=self.gmres_rtol,
				atol=self.gmres_rtol,
				maxiter=self.gmres_maxiter,
			)
			if exit_code != 0:
				raise RuntimeError(
					f"GMRES failed on the Z² solve for T (exit code {exit_code}). "
					"Try increasing gmres_maxiter or loosening gmres_rtol."
				)

			t_row = ZZ_row - pi_W
			T     = float(t_row @ H_WA)

			return E, S, T

		finally:

			# Release pooled GPU memory regardless of success or failure.
			if CUPY_AVAILABLE :
				cp.get_default_memory_pool().free_all_blocks()
				cp.get_default_pinned_memory_pool().free_all_blocks()

			gc.collect()
def clear_cache(self) -> None:
	def clear_cache(self) -> None:
		"""Discard all memoized results ("pi_W", "fundamental", "H_WA", "H_eta")."""
		self._cache.clear()
def compute_msp_exact( T_x: list[list[list[fractions.Fraction]]], pi: tuple[fractions.Fraction], n_states: int, alphabet: list[str], exact_state_cap: int = 1000, verbose: bool = True):
def compute_msp_exact(
	T_x : list[ list[ list[ Fraction ] ] ],
	pi  : tuple[Fraction],
	n_states : int,
	alphabet : list[str],
	exact_state_cap: int = 1000,
	verbose: bool = True,
):
	"""
	Build an exact MSP by breadth-first exploration of belief states using
	exact rational (Fraction) arithmetic.

	Parameters
	----------
	T_x             : per-symbol transition matrices, T_x[x][i][j].
	n_states        : number of hidden states.
	pi              : initial belief over the n_states hidden states.
	alphabet        : output symbols.
	exact_state_cap : abort if more belief states than this are discovered.
	verbose         : print progress.

	Returns
	-------
	MSP over the discovered belief states.

	Raises
	------
	RuntimeError if the BFS frontier is still non-empty at the state cap.
	"""

	n_symbols  = len(alphabet)

	def _apply(mu: tuple[Fraction, ...], x: int) :
		# One filtering step: mu T_x renormalised, plus the emission probability.
		Tx = T_x[x]
		mu_Tx = [Fraction(0)] * n_states
		for i, w in enumerate(mu):
			if w == Fraction(0):
				continue
			row = Tx[i]
			for j in range(n_states):
				if row[j]:
					mu_Tx[j] += w * row[j]

		prob = sum(mu_Tx)
		if prob == Fraction(0):
			return None, Fraction(0)

		mu_next = tuple(v / prob for v in mu_Tx)
		return mu_next, prob

	belief_vectors: list[tuple[Fraction, ...]] = []
	seen_states:    dict[tuple, int]           = {}
	msp_transitions: list                      = []

	def _register(mu: tuple[Fraction, ...]) -> int:
		# Assign the next free index to a newly seen belief vector.
		idx = len(belief_vectors)
		belief_vectors.append(mu)
		seen_states[mu] = idx
		return idx

	start_mu = tuple(pi)
	_register(start_mu)
	frontier = deque([0])

	while frontier and len(belief_vectors) < exact_state_cap:

		idx = frontier.popleft()
		mu  = belief_vectors[idx]

		for x in range(n_symbols):

			mu_next, prob = _apply(mu, x)

			if mu_next is None:
				continue

			if mu_next not in seen_states:
				next_idx = _register(mu_next)
				frontier.append(next_idx)
			else:
				next_idx = seen_states[mu_next]

			msp_transitions.append(
				Transition(
					origin_state_idx = idx,
					target_state_idx = next_idx,
					prob             = float(prob),
					symbol_idx       = x,
					pq               = prob
				)
			)

	if frontier:
		# BUG FIX: message previously misspelled "exceded".
		raise RuntimeError( f"Exact state cap {exact_state_cap} exceeded." )

	if verbose:
		print(f"Exact MSP found: {len(belief_vectors)} states.")

	# BUG FIX: the original referenced the undefined names `n_belief_vectors`
	# and `mixed_state_vectors` below, which raised NameError at runtime.
	# Convert the exact rational beliefs to float vectors for the MSP.
	mixed_state_vectors = [
		np.array([float(v) for v in mu]) for mu in belief_vectors
	]

	msp = MSP(
		states              = [CausalState(name=f"MS_{i}") for i in range(len(belief_vectors))],
		mixed_state_vectors = mixed_state_vectors,
		transitions         = msp_transitions,
		alphabet            = alphabet.copy(),
	)

	return msp
def compute_msp( T_x: list[numpy.ndarray], pi: numpy.ndarray, n_states: int, alphabet: list[str], exact_state_cap: int = 175000, jsd_eps: float = 1e-07, k_ann: int = 50, EPS: float = 1e-12, verbose=True):
def compute_msp(
	T_x : list[np.ndarray],
	pi : np.ndarray,
	n_states : int,
	alphabet : list[str],
	exact_state_cap: int = 175_000,
	jsd_eps: float = 1e-7,
	k_ann: int   = 50,
	EPS : float = 1e-12,
	verbose = True,
) :
	"""
	Build an approximate MSP by floating-point breadth-first exploration of
	belief states, identifying states whose Jensen-Shannon divergence (JSD)
	is below `jsd_eps`.

	If more than `exact_state_cap` states are discovered, the remaining
	dangling frontier states are "closed" by mapping each of their successors
	onto the nearest already-discovered state via approximate nearest
	neighbour search (CAGRA on GPU when available, hnswlib on CPU otherwise).

	Parameters
	----------
	T_x             : per-symbol transition matrices (first one fixes the state count).
	pi              : initial belief vector.
	n_states        : number of hidden states.
	alphabet        : output symbols.
	exact_state_cap : maximum number of belief states discovered exactly.
	jsd_eps         : JSD threshold for identifying two belief states.
	k_ann           : number of ANN candidates rescored with exact JSD.
	EPS             : numerical floor for emission probabilities.
	verbose         : print progress information.

	Returns
	-------
	MSP built over the discovered belief states.
	"""

	xp = cp if CUPY_AVAILABLE else np

	n_input_states = T_x[0].shape[0]
	n_symbols      = len(alphabet)

	# All symbol matrices laid side by side: one matmul advances a batch of
	# beliefs through every symbol at once.
	T_flat_device = xp.asarray(
		np.stack(T_x).transpose(1, 0, 2).reshape(n_input_states, n_symbols * n_input_states)
	)

	# -- Sparse Belief Vectors -------------------------------------------

	sparse_belief_vector_indices: list[np.ndarray] = []
	sparse_belief_vector_values: list[np.ndarray] = []

	def _to_sparse(v: np.ndarray):
		# Keep only the effectively non-zero entries as (indices, values).
		non_zero = np.where( np.abs( v ) > 1e-20 )[ 0 ]
		return non_zero.astype(np.int32), v[ non_zero ]

	def _to_dense( sparse_indices: np.ndarray, sparse_values: np.ndarray ):
		dense = np.zeros( n_input_states )
		dense[ sparse_indices ] = sparse_values
		return dense

	def _batch_to_dense( indicies: list[int] ) -> np.ndarray:
		# Densify a batch of stored sparse belief vectors into a matrix.
		out = np.zeros((len(indicies), n_input_states), dtype=np.float64)
		for row, i in enumerate(indicies):
			out[row, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i]

		return out

	# -- Transitions ----------------------------------------------

	t_origin: list[int]   = []
	t_target: list[int]   = []
	t_prob:   list[float] = []
	t_symbol: list[int]   = []

	# -- Sparse dual bucket coarse hash ----------------------------------

	# Since we use dual (half-offset) bins, some coordinate of p, q must
	# differ by more than 1/(2*coarse_scale) for them NOT to share a bin.
	# To ensure jsd(p,q) < jsd_eps -> p,q share a bin, Pinsker's inequality gives
	# 1 /(2*coarse_scale) <= 2*sqrt(2*jsd_eps)
	# 1 /( 4*sqrt(2*jsd_eps) ) <= coarse_scale

	coarse_scale = 1 / ( 4*np.sqrt( 2 * jsd_eps ) )

	def _coarse_keys( sparse_indices: np.ndarray, sparse_values: np.ndarray) -> list[bytes]:
		# Two quantized keys per vector: aligned and half-offset grids.
		scaled = _to_dense( sparse_indices, sparse_values*coarse_scale )
		return [
			np.round(scaled).astype(np.int32).tobytes(),
			np.round(scaled + 0.5).astype(np.int32).tobytes()
		]

	buckets: dict[bytes, list[int]] = {}

	# Jensen-Shannon Divergence for sparse vectors
	def _jsd_sparse(
		ai: np.ndarray, av: np.ndarray,
		bi: np.ndarray, bv: np.ndarray,
	) -> float:

		# get sorted union of the [non-zero] indices over both vectors
		all_i = np.union1d( ai, bi )

		# form dense vectors restricted to the union support
		p = np.zeros( len( all_i ), dtype=np.float64 )
		q = np.zeros( len( all_i ), dtype=np.float64 )

		# np.searchsorted( all_i, ai ) finds position in all_i where each ai is
		p[ np.searchsorted( all_i, ai ) ] = av
		q[ np.searchsorted( all_i, bi ) ] = bv

		return af_jensenshannondivergence_cpu( p, q )

	# Check if an existing belief vector is within jsd_eps
	def _lookup( sparse_indices: np.ndarray, sparse_values: np.ndarray ) -> int:

		keys = _coarse_keys( sparse_indices, sparse_values )
		candidates = list( { idx for key in keys for idx in buckets.get( key, [] ) } )

		if not candidates :
			return -1

		dense = _to_dense( sparse_indices, sparse_values ).reshape( 1, n_input_states )
		candidate_vectors = _batch_to_dense( candidates )

		distances = af_jensenshannondivergence_cpu( candidate_vectors, dense, axis=1 )
		min_idx  = np.argmin( distances )
		min_dist = distances[ min_idx ]

		if min_dist < jsd_eps :
			return candidates[ min_idx ]

		return -1

	# Map belief vector to coarse hash buckets for subsequent fast lookup
	def _register( sparse_indices: np.ndarray, sparse_values: np.ndarray, idx: int):
		for key in _coarse_keys( sparse_indices, sparse_values ):
			buckets.setdefault( key, [] ).append( idx )

	pi_indicies, pi_values = _to_sparse( pi )
	sparse_belief_vector_indices.append( pi_indicies )
	sparse_belief_vector_values.append( pi_values )
	n_belief_vectors = 1

	_register( pi_indicies, pi_values, 0 )

	frontier = deque([0])

	FRONTIER_CHUNK = 256

	try:
		while frontier and n_belief_vectors < exact_state_cap:

			batch_indices = []
			while frontier and len(batch_indices) < FRONTIER_CHUNK:
				batch_indices.append(frontier.popleft())

			B = len(batch_indices)

			mus_device       = xp.asarray(_batch_to_dense( batch_indices ) )
			flat_device      = mus_device @ T_flat_device
			all_mu_Tx_device = flat_device.reshape( B, n_symbols, n_input_states )
			probs_device     = all_mu_Tx_device.sum(axis=2)

			# Keep only (batch row, symbol) pairs with non-negligible probability.
			vb_device, vx_device = xp.where(probs_device > EPS)

			if vb_device.size == 0:
				del mus_device, flat_device, all_mu_Tx_device, probs_device, vb_device, vx_device
				continue

			vp_device      = probs_device[vb_device, vx_device]
			mu_next_device = all_mu_Tx_device[vb_device, vx_device] / vp_device[:, None]

			mu_next_cpu = to_cpu( mu_next_device )

			vp_cpu      = to_cpu( vp_device      )
			vb_cpu      = to_cpu( vb_device      )
			vx_cpu      = to_cpu( vx_device      )

			# Free GPU arrays for this batch immediately after download
			del mus_device, flat_device, all_mu_Tx_device, probs_device
			del vb_device, vx_device, vp_device, mu_next_device

			if not np.isfinite(mu_next_cpu).all():
				raise ValueError("Non-finite belief vectors in BFS batch")

			for q in range( len( vb_cpu ) ):

				indicies, values = _to_sparse(mu_next_cpu[q])
				orig = int( batch_indices[ int( vb_cpu[ q ] ) ] )
				idx = _lookup( indicies, values )

				if idx == -1:

					# New belief state: store, hash and enqueue it.
					idx = n_belief_vectors
					sparse_belief_vector_indices.append( indicies )
					sparse_belief_vector_values.append( values )
					_register( indicies, values, idx )

					n_belief_vectors += 1
					frontier.append(idx)

				t_origin.append(orig)
				t_target.append(idx)
				t_prob.append(float(vp_cpu[q]))
				t_symbol.append(int(vx_cpu[q]))

		# -- Close the graph by mapping each dangling state to the nearest closed state

		if frontier:

			if verbose:
				# BUG FIX: message previously misspelled "exceded".
				print(f"Max exact states ({exact_state_cap}) exceeded ({n_belief_vectors}). Closing graph...")

			dense_all = np.zeros((n_belief_vectors, n_input_states), dtype=np.float32)

			for i in range(n_belief_vectors):
				dense_all[i, sparse_belief_vector_indices[i]] = sparse_belief_vector_values[i].astype(np.float32)

			ixp = xp if CAGRA_AVAILABLE else np

			if CAGRA_AVAILABLE:

				bv_device_f32 = xp.asarray(dense_all)
				del dense_all

				if verbose :
					print("Building CAGRA index...")

				index = cagra.build(
					cagra.IndexParams(graph_degree=32, metric='sqeuclidean'),
					bv_device_f32,
				)

			else :

				if verbose :
					print("Building hnswlib index...")

				# hnswlib search runs on CPU, so pull the flattened
				# transition matrices off the device as well.
				T_flat_device = to_cpu( T_flat_device )
				bv_device_f32 = dense_all
				index = hnswlib.Index(space='l2', dim=bv_device_f32[0].size)
				index.init_index(max_elements=len(bv_device_f32), ef_construction=500, M=50)
				index.set_ef(50)
				index.add_items(bv_device_f32)

			if verbose:
				print("Index built.")

			closure_buckets: dict[bytes, list[tuple]] = {}
			max_error = 0.0
			avg_error = 0.0
			n_mapped  = 0

			def _closure_lookup( indicies: np.ndarray, values: np.ndarray) -> int:
				# Exact-JSD hit among previously closed successors, if any.
				seen = set()
				for key in _coarse_keys( indicies, values):
					for entry in closure_buckets.get(key, []):
						eid = id(entry)
						if eid not in seen:
							seen.add(eid)
							s_idx, s_val, target = entry
							if _jsd_sparse(indicies, values, s_idx, s_val) < jsd_eps:
								return target
				return -1

			def _closure_register(indicies: np.ndarray, values: np.ndarray, target: int):
				entry = (indicies, values, target)
				for key in _coarse_keys(indicies, values):
					closure_buckets.setdefault(key, []).append(entry)

			CHUNK = 10_000
			frontier_list = list(frontier)

			for start in range(0, len(frontier_list), CHUNK):

				chunk_idxs = frontier_list[start : start + CHUNK]
				C  = len(chunk_idxs)

				if CAGRA_AVAILABLE :
					chunk_device  = cp.asarray(_batch_to_dense(chunk_idxs))
				else :
					chunk_device  = _batch_to_dense(chunk_idxs)

				flat_device   = chunk_device @ T_flat_device

				mu_Tx_device  = flat_device.reshape(C, n_symbols, n_input_states)

				probs_device  = mu_Tx_device.sum(axis=2)

				# Free intermediates we no longer need
				del chunk_device, flat_device

				vf_device, vx_device = ixp.where(probs_device > EPS)

				if vf_device.size == 0:
					del mu_Tx_device, probs_device, vf_device, vx_device
					continue

				vp_device   = probs_device[vf_device, vx_device]
				mu_v_device = mu_Tx_device[vf_device, vx_device] / vp_device[:, None]

				# Free now-consumed GPU arrays
				del mu_Tx_device, probs_device

				mu_v_cpu = to_cpu( mu_v_device )
				vp_cpu   = to_cpu( vp_device )
				vf_cpu   = to_cpu( vf_device )
				vx_cpu   = to_cpu( vx_device )

				del vf_device, vx_device, vp_device

				# float32 copy for CAGRA; free float64 immediately after cast
				unknown_device = mu_v_device.astype(ixp.float32)
				del mu_v_device

				N            = len(vf_cpu)
				orig_arr     = np.array([chunk_idxs[f] for f in vf_cpu], dtype=np.int32)
				next_idx     = np.full(N, -1, dtype=np.int32)
				unknown_mask = np.ones(N, dtype=bool)

				sparse_cache = [_to_sparse(mu_v_cpu[q]) for q in range(N)]

				# First try the cheap coarse-hash closure cache.
				for q in range(N):
					indicies, values = sparse_cache[q]
					target = _closure_lookup(indicies, values)
					if target != -1:
						next_idx[q]     = target
						unknown_mask[q] = False

				if unknown_mask.any():

					unknown_rows = unknown_device[ixp.asarray(unknown_mask)]

					if CAGRA_AVAILABLE :

						if verbose:
							# BUG FIX: message previously misspelled "CARGA".
							print(f"CAGRA search for {unknown_mask.sum()} unknown states...")

						_, indices_device = cagra.search(
							cagra.SearchParams(itopk_size=128),
							index, unknown_rows, k_ann,
						)

					else :
						if verbose :
							print(f"hnswlib search for {unknown_mask.sum()} unknown states...")

						indices_device, _ = index.knn_query( to_cpu( unknown_rows ), k=k_ann )

					output_indices = ixp.asarray( indices_device )

					min_dist = ixp.full(unknown_rows.shape[0], ixp.inf, dtype=ixp.float32)
					min_ks   = ixp.zeros(unknown_rows.shape[0], dtype=ixp.int32)

					# Rescore the k ANN candidates with exact JSD; keep the best.
					for idx_k in range(output_indices.shape[1]):
						sv   = bv_device_f32[output_indices[:, idx_k], :]

						if CAGRA_AVAILABLE :
							d = af_jensenshannondivergence_device(unknown_rows, sv, axis=1)
						else :
							d = af_jensenshannondivergence_cpu(unknown_rows, sv, axis=1)

						mask = d < min_dist
						min_dist = ixp.where(mask, d, min_dist)
						min_ks   = ixp.where(mask, idx_k, min_ks)

					max_dist  = to_cpu( ixp.max(min_dist) )
					mean_dist = to_cpu( ixp.mean(min_dist) )

					# BUG FIX: weight each chunk's mean by the number of rows
					# it covers, so avg_error / n_mapped below is a true mean
					# (the original summed unweighted per-chunk means).
					n_unknown  = int(unknown_mask.sum())
					n_mapped  += N
					avg_error += float(mean_dist) * n_unknown
					max_error  = max(float(max_dist), max_error)

					resolved   = to_cpu( output_indices[ixp.arange(unknown_rows.shape[0]), min_ks] )
					unknown_qs = np.where(unknown_mask)[0]

					for i, q in enumerate(unknown_qs):
						nz_idx, nz_val = sparse_cache[q]
						next_idx[q]    = int(resolved[i])
						_closure_register(nz_idx, nz_val, int(resolved[i]))

					# Free ANN-related GPU arrays
					del unknown_rows, indices_device, output_indices
					del min_dist, min_ks

				del unknown_device

				for q in range(N):
					t_origin.append(int(orig_arr[q]))
					t_target.append(int(next_idx[q]))
					t_prob.append(float(vp_cpu[q]))
					t_symbol.append(int(vx_cpu[q]))

			# -- Free CAGRA index and float32 belief matrix --------------------

			# CAGRA memory lives outside CuPy's pool and is not released by free_all_blocks()
			del index
			del bv_device_f32

			if CUPY_AVAILABLE :

				cp.get_default_memory_pool().free_all_blocks()
				cp.get_default_pinned_memory_pool().free_all_blocks()

			gc.collect()

			if verbose:
				print(f"Closed: {n_mapped} dangling states")
				# BUG FIX: guard against n_mapped == 0 (every successor either
				# hit the closure cache path with no ANN search, or had prob <= EPS).
				if n_mapped:
					print(f"Avg JSD state closure error: {(1.0 / np.log(2)) * avg_error / n_mapped} bits")
				print(f"Max JSD state closure error: {(1.0 / np.log(2)) * max_error} bits")

		# ------------------ Create the MSP --------------------#

		msp_transitions = [
			Transition(
				origin_state_idx = t_origin[i],
				target_state_idx = t_target[i],
				prob             = t_prob[i],
				symbol_idx       = t_symbol[i],
			)
			for i in range(len(t_origin))
		]

		mixed_state_vectors = list(_batch_to_dense(list(range(n_belief_vectors))))

		if verbose:
			print("Assembled msp_transitions.")

		msp = MSP(
			states              = [CausalState(name=f"MS_{i}") for i in range(n_belief_vectors)],
			mixed_state_vectors = mixed_state_vectors,
			transitions         = msp_transitions,
			alphabet            = alphabet.copy(),
		)

		if verbose:
			print("\nDone.\n")

		return msp

	finally:

		# ------------------ Cleanup --------------------#

		if 'T_flat_device' in dir():
			del T_flat_device

		if 'index' in locals():
			del index

		if 'bv_device_f32' in locals():
			del bv_device_f32

		if CUPY_AVAILABLE :

			cp.get_default_memory_pool().free_all_blocks()
			cp.get_default_pinned_memory_pool().free_all_blocks()

		gc.collect()