amachine.am_hmm
from __future__ import annotations

from typing import override
import copy
import json
from collections import defaultdict
from fractions import Fraction
from pathlib import Path
import warnings
from dataclasses import asdict

import networkx as nx

from automata.fa.dfa import DFA

import sympy
import numpy as np

from .am_machine import Machine
from .am_symbol import Symbol

from .am_msp import MSP, compute_msp, compute_msp_exact
from .am_solve import solve_for_pi, solve_for_pi_fractional
from . import am_vis
from . import am_fast

from .am_causal_state import CausalState
from .am_transition import Transition

from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence
from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu

class HMM(Machine) :

    @staticmethod
    def isomorphic_shift(
        input_symbol_indices: np.ndarray,
        input_state_indices: np.ndarray,
        m : HMM,
        shift : int = 1
    ) -> dict[str, np.ndarray]:

        if not any( state.isomorphs for state in m.states ):
            raise ValueError("HMM has no states with isomorphs")

        inputs = np.asarray(input_symbol_indices)
        states = np.asarray(input_state_indices)

        n_states = len(m.states)

        # symbol emitted on each (origin, target) edge;
        # assumes at most one symbol per state pair
        tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)

        for tr in m.transitions:
            tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx

        # Build isomorph remapping: identity by default, overridden where isomorphs exist
        iso_table = np.arange(n_states, dtype=np.int32)

        for i, state in enumerate(m.states):
            if state.isomorphs:

                # extend the isomorph list to include the identity
                isomorphs_with_identity = sorted( [ i ] + [ m.state_idx_map[iso] for iso in state.isomorphs ] )

                # find the identity index
                pos = isomorphs_with_identity.index( i )

                # cyclical shift
                iso_table[ i ] = isomorphs_with_identity[ ( pos + shift ) % len( isomorphs_with_identity ) ]

        origins = states[:-1]
        targets = states[1:]

        out_origins = iso_table[origins]
        out_targets = iso_table[targets]

        inv_sym = tr_sym_table[ out_origins, out_targets ]
        inv_sts = np.empty(states.size, dtype=states.dtype)

        inv_sts[:-1] = out_origins
        inv_sts[-1] = out_targets[-1]

        return {
            "symbol_index": inv_sym.astype(inputs.dtype),
            "state_index": inv_sts,
        }
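    # Usage sketch (illustrative; assumes a fitted machine `m` whose states
    # carry isomorph annotations, and a `data` dict as produced by
    # generate_data() below):
    #
    #   shifted = HMM.isomorphic_shift(
    #       input_symbol_indices=data["symbol_index"],
    #       input_state_indices=data["state_index"],
    #       m=m,
    #       shift=1,
    #   )
    #
    # Each state index is rotated one position through its sorted isomorph
    # class (identity included), and the symbols are re-read off the shifted
    # state-to-state transitions.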
    def __init__(
        self,
        states : list[CausalState] | None = None,
        transitions : list[Transition] | None = None,
        start_state : int = 0,
        alphabet : list[str] | None = None,
        isoclasses : dict[str,set[str]] | None = None,
        name : str = "",
        description : str = "" ) :

        # --- derived (defaults first, so the setters below can populate them) ---

        self.complexity : dict[str, any] = {}

        self.symbol_idx_map : dict[str,int] = {}
        self.state_idx_map : dict[str,int] = {}

        self.pi_fractional = None
        self.pi = None
        self.T = None
        self.T_x = None
        self.msp : MSP = None
        self.reverse_am : HMM = None
        self.is_q_weighted = False
        self.is_minimal = False

        # --- const ----------

        self.EPS = 1e-12

        # --- configuration ---

        self.alphabet = alphabet or []
        self.states = states or []
        self.transitions = transitions or []

        self.set_alphabet( self.alphabet )
        self.set_states( self.states )
        self.set_transitions( self.transitions )

        self.start_state = start_state
        self.isoclasses = isoclasses or {}

        self.name = name
        self.description = description

        # to be deprecated
        self.isoclass = None

    #-------------------------------------------------------#
    #                  Overrides / Getters                  #
    #-------------------------------------------------------#

    @override
    def get_states(self) -> list[CausalState] :
        return self.states

    @override
    def get_transitions(self) -> list[Transition] :
        return self.transitions

    @override
    def get_alphabet(self) -> list[Symbol]:
        return [ Symbol(a) for a in self.alphabet ]

    #-------------------------------------------------------#
    #                     Serialization                     #
    #-------------------------------------------------------#

    def to_dict(self) :
        return {
            "name" : self.name,
            "description" : self.description,
            "states" : [ asdict(state) for state in self.states ],
            "transitions" : [ asdict(transition) for transition in self.transitions ],
            "alphabet" : self.alphabet
        }

    def from_dict( self, config ) :

        self.name = config[ "name" ]
        self.description = config[ "description" ]

        self.set_states( states=[
            CausalState(
                name=state[ "name" ],
                classes=set( state[ "classes" ] )
            )
            for state in config[ "states" ]
        ] )

        self.set_transitions( transitions=[
            Transition(
                origin_state_idx=tr[ "origin_state_idx" ],
                target_state_idx=tr[ "target_state_idx" ],
                prob=tr[ "prob" ],
                symbol_idx=tr[ "symbol_idx" ]
            )
            for tr in config[ "transitions" ]
        ] )

        self.set_alphabet( alphabet=config[ "alphabet" ] )

    def save_config(
        self,
        path : Path,
        with_complexity : bool = False,
        with_block_convergence : bool = False,
        with_structural_properties : bool = False,
        with_causal_properties : bool = False ) :

        config = self.to_dict()

        if with_complexity :

            complexity = self.get_complexities(
                with_block_convergence=with_block_convergence,
                with_reversal=False  # not implemented
            )

            config[ "complexity" ] = complexity

        if with_structural_properties :

            config[ "structural_properties" ] = {
                "unifilar" : self.is_unifilar(),
                "row_stochastic" : self.is_row_stochastic(),
                "strongly_connected" : self.is_strongly_connected(),
                "aperiodic" : self.is_aperiodic(),
                "minimal" : self.is_minimal_as_dfa( topological_only=False ),
                "topologically_minimal" : self.is_minimal_as_dfa( topological_only=True ),
                "is_epsilon_machine" : self.is_epsilon_machine()
            }

        # Not implemented
        # if with_causal_properties :
        #     config[ "causal_properties" ] = {
        #         "k_X" : self.k_X(),
        #         "R" : self.R(),
        #         "causally_reversible" : self.causally_reversible(),
        #         "microscopically_reversible" : self.microscopically_reversible(),
        #     }

        with open( path / "am_config.json", "w", encoding="utf-8" ) as f :
            json.dump( config, f, ensure_ascii=False, indent=2, default=list )

    def from_file( self, path : Path ) :
        with open( path / "am_config.json", "r", encoding="utf-8" ) as f :
            config = json.load(f)
        self.from_dict( config )
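    # Round-trip sketch (illustrative; `out_dir` is a placeholder path, and
    # the equality check assumes a machine without isomorph annotations,
    # which from_dict does not restore):
    #
    #   out_dir = Path("runs/example")
    #   out_dir.mkdir(parents=True, exist_ok=True)
    #   hmm.save_config(out_dir, with_complexity=True)
    #
    #   restored = HMM()
    #   restored.from_file(out_dir)
    #   assert restored.to_dict() == hmm.to_dict()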
    #-------------------------------------------------------#
    #              Setters and State Management             #
    #-------------------------------------------------------#

    def _invalidate(self) :

        """
        Invalidate

        Reset all derived properties so they will be recomputed when requested later
        """

        self.complexity = {}
        self.T = None
        self.T_x = None
        self.pi = None
        self.pi_fractional = None
        self.msp = None
        self.reverse_am = None
        self.is_q_weighted = False
        self.is_minimal = False

    def set_states( self, states : list[CausalState] ) :
        self._invalidate()
        self.states = states.copy()
        self.state_idx_map = {}
        for idx, state in enumerate( self.states ) :
            self.state_idx_map[ state.name ] = idx

    def set_alphabet( self, alphabet : list[str] ) :

        self._invalidate()

        old_alphabet = self.alphabet.copy()

        # sorted, de-duplicated alphabet
        self.alphabet = sorted( set( alphabet ) )

        self.symbol_idx_map = {}
        for idx, symbol in enumerate( self.alphabet ) :
            self.symbol_idx_map[ symbol ] = idx

        # remap existing transitions from old symbol indices to new ones
        for i, tr in enumerate( self.transitions ) :
            symbol = old_alphabet[ tr.symbol_idx ]
            self.transitions[ i ] = Transition(
                origin_state_idx=tr.origin_state_idx,
                target_state_idx=tr.target_state_idx,
                prob=tr.prob,
                symbol_idx=self.symbol_idx_map[ symbol ]
            )

    def set_transitions( self, transitions : list[Transition] ) :
        self._invalidate()
        self.transitions = transitions.copy()

    def extend_states( self, states : list[CausalState] ) :
        self.set_states( self.states + states )

    def extend_alphabet( self, alphabet : list[str] ) :
        self.set_alphabet( self.alphabet + alphabet )

    def extend_transitions( self, transitions : list[Transition] ) :
        self.set_transitions( self.transitions + transitions )

    def get_complexity_measure_if_exists( self, measure ) :
        return self.complexity.get( measure, None )

    def set_complexity_measure( self, measure, value ) :
        self.complexity[ measure ] = value

    #-------------------------------------------------------#
    #                    Get and Compute                    #
    #-------------------------------------------------------#

    def get_complexities(
        self,
        with_block_convergence=False,
        with_reversal=False ) :

        directly_calculable = [
            self.C_mu,
            self.h_mu,
            self.H_1,
            self.rho_mu
        ]

        requires_block_convergence = [
            self.E,
            self.T_inf,
            self.S,
            self.Chi
        ]

        requires_reversal = [
            self.q_mu,
            self.r_mu,
            self.b_mu
        ]

        complexities = { m.__name__ : m() for m in directly_calculable }

        if with_block_convergence :

            complexities |= { m.__name__ : m() for m in requires_block_convergence }

            # per-block-length curves, if block convergence populated them
            for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
                if key in self.complexity :
                    complexities[ key ] = self.complexity[ key ]

        if with_reversal :
            complexities |= { m.__name__ : m() for m in requires_reversal }

        return complexities

    #-------------------------------------------------------#

    def get_metadata(self) :
        return {
            "name" : self.name,
            "complexity" : self.complexity,
            "description" : self.description
        }

    def get_transition_matrix(self) :

        if self.T is not None :
            return self.T

        n_states = len( self.states )
        T = np.zeros((n_states, n_states))

        # accumulate: several symbols may connect the same pair of states
        for tr in self.transitions :
            T[ tr.origin_state_idx, tr.target_state_idx ] += tr.prob

        self.T = T

        return self.T

    #-------------------------------------------------------#

    def get_T_X(self) :

        if self.T_x is not None :
            return self.T_x

        n_states = len( self.states )
        n_symbols = len( self.alphabet )

        # one labeled transition matrix per symbol
        T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]

        for tr in self.transitions :
            T_x[ tr.symbol_idx ][ tr.origin_state_idx, tr.target_state_idx ] = tr.prob

        self.T_x = T_x
        return self.T_x
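    # The full transition matrix is the marginal of the labeled matrices,
    # T = sum_x T^(x), so get_transition_matrix() should agree with the
    # symbol-wise sum. A quick self-check (illustrative):
    #
    #   assert np.allclose(hmm.get_transition_matrix(),
    #                      np.sum(hmm.get_T_X(), axis=0))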
    #-------------------------------------------------------#

    def get_msp_qw(
        self,
        exact_state_cap: int = 1000,
        verbose: bool = True,
    ):
        if self.msp is not None:
            return self.msp

        try :

            print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )

            self.msp = compute_msp_exact(
                T_x=self.get_Tx_fractional(),
                pi=self.get_fractional_stationary_distribution(),
                n_states=len(self.states),
                alphabet=self.alphabet,
                exact_state_cap=exact_state_cap,
                verbose=verbose
            )

            return self.msp

        except RuntimeError as e :
            warnings.warn( f"Exact msp failed: {e}. Falling back to msp approximation." )

        return self.get_msp()

    def get_msp(
        self,
        exact_state_cap: int = 175_000,
        jsd_eps: float = 1e-7,
        k_ann: int = 50,
        verbose = True,
    ) :

        if self.msp is not None:
            return self.msp

        T_x = self.get_T_X()
        pi = self.get_stationary_distribution()

        print( "\nComputing Mixed State Presentation\n" )

        self.msp = compute_msp(
            T_x=T_x,
            pi=pi,
            n_states=len(self.states),
            alphabet=self.alphabet,
            exact_state_cap=exact_state_cap,
            verbose=verbose
        )

        return self.msp

    def get_reverse_am(self) :

        if self.reverse_am is not None:
            return self.reverse_am

        pi = self.get_stationary_distribution()
        self.reverse_am = copy.deepcopy(self)

        # reverse each edge, Bayes-inverting its probability:
        # Pr(i -> j in the reversed chain) = pi[j] * Pr(j -> i) / pi[i]
        new_transitions = []
        for tr in self.transitions:
            i = tr.target_state_idx
            j = tr.origin_state_idx

            p_reversed = (pi[j] * tr.prob) / pi[i]

            new_transitions.append(
                Transition(
                    origin_state_idx=i,
                    target_state_idx=j,
                    prob=p_reversed,
                    symbol_idx=tr.symbol_idx
                )
            )

        self.reverse_am.set_transitions(new_transitions)

        # the naive reversal may be non-unifilar; if so, re-derive causal
        # states from its mixed state presentation and minimize
        if self.reverse_am.is_epsilon_machine():
            return self.reverse_am

        rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )

        self.reverse_am.set_states( rmsp.states )
        self.reverse_am.set_transitions( rmsp.transitions )
        self.reverse_am.msp = rmsp
        self.reverse_am.start_state = 0

        self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
        self.reverse_am.minimize()

        return self.reverse_am

    #-------------------------------------------------------#

    def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

        self.to_q_weighted()

        n_states = len( self.states )
        n_symbols = len( self.alphabet )

        # nested lists of exact Fractions: T_x[symbol][origin][target]
        T_x = [
            [ [ Fraction(0) for _ in range( n_states ) ] for _ in range( n_states ) ]
            for _ in range( n_symbols )
        ]

        for tr in self.transitions :
            T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq

        return T_x

    def get_T_sympy( self ) :

        self.to_q_weighted()

        n = len( self.states )
        T = sympy.zeros( n, n )

        # accumulate: several symbols may connect the same pair of states
        for tr in self.transitions :
            T[ tr.origin_state_idx, tr.target_state_idx ] += tr.pq

        return T

    def get_fractional_stationary_distribution(self) :

        if self.pi_fractional is not None :
            return self.pi_fractional

        G = self.as_digraph()

        if not nx.is_strongly_connected(G):
            raise ValueError( "Single stationary distribution requires strongly connected HMM." )

        T = self.get_T_sympy()

        self.pi_fractional = solve_for_pi_fractional( T )

        return self.pi_fractional
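    # pi is the left eigenvector of T for eigenvalue 1, i.e. pi = pi @ T with
    # sum(pi) == 1; strong connectivity makes it unique. Sanity check
    # (illustrative):
    #
    #   pi = hmm.get_stationary_distribution()
    #   assert np.allclose(pi, pi @ hmm.get_transition_matrix())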
    def get_stationary_distribution(self):

        if self.pi is not None :
            return self.pi

        G = self.as_digraph()

        if not nx.is_strongly_connected(G):
            raise ValueError( "Single stationary distribution requires strongly connected HMM." )

        T = self.get_transition_matrix()
        self.pi = solve_for_pi( T )
        return self.pi

    #-------------------------------------------------------#
    #                  Complexity Measures                  #
    #-------------------------------------------------------#

    def C_mu( self ) :

        """
        C_mu (statistical complexity, or forecasting complexity)

        Interpretations
        - The amount of historical information a process stores.
        - The amount of structure in a process.
        """

        m = self.get_complexity_measure_if_exists( "C_mu" )

        if m is not None :
            return m

        pi = self.get_stationary_distribution()

        # Shannon entropy of the stationary state distribution
        h = 0
        for pr in pi :

            if pr < self.EPS :
                continue

            h += -pr * np.log2( pr )

        self.set_complexity_measure( "C_mu", h )

        return h

    #-------------------------------------------------------#

    def h_mu( self ) :

        """
        h_mu (entropy rate)

        Interpretation
        - Intrinsic randomness in the process.
        """

        m = self.get_complexity_measure_if_exists( "h_mu" )

        if m is not None :
            return m

        pi = self.get_stationary_distribution()

        # per-state entropy of the outgoing symbol-labeled edge probabilities;
        # summing per edge (rather than over the merged state-to-state matrix)
        # stays correct when two symbols connect the same pair of states
        row_entropy = np.zeros( len( self.states ) )
        for tr in self.transitions :
            if tr.prob < self.EPS :
                continue
            row_entropy[ tr.origin_state_idx ] -= tr.prob * np.log2( tr.prob )

        # stationary average over states
        h = 0.0
        for i, pr in enumerate( pi ) :
            if pr < self.EPS :
                continue
            h += pr * row_entropy[ i ]

        self.set_complexity_measure( "h_mu", h )

        return h

    #-------------------------------------------------------#

    def H_1(self):

        """
        H_1 (single-symbol entropy, or marginal entropy)
        Interpretations:
        - How uncertain you are about a single measurement with no context.
        - The total information content per symbol before seeing any neighbors.
        - The maximum possible entropy rate, achieved by a memoryless process.
        """

        m = self.get_complexity_measure_if_exists("H_1")
        if m is not None:
            return m

        pi = self.get_stationary_distribution()
        T_X = self.get_T_X()  # list indexed by symbol: symbol_idx -> matrix

        h = 0.0
        for T_x in T_X:
            # Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
            p_sym = 0.0
            for i, pr in enumerate(pi):
                if pr < self.EPS:
                    continue
                p_sym += pr * T_x[i, :].sum()

            if p_sym < self.EPS:
                continue
            h -= p_sym * np.log2(p_sym)

        self.set_complexity_measure("H_1", h)
        return h
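    # In symbols (base-2 logs), with pi the stationary distribution and p the
    # outgoing transition probabilities:
    #
    #   C_mu = -sum_i pi_i log2 pi_i
    #   h_mu = -sum_i pi_i sum_{edges e from i} p_e log2 p_e
    #   H_1  = -sum_x Pr(x) log2 Pr(x),  with Pr(x) = sum_ij pi_i T^(x)_ij
    #
    # and rho_mu below is the redundancy H_1 - h_mu >= 0.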
    def rho_mu(self) :

        """
        rho_mu (anticipated information, or single-symbol redundancy)
        Interpretations:
        - How much the past reduces your uncertainty about the next symbol.
        - The portion of each symbol's information that was already predictable.
        - The difference between naive per-symbol uncertainty and true randomness.
        - Equal to H_1 - h_mu: the gap closed by knowing context.
        """

        m = self.get_complexity_measure_if_exists("rho_mu")

        if m is not None:
            return m

        rho = self.H_1() - self.h_mu()

        self.set_complexity_measure("rho_mu", rho)

        return rho

    #-------------------------------------------------------#

    def block_convergence( self ) :

        """
        block_convergence
        - Estimates a range of complexity measures using block entropy convergence.
        - E, S, T_inf, T_L, H_L, h_mu_L, H_sync
        """

        trs = [ [] for _ in range( len( self.states ) ) ]
        for tr in self.transitions :
            trs[ tr.origin_state_idx ].append( (
                tr.symbol_idx,
                float( tr.prob ),
                tr.target_state_idx ) )

        pi = self.get_stationary_distribution()

        state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]
        branches = [(1.0, list(state_dist))]

        print( "\nComputing Block Entropy\n" )

        C = am_fast.block_entropy_convergence(
            h_mu = self.h_mu(),
            n_states = len( self.states ),
            n_symbols = len( self.alphabet ),
            convergence_tol = 1e-6,
            precision = 10,
            eps = 1e-25,
            branches = branches,
            trans = trs,
            max_branches = 30_000_000
        )

        print( "Done\n" )

        self.set_complexity_measure( "E", C.E )
        self.set_complexity_measure( "S", C.S )
        self.set_complexity_measure( "T_inf", C.T )
        self.set_complexity_measure( "T_L", C.T_L.tolist() )
        self.set_complexity_measure( "H_L", C.H_L.tolist() )
        self.set_complexity_measure( "h_mu_L", C.h_mu_L.tolist() )
        self.set_complexity_measure( "H_sync", C.H_sync.tolist() )

        return C
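    # For large L the block entropy grows as H(L) ~ E + h_mu * L, so the
    # estimator reads E off as the intercept of that asymptote; the transient
    # information accumulates the gap along the way,
    #
    #   T_inf = sum_L [ E + h_mu * L - H(L) ]
    #
    # (Crutchfield & Feldman, "Regularities Unseen, Randomness Observed").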
    #-------------------------------------------------------#

    def E( self ) :

        """
        E (excess entropy)

        Interpretations:
        - Mutual information from past to future I(Past;Future).
        - How much total information from the past reduces uncertainty in the future.
        - Amount of (apparent) stored information.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "E" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return E

        except Exception as e :
            warnings.warn( f"MSP failed: {e}. Falling back to iterative estimation." )

        C = self.block_convergence()
        E = C.E
        self.set_complexity_measure( "E", E )

        return E

    #-------------------------------------------------------#

    def S( self ) :

        """
        S (synchronization information)

        Interpretation
        - Synchronization information.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "S" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return S

        except Exception as e :
            warnings.warn( f"MSP failed: {e}. Falling back to iterative estimation." )

        C = self.block_convergence()
        S = C.S
        self.set_complexity_measure( "S", S )

        return S

    #-------------------------------------------------------#

    def T_inf( self ) :

        """
        T_inf (transient information)

        Interpretation
        - Transient information.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "T_inf" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return T

        except Exception as e :
            warnings.warn( f"MSP failed: {e}. Falling back to iterative estimation." )

        C = self.block_convergence()
        T_inf = C.T
        self.set_complexity_measure( "T_inf", T_inf )

        return T_inf

    #-------------------------------------------------------#

    def Chi( self ) :

        """
        Chi (crypticity)

        Interpretation
        - Difference between internal stored information and apparent information to an observer.
        - How much information is hidden in the system.
        """

        m = self.get_complexity_measure_if_exists( "Chi" )

        if m is not None :
            return m

        Chi = self.C_mu() - self.E()

        self.set_complexity_measure( "Chi", Chi )

        return Chi
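    # Useful identity: since E <= C_mu (the past-future mutual information is
    # bounded by the information the causal states store), crypticity is
    # non-negative: Chi = C_mu - E >= 0.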
    #-------------------------------------------------------#

    def k_X( self ) :

        """
        k_X (cryptic order)

        Interpretation

        """

        m = self.get_complexity_measure_if_exists( "k_X" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def R( self ) :

        """
        R (Markov order)
        Interpretations:
        - The smallest L such that h_mu(L) = h_mu.
        - The context length beyond which additional history gives no predictive benefit.
        - The memory length of the process: how far back correlations extend.
        - The ideal context window length for a predictor.
        - R = 0 means the process is IID (memoryless).
        - R = infinity means the process has unbounded memory (non-finitary).
        - Only report if convergence was clean, otherwise None.
        """

        m = self.get_complexity_measure_if_exists( "R" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def b_mu( self ) :

        """
        b_mu - requires reverse HMM

        Interpretation

        """

        m = self.get_complexity_measure_if_exists( "b_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def r_mu( self ) :

        """
        r_mu - requires reverse HMM

        Interpretation

        """

        m = self.get_complexity_measure_if_exists( "r_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def q_mu( self ) :

        """
        q_mu

        Interpretation
        """

        m = self.get_complexity_measure_if_exists( "q_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#
    #                       Properties                      #
    #-------------------------------------------------------#

    def causally_reversible( self ) :

        """
        causally_reversible
        """

        m = self.get_complexity_measure_if_exists( "causally_reversible" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def microscopically_reversible( self ) :

        """
        microscopically_reversible
        """

        m = self.get_complexity_measure_if_exists( "microscopically_reversible" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def is_row_stochastic(self) :

        """
        is_row_stochastic
        - All states have outgoing transition probabilities which sum to 1.
        """

        sums = np.zeros( len( self.states ) )
        for tr in self.transitions :
            sums[ tr.origin_state_idx ] += tr.prob
        return np.allclose( sums, 1.0 )

    #-------------------------------------------------------#

    def is_unifilar(self) :

        """
        is_unifilar
        - All states have no more than one outgoing transition per symbol.
        """

        symbol_trs = np.full( ( len( self.states ), len( self.alphabet ) ), -1 )

        for tr in self.transitions :

            if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
                symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx

            elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
                return False

        return True
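    # Unifilarity means (state, symbol) determines the next state. For
    # example, this pair of edges would make state 0 non-unifilar, because
    # symbol 0 can lead to either state 1 or state 2 (illustrative):
    #
    #   Transition(origin_state_idx=0, target_state_idx=1, prob=0.5, symbol_idx=0)
    #   Transition(origin_state_idx=0, target_state_idx=2, prob=0.5, symbol_idx=0)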
    #-------------------------------------------------------#

    def is_strongly_connected(self) :

        """
        is_strongly_connected
        - Every state is reachable from every other state.
        """

        return nx.is_strongly_connected( self.as_digraph() )

    #-------------------------------------------------------#

    def is_aperiodic(self) :

        """
        is_aperiodic
        - "A strongly connected directed graph is aperiodic if there is no
          integer k > 1 that divides the length of every cycle in the graph."
        """

        return nx.is_aperiodic( self.as_digraph() )

    #-------------------------------------------------------#

    def is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :

        with_probs = not topological_only

        # Construct the DFA
        dfa = self.as_dfa( with_probs=with_probs )

        # Minimize the DFA
        # dfa = dfa.minify(retain_names=True)
        dfa = am_fast.minify_cpp( dfa, retain_names=True )

        # check that we already had the minimal number of states
        if len( dfa.states ) != len( self.states ) :
            if verbose :
                print( f"Not minimal: reduces from {len( self.states )} to {len( dfa.states )} states" )
            return False

        return True

    def is_topological_epsilon_machine( self, verbose=True ) :

        """
        is_topological_epsilon_machine
        - See Algorithm 2 in Enumerating Finitary Processes, Johnson et al., 2024.
        """

        if not ( self.is_unifilar() and self.is_strongly_connected() ) :
            if verbose :
                print( "Either non-unifilar or not strongly connected" )
            return False
        else :
            return self.is_minimal_as_dfa( topological_only=True, verbose=verbose )

    def is_epsilon_machine( self, verbose=True ) :

        if not ( self.is_unifilar() and self.is_strongly_connected() ) :
            if verbose :
                print( "Either non-unifilar or not strongly connected" )
            return False
        else :
            return self.is_minimal_as_dfa( topological_only=False, verbose=verbose )

    #-------------------------------------------------------#
    #                      Minimization                     #
    #-------------------------------------------------------#
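    # minimize() reuses classic DFA minimization by encoding each edge's
    # (symbol, probability) pair as a single input letter (see as_dfa below).
    # With every state accepting, two states land in the same equivalence
    # class only if they agree on both symbols and probabilities, which is
    # exactly causal-state merging for a unifilar HMM.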
    def minimize(self, retain_names: bool = True):

        """
        minimize
        - Simplify to minimal form, using Myhill-Nerode equivalence classes.
        """

        if self.is_minimal :
            return

        if not self.is_unifilar():
            raise ValueError(
                "DFA minimization is not valid for non-unifilar HMMs"
            )

        was_strongly_connected = self.is_strongly_connected()
        was_row_stochastic = self.is_row_stochastic()
        n_states_before = len(self.states)

        dfa = self.as_dfa(with_probs=True)

        # min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
        min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

        # Build lookup from original state index -> CausalState object
        orig_state = {i: s for i, s in enumerate(self.states)}
        eq_list = list(min_dfa.states)

        start_eq = min_dfa.initial_state

        # Separate the start state, then sort the rest by the
        # smallest original state index inside each equivalence class.
        other_eqs = [eq for eq in eq_list if eq != start_eq]
        other_eqs.sort(key=lambda eq: min(eq))

        # Recombine so the start class comes first, followed by the sorted remaining classes
        eq_list = [start_eq] + other_eqs

        # Recompute eq_to_idx with the new ordering
        eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

        # new_start is now guaranteed to be 0
        new_start = 0

        # Map each original state index -> its equivalence class
        # Guard: minify() silently drops unreachable states
        orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

        # Build lookup from original state index -> its transitions
        orig_trs = defaultdict(list)
        for t in self.transitions:
            orig_trs[t.origin_state_idx].append(t)

        # One representative per class supplies the outgoing transitions
        new_trs = []
        for eq in min_dfa.states:
            rep = next(iter(eq))
            origin_idx = eq_to_idx[eq]
            for t in orig_trs[rep]:
                target_eq = orig_to_eq[t.target_state_idx]
                target_idx = eq_to_idx[target_eq]
                new_trs.append(Transition(
                    origin_state_idx = origin_idx,
                    target_state_idx = target_idx,
                    prob = t.prob,
                    symbol_idx = t.symbol_idx,
                ))

        members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

        # Compute new names
        if retain_names:
            new_names = [
                "{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
                else members[0].name
                for members in members_list
            ]
        else:
            new_names = [str(j) for j in range(len(eq_list))]

        old_name_to_new_name = {
            m.name: new_names[j]
            for j, members in enumerate(members_list)
            for m in members
        }

        # Build the new states, preserving classes and isomorphs regardless of naming
        new_states = []
        for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
            classes = set().union(*(m.classes for m in members))
            isomorphs = {
                old_name_to_new_name.get(iso, iso)
                for m in members
                for iso in m.isomorphs
                if old_name_to_new_name.get(iso, iso) != name
            }
            new_states.append(CausalState(
                name = name,
                classes = classes,
                isomorphs = isomorphs,
            ))

        self.set_states(new_states)
        self.set_transitions(new_trs)
        self.start_state = new_start

        if n_states_before == len(new_states) :
            print( f"{n_states_before} state HMM was already minimal." )
        else :
            print( f"Minimized from {n_states_before} to {len(new_states)} states" )

        if was_strongly_connected != self.is_strongly_connected() :
            raise RuntimeError(
                "Minimization broke strong connectivity"
            )

        if was_row_stochastic != self.is_row_stochastic() :
            raise RuntimeError(
                "Minimization broke row stochasticity"
            )

        self.is_minimal = True

    #-------------------------------------------------------#
    #                       Modifiers                       #
    #-------------------------------------------------------#

    def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

        # get equivalent networkx graph
        G = self.as_digraph()

        # if already strongly connected, nothing to do
        if not nx.is_strongly_connected( G ) :

            # decompose into strongly connected components and sort by size
            subgraph_nodes = list( nx.strongly_connected_components( G ) )
            subgraph_nodes.sort(key=len)

            # take the largest strongly connected component (a set of state indices)
            component_state_set = subgraph_nodes[-1]
            component_states = sorted( component_state_set )

            # make temporary copies of the old transitions and states
            old_transitions = [
                Transition(
                    origin_state_idx=tr.origin_state_idx,
                    target_state_idx=tr.target_state_idx,
                    prob=tr.prob,
                    symbol_idx=tr.symbol_idx
                )
                for tr in self.transitions
            ]

            old_states = [
                CausalState(
                    name=s.name,
                    classes=s.classes,
                    isomorphs=s.isomorphs
                )
                for s in self.states
            ]

            self.set_states(
                states=[
                    state
                    for i, state in enumerate( old_states ) if i in component_state_set
                ]
            )

            # we will build a new transition list from those belonging to the component
            self.set_transitions( transitions=[] )

            # for tracking which new transitions leave each state (keyed by old state index)
            transitions_from_state = { state : set() for state in component_states }
            new_transitions = []

            for tr in old_transitions :

                # skip transitions that connect separate strongly connected components
                if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
                    continue

                origin_state_name = old_states[ tr.origin_state_idx ].name
                target_state_name = old_states[ tr.target_state_idx ].name

                # track transitions (by index in the new transitions list) that leave this state
                transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

                my_origin_state_idx = self.state_idx_map[ origin_state_name ]
                my_target_state_idx = self.state_idx_map[ target_state_name ]

                new_transitions.append(
                    Transition(
                        origin_state_idx=my_origin_state_idx,
                        target_state_idx=my_target_state_idx,
                        prob=tr.prob,
                        symbol_idx=tr.symbol_idx
                    ) )

            self.set_transitions( transitions=new_transitions )

            # if we removed an outgoing transition from a state, we need to distribute its
            # probability among the remaining outgoing transitions from that state
            for state in component_states :

                # the set of transitions leaving this state
                state_trs = transitions_from_state[ state ]

                # guard: a size-1 component with no self-loop has no surviving edges
                if not state_trs :
                    continue

                # sum the probabilities of the outgoing transitions from the state
                p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

                # how much probability is missing
                diff = 1.0 - p_sum

                # if there is a significant difference
                if abs( diff ) > self.EPS :

                    # share the difference equally between the surviving transitions
                    adjustment = diff / len( state_trs )

                    for i in state_trs :

                        # transition with original probability
                        tr = self.transitions[ i ]

                        # adjusted probability
                        self.transitions[ i ] = Transition(
                            origin_state_idx=tr.origin_state_idx,
                            target_state_idx=tr.target_state_idx,
                            prob=tr.prob + adjustment,
                            symbol_idx=tr.symbol_idx )

            if rename_states :
                self.set_states( [
                    CausalState( name=f"{i}" )
                    for i, s in enumerate( self.states )
                ] )
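    # Renormalization example (illustrative): if a surviving state had edges
    # with probabilities (0.5, 0.3, 0.2) and the 0.2 edge left the component,
    # the deficit 0.2 is split equally, giving (0.6, 0.4). This preserves row
    # stochasticity but is a modeling choice, not a conditional distribution:
    # renormalizing proportionally would instead give (0.625, 0.375).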
    def to_q_weighted( self, denominator_limit=1000 ) :

        """
        to_q_weighted
        - Snaps edge probabilities to exact Fractions, and assigns the Fraction values to Transition.pq.
        """

        if self.is_q_weighted :
            return

        if not self.is_row_stochastic() :
            raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

        # group transition indices by origin state
        t_from = [[] for _ in range(len(self.states))]

        for i, tr in enumerate( self.transitions ) :
            t_from[ tr.origin_state_idx ].append( i )

        new_transitions = []
        for t_list in t_from :

            if not t_list :
                continue

            p_q_sum = Fraction(0,1)
            p_qs = []

            for t_idx in t_list :

                p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
                p_q_sum += p_q
                p_qs.append( p_q )

            # rounding may leave the row slightly off 1; absorb the residue
            # into the largest edge, or retry at higher resolution
            if p_q_sum != Fraction(1,1) :

                max_pq_i = np.argmax( p_qs )
                max_pq = p_qs[ max_pq_i ]

                diff = p_q_sum - Fraction(1,1)

                # recurse with higher resolution
                if diff > max_pq :
                    return self.to_q_weighted( denominator_limit*10 )
                else :
                    p_qs[ max_pq_i ] -= diff

            for i, t_idx in enumerate( t_list ) :
                new_transitions.append(
                    Transition(
                        origin_state_idx=self.transitions[ t_idx ].origin_state_idx,
                        target_state_idx=self.transitions[ t_idx ].target_state_idx,
                        prob=float(p_qs[ i ]),
                        symbol_idx=self.transitions[ t_idx ].symbol_idx,
                        pq=p_qs[ i ]
                    )
                )

        self.set_transitions( new_transitions )
        self.is_q_weighted = True
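    # Fraction.limit_denominator finds the closest rational within the bound,
    # e.g. Fraction(0.3333333333).limit_denominator(1000) == Fraction(1, 3)
    # and Fraction(0.25).limit_denominator(1000) == Fraction(1, 4), so typical
    # float noise snaps back to the intended exact probabilities.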
    #-------------------------------------------------------#
    #                    Data Generation                    #
    #-------------------------------------------------------#

    def generate_data(
        self,
        file_prefix: str,
        n_gen: int,
        include_states: bool,
        isomorphic_shifts : set[int] | None = None,
        random_seed : int = 42 ) -> dict :

        trs = [ [] for _ in range( len( self.states ) ) ]
        for tr in self.transitions :
            trs[ tr.origin_state_idx ].append( (
                tr.symbol_idx,
                float( tr.prob ),
                tr.target_state_idx ) )

        data = am_fast.generate_data(
            n_gen=n_gen,
            start_state=self.start_state,
            transitions=trs,
            alphabet=sorted( self.alphabet ),
            include_states=include_states,
            random_seed=random_seed
        )

        if isomorphic_shifts is not None :

            if not include_states :
                raise ValueError( "Isomorphic inversion requires include_states=True" )

            data[ "isomorphic_shifts" ] = {}

            for shift in isomorphic_shifts :

                shifted = HMM.isomorphic_shift(
                    input_symbol_indices=data[ "symbol_index" ],
                    input_state_indices=data[ "state_index" ],
                    m=self,
                    shift=shift
                )

                data[ "isomorphic_shifts" ][ shift ] = {
                    "symbol_index" : shifted[ "symbol_index" ],
                    "state_index" : shifted[ "state_index" ]
                }

        am_fast.save_data(
            data=data,
            file_prefix=file_prefix,
            alphabet=sorted( self.alphabet ),
            n_states=len( self.states ),
            start_state=self.start_state,
            random_seed=random_seed,
            metadata=self.get_metadata() )

        return data

    #-------------------------------------------------------#
    #                  Basic Visualization                  #
    #-------------------------------------------------------#

    def draw_graph(
        self,
        engine : str = 'dot',
        output_dir : Path = Path('.'),
        show : bool = True
    ) -> None :

        G = self.as_digraph()

        subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )

        am_vis.draw_graph(
            self,
            output_dir=output_dir,
            title="am_graph",
            view=show,
            subgraphs=subgraphs,
            engine=engine )

    #-------------------------------------------------------#
    #                    Alternate Forms                    #
    #-------------------------------------------------------#

    def as_digraph( self ) :

        """
        Build an nx.DiGraph based on the HMM, without edge probabilities or symbols
        """

        G = nx.DiGraph()
        G.add_nodes_from( range( len( self.states ) ) )

        for tr in self.transitions :
            G.add_edge( tr.origin_state_idx, tr.target_state_idx )

        return G

    def as_dfa( self, with_probs : bool ) :

        precision = 8

        def edge_label( symb, prob ) :
            return f"({symb},{round(prob, precision)})"

        # Build states, symbols, and transitions
        dfa_states = { i for i, _ in enumerate( self.states ) }

        if not with_probs :
            dfa_symbols = { t.symbol_idx for t in self.transitions }
        else :
            dfa_symbols = { edge_label( t.symbol_idx, t.prob ) for t in self.transitions }

        dfa_transitions = defaultdict(dict)

        if not with_probs :
            for t in self.transitions :
                dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
        else :
            for t in self.transitions :
                dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx

        # Construct the DFA; every state is accepting, so minimization merges
        # states purely by (labeled) transition structure
        return DFA(
            states=dfa_states,
            input_symbols=dfa_symbols,
            transitions=dfa_transitions,
            initial_state=self.start_state,
            allow_partial=True,
            final_states={ s for s in dfa_states }
        )
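#-------------------------------------------------------#
#                      Usage Sketch                     #
#-------------------------------------------------------#

# A minimal, illustrative demo (assumes CausalState/Transition behave as used
# above). The machine is a Golden-Mean-style process: "0" is always followed
# by "1". Not part of the library API.
if __name__ == "__main__" :

    hmm = HMM(
        name="golden_mean_demo",
        alphabet=[ "0", "1" ],
        states=[ CausalState( name="A" ), CausalState( name="B" ) ],
        transitions=[
            # from A: emit "1" and stay, or emit "0" and move to B
            Transition( origin_state_idx=0, target_state_idx=0, prob=0.5, symbol_idx=1 ),
            Transition( origin_state_idx=0, target_state_idx=1, prob=0.5, symbol_idx=0 ),
            # from B: always emit "1" and return to A
            Transition( origin_state_idx=1, target_state_idx=0, prob=1.0, symbol_idx=1 ),
        ],
    )

    print( "C_mu   :", hmm.C_mu() )
    print( "h_mu   :", hmm.h_mu() )
    print( "H_1    :", hmm.H_1() )
    print( "rho_mu :", hmm.rho_mu() )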
38class HMM(Machine) : 39 40 @staticmethod 41 def isomorphic_shift( 42 input_symbol_indices: np.ndarray, 43 input_state_indices: np.ndarray, 44 m : HMM, 45 shift : int = 1 46 ) -> dict[str, np.ndarray]: 47 48 if not any( state.isomorphs for state in m.states ): 49 raise ValueError("HMM has no states with isomorphs") 50 51 inputs = np.asarray(input_symbol_indices) 52 states = np.asarray(input_state_indices) 53 54 n_states = len(m.states) 55 56 tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32) 57 58 for tr in m.transitions: 59 tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx 60 61 # Build isomorph remapping: identity by default, overridden where isomorphs exist 62 iso_table = np.arange(n_states, dtype=np.int32) 63 64 for i, state in enumerate(m.states): 65 if state.isomorphs: 66 isomorph = sorted(state.isomorphs)[0] 67 iso_table[ i ] = m.state_idx_map[isomorph] 68 69 for i, state in enumerate(m.states): 70 if state.isomorphs: 71 72 # extend the isomorph list to include the identity 73 isormorphs_with_identity = sorted( [ i ] + [ m.state_idx_map[iso] for iso in state.isomorphs ] ) 74 75 # find the identity index 76 pos = isormorphs_with_identity.index( i ) 77 78 # cyclical shift 79 iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ] 80 81 origins = states[:-1] 82 targets = states[1:] 83 84 out_origins = iso_table[origins] 85 out_targets = iso_table[targets] 86 87 inv_sym = tr_sym_table[ out_origins, out_targets ] 88 inv_sts = np.empty(states.size, dtype=states.dtype) 89 90 inv_sts[:-1] = out_origins 91 inv_sts[-1] = out_targets[-1] 92 93 return { 94 "symbol_index": inv_sym.astype(inputs.dtype), 95 "state_index": inv_sts, 96 } 97 98 def __init__( 99 self, 100 states : list[CausalState] | None = None, 101 transitions : list[Transition] | None = None, 102 start_state : int = 0, 103 alphabet : list[str] | None = None, 104 isoclasses : dict[str,set[str]] | None = None, 105 name : str = "", 106 description : str = "" ) : 107 108 self.alphabet = alphabet or [] 109 self.states = states or [] 110 self.transitions = transitions or [] 111 112 self.set_alphabet( self.alphabet ) 113 self.set_states( self.states ) 114 self.set_transitions( self.transitions ) 115 116 self.start_state = start_state 117 self.isoclasses = isoclasses or {} 118 119 self.name = name 120 self.description = description 121 122 # to be depreciated 123 self.isoclass = None 124 125 # --- derived -------- 126 127 self.complexity : dict[str, any] = {} 128 129 self.symbol_idx_map : dict[str,int] = {} 130 self.state_idx_map : dict[str,int] = {} 131 132 self.pi_fractional = None 133 self.pi = None 134 self.T = None 135 self.msp : MSP = None 136 self.reverse_am : HMM = None 137 self.is_q_weighted = False 138 self.is_minimal = False 139 140 # --- const ---------- 141 142 self.EPS = 1e-12 143 144 #-------------------------------------------------------# 145 # Overrides / Getters # 146 #-------------------------------------------------------# 147 148 @override 149 def get_states(self) -> list[CausalStates] : 150 return self.states 151 152 @override 153 def get_transitions(self) -> list[Transition] : 154 return self.transitions 155 156 @override 157 def get_alphabet(self) -> list[Symbol]: 158 return [ Symbol(a) for a in self.alphabet ] 159 160 #-------------------------------------------------------# 161 # Serialization # 162 #-------------------------------------------------------# 163 164 165 def to_dict(self) : 166 return { 167 "name" : self.name, 168 
"description" : self.description, 169 "states" : [ asdict(state) for state in self.states ], 170 "transitions" : [ asdict(transition) for transition in self.transitions ], 171 "alphabet" : self.alphabet 172 } 173 174 def from_dict( self, config ) : 175 176 self.name = config.name 177 self.description = config.description 178 179 self.set_states( states=[ 180 CausalState( 181 name=state[ "name" ], 182 classes=set( state[ "classes" ] ) 183 ) 184 for state in config.states 185 ] ) 186 187 self.set_transitions( transitions=[ 188 Transition( 189 origin_state_idx=tr[ "origin_state_idx" ], 190 target_state_idx=tr[ "target_state_idx" ], 191 prob=tr[ "prob" ], 192 symbol_idx=tr[ "symbol_idx" ] 193 ) 194 for tr in config.transitions 195 ] ) 196 197 self.set_alphabet( alphabet=config.alphabet ) 198 199 def save_config( 200 self, 201 path : Path, 202 with_complexity : bool = False, 203 with_block_convergence : bool = False, 204 with_structural_properties : bool = False, 205 with_causal_properties : bool = False ) : 206 207 config = self.to_dict() 208 209 if with_complexity : 210 211 complexity = self.get_complexities( 212 with_block_convergence=with_block_convergence, 213 with_reversal=False # not implemented 214 ) 215 216 config[ "complexity" ] = complexity 217 218 config[ "structural_properties" ] = { 219 "unifilar" : self.is_unifilar(), 220 "row_stochastic" : self.is_row_stochastic(), 221 "strongly_connected" : self.is_strongly_connected(), 222 "aperiodic" : self.is_aperiodic(), 223 "minimal" : self.is_minimal_as_dfa( topological_only=False ), 224 "topologically_minimal" : self.is_minimal_as_dfa( topological_only=True ), 225 "is_epsilon_machine" : self.is_epsilon_machine() 226 } 227 228 # Not implemented 229 # config[ "causal_properties" ] = { 230 # "k_X" : self.k_X(), 231 # "R" : self.R(), 232 # "causally_reversible" : self.causally_reversible(), 233 # "microscopically_reversible" : self.microscopically_reversible(), 234 # } 235 236 with open( path / "am_config.json", "w", encoding="utf-8" ) as f : 237 json.dump( config, f, ensure_ascii=False, indent=2, default=list ) 238 239 def from_file( self, path : Path ) : 240 with open( Path / "am_config.json", "r" ) as f: 241 config = json.load(f) 242 self.from_dict() 243 244 #-------------------------------------------------------# 245 # Setters and State Management # 246 #-------------------------------------------------------# 247 248 def _invalidate(self) : 249 250 """ 251 Invalidate 252 253 Reset all derived properties so they will be recomputed when requested later 254 """ 255 256 self.complexity = {} 257 self.T = None 258 self.T_x = None 259 self.pi = None 260 self.pi_fractional = None 261 self.msp = None 262 self.reverse_am = None 263 self.is_q_weighted = False 264 self.is_minimal = False 265 266 def set_states( self, states : list[CausalState] ) : 267 self._invalidate() 268 self.states = states.copy() 269 self.state_idx_map = {} 270 for idx, state in enumerate( self.states ) : 271 self.state_idx_map[ state.name ] = idx 272 273 def set_alphabet( self, alphabet : list[str] ) : 274 275 self._invalidate() 276 277 old_alphabet = self.alphabet.copy() 278 279 aSet = set() 280 aSet.update( alphabet ) 281 282 self.alphabet = sorted(list(aSet)) 283 284 self.symbol_idx_map = {} 285 for idx, symbol in enumerate( self.alphabet ) : 286 self.symbol_idx_map[ symbol ] = idx 287 288 for i, tr in enumerate( self.transitions ) : 289 symbol = old_alphabet[ tr.symbol_idx ] 290 self.transitions[ i ] = Transition( 291 origin_state_idx=tr.origin_state_idx, 292 
target_state_idx=tr.target_state_idx, 293 prob=tr.prob, 294 symbol_idx=self.symbol_idx_map[ symbol ] 295 ) 296 297 def set_transitions( self, transitions : list[Transition] ) : 298 self._invalidate() 299 self.transitions = transitions.copy() 300 301 def extend_states( self, states : list[CausalState] ) : 302 self.set_states( self.states + states ) 303 304 def extend_alphabet( self, alphabet : list[str] ) : 305 self.set_alphabet( self.alphabet + alphabet ) 306 307 def extend_transitions( self, transitions : list[Transition] ) : 308 self.set_transitions( self.transitions + transitions ) 309 310 def get_complexity_measure_if_exists(self, measure ) : 311 m = self.complexity.get( measure, None ) 312 return m 313 314 def set_complexity_measure(self, measure, value ) : 315 self.complexity[ measure ] = value 316 317 #-------------------------------------------------------# 318 # Get and Compute # 319 #-------------------------------------------------------# 320 321 def get_complexities( 322 self, 323 with_block_convergence=False, 324 with_reversal=False ) : 325 326 directly_calculable = [ 327 self.C_mu, 328 self.h_mu, 329 self.H_1, 330 self.rho_mu 331 ] 332 333 requires_block_convergence = [ 334 self.E, 335 self.T_inf, 336 self.S, 337 self.Chi 338 ] 339 340 requires_reversal = [ 341 self.q_mu, 342 self.r_mu, 343 self.b_mu 344 ] 345 346 complexities = { m.__name__ : m() for m in directly_calculable } 347 348 if with_block_convergence : 349 350 complexities |= { m.__name__ : m() for m in requires_block_convergence } 351 352 for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] : 353 if key in self.complexity : 354 complexities[ key ] = self.complexity[ key ] 355 356 if with_reversal : 357 complexities |= { m.__name__ : m() for m in requires_reversal } 358 359 return complexities 360 361 #-------------------------------------------------------# 362 363 def get_metadata(self) : 364 return { 365 "name" : self.name, 366 'complexity' : self.complexity, 367 "description" : self.description 368 } 369 370 def get_transition_matrix(self) : 371 372 if self.T is not None : 373 return self.T 374 375 n_states = len( self.states ) 376 T = np.zeros((n_states, n_states)) 377 378 for tr in self.transitions : 379 T[ tr.origin_state_idx, tr.target_state_idx ] = tr.prob 380 381 self.T = T 382 383 return self.T 384 385 #-------------------------------------------------------# 386 387 def get_T_X(self) : 388 389 if self.T_x is not None : 390 return self.T_x 391 392 n_states = len( self.states ) 393 n_symbols = len( self.alphabet ) 394 395 T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ] 396 397 for tr in self.transitions : 398 T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob 399 400 self.T_x = T_x 401 return self.T_x 402 403 #-------------------------------------------------------# 404 405 def get_msp_qw( 406 self, 407 exact_state_cap: int = 1000, 408 verbose: bool = True, 409 ): 410 if self.msp is not None: 411 return self.msp 412 413 try : 414 415 print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" ) 416 417 self.msp = compute_msp_exact( 418 T_x=self.get_Tx_fractional(), 419 pi=self.get_fractional_stationary_distribution(), 420 n_states=len(self.states), 421 alphabet=self.alphabet, 422 exact_state_cap=1000, 423 bool = True 424 ) 425 426 return self.msp 427 428 except RuntimeError as e : 429 warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." 
) 430 431 return self.get_msp() 432 433 def get_msp( 434 self, 435 exact_state_cap: int = 175_000, 436 jsd_eps: float = 1e-7, 437 k_ann: int = 50, 438 verbose = True, 439 ) : 440 441 if self.msp is not None: 442 return self.msp 443 444 T_x = self.get_T_X() 445 pi = self.get_stationary_distribution() 446 447 T_stacked = np.stack(T_x) 448 n_symbols = len(self.alphabet) 449 n_input_states = T_stacked.shape[1] 450 451 print( "\nComputing Mixed State Presentation\n" ) 452 453 self.msp = compute_msp( 454 T_x=T_x, 455 pi=pi, 456 n_states=len(self.states), 457 alphabet=self.alphabet, 458 exact_state_cap=exact_state_cap, 459 verbose=verbose 460 ) 461 462 return self.msp 463 464 def get_reverse_am(self) : 465 466 if self.reverse_am is not None: 467 return self.reverse_am 468 469 pi = self.get_stationary_distribution() 470 self.reverse_am = copy.deepcopy(self) 471 472 new_transitions = [] 473 for tr in self.transitions: 474 i = tr.target_state_idx 475 j = tr.origin_state_idx 476 477 p_reversed = (pi[j] * tr.prob) / pi[i] 478 479 new_transitions.append( 480 Transition( 481 origin_state_idx=i, 482 target_state_idx=j, 483 prob=p_reversed, 484 symbol_idx=tr.symbol_idx 485 ) 486 ) 487 488 self.reverse_am.set_transitions(new_transitions) 489 490 if self.reverse_am.is_epsilon_machine(): 491 return self.reverse_am 492 493 rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 ) 494 495 self.reverse_am.set_states( rmsp.states ) 496 self.reverse_am.set_transitions( rmsp.transitions ) 497 self.reverse_am.msp = rmsp 498 self.reverse_am.start_state = 0 499 500 self.reverse_am.collapse_to_largest_strongly_connected_subgraph() 501 self.reverse_am.minimize() 502 503 return self.reverse_am 504 505 #-------------------------------------------------------# 506 507 def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] : 508 509 self.to_q_weighted() 510 511 n_states = len( self.states ) 512 n_symbols = len( self.alphabet ) 513 514 T_x = [] 515 516 for x in range( n_symbols ) : 517 T_x.append( [] ) 518 for i in range( n_states ) : 519 T_x[ x ].append( [ 0 for _ in range( n_states ) ] ) 520 521 for tr in self.transitions : 522 T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq 523 524 return T_x 525 526 def get_T_sympy( self ) : 527 528 self.to_q_weighted() 529 530 n = len( self.states ) 531 T = sympy.zeros( n, n ) 532 533 for tr in self.transitions : 534 T[ tr.origin_state_idx, tr.target_state_idx ] = tr.pq 535 536 return T 537 538 def get_fractional_stationary_distribution(self) : 539 540 T = self.get_T_sympy() 541 542 if self.pi_fractional is not None : 543 return self.pi_fractional 544 545 G = self.as_digraph() 546 547 if not nx.is_strongly_connected(G): 548 raise ValueError( "Single stationary distribution requires strongly connected HMM." ) 549 550 self.pi_fractional = solve_for_pi_fractional( T ) 551 552 return self.pi_fractional 553 554 def get_stationary_distribution(self): 555 556 if self.pi is not None : 557 return self.pi 558 559 G = self.as_digraph() 560 561 if not nx.is_strongly_connected(G): 562 raise ValueError( "Single stationary distribution requires strongly connected HMM." 
) 563 564 T = self.get_transition_matrix() 565 return solve_for_pi( T ) 566 567 #-------------------------------------------------------# 568 # Complexity Measures # 569 #-------------------------------------------------------# 570 571 def C_mu( self ) : 572 573 """ 574 c_mu (statistical complexity, or forecasting complexity) 575 576 Interpretations 577 - The amount of historical information a process stores. 578 - The amount of structure in a process. 579 """ 580 581 m = self.get_complexity_measure_if_exists( "C_mu" ) 582 583 if m is not None : 584 return m 585 586 pi = self.get_stationary_distribution() 587 588 h = 0 589 for i, pr in enumerate( pi ) : 590 591 if pr < self.EPS : 592 continue 593 594 h += -pr * np.log2( pr ) 595 596 self.set_complexity_measure( "C_mu", h ) 597 598 return h 599 600 #-------------------------------------------------------# 601 602 def h_mu( self ) : 603 604 """ 605 h_mu (entropy rate) 606 607 Interpretation 608 - Intrinsic Randomness in the process. 609 """ 610 611 m = self.get_complexity_measure_if_exists( "h_mu" ) 612 613 if m is not None : 614 return m 615 616 T = self.get_transition_matrix() 617 pi = self.get_stationary_distribution() 618 619 n_states = pi.size 620 621 h = 0 622 for i, pr in enumerate( pi ) : 623 624 if pr < self.EPS : 625 continue 626 627 row_entropy = 0 628 for j in range( len( pi ) ) : 629 630 if T[ i, j ] < self.EPS : 631 continue 632 633 row_entropy -= T[ i, j ] * np.log2( T[ i, j ] ) 634 635 h += pr * row_entropy 636 637 self.set_complexity_measure( "h_mu", h ) 638 639 return h 640 641 #-------------------------------------------------------# 642 643 def H_1(self): 644 645 """ 646 H_1 (single-symbol entropy, or marginal entropy) 647 Interpretations: 648 - How uncertain you are about a single measurement with no context. 649 - The total information content per symbol before seeing any neighbors. 650 - The maximum possible entropy rate, achieved by a memoryless process. 651 """ 652 653 m = self.get_complexity_measure_if_exists("H_1") 654 if m is not None: 655 return m 656 657 pi = self.get_stationary_distribution() 658 T_X = self.get_T_X() # dict: symbol -> matrix 659 660 h = 0.0 661 for T_x in T_X: 662 # Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j] 663 p_sym = 0.0 664 for i, pr in enumerate(pi): 665 if pr < self.EPS: 666 continue 667 p_sym += pr * T_x[i, :].sum() 668 669 if p_sym < self.EPS: 670 continue 671 h -= p_sym * np.log2(p_sym) 672 673 self.set_complexity_measure("H_1", h) 674 return h 675 676 #-------------------------------------------------------# 677 678 def rho_mu(self) : 679 680 """ 681 rho_mu (anticipated information, or single-symbol redundancy) 682 Interpretations: 683 - How much the past reduces your uncertainty about the next symbol. 684 - The portion of each symbol's information that was already predictable. 685 - The difference between naive per-symbol uncertainty and true randomness. 686 - Equal to H_1 - h_mu: the gap closed by knowing context. 687 """ 688 689 m = self.get_complexity_measure_if_exists("rho_mu") 690 691 if m is not None: 692 return m 693 694 rho = self.H_1() - self.h_mu() 695 696 self.set_complexity_measure("rho_mu", rho) 697 698 return rho 699 700 #-------------------------------------------------------# 701 702 def block_convergence( self ) : 703 704 """ 705 block_convergence 706 - Estimates a range of complexity measures using block entropy convergence. 
    #-------------------------------------------------------#

    def block_convergence( self ) :

        """
        block_convergence
        - Estimates a range of complexity measures using block entropy convergence.
        - E, S, T, T_L, H_L, h_mu_L, H_sync
        """

        # adjacency list: per-state ( symbol_idx, prob, target ) triples
        trs = [ [] for _ in range( len( self.states ) ) ]
        for tr in self.transitions :
            trs[ tr.origin_state_idx ].append( (
                tr.symbol_idx,
                float( tr.prob ),
                tr.target_state_idx ) )

        pi = self.get_stationary_distribution()

        state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]
        branches = [ ( 1.0, list( state_dist ) ) ]

        print( "\nComputing Block Entropy\n" )

        C = am_fast.block_entropy_convergence(
            h_mu = self.h_mu(),
            n_states = len( self.states ),
            n_symbols = len( self.alphabet ),
            convergence_tol = 1e-6,
            precision = 10,
            eps = 1e-25,
            branches = branches,
            trans = trs,
            max_branches = 30_000_000
        )

        print( "Done\n" )

        self.set_complexity_measure( "E", C.E )
        self.set_complexity_measure( "S", C.S )
        self.set_complexity_measure( "T_inf", C.T )
        self.set_complexity_measure( "T_L", C.T_L.tolist() )
        self.set_complexity_measure( "H_L", C.H_L.tolist() )
        self.set_complexity_measure( "h_mu_L", C.h_mu_L.tolist() )
        self.set_complexity_measure( "H_sync", C.H_sync.tolist() )

        return C

    #-------------------------------------------------------#

    def E( self ) :

        """
        E (excess entropy)

        Interpretations:
        - Mutual information from past to future, I(Past;Future).
        - How much total information from the past reduces uncertainty about the future.
        - Amount of (apparent) stored information.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "E" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return E

        except Exception as e :
            warnings.warn( f"MSP failed: {e} Falling back to iterative estimation." )

        C = self.block_convergence()
        E = C.E
        self.set_complexity_measure( "E", E )

        return E

    #-------------------------------------------------------#

    def S( self ) :

        """
        S (synchronization information)

        Interpretation
        - The information an observer must acquire to synchronize to
          (exactly determine) the machine's internal state.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "S" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return S

        except Exception as e :
            warnings.warn( f"{e} \nFalling back to iterative estimation." )

        C = self.block_convergence()
        S = C.S
        self.set_complexity_measure( "S", S )

        return S
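    # Orientation for the block-entropy picture (standard relations, not
    # output of this code): the block entropies H(L) approach the line
    # E + h_mu * L from above, so
    #   h_mu_L = H(L) - H(L-1)  decreases to  h_mu, and
    #   E      = lim_L [ H(L) - L * h_mu ].
    # For the alternating process above, H(L) = 1 bit for every L >= 1 and
    # h_mu = 0, giving E = 1 bit: one bit of phase is all the past ever tells you.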
    #-------------------------------------------------------#

    def T_inf( self ) :

        """
        T_inf (transient information)

        Interpretation
        - How slowly block entropy converges: the total excess of H(L) over
          its asymptote E + h_mu * L, summed over block lengths.

        Exact Complexity: The Spectral Decomposition of Intrinsic Computation
        https://arxiv.org/abs/1309.3792
        """

        m = self.get_complexity_measure_if_exists( "T_inf" )

        if m is not None :
            return m

        try :
            msp = self.get_msp()
            E, S, T = msp.get_E_S_T()
            self.set_complexity_measure( "E", E )
            self.set_complexity_measure( "S", S )
            self.set_complexity_measure( "T_inf", T )
            return T

        except Exception as e :
            warnings.warn( f"{e} \nFalling back to iterative estimation." )

        C = self.block_convergence()
        T_inf = C.T
        self.set_complexity_measure( "T_inf", T_inf )

        return T_inf

    #-------------------------------------------------------#

    def Chi( self ) :

        """
        Chi (crypticity)

        Interpretation
        - Difference between internally stored information and the information
          apparent to an observer.
        - How much information is hidden in the system.
        """

        m = self.get_complexity_measure_if_exists( "Chi" )

        if m is not None :
            return m

        Chi = self.C_mu() - self.E()

        self.set_complexity_measure( "Chi", Chi )

        return Chi

    #-------------------------------------------------------#

    def k_X( self ) :

        """
        k_X (cryptic order)

        Interpretation
        - The smallest k at which the k-crypticity Chi(k) reaches Chi: the
          length scale on which state information stays hidden from an observer.
        """

        m = self.get_complexity_measure_if_exists( "k_X" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def R( self ) :

        """
        R (Markov order)
        Interpretations:
        - The smallest L such that h_mu(L) = h_mu.
        - The context length beyond which additional history gives no predictive benefit.
        - The memory length of the process: how far back correlations extend.
        - The ideal context window length for a predictor.
        - R = 0 means the process is IID (memoryless).
        - R = infinity means the process has unbounded memory (non-finitary).
        - Only reported if convergence was clean; otherwise None.
        """

        m = self.get_complexity_measure_if_exists( "R" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def b_mu( self ) :

        """
        b_mu (bound information) - requires reverse HMM

        Interpretation
        - The part of each new symbol's information that the future remembers:
          I[X_0 ; future | past].
        """

        m = self.get_complexity_measure_if_exists( "b_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def r_mu( self ) :

        """
        r_mu (ephemeral information) - requires reverse HMM

        Interpretation
        - The part of each new symbol's information that is forgotten
          immediately: H[X_0 | past, future].
        """

        m = self.get_complexity_measure_if_exists( "r_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def q_mu( self ) :

        """
        q_mu (elusive information)

        Interpretation
        - Past-future information that the present symbol does not carry:
          I[past ; future | X_0].
        """

        m = self.get_complexity_measure_if_exists( "q_mu" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#
    #                       Properties                       #
    #-------------------------------------------------------#

    def causally_reversible( self ) :

        """
        causally_reversible
        """

        m = self.get_complexity_measure_if_exists( "causally_reversible" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")

    #-------------------------------------------------------#

    def microscopically_reversible( self ) :

        """
        microscopically_reversible
        """

        m = self.get_complexity_measure_if_exists( "microscopically_reversible" )

        if m is not None :
            return m

        raise NotImplementedError("Not implemented")
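    # Orientation for the measures above (the standard per-symbol information
    # decomposition; none of it is computed here yet): each symbol's
    # information splits as
    #   H_1  = rho_mu + h_mu    (already predictable + genuinely random), and
    #   h_mu = r_mu + b_mu      (forgotten immediately + remembered by the future).
    # For the alternating process example: H_1 = 1, rho_mu = 1, h_mu = 0, and
    # Chi = C_mu - E = 1 - 1 = 0, so nothing is hidden from an observer.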
    #-------------------------------------------------------#

    def is_row_stochastic(self) :

        """
        is_row_stochastic
        -All states have outgoing transition probabilities which sum to 1
        """

        sums = np.zeros( len( self.states ) )
        for tr in self.transitions :
            sums[ tr.origin_state_idx ] += tr.prob
        return np.allclose( sums, 1.0 )

    #-------------------------------------------------------#

    def is_unifilar(self) :

        """
        is_unifilar
        -Each state has at most one outgoing transition per symbol
        """

        symbol_trs = np.full( ( len( self.states ), len( self.alphabet ) ), -1 )

        for tr in self.transitions :

            if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
                symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx

            elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
                return False

        return True

    #-------------------------------------------------------#

    def is_strongly_connected(self) :

        """
        is_strongly_connected
        -Every state is reachable from every other state
        """

        return nx.is_strongly_connected( self.as_digraph() )

    #-------------------------------------------------------#

    def is_aperiodic(self) :

        """
        is_aperiodic
        -"A strongly connected directed graph is aperiodic
        if there is no integer k > 1 that divides the length of every cycle in the graph."
        """

        return nx.is_aperiodic( self.as_digraph() )

    #-------------------------------------------------------#

    def is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :

        """
        is_minimal_as_dfa
        -True when DFA minimization cannot reduce the number of states
        """

        with_probs = not topological_only

        # Construct the DFA
        dfa = self.as_dfa( with_probs=with_probs )

        # Minimize the DFA
        # pure-python fallback: dfa.minify( retain_names=True )
        dfa = am_fast.minify_cpp( dfa, retain_names=True )

        # check we already have the minimal number of states
        if len( dfa.states ) != len( self.states ) :
            if verbose :
                print( f"Not minimal: reduces from {len( self.states )} to {len( dfa.states )} states" )
            return False

        return True

    def is_topological_epsilon_machine( self, verbose=True ) :

        """
        is_topological_epsilon_machine
        -see algorithm 2 in Enumerating Finitary Processes, Johnson et al., 2024
        """

        if not ( self.is_unifilar() and self.is_strongly_connected() ) :
            if verbose :
                print( "Either non-unifilar or not strongly connected" )
            return False
        else :
            return self.is_minimal_as_dfa( topological_only=True, verbose=verbose )
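    # Minimal sketch of a unifilarity failure (hypothetical machine, assuming
    # only the constructor wiring shown earlier in this module): state 0 emits
    # the same symbol to two different targets, so the check fails.
    #
    #   >>> m = HMM(
    #   ...     states=[ CausalState( name="A" ), CausalState( name="B" ) ],
    #   ...     transitions=[
    #   ...         Transition( origin_state_idx=0, target_state_idx=0, prob=0.5, symbol_idx=0 ),
    #   ...         Transition( origin_state_idx=0, target_state_idx=1, prob=0.5, symbol_idx=0 ),
    #   ...         Transition( origin_state_idx=1, target_state_idx=0, prob=1.0, symbol_idx=0 ),
    #   ...     ],
    #   ...     alphabet=[ "a" ],
    #   ... )
    #   >>> m.is_unifilar()
    #   False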
    def is_epsilon_machine( self, verbose=True ) :

        """
        is_epsilon_machine
        -Unifilar, strongly connected, and minimal as a DFA whose edges are
        labelled with both symbol and probability
        """

        if not ( self.is_unifilar() and self.is_strongly_connected() ) :
            if verbose :
                print( "Either non-unifilar or not strongly connected" )
            return False
        else :
            return self.is_minimal_as_dfa( topological_only=False, verbose=verbose )
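    # Illustrative expectation (hedged: the exact outcome depends on
    # am_fast.minify_cpp): the two-state alternating machine used above is
    # unifilar, strongly connected, and already minimal, so
    # is_epsilon_machine() would be expected to return True; duplicating a
    # state with identical outgoing edges would fail the minimality test.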
    #-------------------------------------------------------#
    #                      Minimization                      #
    #-------------------------------------------------------#

    def minimize(self, retain_names: bool = True):

        """
        minimize
        -Simplify to minimal form, using Myhill-Nerode equivalence classes
        """

        if self.is_minimal :
            return

        if not self.is_unifilar():
            raise ValueError(
                "DFA minimization is not valid for non-unifilar HMMs"
            )

        was_strongly_connected = self.is_strongly_connected()

        was_row_stochastic = self.is_row_stochastic()
        n_states_before = len(self.states)

        dfa = self.as_dfa(with_probs=True)

        # pure-python fallback: self.as_dfa(with_probs=True).minify(retain_names=True)
        min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

        # Build lookup from original state index -> CausalState object
        orig_state = {i: s for i, s in enumerate(self.states)}
        eq_list = list(min_dfa.states)

        start_eq = min_dfa.initial_state

        # Separate the start state, then sort the rest by the
        # smallest original state index inside each equivalence class.
        other_eqs = [eq for eq in eq_list if eq != start_eq]
        other_eqs.sort(key=lambda eq: min(eq))

        # Recombine so the start class comes first, followed by the sorted remaining classes
        eq_list = [start_eq] + other_eqs

        # Recompute eq_to_idx with the new ordering
        eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

        # new_start is now guaranteed to be 0
        new_start = 0

        # Map each original state index -> its equivalence class
        # Guard: minify() silently drops unreachable states
        orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

        # Build lookup from original state index -> its transitions
        orig_trs = defaultdict(list)
        for t in self.transitions:
            orig_trs[t.origin_state_idx].append(t)

        new_trs = []
        for eq in min_dfa.states:
            rep = next(iter(eq))
            origin_idx = eq_to_idx[eq]
            for t in orig_trs[rep]:
                target_eq = orig_to_eq[t.target_state_idx]
                target_idx = eq_to_idx[target_eq]
                new_trs.append(Transition(
                    origin_state_idx = origin_idx,
                    target_state_idx = target_idx,
                    prob = t.prob,
                    symbol_idx = t.symbol_idx,
                ))

        members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

        # Compute new names
        if retain_names:
            new_names = [
                "{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
                else members[0].name
                for members in members_list
            ]
        else:
            new_names = [str(j) for j in range(len(eq_list))]

        old_name_to_new_name = {
            m.name: new_names[j]
            for j, members in enumerate(members_list)
            for m in members
        }

        # Build the new states, preserving classes and isomorphs regardless of naming
        new_states = []
        for eq, members, name in zip(eq_list, members_list, new_names):
            classes = set().union(*(m.classes for m in members))
            isomorphs = {
                old_name_to_new_name.get(iso, iso)
                for m in members
                for iso in m.isomorphs
                if old_name_to_new_name.get(iso, iso) != name
            }
            new_states.append(CausalState(
                name = name,
                classes = classes,
                isomorphs = isomorphs,
            ))

        self.set_states(new_states)
        self.set_transitions(new_trs)
        self.start_state = new_start

        if n_states_before == len(new_states) :
            print( f"{n_states_before} state HMM was already minimal." )
        else :
            print( f"Minimized from {n_states_before} to {len(new_states)} states" )

        if was_strongly_connected != self.is_strongly_connected() :
            raise RuntimeError(
                "Minimization broke strong connectivity"
            )

        if was_row_stochastic != self.is_row_stochastic() :
            raise RuntimeError(
                "Minimization broke row stochasticity"
            )

        self.is_minimal = True
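    # Illustrative sketch (hypothetical three-state machine): if A emits 0 to
    # B with p = 0.5 and 1 to C with p = 0.5, while B and C each emit 0 back
    # to A with p = 1.0, then B and C have identical (symbol, probability)
    # futures. Myhill-Nerode merges them, and minimize() rewires the machine
    # to the two states A and {B,C}, printing "Minimized from 3 to 2 states".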
    #-------------------------------------------------------#
    #                       Modifiers                        #
    #-------------------------------------------------------#

    def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

        # get equivalent networkx graph
        G = self.as_digraph()

        # if already strongly connected, nothing to do
        if not nx.is_strongly_connected( G ) :

            # decompose into strongly connected components and sort by size
            subgraph_nodes = list( nx.strongly_connected_components( G ) )
            subgraph_nodes.sort(key=len)
            component_state_set = subgraph_nodes[-1]

            # take the largest strongly connected component (as a sorted list of state indices)
            component_states = sorted( list( component_state_set ) )

            # make temporary copies of the old transitions and states
            old_transitions = [
                Transition(
                    origin_state_idx=tr.origin_state_idx,
                    target_state_idx=tr.target_state_idx,
                    prob=tr.prob,
                    symbol_idx=tr.symbol_idx
                )
                for tr in self.transitions
            ]

            old_states = [
                CausalState(
                    name=s.name,
                    classes=s.classes,
                    isomorphs=s.isomorphs
                )
                for s in self.states
            ]

            self.set_states(
                states=[
                    state
                    for i, state in enumerate( old_states ) if i in component_state_set
                ]
            )

            # we will build a new transition list from those belonging to the component
            self.set_transitions( transitions=[] )

            # for tracking which new transitions leave each state (keyed by old state index)
            transitions_from_state = { state : set() for state in component_states }
            new_transitions = []

            for tr in old_transitions :

                origin_state_name = old_states[ tr.origin_state_idx ].name
                target_state_name = old_states[ tr.target_state_idx ].name

                # skip transitions that connect separate strongly connected components
                if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
                    continue

                # track transitions (by index in the new transitions list) that leave this state
                transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

                my_origin_state_idx = self.state_idx_map[ origin_state_name ]
                my_target_state_idx = self.state_idx_map[ target_state_name ]

                new_transitions.append(
                    Transition(
                        origin_state_idx=my_origin_state_idx,
                        target_state_idx=my_target_state_idx,
                        prob=tr.prob,
                        symbol_idx=tr.symbol_idx
                    ) )

            self.set_transitions( transitions=new_transitions )

            # if we removed an outgoing transition from a state, we need to distribute its
            # probability among the remaining outgoing transitions from the state
            for state in component_states :

                # get the set of transitions leaving this state
                state_trs = transitions_from_state[ state ]

                # a state in the largest component keeps at least one outgoing
                # transition, unless the component is a single node without a self-loop
                if not state_trs :
                    continue

                # sum the probabilities of the outgoing transitions from the state
                p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

                # how much probability is missing
                diff = 1.0 - p_sum

                # if the difference is significant
                if abs( diff ) > self.EPS :

                    # calculate how much of the difference each transition gets
                    adjustment = diff / len( state_trs )

                    # update the transitions
                    for i in state_trs :

                        # transition with the original probability
                        tr = self.transitions[ i ]

                        # adjusted probability
                        self.transitions[ i ] = Transition(
                            origin_state_idx=tr.origin_state_idx,
                            target_state_idx=tr.target_state_idx,
                            prob=tr.prob + adjustment,
                            symbol_idx=tr.symbol_idx )

            if rename_states :
                self.set_states( [
                    CausalState( name=f"{i}" )
                    for i, s in enumerate( self.states )
                ] )
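    # Arithmetic sketch of the renormalization step (hypothetical numbers):
    # suppose a surviving state had outgoing probabilities 0.5 and 0.3 inside
    # the component, plus 0.2 on a dropped inter-component edge. Then
    # p_sum = 0.8, diff = 0.2, and each of the two survivors gains 0.1,
    # giving renormalized probabilities 0.6 and 0.4.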
    def to_q_weighted( self, denominator_limit=1000 ) :

        """
        to_q_weighted
        -Snaps edge probabilities to exact Fractions, and assigns Fraction values to Transition.pq
        """

        if self.is_q_weighted :
            return

        if not self.is_row_stochastic() :
            raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

        # group transition indices by origin state
        t_from = [ [] for _ in range( len( self.states ) ) ]

        for i, tr in enumerate( self.transitions ) :
            t_from[ tr.origin_state_idx ].append( i )

        new_transitions = []
        for t_list in t_from :

            if not t_list :
                continue

            p_q_sum = Fraction( 0, 1 )
            p_qs = []

            for t_idx in t_list :

                # e.g. Fraction(0.3333333333333333).limit_denominator(1000) == Fraction(1, 3)
                p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
                p_q_sum += p_q
                p_qs.append( p_q )

            if p_q_sum != Fraction( 1, 1 ) :

                # push the rounding residual onto the largest probability
                max_pq_i = max( range( len( p_qs ) ), key=lambda k : p_qs[ k ] )
                max_pq = p_qs[ max_pq_i ]

                diff = p_q_sum - Fraction( 1, 1 )

                # if the residual would swamp the largest probability,
                # recurse with higher resolution
                if diff > max_pq :
                    return self.to_q_weighted( denominator_limit * 10 )
                else :
                    p_qs[ max_pq_i ] -= diff

            for i, t_idx in enumerate( t_list ) :
                new_transitions.append(
                    Transition(
                        origin_state_idx=self.transitions[ t_idx ].origin_state_idx,
                        target_state_idx=self.transitions[ t_idx ].target_state_idx,
                        prob=float( p_qs[ i ] ),
                        symbol_idx=self.transitions[ t_idx ].symbol_idx,
                        pq=p_qs[ i ]
                    )
                )

        self.set_transitions( new_transitions )
        self.is_q_weighted = True

    #-------------------------------------------------------#
    #                    Data Generation                     #
    #-------------------------------------------------------#

    def generate_data(
        self,
        file_prefix : str,
        n_gen : int,
        include_states : bool,
        isomorphic_shifts : set[int] | None = None,
        random_seed : int = 42 ) -> dict[str, any] :

        # adjacency list: per-state ( symbol_idx, prob, target ) triples
        trs = [ [] for _ in range( len( self.states ) ) ]
        for tr in self.transitions :
            trs[ tr.origin_state_idx ].append( (
                tr.symbol_idx,
                float( tr.prob ),
                tr.target_state_idx ) )

        data = am_fast.generate_data(
            n_gen=n_gen,
            start_state=self.start_state,
            transitions=trs,
            alphabet=sorted( list( self.alphabet ) ),
            include_states=include_states,
            random_seed=random_seed
        )

        if isomorphic_shifts is not None :

            if not include_states :
                raise ValueError( "Isomorphic shifts require include_states=True" )

            data[ "isomorphic_shifts" ] = {}

            for shift in isomorphic_shifts :

                shifted = HMM.isomorphic_shift(
                    input_symbol_indices=data[ "symbol_index" ],
                    input_state_indices=data[ "state_index" ],
                    m=self,
                    shift=shift
                )

                data[ "isomorphic_shifts" ][ shift ] = {
                    "symbol_index" : shifted[ "symbol_index" ],
                    "state_index" : shifted[ "state_index" ]
                }

        am_fast.save_data(
            data=data,
            file_prefix=file_prefix,
            alphabet=sorted( list( self.alphabet ) ),
            n_states=len( self.states ),
            start_state=self.start_state,
            random_seed=random_seed,
            metadata=self.get_metadata() )

        return data
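    # Usage sketch (hypothetical file prefix; generate_data returns the arrays
    # and also hands them to am_fast.save_data, which presumably writes files
    # under that prefix):
    #
    #   data = machine.generate_data(
    #       file_prefix="./even_process",
    #       n_gen=1_000_000,
    #       include_states=True,
    #       isomorphic_shifts={ 1 },
    #   )
    #   symbols = data[ "symbol_index" ]   # assumed: array of symbol indices
    #   states  = data[ "state_index" ]    # assumed: array of visited states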
    #-------------------------------------------------------#
    #                  Basic Visualization                   #
    #-------------------------------------------------------#

    def draw_graph(
        self,
        engine : str = 'dot',
        output_dir : Path = Path('.'),
        show : bool = True
    ) -> None :

        G = self.as_digraph()

        # highlight the strongly connected components when the machine is not strongly connected
        subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )

        am_vis.draw_graph(
            self,
            output_dir=output_dir,
            title="am_graph",
            view=show,
            subgraphs=subgraphs,
            engine=engine )

    #-------------------------------------------------------#
    #                    Alternate Forms                     #
    #-------------------------------------------------------#

    def as_digraph( self ) :

        """
        Build an nx DiGraph based on the HMM, without edge probabilities or symbols
        """

        G = nx.DiGraph()
        G.add_nodes_from( range( len( self.states ) ) )

        for tr in self.transitions :
            G.add_edge( tr.origin_state_idx, tr.target_state_idx )

        return G

    def as_dfa( self, with_probs : bool ) :

        precision = 8

        def edge_label( symb, prob ) :
            return f"({symb},{round(prob, precision)})"

        # Build states, symbols, and transitions
        dfa_states = set( range( len( self.states ) ) )

        if not with_probs :
            dfa_symbols = { t.symbol_idx for t in self.transitions }
        else :
            dfa_symbols = { edge_label( t.symbol_idx, t.prob ) for t in self.transitions }

        dfa_transitions = defaultdict(dict)

        if not with_probs :
            for t in self.transitions :
                dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
        else :
            for t in self.transitions :
                dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx

        # Construct the DFA; every state is accepting, since we only care
        # about structural (Myhill-Nerode) equivalence, not a language
        return DFA(
            states=dfa_states,
            input_symbols=dfa_symbols,
            transitions=dfa_transitions,
            initial_state=self.start_state,
            allow_partial=True,
            final_states=set( dfa_states )
        )
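    # How the probability-aware encoding works (illustrative): with
    # with_probs=True, an edge 0 --symbol 1, p=0.25--> 2 becomes the composite
    # DFA input symbol "(1,0.25)", e.g.
    #
    #   >>> precision = 8
    #   >>> f"({1},{round(0.25, precision)})"
    #   '(1,0.25)'
    #
    # so DFA minimization can only merge states whose outgoing edges agree on
    # both symbol and (rounded) probability.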