amachine.am_hmm

from __future__ import annotations

import copy
import gc
import json
import time
import warnings
from abc import abstractmethod
from collections import defaultdict, deque
from dataclasses import asdict, dataclass, field
from fractions import Fraction
from pathlib import Path
from typing import Any, override

import networkx as nx
import numpy as np
import sympy
from automata.fa.dfa import DFA

from . import am_fast, am_vis
from .am_causal_state import CausalState
from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu
from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence
from .am_machine import Machine
from .am_msp import MSP, compute_msp, compute_msp_exact
from .am_solve import solve_for_pi, solve_for_pi_fractional
from .am_symbol import Symbol
from .am_transition import Transition
  36
  37class HMM(Machine) :
  38
  39	@staticmethod
  40	def isomorphic_shift(
  41		input_symbol_indices: np.ndarray,
  42		input_state_indices:  np.ndarray,
  43		m : HMM,
  44		shift : int = 1
  45	) -> dict[str, np.ndarray]:
  46
  47		if not any( state.isomorphs for state in m.states ):
  48			raise ValueError("HMM has no states with isomorphs")
  49
  50		inputs = np.asarray(input_symbol_indices)
  51		states = np.asarray(input_state_indices)
  52
  53		n_states = len(m.states)
  54
  55		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
  56
  57		for tr in m.transitions:
  58			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
  59
  60		# Build isomorph remapping: identity by default, overridden where isomorphs exist
  61		iso_table = np.arange(n_states, dtype=np.int32)
  62
  63		for i, state in enumerate(m.states):
  64			if state.isomorphs:
  65				isomorph       = sorted(state.isomorphs)[0]
  66				iso_table[ i ] = m.state_idx_map[isomorph]
  67
  68		for i, state in enumerate(m.states):
  69			if state.isomorphs:
  70
  71				# extend the isomorph list to include the identity
  72				isormorphs_with_identity = sorted( [ i ] + [ m.state_idx_map[iso] for iso in state.isomorphs ] )
  73
  74				# find the identity index
  75				pos = isormorphs_with_identity.index( i )
  76
  77				# cyclical shift 
  78				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
  79
  80		origins = states[:-1]
  81		targets = states[1:]
  82
  83		out_origins = iso_table[origins]
  84		out_targets = iso_table[targets]
  85
  86		inv_sym = tr_sym_table[ out_origins, out_targets ]
  87		inv_sts = np.empty(states.size, dtype=states.dtype)
  88
  89		inv_sts[:-1] = out_origins
  90		inv_sts[-1]  = out_targets[-1]
  91
  92		return {
  93			"symbol_index": inv_sym.astype(inputs.dtype),
  94			"state_index":  inv_sts,
  95		}
  96
	def __init__( 
		self,
		states : list[CausalState] | None = None,
		transitions : list[Transition] | None = None,
		start_state : int = 0,
		alphabet : list[str] | None = None,
		isoclasses : dict[str,set[str]] | None = None,
		name : str = "",
		description : str = "" ) : 

		"""
		Hidden Markov machine over a finite alphabet.

		Note: the setter-call order below matters.  ``set_alphabet`` sorts and
		de-duplicates the alphabet and re-indexes the already-assigned
		transitions against it, before ``set_transitions`` takes its copy.
		Each setter also resets all derived caches via ``_invalidate``
		(which additionally initialises ``self.T_x``).
		"""

		self.alphabet    = alphabet or []
		self.states      = states or []
		self.transitions = transitions or []

		self.set_alphabet( self.alphabet )
		self.set_states( self.states )
		self.set_transitions( self.transitions )

		self.start_state = start_state
		self.isoclasses  = isoclasses or {}

		self.name = name
		self.description = description

		# to be deprecated
		self.isoclass = None

		# --- derived (lazily computed; reset by _invalidate) --------

		self.complexity : dict[str, Any] = {}

		self.symbol_idx_map : dict[str,int] = {}
		self.state_idx_map  : dict[str,int] = {}

		self.pi_fractional = None
		self.pi = None
		self.T  = None
		self.msp : MSP = None
		self.reverse_am : HMM = None
		self.is_q_weighted = False
		self.is_minimal = False

		# --- const ----------

		# probabilities below this threshold are treated as zero
		self.EPS = 1e-12
 142
 143	#-------------------------------------------------------#
 144	#                Overrides / Getters                    #
 145	#-------------------------------------------------------#
 146
 147	@override
 148	def get_states(self) -> list[CausalStates] : 
 149		return self.states
 150
 151	@override
 152	def get_transitions(self) -> list[Transition] : 
 153		return self.transitions
 154
	@override
	def get_alphabet(self) -> list[Symbol]:
		"""Wrap each alphabet string in a Symbol."""
		return [ Symbol(letter) for letter in self.alphabet ]
 158
 159	#-------------------------------------------------------#
 160	#                     Serialization                     #
 161	#-------------------------------------------------------#
 162
 163
 164	def to_dict(self) :
 165		return {
 166			"name"            : self.name,
 167			"description"     : self.description,
 168			"states"          : [ asdict(state)      for state      in self.states      ],
 169			"transitions"     : [ asdict(transition) for transition in self.transitions ],
 170			"alphabet"        : self.alphabet
 171		}
 172
	def from_dict( self, config ) :

		"""
		Populate this machine from a configuration produced by ``to_dict``.

		Accepts either a plain dict (e.g. the result of ``json.load``, as in
		``from_file``) or an object exposing the same top-level fields as
		attributes.  Previously only attribute access was used for the
		top-level fields, which broke the JSON round trip.
		"""

		# uniform top-level field access for dicts and attribute-style configs
		if isinstance( config, dict ) :
			get = config.__getitem__
		else :
			get = lambda key : getattr( config, key )

		self.name        = get( "name" )
		self.description = get( "description" )

		self.set_states( states=[ 
			CausalState( 
				name=state[ "name" ],
				classes=set( state[ "classes" ] )
			)
			for state in get( "states" )
		] )

		self.set_transitions( transitions=[ 
			Transition( 
				origin_state_idx=tr[ "origin_state_idx" ],
				target_state_idx=tr[ "target_state_idx" ],
				prob=tr[ "prob" ],
				symbol_idx=tr[ "symbol_idx" ]
			)
			for tr in get( "transitions" )
		] )

		self.set_alphabet( alphabet=get( "alphabet" ) )
 197
	def save_config(
		self, 
		path : Path, 
		with_complexity : bool = False, 
		with_block_convergence :  bool = False,
		with_structural_properties : bool = False,
		with_causal_properties : bool = False ) :

		"""
		Write the machine configuration to ``path / "am_config.json"``.

		Complexity measures are included on request (block-convergence based
		ones only when additionally asked for); structural properties are
		always included.  ``with_structural_properties`` and
		``with_causal_properties`` are currently accepted but unused.
		"""

		payload = self.to_dict()

		if with_complexity :
			# reversal-based measures are not implemented yet
			payload[ "complexity" ] = self.get_complexities( 
				with_block_convergence=with_block_convergence,
				with_reversal=False # not implemented
			)

		payload[ "structural_properties" ] = {
			"unifilar"              : self.is_unifilar(),
			"row_stochastic"        : self.is_row_stochastic(),
			"strongly_connected"    : self.is_strongly_connected(),
			"aperiodic"             : self.is_aperiodic(),
			"minimal"               : self.is_minimal_as_dfa( topological_only=False ),
			"topologically_minimal" : self.is_minimal_as_dfa( topological_only=True ),
			"is_epsilon_machine"    : self.is_epsilon_machine()
		}

		# Not implemented
		# payload[ "causal_properties" ] = {
		# 	"k_X"                        : self.k_X(),
		# 	"R"                          : self.R(),
		# 	"causally_reversible"        : self.causally_reversible(),
		# 	"microscopically_reversible" : self.microscopically_reversible(),
		# }

		with open( path / "am_config.json", "w", encoding="utf-8" ) as f :
			json.dump( payload, f, ensure_ascii=False, indent=2, default=list )
 237
	def from_file( self, path : Path ) :

		"""
		Load ``path / "am_config.json"`` and populate this machine.

		Fixes two defects: the class ``Path`` was used instead of the
		``path`` argument when opening the file, and the loaded config was
		never passed to ``from_dict``.
		"""

		with open( path / "am_config.json", "r", encoding="utf-8" ) as f:
			config = json.load(f)

		self.from_dict( config )
 242
 243	#-------------------------------------------------------#
 244	#             Setters and State Management              #
 245	#-------------------------------------------------------#
 246	
 247	def _invalidate(self) :
 248
 249		"""
 250		Invalidate
 251
 252		Reset all derived properties so they will be recomputed when requested later
 253		"""
 254
 255		self.complexity = {}
 256		self.T = None
 257		self.T_x = None
 258		self.pi = None
 259		self.pi_fractional = None
 260		self.msp = None 
 261		self.reverse_am = None
 262		self.is_q_weighted = False
 263		self.is_minimal = False
 264
 265	def set_states( self, states : list[CausalState] ) :
 266		self._invalidate()
 267		self.states = states.copy()
 268		self.state_idx_map = {}
 269		for idx, state in enumerate( self.states ) :
 270			self.state_idx_map[ state.name ] = idx
 271
 272	def set_alphabet( self, alphabet : list[str] ) :
 273
 274		self._invalidate()
 275
 276		old_alphabet = self.alphabet.copy()
 277
 278		aSet = set()
 279		aSet.update( alphabet )
 280
 281		self.alphabet = sorted(list(aSet))
 282
 283		self.symbol_idx_map = {}
 284		for idx, symbol in enumerate( self.alphabet ) :
 285			self.symbol_idx_map[ symbol ] = idx
 286
 287		for i, tr in enumerate( self.transitions ) :
 288			symbol = old_alphabet[ tr.symbol_idx ]
 289			self.transitions[ i ] = Transition(
 290				origin_state_idx=tr.origin_state_idx,
 291				target_state_idx=tr.target_state_idx,
 292				prob=tr.prob,
 293				symbol_idx=self.symbol_idx_map[ symbol ]
 294			)
 295
 296	def set_transitions( self, transitions : list[Transition] ) :
 297		self._invalidate()
 298		self.transitions = transitions.copy()
 299
 300	def extend_states( self, states : list[CausalState] ) :
 301		self.set_states( self.states + states )
 302
 303	def extend_alphabet( self, alphabet : list[str] ) :
 304		self.set_alphabet( self.alphabet + alphabet  )
 305
 306	def extend_transitions( self, transitions : list[Transition] ) :
 307		self.set_transitions( self.transitions + transitions  )
 308
 309	def get_complexity_measure_if_exists(self, measure ) :
 310		m = self.complexity.get( measure, None )
 311		return m
 312
 313	def set_complexity_measure(self, measure, value ) :
 314		self.complexity[ measure ] = value
 315
 316	#-------------------------------------------------------#
 317	#                   Get and Compute                     #
 318	#-------------------------------------------------------#
 319
 320	def get_complexities( 
 321		self, 
 322		with_block_convergence=False, 
 323		with_reversal=False ) :
 324
 325		directly_calculable = [
 326			self.C_mu,
 327			self.h_mu,
 328			self.H_1,
 329			self.rho_mu
 330		]
 331
 332		requires_block_convergence = [
 333			self.E, 
 334			self.T_inf,
 335			self.S,
 336			self.Chi
 337		]
 338
 339		requires_reversal = [
 340			self.q_mu,
 341			self.r_mu,
 342			self.b_mu
 343		]
 344			
 345		complexities = { m.__name__ : m() for m in directly_calculable }
 346
 347		if with_block_convergence :
 348
 349			complexities |= { m.__name__ : m() for m in requires_block_convergence }
 350
 351			for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
 352				if key in self.complexity :
 353					complexities[ key ] = self.complexity[ key ]
 354
 355		if with_reversal :
 356			complexities |= { m.__name__ : m() for m in requires_reversal }
 357
 358		return complexities
 359
 360	#-------------------------------------------------------#
 361
 362	def get_metadata(self) :
 363		return {
 364			"name" : self.name,
 365			'complexity' : self.complexity,
 366			"description" : self.description
 367		}
 368
 369	def get_transition_matrix(self) :
 370
 371		if self.T  is not None :
 372			return self.T
 373
 374		n_states = len( self.states )
 375		T = np.zeros((n_states, n_states))
 376
 377		for tr in self.transitions :    
 378			T[ tr.origin_state_idx, tr.target_state_idx  ] = tr.prob
 379
 380		self.T = T
 381
 382		return self.T
 383
 384	#-------------------------------------------------------#
 385
 386	def get_T_X(self) :
 387
 388		if self.T_x  is not None :
 389			return self.T_x
 390
 391		n_states  = len( self.states )
 392		n_symbols = len( self.alphabet )
 393
 394		T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]
 395
 396		for tr in self.transitions :
 397			T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob
 398
 399		self.T_x = T_x
 400		return self.T_x
 401
 402	#-------------------------------------------------------#
 403
 404	def get_msp_qw(
 405		self,
 406		exact_state_cap: int = 1000,
 407		verbose: bool = True,
 408	):
 409		if self.msp is not None:
 410			return self.msp
 411
 412		try : 
 413
 414			print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )
 415
 416			self.msp = compute_msp_exact(
 417				T_x=self.get_Tx_fractional(),
 418				pi=self.get_fractional_stationary_distribution(),
 419				n_states=len(self.states),
 420				alphabet=self.alphabet,
 421				exact_state_cap=1000,
 422				bool = True
 423			)
 424
 425			return self.msp 
 426
 427		except RuntimeError as e :
 428			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )
 429
 430		return self.get_msp()
 431
 432	def get_msp(
 433		self,
 434		exact_state_cap: int = 175_000,
 435		jsd_eps:         float = 1e-7,
 436		k_ann:           int   = 50,
 437		verbose                = True,
 438	) :
 439
 440		if self.msp is not None:
 441			return self.msp
 442	 
 443		T_x = self.get_T_X()
 444		pi  = self.get_stationary_distribution()
 445
 446		T_stacked      = np.stack(T_x)
 447		n_symbols      = len(self.alphabet)
 448		n_input_states = T_stacked.shape[1]
 449
 450		print( "\nComputing Mixed State Presentation\n" )
 451
 452		self.msp = compute_msp( 
 453			T_x=T_x,
 454			pi=pi,
 455			n_states=len(self.states),
 456			alphabet=self.alphabet,
 457			exact_state_cap=exact_state_cap,
 458			verbose=verbose
 459		)
 460
 461		return self.msp
 462
	def get_reverse_am(self) :

		"""
		Time-reversed machine, cached.

		Builds the reversal by Bayes-inverting each transition against the
		stationary distribution; when the raw reversal is not already an
		epsilon machine, re-derives its states from the (exact) mixed state
		presentation, restricts to the largest strongly connected component
		and minimizes.
		"""

		if self.reverse_am is not None:
			return self.reverse_am

		pi = self.get_stationary_distribution()
		self.reverse_am = copy.deepcopy(self)
		
		new_transitions = []
		for tr in self.transitions:
			i = tr.target_state_idx
			j = tr.origin_state_idx
			
			# Bayes inversion: Pr_rev(i -> j) = pi[j] * Pr(j -> i) / pi[i]
			# NOTE(review): assumes pi[i] > 0 for every target state — holds
			# when strongly connected (which get_stationary_distribution
			# enforces); confirm otherwise.
			p_reversed = (pi[j] * tr.prob) / pi[i]
			
			new_transitions.append(
				Transition(
					origin_state_idx=i,
					target_state_idx=j,
					prob=p_reversed,
					symbol_idx=tr.symbol_idx
				)
			)

		self.reverse_am.set_transitions(new_transitions)

		# already unifilar + minimal: the raw reversal is the reverse machine
		if self.reverse_am.is_epsilon_machine():
			return self.reverse_am

		# otherwise re-derive causal states from the reverse machine's MSP
		rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )

		self.reverse_am.set_states( rmsp.states )
		self.reverse_am.set_transitions( rmsp.transitions )
		self.reverse_am.msp = rmsp
		self.reverse_am.start_state = 0

		self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
		self.reverse_am.minimize()

		return self.reverse_am
 503
 504	#-------------------------------------------------------#
 505
	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

		"""
		Per-symbol transition matrices as nested lists of exact weights.

		Ensures q-weighted form first; entries come from ``tr.pq`` rather
		than the float ``tr.prob``.
		"""

		self.to_q_weighted()

		n_states  = len( self.states )
		n_symbols = len( self.alphabet )

		# one n_states x n_states zero matrix per symbol
		T_x = [ [ [ 0 ] * n_states for _ in range( n_states ) ] for _ in range( n_symbols ) ]

		for tr in self.transitions :
			T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq

		return T_x
 524
	def get_T_sympy( self ) :
		"""Exact transition matrix as a sympy Matrix of q-weights."""
		self.to_q_weighted()
		size = len( self.states )
		matrix = sympy.zeros( size, size )
		for tr in self.transitions :
			matrix[ tr.origin_state_idx, tr.target_state_idx ] = tr.pq
		return matrix
 536
 537	def get_fractional_stationary_distribution(self) :
 538
 539		T = self.get_T_sympy()
 540
 541		if self.pi_fractional is not None :
 542			return self.pi_fractional
 543
 544		G = self.as_digraph()
 545
 546		if not nx.is_strongly_connected(G):
 547			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
 548
 549		self.pi_fractional = solve_for_pi_fractional( T )
 550
 551		return self.pi_fractional
 552
 553	def get_stationary_distribution(self):
 554
 555		if self.pi is not None :
 556			return self.pi
 557
 558		G = self.as_digraph()
 559		
 560		if not nx.is_strongly_connected(G):
 561			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
 562
 563		T = self.get_transition_matrix()
 564		return solve_for_pi( T )		
 565
 566	#-------------------------------------------------------#
 567	#                 Complexity Measures                   #
 568	#-------------------------------------------------------#
 569	
 570	def C_mu( self ) :
 571
 572		"""
 573		c_mu (statistical complexity, or forecasting complexity)
 574
 575		Interpretations
 576			- The amount of historical information a process stores.
 577			- The amount of structure in a process.
 578		"""
 579
 580		m = self.get_complexity_measure_if_exists( "C_mu" )
 581
 582		if m is not None :
 583			return m
 584
 585		pi = self.get_stationary_distribution()
 586
 587		h = 0
 588		for i, pr in enumerate( pi ) :
 589			
 590			if pr < self.EPS :
 591				continue
 592
 593			h += -pr * np.log2( pr )
 594
 595		self.set_complexity_measure( "C_mu", h )
 596
 597		return h
 598
 599	#-------------------------------------------------------#
 600
 601	def h_mu( self ) :
 602
 603		"""
 604		h_mu (entropy rate)
 605
 606		Interpretation
 607			- Intrinsic Randomness in the process.
 608		"""
 609
 610		m = self.get_complexity_measure_if_exists( "h_mu" )
 611
 612		if m is not None :
 613			return m
 614
 615		T  = self.get_transition_matrix()
 616		pi = self.get_stationary_distribution()
 617
 618		n_states = pi.size
 619
 620		h = 0
 621		for i, pr in enumerate( pi ) :
 622
 623			if pr < self.EPS :
 624				continue
 625
 626			row_entropy = 0
 627			for j in range( len( pi ) ) :
 628
 629				if T[ i, j ]  < self.EPS :
 630					continue
 631
 632				row_entropy -= T[ i, j ] * np.log2( T[ i, j ] )
 633
 634			h += pr * row_entropy
 635
 636		self.set_complexity_measure( "h_mu", h )
 637
 638		return h
 639
 640	#-------------------------------------------------------#
 641
 642	def H_1(self):
 643
 644		"""
 645		H_1 (single-symbol entropy, or marginal entropy)
 646		Interpretations:
 647			- How uncertain you are about a single measurement with no context.
 648			- The total information content per symbol before seeing any neighbors.
 649			- The maximum possible entropy rate, achieved by a memoryless process.
 650		"""
 651
 652		m = self.get_complexity_measure_if_exists("H_1")
 653		if m is not None:
 654			return m
 655
 656		pi  = self.get_stationary_distribution()
 657		T_X = self.get_T_X()  # dict: symbol -> matrix
 658
 659		h = 0.0
 660		for T_x in T_X:
 661			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
 662			p_sym = 0.0
 663			for i, pr in enumerate(pi):
 664				if pr < self.EPS:
 665					continue
 666				p_sym += pr * T_x[i, :].sum()
 667
 668			if p_sym < self.EPS:
 669				continue
 670			h -= p_sym * np.log2(p_sym)
 671
 672		self.set_complexity_measure("H_1", h)
 673		return h
 674
 675	#-------------------------------------------------------#
 676
 677	def rho_mu(self) :
 678		
 679		"""
 680		rho_mu (anticipated information, or single-symbol redundancy)
 681		Interpretations:
 682			- How much the past reduces your uncertainty about the next symbol.
 683			- The portion of each symbol's information that was already predictable.
 684			- The difference between naive per-symbol uncertainty and true randomness.
 685			- Equal to H_1 - h_mu: the gap closed by knowing context.
 686		"""
 687
 688		m = self.get_complexity_measure_if_exists("rho_mu")
 689		
 690		if m is not None:
 691			return m
 692
 693		rho = self.H_1() - self.h_mu()
 694		
 695		self.set_complexity_measure("rho_mu", rho)
 696		
 697		return rho
 698
 699	#-------------------------------------------------------#
 700
	def block_convergence( self ) :

		"""
		block_convergence
			- Estimates a range of complexity measures using block entropy convergence.
			- E, S, T, T_L, H_L, h_mu_L, H_sync
			- Results are cached via set_complexity_measure as a side effect.
		"""

		# per-state adjacency list of (symbol, prob, target) triples
		outgoing = [ [] for _ in self.states ]
		for tr in self.transitions :
			triple = ( tr.symbol_idx, float( tr.prob ), tr.target_state_idx )
			outgoing[ tr.origin_state_idx ].append( triple )

		pi = self.get_stationary_distribution()

		state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]
		branches = [ ( 1.0, list( state_dist ) ) ]

		print( "\nComputing Block Entropy\n" )

		C = am_fast.block_entropy_convergence(
			h_mu            = self.h_mu(),
			n_states        = len( self.states ),
			n_symbols       = len( self.alphabet ),
			convergence_tol = 1e-6,
			precision       = 10,
			eps             = 1e-25,
			branches        = branches,
			trans           = outgoing,
			max_branches    = 30_000_000
		)

		print( "Done\n" )

		self.set_complexity_measure( "E",      C.E )
		self.set_complexity_measure( "S",      C.S )
		self.set_complexity_measure( "T_inf",  C.T )
		self.set_complexity_measure( "T_L",    C.T_L.tolist() )
		self.set_complexity_measure( "H_L",    C.H_L.tolist() )
		self.set_complexity_measure( "h_mu_L", C.h_mu_L.tolist() )
		self.set_complexity_measure( "H_sync", C.H_sync.tolist() )

		return C
 746
 747	#-------------------------------------------------------#
 748
 749	def E( self ) :
 750
 751		"""
 752		E (excess entropy)
 753
 754		Interpretations:
 755			- Mutual information from past to future I(Past;Future).
 756			- How much total information from the past reduces uncertainty in the future
 757			- Amount of (apparent) stored information
 758
 759		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 760		https://arxiv.org/abs/1309.3792
 761		"""
 762
 763		m = self.get_complexity_measure_if_exists( "E" )
 764
 765		if m is not None :
 766			return m
 767
 768		try : 
 769			msp = self.get_msp()
 770			E, S, T = msp.get_E_S_T()
 771			self.set_complexity_measure( "E", E )
 772			self.set_complexity_measure( "S", S )
 773			self.set_complexity_measure( "T_inf", T )
 774			
 775		except Exception as e :
 776
 777			print( f"MSP failed {e}" )
 778			exit()
 779
 780			C = self.block_convergence()	
 781			E = C.E
 782			self.set_complexity_measure( "E", E )
 783
 784		return E
 785
 786	#-------------------------------------------------------#
 787
 788	def S( self ) :
 789
 790		"""
 791		S (synchronization information)
 792
 793		Interpretation
 794			-synchronization information
 795
 796		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 797		https://arxiv.org/abs/1309.3792
 798		"""
 799
 800		m = self.get_complexity_measure_if_exists( "S" )
 801
 802		if m is not None :
 803			return m
 804
 805		try : 
 806			msp = self.get_msp()
 807			E, S, T = msp.get_E_S_T()
 808			self.set_complexity_measure( "E", E )
 809			self.set_complexity_measure( "S", S )
 810			self.set_complexity_measure( "T_inf", T )
 811
 812		except Exception as e :
 813			print( f"{e} \nFalling back to iterative estimation.")
 814			exit()
 815
 816			C = self.block_convergence()	
 817			S = C.S
 818			self.set_complexity_measure( "S", S )
 819
 820		return S
 821
 822	#-------------------------------------------------------#
 823
 824	def T_inf( self ) :
 825
 826		"""
 827		T_inf (transient information)
 828
 829		Interpretation
 830			-Transient information
 831
 832		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 833		https://arxiv.org/abs/1309.3792
 834		"""
 835
 836		m = self.get_complexity_measure_if_exists( "T_inf" )
 837
 838		if m is not None :
 839			return m
 840
 841		try : 
 842			msp = self.get_msp()
 843			E, S, T = msp.get_E_S_T()
 844			self.set_complexity_measure( "E", E )
 845			self.set_complexity_measure( "S", S )
 846			self.set_complexity_measure( "T_inf", T )
 847
 848		except Exception as e :
 849			print( f"{e} \nFalling back to iterative estimation.")
 850			exit()
 851			C = self.block_convergence()	
 852			T_inf = C.T
 853			self.set_complexity_measure( "T_inf", T_inf )
 854
 855		return T_inf
 856
 857	#-------------------------------------------------------#
 858
 859	def Chi( self ) :
 860
 861		"""
 862		Chi (crypticity)
 863
 864		Interpretation
 865			- Difference between internal stored information and apparent information to an observer
 866			- How muching information is hiding in the system
 867		"""
 868
 869		m = self.get_complexity_measure_if_exists( "Chi" )
 870
 871		if m is not None :
 872			return m
 873
 874		Chi = self.C_mu() - self.E()
 875
 876		self.set_complexity_measure( "Chi", Chi )
 877
 878		return Chi
 879
 880	#-------------------------------------------------------#
 881
 882	def k_X( self ) :
 883
 884		"""
 885		k_x (cryptic order)
 886
 887		Interpretation
 888			
 889		"""
 890
 891		m = self.get_complexity_measure_if_exists( "k_X" )
 892
 893		if m is not None :
 894			return m
 895
 896		raise NotImplementedError("Not implemented")
 897
 898	#-------------------------------------------------------#
 899
 900	def R( self ) :
 901
 902		"""
 903		R (Markov order)
 904		Interpretations:
 905			- The smallest L such that h_mu(L) = h_mu.
 906			- The context length beyond which additional history gives no predictive benefit.
 907			- The memory length of the process — how far back correlations extend.
 908			- The ideal context window length for a predictor.
 909			- R = 0 means the process is IID (memoryless).
 910			- R = infinity means the process has unbounded memory (non-finitary).
 911			— only report if convergence was clean, otherwise None
 912		"""
 913
 914		m = self.get_complexity_measure_if_exists( "R" )
 915
 916		if m is not None :
 917			return m
 918
 919		raise NotImplementedError("Not implemented")
 920
 921	#-------------------------------------------------------#
 922
 923	def b_mu( self ) :
 924
 925		"""
 926		b_mu - requires reverse HMM
 927
 928
 929		Interpretation
 930			
 931		"""
 932
 933		m = self.get_complexity_measure_if_exists( "b_mu" )
 934
 935		if m is not None :
 936			return m
 937
 938		raise NotImplementedError("Not implemented")
 939
 940	#-------------------------------------------------------#
 941
 942	def r_mu( self ) :
 943
 944		"""
 945		r_mu - requires reverse HMM
 946
 947
 948		Interpretation
 949			
 950		"""
 951
 952		m = self.get_complexity_measure_if_exists( "r_mu" )
 953
 954		if m is not None :
 955			return m
 956
 957		raise NotImplementedError("Not implemented")
 958
 959	#-------------------------------------------------------#
 960
 961	def q_mu( self ) :
 962
 963		"""
 964		q_mu 
 965
 966		Interpretation
 967		"""
 968
 969		m = self.get_complexity_measure_if_exists( "q_mu" )
 970
 971		if m is not None :
 972			return m
 973
 974		raise NotImplementedError("Not implemented")
 975
 976	#-------------------------------------------------------#
 977	#                      Properties                       #
 978	#-------------------------------------------------------#
 979
 980	def causally_reversible( self ) :
 981
 982		"""
 983		causally_reversible
 984		"""
 985
 986		m = self.get_complexity_measure_if_exists( "causally_reversible" )
 987
 988		if m is not None :
 989			return m
 990
 991		raise NotImplementedError("Not implemented")
 992
 993	#-------------------------------------------------------#
 994
 995	def microscopically_reversible( self ) :
 996
 997		"""
 998		microscopically_reversible
 999		"""
1000
1001		m = self.get_complexity_measure_if_exists( "microscopically_reversible" )
1002
1003		if m is not None :
1004			return m
1005
1006		raise NotImplementedError("Not implemented")
1007
1008	#-------------------------------------------------------#
1009
1010	def is_row_stochastic(self) :
1011
1012		"""
1013		is_row_stochastic
1014			-All states have outgoing transition probabilities which sum 1
1015		"""
1016
1017		sums = np.zeros( len( self.states ) )
1018		for tr in self.transitions :
1019			sums[ tr.origin_state_idx ] += tr.prob
1020		return np.allclose( sums, 1.0 )
1021
1022	#-------------------------------------------------------#
1023
1024	def is_unifilar(self) :
1025
1026		"""
1027		is_unifilar
1028			-All states have no more than one outgoing transition per-symbol
1029		"""
1030
1031		symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 )
1032
1033		for tr in self.transitions : 
1034		
1035			if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
1036				symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx
1037		
1038			elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
1039				return False
1040		
1041		return True
1042
1043	#-------------------------------------------------------#
1044
1045	def is_strongly_connected(self) :
1046
1047		"""
1048		is_strongly_connected
1049			-Every state is reachable from every other state
1050		"""
1051
1052		return nx.is_strongly_connected( self.as_digraph() )
1053
1054	#-------------------------------------------------------#
1055
1056	def is_aperiodic(self) :
1057
1058		"""
1059		is_aperiodic
1060			-"A strongly connected directed graph is aperiodic 
1061				if there is no integer k > 1 that divides the length of every cycle in the graph."
1062		"""
1063
1064		return nx.is_aperiodic( self.as_digraph() )
1065
1066	#-------------------------------------------------------#
1067
1068	def is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :
1069
1070		with_probs = not topological_only
1071
1072		# Construct the DFA
1073		dfa = self.as_dfa( with_probs=with_probs )
1074
1075		# Minimize the DFA
1076		#dfa = dfa.minify(retain_names=True)
1077		dfa = am_fast.minify_cpp( dfa, retain_names=True )
1078
1079		# check we have minimal number of states
1080		if len( dfa.states ) != len( self.states ) :
1081			if verbose : 
1082				print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" )
1083			return False
1084
1085		return True
1086
1087	def is_topological_epsilon_machine( self, verbose=True ) :
1088
1089		"""
1090		is_topological_epsilon_machine
1091			-see algorithm 2 in Enumerating Finitary Processes, Johnson et. al, 2024
1092		"""
1093
1094		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1095			if verbose : 
1096				print( f"Either non unifilar or not strongly connected" )
1097			return False
1098		else :
1099			return self.is_minimal_as_dfa( topological_only=True, verbose=verbose )
1100
1101	def is_epsilon_machine( self, verbose=True ) :
1102
1103		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1104			if verbose : 
1105				print( f"Either non unifilar or not strongly connected" )
1106			return False
1107		else :
1108			return self.is_minimal_as_dfa( topological_only=False, verbose=verbose )
1109
	#-------------------------------------------------------#
	#                     Minimization                      #
	#-------------------------------------------------------#
1113
1114	def minimize(self, retain_names: bool = True):
1115
1116		"""
1117		minimize
1118			-Simplify to minimal form, using Myhill Nerode
1119		"""
1120
1121		if self.is_minimal :
1122			return
1123
1124		start = time.perf_counter()
1125
1126		if not self.is_unifilar():
1127			raise ValueError(
1128				"DFA minimization is not valid for non-unifilar HMMs"
1129			)
1130
1131		was_strongly_connected = self.is_strongly_connected()
1132
1133		was_row_stochastic = self.is_row_stochastic()
1134		n_states_before = len(self.states)
1135
1136		dfa = self.as_dfa(with_probs=True)
1137
1138		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
1139		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )
1140
1141		# Build lookup from original state index -> CausalState object
1142		orig_state   = {i: s for i, s in enumerate(self.states)}
1143		eq_list      = list(min_dfa.states)
1144
1145		start_eq = min_dfa.initial_state
1146		
1147		# Separate the start state, then sort the rest by the 
1148		# smallest original state index inside each equivalence class.
1149		other_eqs = [eq for eq in eq_list if eq != start_eq]
1150		other_eqs.sort(key=lambda eq: min(eq))
1151
1152		# Recombine so start eq comes first, followed by the sorted remaining classes
1153		eq_list = [start_eq] + other_eqs
1154		# ----------------------------------------------------------
1155
1156		# Recompute eq_to_idx with the new ordering
1157		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}
1158
1159		# new_start is now guaranteed to be 0
1160		new_start = 0
1161
1162		# Map each original state index -> its equivalence class
1163		# Guard: minify() silently drops unreachable states
1164		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}
1165
1166		# Build lookup from original state index -> its transitions
1167		orig_trs = defaultdict(list)
1168		for t in self.transitions:
1169			orig_trs[t.origin_state_idx].append(t)
1170
1171		new_trs = []
1172		for eq in min_dfa.states:
1173			rep        = next(iter(eq))
1174			origin_idx = eq_to_idx[eq]
1175			for t in orig_trs[rep]:
1176				target_eq  = orig_to_eq[t.target_state_idx]
1177				target_idx = eq_to_idx[target_eq]
1178				new_trs.append(Transition(
1179					origin_state_idx = origin_idx,
1180					target_state_idx = target_idx,
1181					prob             = t.prob,
1182					symbol_idx       = t.symbol_idx,
1183				))
1184
1185		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism
1186
1187		# Compute new names
1188		if retain_names:
1189			new_names = [
1190				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
1191				else members[0].name
1192				for members in members_list
1193			]
1194		else:
1195			new_names = [str(j) for j in range(len(eq_list))]
1196
1197		old_name_to_new_name = {
1198			m.name: new_names[j]
1199			for j, members in enumerate(members_list)
1200			for m in members
1201		}
1202
1203		# Build the new states, preserving classes and isomorphs regardless of naming
1204		new_states = []
1205		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
1206			classes   = set().union(*(m.classes for m in members))
1207			isomorphs = {
1208				old_name_to_new_name.get(iso, iso)
1209				for m in members
1210				for iso in m.isomorphs
1211				if old_name_to_new_name.get(iso, iso) != name
1212			}
1213			new_states.append(CausalState(
1214				name      = name,
1215				classes   = classes,
1216				isomorphs = isomorphs,
1217			))
1218
1219		self.set_states(new_states)
1220		self.set_transitions(new_trs)
1221		self.start_state = new_start
1222
1223		if n_states_before == len(new_states) :
1224			print( f"{n_states_before} state HMM was already minimal." )
1225		else :
1226			print( f"Minimized from {n_states_before} to {len(new_states)}" )
1227
1228		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
1229			raise RuntimeError(
1230				f"Minimization broke strongly connected"
1231			)
1232
1233		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
1234			raise RuntimeError(
1235				f"Minimization broke row stochasticity"
1236			)
1237
1238		self.is_minimal = True
1239
1240	#-------------------------------------------------------#
1241	#                      Modifiers                        #
1242	#-------------------------------------------------------#
1243
1244
	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

		"""
		Restrict the machine to its largest strongly connected component (SCC).

		Transitions leaving the component are dropped; the probability mass
		each surviving state loses is then redistributed uniformly over its
		remaining outgoing transitions. No-op if already strongly connected.

		Parameters
			rename_states : when True, rename surviving states "0", "1", ...
				(NOTE(review): this replaces states with bare CausalState
				objects, dropping their classes/isomorphs — confirm intended).
		"""

		# get equivalent networkx graph (nodes are state indices)
		G = self.as_digraph()

		# if already strongly connected, nothing to do
		if not nx.is_strongly_connected( G ) :

			start = time.perf_counter()
			subgraph_nodes = list( nx.strongly_connected_components( G ) )

			# decompose into strongly connected components and sort by length
			subgraph_nodes.sort(key=len)
			component_state_set = subgraph_nodes[-1]

			# Take the largest strongly connected component (as sorted list of state indices)
			component_states = sorted( list( component_state_set ) )

			# make temporary copies of the old transitions and states
			old_transitions = [
				Transition(
					origin_state_idx=tr.origin_state_idx,
					target_state_idx=tr.target_state_idx,
					prob=tr.prob,
					symbol_idx=tr.symbol_idx
				)

				for tr in self.transitions
			]

			old_states = [
				CausalState(
					name=s.name,
					classes=s.classes,
					isomorphs=s.isomorphs
				)
				for s in self.states
			]

			# keep only the states inside the chosen component; this rebuilds
			# state_idx_map, which is used below to translate indices
			self.set_states(
				states=[ 
					state
					for i, state in enumerate( old_states ) if i in component_state_set
				]
			)

			# we will build new transition list based on those belonging to the component
			self.set_transitions( transitions= [] )

			# for tracking which new transitions leave each state
			# (keyed by OLD state index; values are indices into new_transitions)
			transitions_from_state = { state : set() for state in component_states }
			new_transitions = []

			for tr in old_transitions :

				origin_state_name = old_states[ tr.origin_state_idx ].name
				target_state_name = old_states[ tr.target_state_idx ].name

				# skip transitions that connect separate strongly connected components
				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
					continue

				# track transitions (by index in new transitions list) that leave this state
				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

				# translate old names to indices in the collapsed machine
				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
				my_target_state_idx = self.state_idx_map[ target_state_name ]

				new_transitions.append( 
					Transition(
						origin_state_idx=my_origin_state_idx,
						target_state_idx=my_target_state_idx,
						prob=tr.prob,
						symbol_idx=tr.symbol_idx
					)  )

			self.set_transitions( transitions=new_transitions )

			# if we removed an outgoing transition from a state, we need to distribute its probability 
			# among the remaining outgoing transitions from the state
			for state in component_states :
				
				# get the set of transitions leaving this state
				state_trs = transitions_from_state[ state ]

				# sum the probabilities of the outgoing transitions from the state
				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

				# how much probability is missing
				diff = 1.0 - p_sum

				# if significant difference
				# NOTE(review): if a surviving state had no outgoing transitions,
				# len( state_trs ) == 0 would divide by zero below — presumably
				# impossible for states inside an SCC; confirm.
				if abs( diff ) > self.EPS : 

					# calculate how much of the difference each transition gets
					adjustment = diff / len( state_trs )
					
					# update the transitions
					for i in state_trs :

						# transition with original probability
						tr = self.transitions[ i ]

						# adjusted probability (Transition is immutable; replace it)
						self.transitions[ i ] = Transition(
							origin_state_idx=tr.origin_state_idx, 
							target_state_idx=tr.target_state_idx, 
							prob=tr.prob + adjustment, 
							symbol_idx=tr.symbol_idx )

			if rename_states :
				self.set_states( [
					CausalState( name=f"{i}" )
					for i, s in enumerate( self.states )
				] )
1361
1362
	def to_q_weighted( self, denominator_limit=1000 ) :

		"""
		to_q_weighted
			-Snaps edge probabilities to exact Fractions, and assigns Fraction values to Transition.pq

		Parameters
			denominator_limit : largest denominator allowed when snapping each
				probability; automatically retried at 10x resolution when the
				rounding residue is too large to absorb.

		Raises
			ValueError : if the machine is not row stochastic.
		"""

		# nothing to do if already converted
		if self.is_q_weighted :
			return

		if not self.is_row_stochastic() :
			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

		# bucket transition indices by origin state, so each row can be
		# normalized independently
		t_from = [[] for _ in range(len(self.states))]

		for i, tr in enumerate( self.transitions ) :
			t_from[ tr.origin_state_idx ].append( i )

		new_transitions = []
		for t_list in t_from :
			
			if not t_list : 
				continue

			# snap each outgoing probability in this row to a Fraction
			p_q_sum = Fraction(0,1)
			p_qs = []

			for t_idx in t_list :
			
				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
				p_q_sum += p_q
				p_qs.append( p_q )

			# rows must sum to exactly 1; push any rounding residue onto the
			# largest outgoing probability
			if p_q_sum != Fraction(1,1) :
				
				max_pq_i = np.argmax( p_qs )
				max_oq = p_qs[ max_pq_i ]

				diff = p_q_sum - Fraction(1,1)

				# recurse with higher resolution
				if diff > max_oq :
					return self.to_q_weighted( denominator_limit*10 )
				else :
					p_qs[ max_pq_i ] -= diff

			# rebuild the row's transitions with both float prob and exact pq
			for i, t_idx in enumerate( t_list ) :	
				new_transitions.append( 
					Transition( 
						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
						target_state_idx=self.transitions[  t_idx ].target_state_idx,
						prob=float(p_qs[ i ]),
						symbol_idx=self.transitions[  t_idx ].symbol_idx,
						pq=p_qs[ i ]
					)
				)

		# set_transitions invalidates caches (clearing is_q_weighted), so the
		# flag must be set afterwards
		self.set_transitions( new_transitions )
		self.is_q_weighted = True
1422
1423	#-------------------------------------------------------#
1424	#                    Data Generation                    #
1425	#-------------------------------------------------------#
1426
1427	def generate_data(
1428		self,
1429		file_prefix: str,
1430		n_gen: int,
1431		include_states: bool,
1432		isomorphic_shifts : set[int]=None,
1433		random_seed : int=42 ) -> dict[any] : 
1434
1435		trs = [ [] for _ in range( len( self.states ) ) ]
1436		for tr in self.transitions :
1437			trs[ tr.origin_state_idx ].append( ( 
1438				tr.symbol_idx, 
1439				float( tr.prob ),
1440				tr.target_state_idx ) )
1441
1442		data = am_fast.generate_data(
1443			n_gen=n_gen,
1444			start_state=self.start_state,
1445			transitions=trs,
1446			alphabet=sorted(list(self.alphabet)),
1447			include_states=include_states,
1448			random_seed=random_seed
1449		)
1450
1451		if isomorphic_shifts is not None :
1452
1453			if not include_states :
1454				raise ValueError( "Isomorphic inversion requires include_states=True" )
1455
1456			data[ "isomorphic_shifts" ] = {}
1457
1458			for shift in isomorphic_shifts :
1459
1460				shifted = HMM.isomorphic_shift(
1461					symbol_indices=data[ "symbol_index" ], 
1462					state_indices=data[ "state_index" ], 
1463					m=self,
1464					shift=shift
1465				)
1466
1467				data[ "isomorphic_shifts" ][ shift ] = {
1468					"symbol_index" : shifted[ "symbol_index" ],
1469					"state_index"  : shifted[ "state_index" ]
1470				}
1471
1472		am_fast.save_data(
1473			data=data,
1474			file_prefix=file_prefix,
1475			alphabet=sorted(list(self.alphabet)),
1476			n_states=len( self.states ),
1477			start_state=self.start_state,
1478			random_seed=random_seed,
1479			metadata=self.get_metadata() )
1480
1481		return data
1482
1483	#-------------------------------------------------------#
1484	#                 Basic Visualization                   #
1485	#-------------------------------------------------------#
1486
1487	def draw_graph(
1488		self,
1489		engine     : str  = 'dot',
1490		output_dir : Path = Path('.'),
1491		show       : bool = True
1492	) -> None :
1493
1494		G = self.as_digraph()
1495
1496		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1497		
1498		am_vis.draw_graph( 
1499			self, 
1500			output_dir=output_dir, 
1501			title="am_graph", 
1502			view=show, 
1503			subgraphs=subgraphs, 
1504			engine=engine )
1505
1506	#-------------------------------------------------------#
1507	#                   Alternate Forms                     #
1508	#-------------------------------------------------------#
1509
1510
1511	def as_digraph( self ) :
1512
1513		"""
1514		Build nx DiGraph based on the HMM, without edge probabilities or symbols
1515		"""
1516
1517		G = nx.DiGraph()
1518		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1519
1520		for tr in self.transitions :
1521			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1522
1523		return G
1524
1525	def as_dfa( self, with_probs : bool ) :
1526
1527		precision=8
1528
1529		def edge_label( symb, prob ) :
1530			return f"({symb},{round(prob, precision)})"
1531
1532		# Build states, symbols, and transitions 
1533		dfa_states  = { i for i, _ in enumerate( self.states ) }
1534			
1535		if not with_probs :
1536			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1537		else : 
1538			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1539
1540		dfa_transitions = defaultdict(dict)
1541
1542		if not with_probs :
1543			for t in self.transitions :
1544				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1545		else :
1546			for t in self.transitions :
1547				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1548
1549		# Construct the DFA
1550		return DFA(
1551			states=dfa_states,
1552			input_symbols=dfa_symbols,
1553			transitions=dfa_transitions,
1554			initial_state=self.start_state,
1555			allow_partial=True,
1556			final_states={ 
1557				s for s in dfa_states
1558			}
1559		)
class HMM(amachine.am_machine.Machine):
  38class HMM(Machine) :
  39
  40	@staticmethod
  41	def isomorphic_shift(
  42		input_symbol_indices: np.ndarray,
  43		input_state_indices:  np.ndarray,
  44		m : HMM,
  45		shift : int = 1
  46	) -> dict[str, np.ndarray]:
  47
  48		if not any( state.isomorphs for state in m.states ):
  49			raise ValueError("HMM has no states with isomorphs")
  50
  51		inputs = np.asarray(input_symbol_indices)
  52		states = np.asarray(input_state_indices)
  53
  54		n_states = len(m.states)
  55
  56		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
  57
  58		for tr in m.transitions:
  59			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
  60
  61		# Build isomorph remapping: identity by default, overridden where isomorphs exist
  62		iso_table = np.arange(n_states, dtype=np.int32)
  63
  64		for i, state in enumerate(m.states):
  65			if state.isomorphs:
  66				isomorph       = sorted(state.isomorphs)[0]
  67				iso_table[ i ] = m.state_idx_map[isomorph]
  68
  69		for i, state in enumerate(m.states):
  70			if state.isomorphs:
  71
  72				# extend the isomorph list to include the identity
  73				isormorphs_with_identity = sorted( [ i ] + [ m.state_idx_map[iso] for iso in state.isomorphs ] )
  74
  75				# find the identity index
  76				pos = isormorphs_with_identity.index( i )
  77
  78				# cyclical shift 
  79				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
  80
  81		origins = states[:-1]
  82		targets = states[1:]
  83
  84		out_origins = iso_table[origins]
  85		out_targets = iso_table[targets]
  86
  87		inv_sym = tr_sym_table[ out_origins, out_targets ]
  88		inv_sts = np.empty(states.size, dtype=states.dtype)
  89
  90		inv_sts[:-1] = out_origins
  91		inv_sts[-1]  = out_targets[-1]
  92
  93		return {
  94			"symbol_index": inv_sym.astype(inputs.dtype),
  95			"state_index":  inv_sts,
  96		}
  97
  98	def __init__( 
  99		self,
 100		states : list[CausalState] | None = None,
 101		transitions : list[Transition] | None = None,
 102		start_state : int = 0,
 103		alphabet : list[str] | None = None,
 104		isoclasses : dict[str,set[str]] | None = None,
 105		name : str = "",
 106		description : str = "" ) : 
 107
 108		self.alphabet    = alphabet or []
 109		self.states      = states or []
 110		self.transitions = transitions or []
 111
 112		self.set_alphabet( self.alphabet )
 113		self.set_states( self.states )
 114		self.set_transitions( self.transitions )
 115
 116		self.start_state = start_state
 117		self.isoclasses  = isoclasses or {}
 118
 119		self.name = name
 120		self.description = description
 121
 122		# to be depreciated
 123		self.isoclass = None
 124
 125		# --- derived --------
 126
 127		self.complexity : dict[str, any] = {}
 128
 129		self.symbol_idx_map : dict[str,int] = {}
 130		self.state_idx_map  : dict[str,int] = {}
 131
 132		self.pi_fractional = None
 133		self.pi = None
 134		self.T  = None
 135		self.msp : MSP = None
 136		self.reverse_am : HMM = None
 137		self.is_q_weighted = False
 138		self.is_minimal = False
 139
 140		# --- const ----------
 141
 142		self.EPS = 1e-12
 143
 144	#-------------------------------------------------------#
 145	#                Overrides / Getters                    #
 146	#-------------------------------------------------------#
 147
	@override
	def get_states(self) -> list[CausalState] :  # was list[CausalStates]: undefined name
		"""Return the machine's causal states."""
		return self.states

	@override
	def get_transitions(self) -> list[Transition] : 
		"""Return the machine's transitions."""
		return self.transitions

	@override
	def get_alphabet(self) -> list[Symbol]:
		"""Return the alphabet wrapped as Symbol objects."""
		return [ Symbol(a) for a in self.alphabet ]
 159
 160	#-------------------------------------------------------#
 161	#                     Serialization                     #
 162	#-------------------------------------------------------#
 163
 164
	def to_dict(self) :
		"""Serialize the machine to a JSON-compatible dict (inverse of from_dict)."""
		state_dicts      = [ asdict( s ) for s in self.states ]
		transition_dicts = [ asdict( t ) for t in self.transitions ]
		return {
			"name"            : self.name,
			"description"     : self.description,
			"states"          : state_dicts,
			"transitions"     : transition_dicts,
			"alphabet"        : self.alphabet
		}
 173
	def from_dict( self, config ) :

		"""
		Populate this machine from a config dict as produced by to_dict()
		(e.g. loaded back from am_config.json).

		Uses subscript access throughout: config is a plain dict, so the
		previous attribute access (config.name, config.states, ...) raised
		AttributeError on round-tripped configs.

		NOTE(review): state isomorphs are not serialized/restored here.
		"""

		self.name            = config[ "name" ]
		self.description     = config[ "description" ]

		self.set_states( states=[ 
			CausalState( 
				name=state[ "name" ],
				classes=set( state[ "classes" ] )
			)
			for state in config[ "states" ]
		] )

		self.set_transitions( transitions=[ 
			Transition( 
				origin_state_idx=tr[ "origin_state_idx" ],
				target_state_idx=tr[ "target_state_idx" ],
				prob=tr[ "prob" ],
				symbol_idx=tr[ "symbol_idx" ]
			)
			for tr in config[ "transitions" ]
		] )

		self.set_alphabet( alphabet=config[ "alphabet" ] ) 
 198
	def save_config(
		self, 
		path : Path, 
		with_complexity : bool = False, 
		with_block_convergence :  bool = False,
		with_structural_properties : bool = False,
		with_causal_properties : bool = False ) :

		"""
		Write this machine (plus optional derived measures) to path/am_config.json.

		NOTE(review): with_structural_properties is currently unused —
		structural properties are always written — and with_causal_properties
		is not implemented.
		"""

		config = self.to_dict()

		if with_complexity :
			config[ "complexity" ] = self.get_complexities( 
				with_block_convergence=with_block_convergence,
				with_reversal=False # not implemented
			)

		# always recorded, regardless of with_structural_properties
		config[ "structural_properties" ] = {
			"unifilar"              : self.is_unifilar(),
			"row_stochastic"        : self.is_row_stochastic(),
			"strongly_connected"    : self.is_strongly_connected(),
			"aperiodic"             : self.is_aperiodic(),
			"minimal"               : self.is_minimal_as_dfa( topological_only=False ),
			"topologically_minimal" : self.is_minimal_as_dfa( topological_only=True ),
			"is_epsilon_machine"    : self.is_epsilon_machine()
		}

		# causal properties (k_X, R, reversibility) are not implemented yet

		with open( path / "am_config.json", "w", encoding="utf-8" ) as f :
			json.dump( config, f, ensure_ascii=False, indent=2, default=list )
 238
	def from_file( self, path : Path ) :
		"""
		Load path/am_config.json and populate this machine via from_dict.

		Fixes two defects: previously opened `Path / "am_config.json"` (the
		pathlib class, not the argument) and called from_dict() without the
		loaded config.
		"""
		with open( path / "am_config.json", "r" ) as f:
			config = json.load(f)
		self.from_dict( config )
 243
 244	#-------------------------------------------------------#
 245	#             Setters and State Management              #
 246	#-------------------------------------------------------#
 247	
	def _invalidate(self) :

		"""
		Invalidate

		Reset all derived properties so they will be recomputed when requested later
		"""

		# cached matrices / distributions / presentations
		for attr in ( "T", "T_x", "pi", "pi_fractional", "msp", "reverse_am" ) :
			setattr( self, attr, None )

		# cached measures and status flags
		self.complexity = {}
		self.is_q_weighted = False
		self.is_minimal = False
 265
	def set_states( self, states : list[CausalState] ) :
		"""Replace the state list (copied) and rebuild the name -> index map."""
		self._invalidate()
		self.states = states.copy()
		self.state_idx_map = {
			state.name : idx for idx, state in enumerate( self.states )
		}
 272
	def set_alphabet( self, alphabet : list[str] ) :

		"""
		Replace the alphabet (deduplicated, sorted) and re-index every existing
		transition's symbol_idx against the new ordering.
		"""

		self._invalidate()

		previous_alphabet = self.alphabet.copy()

		# canonical form: unique symbols in sorted order
		self.alphabet = sorted( set( alphabet ) )

		self.symbol_idx_map = {
			symbol : idx for idx, symbol in enumerate( self.alphabet )
		}

		# remap transitions in place: old index -> symbol -> new index
		for i, tr in enumerate( self.transitions ) :
			remapped_idx = self.symbol_idx_map[ previous_alphabet[ tr.symbol_idx ] ]
			self.transitions[ i ] = Transition(
				origin_state_idx=tr.origin_state_idx,
				target_state_idx=tr.target_state_idx,
				prob=tr.prob,
				symbol_idx=remapped_idx
			)
 296
	def set_transitions( self, transitions : list[Transition] ) :
		"""Replace the transition list (copied) and invalidate derived caches."""
		self._invalidate()
		self.transitions = list( transitions )
 300
	def extend_states( self, states : list[CausalState] ) :
		"""Append states, rebuilding indices via set_states."""
		self.set_states( [ *self.states, *states ] )
 303
	def extend_alphabet( self, alphabet : list[str] ) :
		"""Add symbols, re-sorting and re-indexing via set_alphabet."""
		self.set_alphabet( [ *self.alphabet, *alphabet ] )
 306
	def extend_transitions( self, transitions : list[Transition] ) :
		"""Append transitions via set_transitions (invalidates caches)."""
		self.set_transitions( [ *self.transitions, *transitions ] )
 309
	def get_complexity_measure_if_exists(self, measure ) :
		"""Return the cached value of *measure*, or None when absent."""
		return self.complexity.get( measure )
 313
	def set_complexity_measure(self, measure, value ) :
		"""Cache *value* under the key *measure*."""
		self.complexity[ measure ] = value
 316
 317	#-------------------------------------------------------#
 318	#                   Get and Compute                     #
 319	#-------------------------------------------------------#
 320
	def get_complexities( 
		self, 
		with_block_convergence=False, 
		with_reversal=False ) :

		"""
		Collect complexity measures as {measure name: value}.

		Always computes C_mu, h_mu, H_1, rho_mu. With with_block_convergence,
		adds E, T_inf, S, Chi plus any cached per-length series (H_L, T_L,
		h_mu_L, H_sync). With with_reversal, adds q_mu, r_mu, b_mu.
		"""

		# cheap, closed-form measures
		complexities = {
			m.__name__ : m()
			for m in ( self.C_mu, self.h_mu, self.H_1, self.rho_mu )
		}

		if with_block_convergence :

			complexities |= {
				m.__name__ : m()
				for m in ( self.E, self.T_inf, self.S, self.Chi )
			}

			# per-block-length series cached as a side effect of convergence
			for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
				if key in self.complexity :
					complexities[ key ] = self.complexity[ key ]

		if with_reversal :
			complexities |= {
				m.__name__ : m()
				for m in ( self.q_mu, self.r_mu, self.b_mu )
			}

		return complexities
 360
 361	#-------------------------------------------------------#
 362
	def get_metadata(self) :
		"""Descriptive metadata (name, cached complexity dict, description)."""
		return {
			"name" : self.name,
			"complexity" : self.complexity,
			"description" : self.description
		}
 369
	def get_transition_matrix(self) :

		"""
		Dense state-to-state transition matrix T (probabilities summed over
		symbols collapse to one entry per edge), cached on first build.
		"""

		if self.T is None :

			n = len( self.states )
			matrix = np.zeros((n, n))

			for tr in self.transitions :
				matrix[ tr.origin_state_idx, tr.target_state_idx ] = tr.prob

			self.T = matrix

		return self.T
 384
 385	#-------------------------------------------------------#
 386
	def get_T_X(self) :

		"""
		Per-symbol transition matrices: a list indexed by symbol_idx, where
		entry x holds T^(x)[i, j] = P(emit x, go i -> j). Cached.
		"""

		if self.T_x is None :

			n = len( self.states )
			matrices = [ np.zeros((n, n)) for _ in range( len( self.alphabet ) ) ]

			for tr in self.transitions :
				matrices[ tr.symbol_idx ][ tr.origin_state_idx, tr.target_state_idx ] = tr.prob

			self.T_x = matrices

		return self.T_x
 402
 403	#-------------------------------------------------------#
 404
	def get_msp_qw(
		self,
		exact_state_cap: int = 1000,
		verbose: bool = True,
	):
		"""
		Mixed State Presentation via exact Fraction arithmetic, cached.
		Falls back to the floating point approximation (get_msp) when the
		exact computation raises RuntimeError.

		Parameters
			exact_state_cap : cap on exact MSP states before giving up
			verbose         : forwarded to compute_msp_exact
		"""
		if self.msp is not None:
			return self.msp

		try : 

			print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )

			self.msp = compute_msp_exact(
				T_x=self.get_Tx_fractional(),
				pi=self.get_fractional_stationary_distribution(),
				n_states=len(self.states),
				alphabet=self.alphabet,
				# forward the parameter (was hard-coded 1000, silently
				# ignoring callers' exact_state_cap)
				exact_state_cap=exact_state_cap,
				# was a stray `bool = True` keyword
				verbose=verbose
			)

			return self.msp 

		except RuntimeError as e :
			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )

		return self.get_msp()
 432
	def get_msp(
		self,
		exact_state_cap: int = 175_000,
		jsd_eps:         float = 1e-7,
		k_ann:           int   = 50,
		verbose                = True,
	) :

		"""
		Mixed State Presentation computed with floating point arithmetic, cached.

		NOTE(review): jsd_eps and k_ann are currently unused; kept for
		interface compatibility. (Removed the unused T_stacked / n_symbols /
		n_input_states locals that were computed and discarded.)
		"""

		if self.msp is not None:
			return self.msp

		T_x = self.get_T_X()
		pi  = self.get_stationary_distribution()

		print( "\nComputing Mixed State Presentation\n" )

		self.msp = compute_msp( 
			T_x=T_x,
			pi=pi,
			n_states=len(self.states),
			alphabet=self.alphabet,
			exact_state_cap=exact_state_cap,
			verbose=verbose
		)

		return self.msp
 463
	def get_reverse_am(self) :

		"""
		Build (and cache) the time-reversed machine.

		Each transition j -> i with probability p is reversed to i -> j with
		probability pi[j] * p / pi[i] (Bayes' rule over the stationary
		distribution). If the reversed machine is not already an epsilon
		machine, it is re-derived from its mixed state presentation, collapsed
		to its largest strongly connected component, and minimized.
		"""

		if self.reverse_am is not None:
			return self.reverse_am

		pi = self.get_stationary_distribution()
		self.reverse_am = copy.deepcopy(self)
		
		new_transitions = []
		for tr in self.transitions:
			i = tr.target_state_idx
			j = tr.origin_state_idx
			
			# reverse j -> i into i -> j, reweighted by the stationary distribution
			# NOTE(review): assumes pi[i] > 0 — get_stationary_distribution
			# already requires strong connectivity, so this should hold
			p_reversed = (pi[j] * tr.prob) / pi[i]
			
			new_transitions.append(
				Transition(
					origin_state_idx=i,
					target_state_idx=j,
					prob=p_reversed,
					symbol_idx=tr.symbol_idx
				)
			)

		self.reverse_am.set_transitions(new_transitions)

		# simple case: the reversal is already an epsilon machine
		if self.reverse_am.is_epsilon_machine():
			return self.reverse_am

		# otherwise re-derive causal states from the mixed state presentation
		rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )

		# msp must be assigned AFTER the setters: set_states/set_transitions
		# invalidate the msp cache
		self.reverse_am.set_states( rmsp.states )
		self.reverse_am.set_transitions( rmsp.transitions )
		self.reverse_am.msp = rmsp
		self.reverse_am.start_state = 0

		self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
		self.reverse_am.minimize()

		return self.reverse_am
 504
 505	#-------------------------------------------------------#
 506
	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

		"""
		Per-symbol transition matrices as nested lists of exact Fractions,
		indexed [symbol][origin][target]. Triggers q-weighting so Transition.pq
		is populated.
		"""

		self.to_q_weighted()

		n_states  = len( self.states )
		n_symbols = len( self.alphabet )

		# zero-filled with exact Fractions (previously plain int 0, which
		# contradicted the declared element type)
		T_x = [
			[ [ Fraction(0) for _ in range( n_states ) ] for _ in range( n_states ) ]
			for _ in range( n_symbols )
		]

		for tr in self.transitions :
			T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq

		return T_x
 525
	def get_T_sympy( self ) :

		"""
		State-to-state transition matrix as a sympy Matrix of exact Fraction
		probabilities. Triggers q-weighting so Transition.pq is populated.
		"""

		self.to_q_weighted()

		size = len( self.states )
		matrix = sympy.zeros( size, size )

		for tr in self.transitions :
			matrix[ tr.origin_state_idx, tr.target_state_idx ] = tr.pq

		return matrix
 537
	def get_fractional_stationary_distribution(self) :

		"""
		Exact (Fraction) stationary distribution, cached.

		Raises
			ValueError : if the machine is not strongly connected.
		"""

		# check the cache FIRST — previously the sympy transition matrix was
		# built (triggering to_q_weighted) even on a cache hit
		if self.pi_fractional is not None :
			return self.pi_fractional

		G = self.as_digraph()

		if not nx.is_strongly_connected(G):
			raise ValueError( "Single stationary distribution requires strongly connected HMM." )

		T = self.get_T_sympy()

		self.pi_fractional = solve_for_pi_fractional( T )

		return self.pi_fractional
 553
	def get_stationary_distribution(self):

		"""
		Stationary distribution pi (floats), cached.

		Raises
			ValueError : if the machine is not strongly connected.
		"""

		if self.pi is not None :
			return self.pi

		G = self.as_digraph()
		
		if not nx.is_strongly_connected(G):
			raise ValueError( "Single stationary distribution requires strongly connected HMM." )

		T = self.get_transition_matrix()

		# store the result — previously the method checked self.pi but never
		# assigned it, so pi was re-solved on every call
		self.pi = solve_for_pi( T )

		return self.pi
 566
 567	#-------------------------------------------------------#
 568	#                 Complexity Measures                   #
 569	#-------------------------------------------------------#
 570	
	def C_mu( self ) :

		"""
		c_mu (statistical complexity, or forecasting complexity)

		Interpretations
			- The amount of historical information a process stores.
			- The amount of structure in a process.

		Computed as the Shannon entropy of the stationary state distribution.
		"""

		cached = self.get_complexity_measure_if_exists( "C_mu" )
		if cached is not None :
			return cached

		pi = self.get_stationary_distribution()

		# H[pi], skipping negligible probabilities
		entropy = -sum(
			pr * np.log2( pr )
			for pr in pi
			if pr >= self.EPS
		)

		self.set_complexity_measure( "C_mu", entropy )

		return entropy
 599
 600	#-------------------------------------------------------#
 601
	def h_mu( self ) :

		"""
		h_mu (entropy rate)

		Interpretation
			- Intrinsic Randomness in the process.

		Computed as the pi-weighted average of each state's outgoing-row entropy.
		"""

		cached = self.get_complexity_measure_if_exists( "h_mu" )
		if cached is not None :
			return cached

		T  = self.get_transition_matrix()
		pi = self.get_stationary_distribution()

		h = 0
		for i, pr in enumerate( pi ) :

			# states with negligible stationary weight contribute nothing
			if pr < self.EPS :
				continue

			# Shannon entropy of state i's outgoing transition row
			row_entropy = -sum(
				T[ i, j ] * np.log2( T[ i, j ] )
				for j in range( len( pi ) )
				if T[ i, j ] >= self.EPS
			)

			h += pr * row_entropy

		self.set_complexity_measure( "h_mu", h )

		return h
 640
 641	#-------------------------------------------------------#
 642
	def H_1(self):

		"""
		H_1 (single-symbol entropy, or marginal entropy)
		Interpretations:
			- How uncertain you are about a single measurement with no context.
			- The total information content per symbol before seeing any neighbors.
			- The maximum possible entropy rate, achieved by a memoryless process.
		"""

		cached = self.get_complexity_measure_if_exists("H_1")
		if cached is not None:
			return cached

		pi  = self.get_stationary_distribution()
		T_X = self.get_T_X()  # list indexed by symbol -> matrix

		h = 0.0
		for T_x in T_X:

			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
			p_sym = sum(
				pr * T_x[i].sum()
				for i, pr in enumerate(pi)
				if pr >= self.EPS
			)

			# negligible symbols contribute nothing
			if p_sym >= self.EPS:
				h -= p_sym * np.log2(p_sym)

		self.set_complexity_measure("H_1", h)

		return h
 675
 676	#-------------------------------------------------------#
 677
	def rho_mu(self) :
		
		"""
		rho_mu (anticipated information, or single-symbol redundancy)
		Interpretations:
			- How much the past reduces your uncertainty about the next symbol.
			- The portion of each symbol's information that was already predictable.
			- The difference between naive per-symbol uncertainty and true randomness.
			- Equal to H_1 - h_mu: the gap closed by knowing context.
		"""

		cached = self.get_complexity_measure_if_exists("rho_mu")

		if cached is None:
			cached = self.H_1() - self.h_mu()
			self.set_complexity_measure("rho_mu", cached)

		return cached
 699
 700	#-------------------------------------------------------#
 701
	def block_convergence( self ) :

		"""
		block_convergence
			- Estimates a range of complexity measures via block entropy convergence.
			- Produces E, S, T, T_L, H_L, h_mu_L, H_sync
		"""

		# per-state outgoing transitions as (symbol_idx, probability, target_idx)
		outgoing = [ [] for _ in range( len( self.states ) ) ]
		for tr in self.transitions :
			outgoing[ tr.origin_state_idx ].append(
				( tr.symbol_idx, float( tr.prob ), tr.target_state_idx ) )

		pi = self.get_stationary_distribution()

		# a single root branch carrying the full stationary distribution
		initial_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]

		print( "\nComputing Block Entropy\n" )

		result = am_fast.block_entropy_convergence(
			h_mu            = self.h_mu(),
			n_states        = len( self.states ),
			n_symbols       = len( self.alphabet ),
			convergence_tol = 1e-6,
			precision       = 10,
			eps             = 1e-25,
			branches        = [ ( 1.0, list( initial_dist ) ) ],
			trans           = outgoing,
			max_branches    = 30_000_000
		)

		print( "Done\n" )

		self.set_complexity_measure( "E",      result.E )
		self.set_complexity_measure( "S",      result.S )
		self.set_complexity_measure( "T_inf",  result.T )
		self.set_complexity_measure( "T_L",    result.T_L.tolist() )
		self.set_complexity_measure( "H_L",    result.H_L.tolist() )
		self.set_complexity_measure( "h_mu_L", result.h_mu_L.tolist() )
		self.set_complexity_measure( "H_sync", result.H_sync.tolist() )

		return result
 747
 748	#-------------------------------------------------------#
 749
 750	def E( self ) :
 751
 752		"""
 753		E (excess entropy)
 754
 755		Interpretations:
 756			- Mutual information from past to future I(Past;Future).
 757			- How much total information from the past reduces uncertainty in the future
 758			- Amount of (apparent) stored information
 759
 760		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 761		https://arxiv.org/abs/1309.3792
 762		"""
 763
 764		m = self.get_complexity_measure_if_exists( "E" )
 765
 766		if m is not None :
 767			return m
 768
 769		try : 
 770			msp = self.get_msp()
 771			E, S, T = msp.get_E_S_T()
 772			self.set_complexity_measure( "E", E )
 773			self.set_complexity_measure( "S", S )
 774			self.set_complexity_measure( "T_inf", T )
 775			
 776		except Exception as e :
 777
 778			print( f"MSP failed {e}" )
 779			exit()
 780
 781			C = self.block_convergence()	
 782			E = C.E
 783			self.set_complexity_measure( "E", E )
 784
 785		return E
 786
 787	#-------------------------------------------------------#
 788
 789	def S( self ) :
 790
 791		"""
 792		S (synchronization information)
 793
 794		Interpretation
 795			-synchronization information
 796
 797		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 798		https://arxiv.org/abs/1309.3792
 799		"""
 800
 801		m = self.get_complexity_measure_if_exists( "S" )
 802
 803		if m is not None :
 804			return m
 805
 806		try : 
 807			msp = self.get_msp()
 808			E, S, T = msp.get_E_S_T()
 809			self.set_complexity_measure( "E", E )
 810			self.set_complexity_measure( "S", S )
 811			self.set_complexity_measure( "T_inf", T )
 812
 813		except Exception as e :
 814			print( f"{e} \nFalling back to iterative estimation.")
 815			exit()
 816
 817			C = self.block_convergence()	
 818			S = C.S
 819			self.set_complexity_measure( "S", S )
 820
 821		return S
 822
 823	#-------------------------------------------------------#
 824
 825	def T_inf( self ) :
 826
 827		"""
 828		T_inf (transient information)
 829
 830		Interpretation
 831			-Transient information
 832
 833		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
 834		https://arxiv.org/abs/1309.3792
 835		"""
 836
 837		m = self.get_complexity_measure_if_exists( "T_inf" )
 838
 839		if m is not None :
 840			return m
 841
 842		try : 
 843			msp = self.get_msp()
 844			E, S, T = msp.get_E_S_T()
 845			self.set_complexity_measure( "E", E )
 846			self.set_complexity_measure( "S", S )
 847			self.set_complexity_measure( "T_inf", T )
 848
 849		except Exception as e :
 850			print( f"{e} \nFalling back to iterative estimation.")
 851			exit()
 852			C = self.block_convergence()	
 853			T_inf = C.T
 854			self.set_complexity_measure( "T_inf", T_inf )
 855
 856		return T_inf
 857
 858	#-------------------------------------------------------#
 859
 860	def Chi( self ) :
 861
 862		"""
 863		Chi (crypticity)
 864
 865		Interpretation
 866			- Difference between internal stored information and apparent information to an observer
 867			- How muching information is hiding in the system
 868		"""
 869
 870		m = self.get_complexity_measure_if_exists( "Chi" )
 871
 872		if m is not None :
 873			return m
 874
 875		Chi = self.C_mu() - self.E()
 876
 877		self.set_complexity_measure( "Chi", Chi )
 878
 879		return Chi
 880
 881	#-------------------------------------------------------#
 882
 883	def k_X( self ) :
 884
 885		"""
 886		k_x (cryptic order)
 887
 888		Interpretation
 889			
 890		"""
 891
 892		m = self.get_complexity_measure_if_exists( "k_X" )
 893
 894		if m is not None :
 895			return m
 896
 897		raise NotImplementedError("Not implemented")
 898
 899	#-------------------------------------------------------#
 900
 901	def R( self ) :
 902
 903		"""
 904		R (Markov order)
 905		Interpretations:
 906			- The smallest L such that h_mu(L) = h_mu.
 907			- The context length beyond which additional history gives no predictive benefit.
 908			- The memory length of the process — how far back correlations extend.
 909			- The ideal context window length for a predictor.
 910			- R = 0 means the process is IID (memoryless).
 911			- R = infinity means the process has unbounded memory (non-finitary).
 912			— only report if convergence was clean, otherwise None
 913		"""
 914
 915		m = self.get_complexity_measure_if_exists( "R" )
 916
 917		if m is not None :
 918			return m
 919
 920		raise NotImplementedError("Not implemented")
 921
 922	#-------------------------------------------------------#
 923
 924	def b_mu( self ) :
 925
 926		"""
 927		b_mu - requires reverse HMM
 928
 929
 930		Interpretation
 931			
 932		"""
 933
 934		m = self.get_complexity_measure_if_exists( "b_mu" )
 935
 936		if m is not None :
 937			return m
 938
 939		raise NotImplementedError("Not implemented")
 940
 941	#-------------------------------------------------------#
 942
 943	def r_mu( self ) :
 944
 945		"""
 946		r_mu - requires reverse HMM
 947
 948
 949		Interpretation
 950			
 951		"""
 952
 953		m = self.get_complexity_measure_if_exists( "r_mu" )
 954
 955		if m is not None :
 956			return m
 957
 958		raise NotImplementedError("Not implemented")
 959
 960	#-------------------------------------------------------#
 961
 962	def q_mu( self ) :
 963
 964		"""
 965		q_mu 
 966
 967		Interpretation
 968		"""
 969
 970		m = self.get_complexity_measure_if_exists( "q_mu" )
 971
 972		if m is not None :
 973			return m
 974
 975		raise NotImplementedError("Not implemented")
 976
 977	#-------------------------------------------------------#
 978	#                      Properties                       #
 979	#-------------------------------------------------------#
 980
 981	def causally_reversible( self ) :
 982
 983		"""
 984		causally_reversible
 985		"""
 986
 987		m = self.get_complexity_measure_if_exists( "causally_reversible" )
 988
 989		if m is not None :
 990			return m
 991
 992		raise NotImplementedError("Not implemented")
 993
 994	#-------------------------------------------------------#
 995
 996	def microscopically_reversible( self ) :
 997
 998		"""
 999		microscopically_reversible
1000		"""
1001
1002		m = self.get_complexity_measure_if_exists( "microscopically_reversible" )
1003
1004		if m is not None :
1005			return m
1006
1007		raise NotImplementedError("Not implemented")
1008
1009	#-------------------------------------------------------#
1010
1011	def is_row_stochastic(self) :
1012
1013		"""
1014		is_row_stochastic
1015			-All states have outgoing transition probabilities which sum 1
1016		"""
1017
1018		sums = np.zeros( len( self.states ) )
1019		for tr in self.transitions :
1020			sums[ tr.origin_state_idx ] += tr.prob
1021		return np.allclose( sums, 1.0 )
1022
1023	#-------------------------------------------------------#
1024
1025	def is_unifilar(self) :
1026
1027		"""
1028		is_unifilar
1029			-All states have no more than one outgoing transition per-symbol
1030		"""
1031
1032		symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 )
1033
1034		for tr in self.transitions : 
1035		
1036			if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
1037				symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx
1038		
1039			elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
1040				return False
1041		
1042		return True
1043
1044	#-------------------------------------------------------#
1045
1046	def is_strongly_connected(self) :
1047
1048		"""
1049		is_strongly_connected
1050			-Every state is reachable from every other state
1051		"""
1052
1053		return nx.is_strongly_connected( self.as_digraph() )
1054
1055	#-------------------------------------------------------#
1056
1057	def is_aperiodic(self) :
1058
1059		"""
1060		is_aperiodic
1061			-"A strongly connected directed graph is aperiodic 
1062				if there is no integer k > 1 that divides the length of every cycle in the graph."
1063		"""
1064
1065		return nx.is_aperiodic( self.as_digraph() )
1066
1067	#-------------------------------------------------------#
1068
1069	def is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :
1070
1071		with_probs = not topological_only
1072
1073		# Construct the DFA
1074		dfa = self.as_dfa( with_probs=with_probs )
1075
1076		# Minimize the DFA
1077		#dfa = dfa.minify(retain_names=True)
1078		dfa = am_fast.minify_cpp( dfa, retain_names=True )
1079
1080		# check we have minimal number of states
1081		if len( dfa.states ) != len( self.states ) :
1082			if verbose : 
1083				print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" )
1084			return False
1085
1086		return True
1087
1088	def is_topological_epsilon_machine( self, verbose=True ) :
1089
1090		"""
1091		is_topological_epsilon_machine
1092			-see algorithm 2 in Enumerating Finitary Processes, Johnson et. al, 2024
1093		"""
1094
1095		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1096			if verbose : 
1097				print( f"Either non unifilar or not strongly connected" )
1098			return False
1099		else :
1100			return self.is_minimal_as_dfa( topological_only=True, verbose=verbose )
1101
1102	def is_epsilon_machine( self, verbose=True ) :
1103
1104		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1105			if verbose : 
1106				print( f"Either non unifilar or not strongly connected" )
1107			return False
1108		else :
1109			return self.is_minimal_as_dfa( topological_only=False, verbose=verbose )
1110
1111	#-------------------------------------------------------#
1112	#                      Properties                       #
1113	#-------------------------------------------------------#
1114
1115	def minimize(self, retain_names: bool = True):
1116
1117		"""
1118		minimize
1119			-Simplify to minimal form, using Myhill Nerode
1120		"""
1121
1122		if self.is_minimal :
1123			return
1124
1125		start = time.perf_counter()
1126
1127		if not self.is_unifilar():
1128			raise ValueError(
1129				"DFA minimization is not valid for non-unifilar HMMs"
1130			)
1131
1132		was_strongly_connected = self.is_strongly_connected()
1133
1134		was_row_stochastic = self.is_row_stochastic()
1135		n_states_before = len(self.states)
1136
1137		dfa = self.as_dfa(with_probs=True)
1138
1139		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
1140		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )
1141
1142		# Build lookup from original state index -> CausalState object
1143		orig_state   = {i: s for i, s in enumerate(self.states)}
1144		eq_list      = list(min_dfa.states)
1145
1146		start_eq = min_dfa.initial_state
1147		
1148		# Separate the start state, then sort the rest by the 
1149		# smallest original state index inside each equivalence class.
1150		other_eqs = [eq for eq in eq_list if eq != start_eq]
1151		other_eqs.sort(key=lambda eq: min(eq))
1152
1153		# Recombine so start eq comes first, followed by the sorted remaining classes
1154		eq_list = [start_eq] + other_eqs
1155		# ----------------------------------------------------------
1156
1157		# Recompute eq_to_idx with the new ordering
1158		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}
1159
1160		# new_start is now guaranteed to be 0
1161		new_start = 0
1162
1163		# Map each original state index -> its equivalence class
1164		# Guard: minify() silently drops unreachable states
1165		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}
1166
1167		# Build lookup from original state index -> its transitions
1168		orig_trs = defaultdict(list)
1169		for t in self.transitions:
1170			orig_trs[t.origin_state_idx].append(t)
1171
1172		new_trs = []
1173		for eq in min_dfa.states:
1174			rep        = next(iter(eq))
1175			origin_idx = eq_to_idx[eq]
1176			for t in orig_trs[rep]:
1177				target_eq  = orig_to_eq[t.target_state_idx]
1178				target_idx = eq_to_idx[target_eq]
1179				new_trs.append(Transition(
1180					origin_state_idx = origin_idx,
1181					target_state_idx = target_idx,
1182					prob             = t.prob,
1183					symbol_idx       = t.symbol_idx,
1184				))
1185
1186		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism
1187
1188		# Compute new names
1189		if retain_names:
1190			new_names = [
1191				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
1192				else members[0].name
1193				for members in members_list
1194			]
1195		else:
1196			new_names = [str(j) for j in range(len(eq_list))]
1197
1198		old_name_to_new_name = {
1199			m.name: new_names[j]
1200			for j, members in enumerate(members_list)
1201			for m in members
1202		}
1203
1204		# Build the new states, preserving classes and isomorphs regardless of naming
1205		new_states = []
1206		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
1207			classes   = set().union(*(m.classes for m in members))
1208			isomorphs = {
1209				old_name_to_new_name.get(iso, iso)
1210				for m in members
1211				for iso in m.isomorphs
1212				if old_name_to_new_name.get(iso, iso) != name
1213			}
1214			new_states.append(CausalState(
1215				name      = name,
1216				classes   = classes,
1217				isomorphs = isomorphs,
1218			))
1219
1220		self.set_states(new_states)
1221		self.set_transitions(new_trs)
1222		self.start_state = new_start
1223
1224		if n_states_before == len(new_states) :
1225			print( f"{n_states_before} state HMM was already minimal." )
1226		else :
1227			print( f"Minimized from {n_states_before} to {len(new_states)}" )
1228
1229		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
1230			raise RuntimeError(
1231				f"Minimization broke strongly connected"
1232			)
1233
1234		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
1235			raise RuntimeError(
1236				f"Minimization broke row stochasticity"
1237			)
1238
1239		self.is_minimal = True
1240
1241	#-------------------------------------------------------#
1242	#                      Modifiers                        #
1243	#-------------------------------------------------------#
1244
1245
1246	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :
1247
1248		# get equivalent networkx graph
1249		G = self.as_digraph()
1250
1251		# if already strongly connected, nothing to do
1252		if not nx.is_strongly_connected( G ) :
1253
1254			start = time.perf_counter()
1255			subgraph_nodes = list( nx.strongly_connected_components( G ) )
1256
1257			# decompose into strongly connected components and sort by length
1258			# subgraph_nodes = list(nx.strongly_connected_components( G ))
1259			subgraph_nodes.sort(key=len)
1260			component_state_set = subgraph_nodes[-1]
1261
1262			# Take the largest strongly connected component (as list of state names)
1263			component_states = sorted( list( component_state_set ) )
1264
1265			# make temporary copies of the old transitions and states
1266			old_transitions = [
1267				Transition(
1268					origin_state_idx=tr.origin_state_idx,
1269					target_state_idx=tr.target_state_idx,
1270					prob=tr.prob,
1271					symbol_idx=tr.symbol_idx
1272				)
1273
1274				for tr in self.transitions
1275			]
1276
1277			old_states = [
1278				CausalState(
1279					name=s.name,
1280					classes=s.classes,
1281					isomorphs=s.isomorphs
1282				)
1283				for s in self.states
1284			]
1285
1286			self.set_states(
1287				states=[ 
1288					state
1289					for i, state in enumerate( old_states ) if i in component_state_set
1290				]
1291			)
1292
1293			# we will build new transition list based on those belonging to the component
1294			self.set_transitions( transitions= [] )
1295
1296			# for tracking which new transitions leave each state
1297			transitions_from_state = { state : set() for state in component_states }
1298			new_transitions = []
1299
1300			for tr in old_transitions :
1301
1302				origin_state_name = old_states[ tr.origin_state_idx ].name
1303				target_state_name = old_states[ tr.target_state_idx ].name
1304
1305				# skip transitions that connect separate strongly connected components
1306				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
1307					continue
1308
1309				# track transitions (by index in new transitions list) that leave this state
1310				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )
1311
1312				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
1313				my_target_state_idx = self.state_idx_map[ target_state_name ]
1314
1315				new_transitions.append( 
1316					Transition(
1317						origin_state_idx=my_origin_state_idx,
1318						target_state_idx=my_target_state_idx,
1319						prob=tr.prob,
1320						symbol_idx=tr.symbol_idx
1321					)  )
1322
1323			self.set_transitions( transitions=new_transitions )
1324
1325			# if we removed an outgoing transition from a state, we need to distribute its probability 
1326			# among the remaining outgoing transitions from the state
1327			for state in component_states :
1328				
1329				# get the set of transitions leaving this state
1330				state_trs = transitions_from_state[ state ]
1331
1332				# sum the probabilities of the outgoing transitions from the state
1333				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )
1334
1335				# how much probability is missing
1336				diff = 1.0 - p_sum
1337
1338				# if significant difference
1339				if abs( diff ) > self.EPS : 
1340
1341					# calculate how much of the difference each transition gets
1342					adjustment = diff / len( state_trs )
1343					
1344					# update the transitions
1345					for i in state_trs :
1346
1347						# transition with origional probability
1348						tr = self.transitions[ i ]
1349
1350						# adjusted probability
1351						self.transitions[ i ] = Transition(
1352							origin_state_idx=tr.origin_state_idx, 
1353							target_state_idx=tr.target_state_idx, 
1354							prob=tr.prob + adjustment, 
1355							symbol_idx=tr.symbol_idx )
1356
1357			if rename_states :
1358				self.set_states( [
1359					CausalState( name=f"{i}" )
1360					for i, s in enumerate( self.states )
1361				] )
1362
1363
	def to_q_weighted( self, denominator_limit=1000 ) :

		"""
		to_q_weighted
			-Snaps edge probabilities to exact Fractions, and assigns Fraction values to Transition.pq

		Parameters
			denominator_limit : largest denominator allowed when snapping each
				probability to a Fraction; if a row cannot be repaired to sum
				to exactly 1, the method retries with 10x this limit.

		Raises
			ValueError : if the machine is not row stochastic to begin with.
		"""

		# already converted; nothing to do
		if self.is_q_weighted :
			return

		if not self.is_row_stochastic() :
			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

		# group transition indices by origin state (one "row" per state)
		t_from = [[] for _ in range(len(self.states))]

		for i, tr in enumerate( self.transitions ) :
			t_from[ tr.origin_state_idx ].append( i )

		new_transitions = []
		for t_list in t_from :
			
			if not t_list : 
				continue

			# snap each probability in the row to a Fraction and track the row sum
			p_q_sum = Fraction(0,1)
			p_qs = []

			for t_idx in t_list :
			
				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
				p_q_sum += p_q
				p_qs.append( p_q )

			# rounding may leave the row not summing to exactly 1; absorb the
			# discrepancy into the row's largest entry
			if p_q_sum != Fraction(1,1) :
				
				max_pq_i = np.argmax( p_qs )
				max_oq = p_qs[ max_pq_i ]

				diff = p_q_sum - Fraction(1,1)

				# recurse with higher resolution
				# (subtracting diff here would drive the largest entry negative)
				if diff > max_oq :
					return self.to_q_weighted( denominator_limit*10 )
				else :
					p_qs[ max_pq_i ] -= diff

			# rebuild the row's transitions with both the float prob and exact pq
			for i, t_idx in enumerate( t_list ) :	
				new_transitions.append( 
					Transition( 
						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
						target_state_idx=self.transitions[  t_idx ].target_state_idx,
						prob=float(p_qs[ i ]),
						symbol_idx=self.transitions[  t_idx ].symbol_idx,
						pq=p_qs[ i ]
					)
				)

		self.set_transitions( new_transitions )
		self.is_q_weighted = True
1423
1424	#-------------------------------------------------------#
1425	#                    Data Generation                    #
1426	#-------------------------------------------------------#
1427
1428	def generate_data(
1429		self,
1430		file_prefix: str,
1431		n_gen: int,
1432		include_states: bool,
1433		isomorphic_shifts : set[int]=None,
1434		random_seed : int=42 ) -> dict[any] : 
1435
1436		trs = [ [] for _ in range( len( self.states ) ) ]
1437		for tr in self.transitions :
1438			trs[ tr.origin_state_idx ].append( ( 
1439				tr.symbol_idx, 
1440				float( tr.prob ),
1441				tr.target_state_idx ) )
1442
1443		data = am_fast.generate_data(
1444			n_gen=n_gen,
1445			start_state=self.start_state,
1446			transitions=trs,
1447			alphabet=sorted(list(self.alphabet)),
1448			include_states=include_states,
1449			random_seed=random_seed
1450		)
1451
1452		if isomorphic_shifts is not None :
1453
1454			if not include_states :
1455				raise ValueError( "Isomorphic inversion requires include_states=True" )
1456
1457			data[ "isomorphic_shifts" ] = {}
1458
1459			for shift in isomorphic_shifts :
1460
1461				shifted = HMM.isomorphic_shift(
1462					symbol_indices=data[ "symbol_index" ], 
1463					state_indices=data[ "state_index" ], 
1464					m=self,
1465					shift=shift
1466				)
1467
1468				data[ "isomorphic_shifts" ][ shift ] = {
1469					"symbol_index" : shifted[ "symbol_index" ],
1470					"state_index"  : shifted[ "state_index" ]
1471				}
1472
1473		am_fast.save_data(
1474			data=data,
1475			file_prefix=file_prefix,
1476			alphabet=sorted(list(self.alphabet)),
1477			n_states=len( self.states ),
1478			start_state=self.start_state,
1479			random_seed=random_seed,
1480			metadata=self.get_metadata() )
1481
1482		return data
1483
1484	#-------------------------------------------------------#
1485	#                 Basic Visualization                   #
1486	#-------------------------------------------------------#
1487
1488	def draw_graph(
1489		self,
1490		engine     : str  = 'dot',
1491		output_dir : Path = Path('.'),
1492		show       : bool = True
1493	) -> None :
1494
1495		G = self.as_digraph()
1496
1497		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1498		
1499		am_vis.draw_graph( 
1500			self, 
1501			output_dir=output_dir, 
1502			title="am_graph", 
1503			view=show, 
1504			subgraphs=subgraphs, 
1505			engine=engine )
1506
1507	#-------------------------------------------------------#
1508	#                   Alternate Forms                     #
1509	#-------------------------------------------------------#
1510
1511
1512	def as_digraph( self ) :
1513
1514		"""
1515		Build nx DiGraph based on the HMM, without edge probabilities or symbols
1516		"""
1517
1518		G = nx.DiGraph()
1519		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1520
1521		for tr in self.transitions :
1522			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1523
1524		return G
1525
1526	def as_dfa( self, with_probs : bool ) :
1527
1528		precision=8
1529
1530		def edge_label( symb, prob ) :
1531			return f"({symb},{round(prob, precision)})"
1532
1533		# Build states, symbols, and transitions 
1534		dfa_states  = { i for i, _ in enumerate( self.states ) }
1535			
1536		if not with_probs :
1537			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1538		else : 
1539			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1540
1541		dfa_transitions = defaultdict(dict)
1542
1543		if not with_probs :
1544			for t in self.transitions :
1545				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1546		else :
1547			for t in self.transitions :
1548				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1549
1550		# Construct the DFA
1551		return DFA(
1552			states=dfa_states,
1553			input_symbols=dfa_symbols,
1554			transitions=dfa_transitions,
1555			initial_state=self.start_state,
1556			allow_partial=True,
1557			final_states={ 
1558				s for s in dfa_states
1559			}
1560		)

Helper class that provides a standard way to create an ABC using inheritance.

HMM( states: list[amachine.am_causal_state.CausalState] | None = None, transitions: list[amachine.am_transition.Transition] | None = None, start_state: int = 0, alphabet: list[str] | None = None, isoclasses: dict[str, set[str]] | None = None, name: str = '', description: str = '')
 98	def __init__( 
 99		self,
100		states : list[CausalState] | None = None,
101		transitions : list[Transition] | None = None,
102		start_state : int = 0,
103		alphabet : list[str] | None = None,
104		isoclasses : dict[str,set[str]] | None = None,
105		name : str = "",
106		description : str = "" ) : 
107
108		self.alphabet    = alphabet or []
109		self.states      = states or []
110		self.transitions = transitions or []
111
112		self.set_alphabet( self.alphabet )
113		self.set_states( self.states )
114		self.set_transitions( self.transitions )
115
116		self.start_state = start_state
117		self.isoclasses  = isoclasses or {}
118
119		self.name = name
120		self.description = description
121
 122		# to be deprecated
123		self.isoclass = None
124
125		# --- derived --------
126
127		self.complexity : dict[str, any] = {}
128
129		self.symbol_idx_map : dict[str,int] = {}
130		self.state_idx_map  : dict[str,int] = {}
131
132		self.pi_fractional = None
133		self.pi = None
134		self.T  = None
135		self.msp : MSP = None
136		self.reverse_am : HMM = None
137		self.is_q_weighted = False
138		self.is_minimal = False
139
140		# --- const ----------
141
142		self.EPS = 1e-12
@staticmethod
def isomorphic_shift( input_symbol_indices: numpy.ndarray, input_state_indices: numpy.ndarray, m: HMM, shift: int = 1) -> dict[str, numpy.ndarray]:
40	@staticmethod
41	def isomorphic_shift(
42		input_symbol_indices: np.ndarray,
43		input_state_indices:  np.ndarray,
44		m : HMM,
45		shift : int = 1
46	) -> dict[str, np.ndarray]:
47
48		if not any( state.isomorphs for state in m.states ):
49			raise ValueError("HMM has no states with isomorphs")
50
51		inputs = np.asarray(input_symbol_indices)
52		states = np.asarray(input_state_indices)
53
54		n_states = len(m.states)
55
56		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
57
58		for tr in m.transitions:
59			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
60
61		# Build isomorph remapping: identity by default, overridden where isomorphs exist
62		iso_table = np.arange(n_states, dtype=np.int32)
63
64		for i, state in enumerate(m.states):
65			if state.isomorphs:
66				isomorph       = sorted(state.isomorphs)[0]
67				iso_table[ i ] = m.state_idx_map[isomorph]
68
69		for i, state in enumerate(m.states):
70			if state.isomorphs:
71
72				# extend the isomorph list to include the identity
73				isormorphs_with_identity = sorted( [ i ] + [ m.state_idx_map[iso] for iso in state.isomorphs ] )
74
75				# find the identity index
76				pos = isormorphs_with_identity.index( i )
77
78				# cyclical shift 
79				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
80
81		origins = states[:-1]
82		targets = states[1:]
83
84		out_origins = iso_table[origins]
85		out_targets = iso_table[targets]
86
87		inv_sym = tr_sym_table[ out_origins, out_targets ]
88		inv_sts = np.empty(states.size, dtype=states.dtype)
89
90		inv_sts[:-1] = out_origins
91		inv_sts[-1]  = out_targets[-1]
92
93		return {
94			"symbol_index": inv_sym.astype(inputs.dtype),
95			"state_index":  inv_sts,
96		}
alphabet
states
transitions
start_state
isoclasses
name
description
isoclass
complexity: dict[str, any]
symbol_idx_map: dict[str, int]
state_idx_map: dict[str, int]
pi_fractional
pi
T
reverse_am: HMM
is_q_weighted
is_minimal
EPS
@override
def get_states(self) -> 'list[CausalStates]':
148	@override
149	def get_states(self) -> list[CausalStates] : 
150		return self.states
@override
def get_transitions(self) -> list[amachine.am_transition.Transition]:
152	@override
153	def get_transitions(self) -> list[Transition] : 
154		return self.transitions
@override
def get_alphabet(self) -> list[amachine.am_symbol.Symbol]:
156	@override
157	def get_alphabet(self) -> list[Symbol]:
158		return [ Symbol(a) for a in self.alphabet ]
def to_dict(self):
165	def to_dict(self) :
166		return {
167			"name"            : self.name,
168			"description"     : self.description,
169			"states"          : [ asdict(state)      for state      in self.states      ],
170			"transitions"     : [ asdict(transition) for transition in self.transitions ],
171			"alphabet"        : self.alphabet
172		}
def from_dict(self, config):
174	def from_dict( self, config ) :
175
176		self.name            = config.name 
177		self.description     = config.description 
178		
179		self.set_states( states=[ 
180			CausalState( 
181				name=state[ "name" ],
182				classes=set( state[ "classes" ] )
183			)
184			for state in config.states
185		] )
186
187		self.set_transitions( transitions=[ 
188			Transition( 
189				origin_state_idx=tr[ "origin_state_idx" ],
190				target_state_idx=tr[ "target_state_idx" ],
191				prob=tr[ "prob" ],
192				symbol_idx=tr[ "symbol_idx" ]
193			)
194			for tr in config.transitions
195		] )
196
197		self.set_alphabet( alphabet=config.alphabet ) 
def save_config( self, path: pathlib.Path, with_complexity: bool = False, with_block_convergence: bool = False, with_structural_properties: bool = False, with_causal_properties: bool = False):
199	def save_config(
200		self, 
201		path : Path, 
202		with_complexity : bool = False, 
203		with_block_convergence :  bool = False,
204		with_structural_properties : bool = False,
205		with_causal_properties : bool = False ) :
206
207		config = self.to_dict()
208
209		if with_complexity :
210			
211			complexity = self.get_complexities( 
212				with_block_convergence=with_block_convergence,
213				with_reversal=False # not implemented
214			)
215
216			config[ "complexity" ] = complexity
217
218		config[ "structural_properties" ] = {
219			"unifilar"              : self.is_unifilar(),
220			"row_stochastic"        : self.is_row_stochastic(),
221			"strongly_connected"    : self.is_strongly_connected(),
222			"aperiodic"             : self.is_aperiodic(),
223			"minimal"               : self.is_minimal_as_dfa( topological_only=False ),
224			"topologically_minimal" : self.is_minimal_as_dfa( topological_only=True ),
225			"is_epsilon_machine"    : self.is_epsilon_machine()
226		}
227
228		# Not implemented
229		# config[ "causal_properties" ] = {
230		# 	"k_X"                        : self.k_X(),
231		# 	"R"                          : self.R(),
232		# 	"causally_reversible"        : self.causally_reversible(),
233		# 	"microscopically_reversible" : self.microscopically_reversible(),
234		# }
235
236		with open( path / "am_config.json", "w", encoding="utf-8" ) as f :
237			json.dump( config, f, ensure_ascii=False, indent=2, default=list )
def from_file(self, path: pathlib.Path):
239	def from_file( self, path : Path ) :
240		with open( Path / "am_config.json", "r" ) as f:
241			config = json.load(f)
242		self.from_dict()
def set_states(self, states: list[amachine.am_causal_state.CausalState]):
266	def set_states( self, states : list[CausalState] ) :
267		self._invalidate()
268		self.states = states.copy()
269		self.state_idx_map = {}
270		for idx, state in enumerate( self.states ) :
271			self.state_idx_map[ state.name ] = idx
def set_alphabet(self, alphabet: list[str]):
273	def set_alphabet( self, alphabet : list[str] ) :
274
275		self._invalidate()
276
277		old_alphabet = self.alphabet.copy()
278
279		aSet = set()
280		aSet.update( alphabet )
281
282		self.alphabet = sorted(list(aSet))
283
284		self.symbol_idx_map = {}
285		for idx, symbol in enumerate( self.alphabet ) :
286			self.symbol_idx_map[ symbol ] = idx
287
288		for i, tr in enumerate( self.transitions ) :
289			symbol = old_alphabet[ tr.symbol_idx ]
290			self.transitions[ i ] = Transition(
291				origin_state_idx=tr.origin_state_idx,
292				target_state_idx=tr.target_state_idx,
293				prob=tr.prob,
294				symbol_idx=self.symbol_idx_map[ symbol ]
295			)
def set_transitions(self, transitions: list[amachine.am_transition.Transition]):
297	def set_transitions( self, transitions : list[Transition] ) :
298		self._invalidate()
299		self.transitions = transitions.copy()
def extend_states(self, states: list[amachine.am_causal_state.CausalState]):
301	def extend_states( self, states : list[CausalState] ) :
302		self.set_states( self.states + states )
def extend_alphabet(self, alphabet: list[str]):
304	def extend_alphabet( self, alphabet : list[str] ) :
305		self.set_alphabet( self.alphabet + alphabet  )
def extend_transitions(self, transitions: list[amachine.am_transition.Transition]):
307	def extend_transitions( self, transitions : list[Transition] ) :
308		self.set_transitions( self.transitions + transitions  )
def get_complexity_measure_if_exists(self, measure):
310	def get_complexity_measure_if_exists(self, measure ) :
311		m = self.complexity.get( measure, None )
312		return m
def set_complexity_measure(self, measure, value):
314	def set_complexity_measure(self, measure, value ) :
315		self.complexity[ measure ] = value
def get_complexities(self, with_block_convergence=False, with_reversal=False):
321	def get_complexities( 
322		self, 
323		with_block_convergence=False, 
324		with_reversal=False ) :
325
326		directly_calculable = [
327			self.C_mu,
328			self.h_mu,
329			self.H_1,
330			self.rho_mu
331		]
332
333		requires_block_convergence = [
334			self.E, 
335			self.T_inf,
336			self.S,
337			self.Chi
338		]
339
340		requires_reversal = [
341			self.q_mu,
342			self.r_mu,
343			self.b_mu
344		]
345			
346		complexities = { m.__name__ : m() for m in directly_calculable }
347
348		if with_block_convergence :
349
350			complexities |= { m.__name__ : m() for m in requires_block_convergence }
351
352			for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
353				if key in self.complexity :
354					complexities[ key ] = self.complexity[ key ]
355
356		if with_reversal :
357			complexities |= { m.__name__ : m() for m in requires_reversal }
358
359		return complexities
def get_metadata(self):
363	def get_metadata(self) :
364		return {
365			"name" : self.name,
366			'complexity' : self.complexity,
367			"description" : self.description
368		}
def get_transition_matrix(self):
370	def get_transition_matrix(self) :
371
372		if self.T  is not None :
373			return self.T
374
375		n_states = len( self.states )
376		T = np.zeros((n_states, n_states))
377
378		for tr in self.transitions :    
379			T[ tr.origin_state_idx, tr.target_state_idx  ] = tr.prob
380
381		self.T = T
382
383		return self.T
def get_T_X(self):
387	def get_T_X(self) :
388
389		if self.T_x  is not None :
390			return self.T_x
391
392		n_states  = len( self.states )
393		n_symbols = len( self.alphabet )
394
395		T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]
396
397		for tr in self.transitions :
398			T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob
399
400		self.T_x = T_x
401		return self.T_x
def get_msp_qw(self, exact_state_cap: int = 1000, verbose: bool = True):
405	def get_msp_qw(
406		self,
407		exact_state_cap: int = 1000,
408		verbose: bool = True,
409	):
410		if self.msp is not None:
411			return self.msp
412
413		try : 
414
415			print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )
416
417			self.msp = compute_msp_exact(
418				T_x=self.get_Tx_fractional(),
419				pi=self.get_fractional_stationary_distribution(),
420				n_states=len(self.states),
421				alphabet=self.alphabet,
422				exact_state_cap=1000,
423				bool = True
424			)
425
426			return self.msp 
427
428		except RuntimeError as e :
429			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )
430
431		return self.get_msp()
def get_msp( self, exact_state_cap: int = 175000, jsd_eps: float = 1e-07, k_ann: int = 50, verbose=True):
433	def get_msp(
434		self,
435		exact_state_cap: int = 175_000,
436		jsd_eps:         float = 1e-7,
437		k_ann:           int   = 50,
438		verbose                = True,
439	) :
440
441		if self.msp is not None:
442			return self.msp
443	 
444		T_x = self.get_T_X()
445		pi  = self.get_stationary_distribution()
446
447		T_stacked      = np.stack(T_x)
448		n_symbols      = len(self.alphabet)
449		n_input_states = T_stacked.shape[1]
450
451		print( "\nComputing Mixed State Presentation\n" )
452
453		self.msp = compute_msp( 
454			T_x=T_x,
455			pi=pi,
456			n_states=len(self.states),
457			alphabet=self.alphabet,
458			exact_state_cap=exact_state_cap,
459			verbose=verbose
460		)
461
462		return self.msp
	def get_reverse_am(self) :
		"""
		get_reverse_am
			-Return (and cache) the time-reversed machine: each transition is
			 Bayes-inverted against the stationary distribution; if the result
			 is not already an epsilon-machine it is re-unifilarized via its
			 MSP, restricted to its largest strongly connected component, and
			 minimized.
		"""

		if self.reverse_am is not None:
			return self.reverse_am

		pi = self.get_stationary_distribution()
		self.reverse_am = copy.deepcopy(self)
		
		# Reverse each edge j -> i into i -> j with probability
		# pi[j] * p / pi[i] (Bayes rule). Presumably pi[i] > 0 — the
		# stationary solver requires strong connectivity; confirm.
		new_transitions = []
		for tr in self.transitions:
			i = tr.target_state_idx
			j = tr.origin_state_idx
			
			p_reversed = (pi[j] * tr.prob) / pi[i]
			
			new_transitions.append(
				Transition(
					origin_state_idx=i,
					target_state_idx=j,
					prob=p_reversed,
					symbol_idx=tr.symbol_idx
				)
			)

		self.reverse_am.set_transitions(new_transitions)

		# Reversal preserved unifilarity etc. — nothing more to do.
		if self.reverse_am.is_epsilon_machine():
			return self.reverse_am

		# Otherwise adopt the states/transitions of the reversed machine's
		# mixed state presentation to restore unifilarity.
		rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )

		self.reverse_am.set_states( rmsp.states )
		self.reverse_am.set_transitions( rmsp.transitions )
		self.reverse_am.msp = rmsp
		self.reverse_am.start_state = 0

		self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
		self.reverse_am.minimize()

		return self.reverse_am
def get_Tx_fractional(self) -> list[list[list[fractions.Fraction]]]:
507	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :
508
509		self.to_q_weighted()
510
511		n_states  = len( self.states )
512		n_symbols = len( self.alphabet )
513
514		T_x = []
515
516		for x in range( n_symbols ) :
517			T_x.append( [] )
518			for i in range( n_states ) :
519				T_x[ x ].append( [ 0 for _ in range( n_states ) ] )
520
521		for tr in self.transitions :
522			T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq
523
524		return T_x
def get_T_sympy(self):
526	def get_T_sympy( self ) :
527
528		self.to_q_weighted()
529
530		n = len( self.states )
531		T = sympy.zeros( n, n )
532
533		for tr in self.transitions :
534			T[ tr.origin_state_idx, tr.target_state_idx ] = tr.pq
535
536		return T
def get_fractional_stationary_distribution(self):
538	def get_fractional_stationary_distribution(self) :
539
540		T = self.get_T_sympy()
541
542		if self.pi_fractional is not None :
543			return self.pi_fractional
544
545		G = self.as_digraph()
546
547		if not nx.is_strongly_connected(G):
548			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
549
550		self.pi_fractional = solve_for_pi_fractional( T )
551
552		return self.pi_fractional
def get_stationary_distribution(self):
554	def get_stationary_distribution(self):
555
556		if self.pi is not None :
557			return self.pi
558
559		G = self.as_digraph()
560		
561		if not nx.is_strongly_connected(G):
562			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
563
564		T = self.get_transition_matrix()
565		return solve_for_pi( T )		
def C_mu(self):
571	def C_mu( self ) :
572
573		"""
574		c_mu (statistical complexity, or forecasting complexity)
575
576		Interpretations
577			- The amount of historical information a process stores.
578			- The amount of structure in a process.
579		"""
580
581		m = self.get_complexity_measure_if_exists( "C_mu" )
582
583		if m is not None :
584			return m
585
586		pi = self.get_stationary_distribution()
587
588		h = 0
589		for i, pr in enumerate( pi ) :
590			
591			if pr < self.EPS :
592				continue
593
594			h += -pr * np.log2( pr )
595
596		self.set_complexity_measure( "C_mu", h )
597
598		return h

c_mu (statistical complexity, or forecasting complexity)

Interpretations - The amount of historical information a process stores. - The amount of structure in a process.

def h_mu(self):
602	def h_mu( self ) :
603
604		"""
605		h_mu (entropy rate)
606
607		Interpretation
608			- Intrinsic Randomness in the process.
609		"""
610
611		m = self.get_complexity_measure_if_exists( "h_mu" )
612
613		if m is not None :
614			return m
615
616		T  = self.get_transition_matrix()
617		pi = self.get_stationary_distribution()
618
619		n_states = pi.size
620
621		h = 0
622		for i, pr in enumerate( pi ) :
623
624			if pr < self.EPS :
625				continue
626
627			row_entropy = 0
628			for j in range( len( pi ) ) :
629
630				if T[ i, j ]  < self.EPS :
631					continue
632
633				row_entropy -= T[ i, j ] * np.log2( T[ i, j ] )
634
635			h += pr * row_entropy
636
637		self.set_complexity_measure( "h_mu", h )
638
639		return h

h_mu (entropy rate)

Interpretation - Intrinsic Randomness in the process.

def H_1(self):
643	def H_1(self):
644
645		"""
646		H_1 (single-symbol entropy, or marginal entropy)
647		Interpretations:
648			- How uncertain you are about a single measurement with no context.
649			- The total information content per symbol before seeing any neighbors.
650			- The maximum possible entropy rate, achieved by a memoryless process.
651		"""
652
653		m = self.get_complexity_measure_if_exists("H_1")
654		if m is not None:
655			return m
656
657		pi  = self.get_stationary_distribution()
658		T_X = self.get_T_X()  # dict: symbol -> matrix
659
660		h = 0.0
661		for T_x in T_X:
662			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
663			p_sym = 0.0
664			for i, pr in enumerate(pi):
665				if pr < self.EPS:
666					continue
667				p_sym += pr * T_x[i, :].sum()
668
669			if p_sym < self.EPS:
670				continue
671			h -= p_sym * np.log2(p_sym)
672
673		self.set_complexity_measure("H_1", h)
674		return h

H_1 (single-symbol entropy, or marginal entropy) Interpretations: - How uncertain you are about a single measurement with no context. - The total information content per symbol before seeing any neighbors. - The maximum possible entropy rate, achieved by a memoryless process.

def rho_mu(self):
678	def rho_mu(self) :
679		
680		"""
681		rho_mu (anticipated information, or single-symbol redundancy)
682		Interpretations:
683			- How much the past reduces your uncertainty about the next symbol.
684			- The portion of each symbol's information that was already predictable.
685			- The difference between naive per-symbol uncertainty and true randomness.
686			- Equal to H_1 - h_mu: the gap closed by knowing context.
687		"""
688
689		m = self.get_complexity_measure_if_exists("rho_mu")
690		
691		if m is not None:
692			return m
693
694		rho = self.H_1() - self.h_mu()
695		
696		self.set_complexity_measure("rho_mu", rho)
697		
698		return rho

rho_mu (anticipated information, or single-symbol redundancy) Interpretations: - How much the past reduces your uncertainty about the next symbol. - The portion of each symbol's information that was already predictable. - The difference between naive per-symbol uncertainty and true randomness. - Equal to H_1 - h_mu: the gap closed by knowing context.

def block_convergence(self):
702	def block_convergence( self ) :
703
704		"""
705		block_convergence
706			- Estimates a range of complexity measures using block entropy convergence.
707			- E, S, T, T_L, H_L, h_mu_L, H_sync
708		"""
709
710		trs = [ [] for _ in range( len( self.states ) ) ]
711		for tr in self.transitions :
712			trs[ tr.origin_state_idx ].append( ( 
713				tr.symbol_idx, 
714				float( tr.prob ),
715				tr.target_state_idx ) )
716
717		pi = self.get_stationary_distribution()
718
719		state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]
720		branches = [(1.0, list(state_dist))]
721
722		print( "\nComputing Block Entropy\n" )
723
724		C = am_fast.block_entropy_convergence(
725			h_mu            = self.h_mu(),
726			n_states        = len( self.states ),
727			n_symbols       = len( self.alphabet ),
728			convergence_tol = 1e-6,
729			precision       = 10,
730			eps             = 1e-25,
731			branches        = branches,
732			trans           = trs,
733			max_branches    = 30_000_000
734		)
735
736		print( "Done\n" )
737
738		self.set_complexity_measure( f"E",          C.E )
739		self.set_complexity_measure( f"S",          C.S )
740		self.set_complexity_measure( f"T_inf",      C.T )
741		self.set_complexity_measure( f"T_L",        C.T_L.tolist() )
742		self.set_complexity_measure( f"H_L",        C.H_L.tolist() )
743		self.set_complexity_measure( f"h_mu_L",  C.h_mu_L.tolist() )
744		self.set_complexity_measure( f"H_sync",  C.H_sync.tolist() )
745
746		return C

block_convergence - Estimates a range of complexity measures using block entropy convergence. - E, S, T, T_L, H_L, h_mu_L, H_sync

def E(self):
750	def E( self ) :
751
752		"""
753		E (excess entropy)
754
755		Interpretations:
756			- Mutual information from past to future I(Past;Future).
757			- How much total information from the past reduces uncertainty in the future
758			- Amount of (apparent) stored information
759
760		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
761		https://arxiv.org/abs/1309.3792
762		"""
763
764		m = self.get_complexity_measure_if_exists( "E" )
765
766		if m is not None :
767			return m
768
769		try : 
770			msp = self.get_msp()
771			E, S, T = msp.get_E_S_T()
772			self.set_complexity_measure( "E", E )
773			self.set_complexity_measure( "S", S )
774			self.set_complexity_measure( "T_inf", T )
775			
776		except Exception as e :
777
778			print( f"MSP failed {e}" )
779			exit()
780
781			C = self.block_convergence()	
782			E = C.E
783			self.set_complexity_measure( "E", E )
784
785		return E

E (excess entropy)

Interpretations: - Mutual information from past to future I(Past;Future). - How much total information from the past reduces uncertainty in the future - Amount of (apparent) stored information

Exact Complexity: The Spectral Decomposition of Intrinsic Computation https://arxiv.org/abs/1309.3792

def S(self):
789	def S( self ) :
790
791		"""
792		S (synchronization information)
793
794		Interpretation
795			-synchronization information
796
797		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
798		https://arxiv.org/abs/1309.3792
799		"""
800
801		m = self.get_complexity_measure_if_exists( "S" )
802
803		if m is not None :
804			return m
805
806		try : 
807			msp = self.get_msp()
808			E, S, T = msp.get_E_S_T()
809			self.set_complexity_measure( "E", E )
810			self.set_complexity_measure( "S", S )
811			self.set_complexity_measure( "T_inf", T )
812
813		except Exception as e :
814			print( f"{e} \nFalling back to iterative estimation.")
815			exit()
816
817			C = self.block_convergence()	
818			S = C.S
819			self.set_complexity_measure( "S", S )
820
821		return S

S (synchronization information)

Interpretation -synchronization information

Exact Complexity: The Spectral Decomposition of Intrinsic Computation https://arxiv.org/abs/1309.3792

def T_inf(self):
825	def T_inf( self ) :
826
827		"""
828		T_inf (transient information)
829
830		Interpretation
831			-Transient information
832
833		Exact Complexity: The Spectral Decomposition of Intrinsic Computation
834		https://arxiv.org/abs/1309.3792
835		"""
836
837		m = self.get_complexity_measure_if_exists( "T_inf" )
838
839		if m is not None :
840			return m
841
842		try : 
843			msp = self.get_msp()
844			E, S, T = msp.get_E_S_T()
845			self.set_complexity_measure( "E", E )
846			self.set_complexity_measure( "S", S )
847			self.set_complexity_measure( "T_inf", T )
848
849		except Exception as e :
850			print( f"{e} \nFalling back to iterative estimation.")
851			exit()
852			C = self.block_convergence()	
853			T_inf = C.T
854			self.set_complexity_measure( "T_inf", T_inf )
855
856		return T_inf

T_inf (transient information)

Interpretation -Transient information

Exact Complexity: The Spectral Decomposition of Intrinsic Computation https://arxiv.org/abs/1309.3792

def Chi(self):
860	def Chi( self ) :
861
862		"""
863		Chi (crypticity)
864
865		Interpretation
866			- Difference between internal stored information and apparent information to an observer
867			- How muching information is hiding in the system
868		"""
869
870		m = self.get_complexity_measure_if_exists( "Chi" )
871
872		if m is not None :
873			return m
874
875		Chi = self.C_mu() - self.E()
876
877		self.set_complexity_measure( "Chi", Chi )
878
879		return Chi

Chi (crypticity)

Interpretation - Difference between internal stored information and apparent information to an observer - How muching information is hiding in the system

def k_X(self):
883	def k_X( self ) :
884
885		"""
886		k_x (cryptic order)
887
888		Interpretation
889			
890		"""
891
892		m = self.get_complexity_measure_if_exists( "k_X" )
893
894		if m is not None :
895			return m
896
897		raise NotImplementedError("Not implemented")

k_x (cryptic order)

Interpretation

def R(self):
901	def R( self ) :
902
903		"""
904		R (Markov order)
905		Interpretations:
906			- The smallest L such that h_mu(L) = h_mu.
907			- The context length beyond which additional history gives no predictive benefit.
908			- The memory length of the process — how far back correlations extend.
909			- The ideal context window length for a predictor.
910			- R = 0 means the process is IID (memoryless).
911			- R = infinity means the process has unbounded memory (non-finitary).
912			— only report if convergence was clean, otherwise None
913		"""
914
915		m = self.get_complexity_measure_if_exists( "R" )
916
917		if m is not None :
918			return m
919
920		raise NotImplementedError("Not implemented")

R (Markov order) Interpretations: - The smallest L such that h_mu(L) = h_mu. - The context length beyond which additional history gives no predictive benefit. - The memory length of the process — how far back correlations extend. - The ideal context window length for a predictor. - R = 0 means the process is IID (memoryless). - R = infinity means the process has unbounded memory (non-finitary). — only report if convergence was clean, otherwise None

def b_mu(self):
924	def b_mu( self ) :
925
926		"""
927		b_mu - requires reverse HMM
928
929
930		Interpretation
931			
932		"""
933
934		m = self.get_complexity_measure_if_exists( "b_mu" )
935
936		if m is not None :
937			return m
938
939		raise NotImplementedError("Not implemented")

b_mu - requires reverse HMM

Interpretation

def r_mu(self):
943	def r_mu( self ) :
944
945		"""
946		r_mu - requires reverse HMM
947
948
949		Interpretation
950			
951		"""
952
953		m = self.get_complexity_measure_if_exists( "r_mu" )
954
955		if m is not None :
956			return m
957
958		raise NotImplementedError("Not implemented")

r_mu - requires reverse HMM

Interpretation

def q_mu(self):
962	def q_mu( self ) :
963
964		"""
965		q_mu 
966
967		Interpretation
968		"""
969
970		m = self.get_complexity_measure_if_exists( "q_mu" )
971
972		if m is not None :
973			return m
974
975		raise NotImplementedError("Not implemented")

q_mu

Interpretation

def causally_reversible(self):
981	def causally_reversible( self ) :
982
983		"""
984		causally_reversible
985		"""
986
987		m = self.get_complexity_measure_if_exists( "causally_reversible" )
988
989		if m is not None :
990			return m
991
992		raise NotImplementedError("Not implemented")

causally_reversible

def microscopically_reversible(self):
 996	def microscopically_reversible( self ) :
 997
 998		"""
 999		microscopically_reversible
1000		"""
1001
1002		m = self.get_complexity_measure_if_exists( "microscopically_reversible" )
1003
1004		if m is not None :
1005			return m
1006
1007		raise NotImplementedError("Not implemented")

microscopically_reversible

def is_row_stochastic(self):
1011	def is_row_stochastic(self) :
1012
1013		"""
1014		is_row_stochastic
1015			-All states have outgoing transition probabilities which sum 1
1016		"""
1017
1018		sums = np.zeros( len( self.states ) )
1019		for tr in self.transitions :
1020			sums[ tr.origin_state_idx ] += tr.prob
1021		return np.allclose( sums, 1.0 )

is_row_stochastic -All states have outgoing transition probabilities which sum 1

def is_unifilar(self):
1025	def is_unifilar(self) :
1026
1027		"""
1028		is_unifilar
1029			-All states have no more than one outgoing transition per-symbol
1030		"""
1031
1032		symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 )
1033
1034		for tr in self.transitions : 
1035		
1036			if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
1037				symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx
1038		
1039			elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
1040				return False
1041		
1042		return True

is_unifilar -All states have no more than one outgoing transition per-symbol

def is_strongly_connected(self):
1046	def is_strongly_connected(self) :
1047
1048		"""
1049		is_strongly_connected
1050			-Every state is reachable from every other state
1051		"""
1052
1053		return nx.is_strongly_connected( self.as_digraph() )

is_strongly_connected -Every state is reachable from every other state

def is_aperiodic(self):
1057	def is_aperiodic(self) :
1058
1059		"""
1060		is_aperiodic
1061			-"A strongly connected directed graph is aperiodic 
1062				if there is no integer k > 1 that divides the length of every cycle in the graph."
1063		"""
1064
1065		return nx.is_aperiodic( self.as_digraph() )

is_aperiodic -"A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."

def is_minimal_as_dfa(self, topological_only: bool, verbose=True):
1069	def is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :
1070
1071		with_probs = not topological_only
1072
1073		# Construct the DFA
1074		dfa = self.as_dfa( with_probs=with_probs )
1075
1076		# Minimize the DFA
1077		#dfa = dfa.minify(retain_names=True)
1078		dfa = am_fast.minify_cpp( dfa, retain_names=True )
1079
1080		# check we have minimal number of states
1081		if len( dfa.states ) != len( self.states ) :
1082			if verbose : 
1083				print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" )
1084			return False
1085
1086		return True
def is_topological_epsilon_machine(self, verbose=True):
1088	def is_topological_epsilon_machine( self, verbose=True ) :
1089
1090		"""
1091		is_topological_epsilon_machine
1092			-see algorithm 2 in Enumerating Finitary Processes, Johnson et. al, 2024
1093		"""
1094
1095		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1096			if verbose : 
1097				print( f"Either non unifilar or not strongly connected" )
1098			return False
1099		else :
1100			return self.is_minimal_as_dfa( topological_only=True, verbose=verbose )

is_topological_epsilon_machine -see algorithm 2 in Enumerating Finitary Processes, Johnson et. al, 2024

def is_epsilon_machine(self, verbose=True):
1102	def is_epsilon_machine( self, verbose=True ) :
1103
1104		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1105			if verbose : 
1106				print( f"Either non unifilar or not strongly connected" )
1107			return False
1108		else :
1109			return self.is_minimal_as_dfa( topological_only=False, verbose=verbose )
	def minimize(self, retain_names: bool = True):

		"""
		minimize
			-Simplify to minimal form, using Myhill Nerode

		Parameters
			retain_names : keep merged-class names as "{a,b,...}" rather than
				renumbering the classes 0..n-1.

		Raises
			ValueError   : for non-unifilar machines (DFA minimization invalid).
			RuntimeError : if minimization breaks strong connectivity or row
				stochasticity (post-condition sanity checks).
		"""

		if self.is_minimal :
			return

		# NOTE(review): 'start' is never read afterwards — leftover timer.
		start = time.perf_counter()

		if not self.is_unifilar():
			raise ValueError(
				"DFA minimization is not valid for non-unifilar HMMs"
			)

		# Record structural facts to verify they survive minimization.
		was_strongly_connected = self.is_strongly_connected()

		was_row_stochastic = self.is_row_stochastic()
		n_states_before = len(self.states)

		dfa = self.as_dfa(with_probs=True)

		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

		# Build lookup from original state index -> CausalState object
		orig_state   = {i: s for i, s in enumerate(self.states)}
		eq_list      = list(min_dfa.states)

		start_eq = min_dfa.initial_state
		
		# Separate the start state, then sort the rest by the 
		# smallest original state index inside each equivalence class.
		other_eqs = [eq for eq in eq_list if eq != start_eq]
		other_eqs.sort(key=lambda eq: min(eq))

		# Recombine so start eq comes first, followed by the sorted remaining classes
		eq_list = [start_eq] + other_eqs
		# ----------------------------------------------------------

		# Recompute eq_to_idx with the new ordering
		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

		# new_start is now guaranteed to be 0
		new_start = 0

		# Map each original state index -> its equivalence class
		# Guard: minify() silently drops unreachable states
		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

		# Build lookup from original state index -> its transitions
		orig_trs = defaultdict(list)
		for t in self.transitions:
			orig_trs[t.origin_state_idx].append(t)

		# Rebuild transitions once per equivalence class, using an arbitrary
		# representative (valid because all members are Myhill-Nerode equivalent).
		new_trs = []
		for eq in min_dfa.states:
			rep        = next(iter(eq))
			origin_idx = eq_to_idx[eq]
			for t in orig_trs[rep]:
				target_eq  = orig_to_eq[t.target_state_idx]
				target_idx = eq_to_idx[target_eq]
				new_trs.append(Transition(
					origin_state_idx = origin_idx,
					target_state_idx = target_idx,
					prob             = t.prob,
					symbol_idx       = t.symbol_idx,
				))

		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

		# Compute new names
		if retain_names:
			new_names = [
				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
				else members[0].name
				for members in members_list
			]
		else:
			new_names = [str(j) for j in range(len(eq_list))]

		# Old member name -> merged-class name, used to rewrite isomorph links.
		old_name_to_new_name = {
			m.name: new_names[j]
			for j, members in enumerate(members_list)
			for m in members
		}

		# Build the new states, preserving classes and isomorphs regardless of naming
		new_states = []
		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
			classes   = set().union(*(m.classes for m in members))
			# Translate isomorph references to merged names; drop self-references.
			isomorphs = {
				old_name_to_new_name.get(iso, iso)
				for m in members
				for iso in m.isomorphs
				if old_name_to_new_name.get(iso, iso) != name
			}
			new_states.append(CausalState(
				name      = name,
				classes   = classes,
				isomorphs = isomorphs,
			))

		self.set_states(new_states)
		self.set_transitions(new_trs)
		self.start_state = new_start

		if n_states_before == len(new_states) :
			print( f"{n_states_before} state HMM was already minimal." )
		else :
			print( f"Minimized from {n_states_before} to {len(new_states)}" )

		# Post-condition checks: minimization must not change these facts.
		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
			raise RuntimeError(
				f"Minimization broke strongly connected"
			)

		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
			raise RuntimeError(
				f"Minimization broke row stochasticity"
			)

		self.is_minimal = True
	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :
		"""
		collapse_to_largest_strongly_connected_subgraph
			-Restrict the machine to its largest strongly connected component:
			 drop states/transitions outside it and redistribute the removed
			 probability mass evenly over each surviving state's remaining
			 outgoing transitions. No-op when already strongly connected.
		"""

		# get equivalent networkx graph
		G = self.as_digraph()

		# if already strongly connected, nothing to do
		if not nx.is_strongly_connected( G ) :

			# NOTE(review): 'start' timer is never read afterwards.
			start = time.perf_counter()
			subgraph_nodes = list( nx.strongly_connected_components( G ) )

			# decompose into strongly connected components and sort by length
			# subgraph_nodes = list(nx.strongly_connected_components( G ))
			subgraph_nodes.sort(key=len)
			component_state_set = subgraph_nodes[-1]

			# Take the largest strongly connected component (as list of state names)
			component_states = sorted( list( component_state_set ) )

			# make temporary copies of the old transitions and states
			old_transitions = [
				Transition(
					origin_state_idx=tr.origin_state_idx,
					target_state_idx=tr.target_state_idx,
					prob=tr.prob,
					symbol_idx=tr.symbol_idx
				)

				for tr in self.transitions
			]

			old_states = [
				CausalState(
					name=s.name,
					classes=s.classes,
					isomorphs=s.isomorphs
				)
				for s in self.states
			]

			# Keep only states inside the chosen component (re-indexes them).
			# Presumably digraph nodes are state indices, so membership in
			# component_state_set tests by index — confirm against as_digraph().
			self.set_states(
				states=[ 
					state
					for i, state in enumerate( old_states ) if i in component_state_set
				]
			)

			# we will build new transition list based on those belonging to the component
			self.set_transitions( transitions= [] )

			# for tracking which new transitions leave each state
			transitions_from_state = { state : set() for state in component_states }
			new_transitions = []

			for tr in old_transitions :

				origin_state_name = old_states[ tr.origin_state_idx ].name
				target_state_name = old_states[ tr.target_state_idx ].name

				# skip transitions that connect separate strongly connected components
				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
					continue

				# track transitions (by index in new transitions list) that leave this state
				# NOTE(review): dict keys are component_states entries but this
				# indexes by the *old* origin index — consistent only if the
				# digraph node labels are those same indices; confirm.
				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
				my_target_state_idx = self.state_idx_map[ target_state_name ]

				new_transitions.append( 
					Transition(
						origin_state_idx=my_origin_state_idx,
						target_state_idx=my_target_state_idx,
						prob=tr.prob,
						symbol_idx=tr.symbol_idx
					)  )

			self.set_transitions( transitions=new_transitions )

			# if we removed an outgoing transition from a state, we need to distribute its probability 
			# among the remaining outgoing transitions from the state
			for state in component_states :
				
				# get the set of transitions leaving this state
				state_trs = transitions_from_state[ state ]

				# sum the probabilities of the outgoing transitions from the state
				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

				# how much probability is missing
				diff = 1.0 - p_sum

				# if significant difference
				if abs( diff ) > self.EPS : 

					# calculate how much of the difference each transition gets
					adjustment = diff / len( state_trs )
					
					# update the transitions
					for i in state_trs :

						# transition with origional probability
						tr = self.transitions[ i ]

						# adjusted probability
						self.transitions[ i ] = Transition(
							origin_state_idx=tr.origin_state_idx, 
							target_state_idx=tr.target_state_idx, 
							prob=tr.prob + adjustment, 
							symbol_idx=tr.symbol_idx )

			# Optionally forget original names and renumber 0..n-1.
			if rename_states :
				self.set_states( [
					CausalState( name=f"{i}" )
					for i, s in enumerate( self.states )
				] )
def to_q_weighted(self, denominator_limit=1000):
1364	def to_q_weighted( self, denominator_limit=1000 ) :
1365
1366		"""
1367		to_q_weighted
1368			-Snaps edge probabilities to exact Fractions, and assigns Fraction values to Transition.pq
1369		"""
1370
1371		if self.is_q_weighted :
1372			return
1373
1374		if not self.is_row_stochastic() :
1375			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )
1376
1377		t_from = [[] for _ in range(len(self.states))]
1378
1379		for i, tr in enumerate( self.transitions ) :
1380			t_from[ tr.origin_state_idx ].append( i )
1381
1382		new_transitions = []
1383		for t_list in t_from :
1384			
1385			if not t_list : 
1386				continue
1387
1388			p_q_sum = Fraction(0,1)
1389			p_qs = []
1390
1391			for t_idx in t_list :
1392			
1393				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
1394				p_q_sum += p_q
1395				p_qs.append( p_q )
1396
1397			if p_q_sum != Fraction(1,1) :
1398				
1399				max_pq_i = np.argmax( p_qs )
1400				max_oq = p_qs[ max_pq_i ]
1401
1402				diff = p_q_sum - Fraction(1,1)
1403
1404				# recurse with higher resolution
1405				if diff > max_oq :
1406					return self.to_q_weighted( denominator_limit*10 )
1407				else :
1408					p_qs[ max_pq_i ] -= diff
1409
1410			for i, t_idx in enumerate( t_list ) :	
1411				new_transitions.append( 
1412					Transition( 
1413						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
1414						target_state_idx=self.transitions[  t_idx ].target_state_idx,
1415						prob=float(p_qs[ i ]),
1416						symbol_idx=self.transitions[  t_idx ].symbol_idx,
1417						pq=p_qs[ i ]
1418					)
1419				)
1420
1421		self.set_transitions( new_transitions )
1422		self.is_q_weighted = True

to_q_weighted — snaps edge probabilities to exact Fractions and assigns Fraction values to Transition.pq

def generate_data( self, file_prefix: str, n_gen: int, include_states: bool, isomorphic_shifts: set[int] | None = None, random_seed: int = 42) -> dict:
1428	def generate_data(
1429		self,
1430		file_prefix: str,
1431		n_gen: int,
1432		include_states: bool,
1433		isomorphic_shifts : set[int]=None,
1434		random_seed : int=42 ) -> dict[any] : 
1435
1436		trs = [ [] for _ in range( len( self.states ) ) ]
1437		for tr in self.transitions :
1438			trs[ tr.origin_state_idx ].append( ( 
1439				tr.symbol_idx, 
1440				float( tr.prob ),
1441				tr.target_state_idx ) )
1442
1443		data = am_fast.generate_data(
1444			n_gen=n_gen,
1445			start_state=self.start_state,
1446			transitions=trs,
1447			alphabet=sorted(list(self.alphabet)),
1448			include_states=include_states,
1449			random_seed=random_seed
1450		)
1451
1452		if isomorphic_shifts is not None :
1453
1454			if not include_states :
1455				raise ValueError( "Isomorphic inversion requires include_states=True" )
1456
1457			data[ "isomorphic_shifts" ] = {}
1458
1459			for shift in isomorphic_shifts :
1460
1461				shifted = HMM.isomorphic_shift(
1462					symbol_indices=data[ "symbol_index" ], 
1463					state_indices=data[ "state_index" ], 
1464					m=self,
1465					shift=shift
1466				)
1467
1468				data[ "isomorphic_shifts" ][ shift ] = {
1469					"symbol_index" : shifted[ "symbol_index" ],
1470					"state_index"  : shifted[ "state_index" ]
1471				}
1472
1473		am_fast.save_data(
1474			data=data,
1475			file_prefix=file_prefix,
1476			alphabet=sorted(list(self.alphabet)),
1477			n_states=len( self.states ),
1478			start_state=self.start_state,
1479			random_seed=random_seed,
1480			metadata=self.get_metadata() )
1481
1482		return data
def draw_graph( self, engine: str = 'dot', output_dir: Path = Path('.'), show: bool = True) -> None:
1488	def draw_graph(
1489		self,
1490		engine     : str  = 'dot',
1491		output_dir : Path = Path('.'),
1492		show       : bool = True
1493	) -> None :
1494
1495		G = self.as_digraph()
1496
1497		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1498		
1499		am_vis.draw_graph( 
1500			self, 
1501			output_dir=output_dir, 
1502			title="am_graph", 
1503			view=show, 
1504			subgraphs=subgraphs, 
1505			engine=engine )
def as_digraph(self):
1512	def as_digraph( self ) :
1513
1514		"""
1515		Build nx DiGraph based on the HMM, without edge probabilities or symbols
1516		"""
1517
1518		G = nx.DiGraph()
1519		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1520
1521		for tr in self.transitions :
1522			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1523
1524		return G

Build an nx DiGraph based on the HMM, without edge probabilities or symbols.

def as_dfa(self, with_probs: bool):
1526	def as_dfa( self, with_probs : bool ) :
1527
1528		precision=8
1529
1530		def edge_label( symb, prob ) :
1531			return f"({symb},{round(prob, precision)})"
1532
1533		# Build states, symbols, and transitions 
1534		dfa_states  = { i for i, _ in enumerate( self.states ) }
1535			
1536		if not with_probs :
1537			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1538		else : 
1539			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1540
1541		dfa_transitions = defaultdict(dict)
1542
1543		if not with_probs :
1544			for t in self.transitions :
1545				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1546		else :
1547			for t in self.transitions :
1548				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1549
1550		# Construct the DFA
1551		return DFA(
1552			states=dfa_states,
1553			input_symbols=dfa_symbols,
1554			transitions=dfa_transitions,
1555			initial_state=self.start_state,
1556			allow_partial=True,
1557			final_states={ 
1558				s for s in dfa_states
1559			}
1560		)