amachine.am_create

  1from collections import defaultdict
  2import copy
  3import random
  4
  5from .am_hmm          import HMM
  6from .am_causal_state import CausalState
  7from .am_transition   import Transition
  8
  9from .am_random import uniform_dist, exp_uniform_blend
 10
 11def star_join(
 12	exit_symbol : str, 
 13	enter_symbols : list[str],
 14	machines : list[HMM],
 15	mode_residency_factor : float ) -> HMM :
 16
 17	machine = HMM()
 18
 19	isomorphic_groups = defaultdict(list)
 20	for i, m in enumerate( machines ) :
 21		if m.isoclass :
 22			isomorphic_groups[ m.isoclass ].append( i )
 23
 24	# since we are merging multuple machines which might have name collision
 25	# we need to rename the states to ensure uniqueness
 26	def rename_state( base_name : str, g : int ) :
 27		return f"{g}/{base_name}"
 28
 29	def get_gid( idx : int, isoclass : int | None ) :
 30		return idx if isoclass is None else isoclass
 31
 32	machine.set_alphabet( [ exit_symbol ] )
 33
 34	for m in machines :
 35		machine.extend_alphabet( alphabet=m.alphabet )
 36
 37	# create a connector state and connector state class
 38
 39	connector_state = CausalState(
 40		name=f"/c", 
 41		classes=set({"connector"}) 
 42	)
 43	
 44	# initial states before adding each machines states
 45	machine.set_states( [ connector_state ] )
 46	machine.start_state = 0
 47
 48	# number of machines (groups of states)
 49	n_groups = len( machines )
 50
 51	# make sure we have enough symbols (otherwise connector can't be unifilar)
 52	if n_groups > len(enter_symbols) :
 53		raise Exception(
 54			f"Too few enter symbols given number of machines"
 55		)
 56
 57	# for each given machine
 58	for m_idx, m in enumerate( machines ) : 
 59
 60		# default to the index of the machine in the list
 61		m_gid = get_gid( m_idx, m.isoclass )
 62
 63		# give the states from this machine a class name
 64		m_classes = { 
 65			f"m_{m_idx}", 
 66			f"isoclass_{m.isoclass}" 
 67		}
 68
 69		added_states = []
 70		
 71		# create a state and extend our existing machine to include it
 72		for s_idx, state in enumerate( m.states ) :
 73
 74			isomorphs=set()
 75			if m.isoclass is not None and m.isoclass in isomorphic_groups :
 76				for other_idx in isomorphic_groups[ m.isoclass ] :
 77					
 78					if other_idx == m_idx : 
 79						continue
 80					
 81					other_m = machines[ other_idx ]
 82					isomorphs.add(  
 83						rename_state( 
 84							other_m.states[ s_idx ].name, 
 85							get_gid( other_idx, other_m.isoclass ) )
 86					)
 87
 88			added_states.append( 
 89				CausalState( 
 90					name=rename_state(state.name, m_gid),
 91					classes=( m_classes | state.classes ),
 92					isomorphs=isomorphs
 93				) 
 94			)
 95
 96		machine.extend_states( added_states )
 97
 98		added_transitions = []
 99
100		# add all of the transitions from the machine
101		for tr in m.transitions :
102
103			# get the names of the states for the transition
104			origin_state_name = rename_state( m.states[ tr.origin_state_idx ].name, m_gid )
105			target_state_name = rename_state( m.states[ tr.target_state_idx ].name, m_gid )
106
107			# idx of the symbol remaped to this machines alphabet list
108			new_symbol_idx = machine.symbol_idx_map[ m.alphabet[ tr.symbol_idx ] ]
109
110			# create and add the new transition
111			added_transitions.append( Transition(
112				origin_state_idx=machine.state_idx_map[ origin_state_name ],
113				target_state_idx=machine.state_idx_map[ target_state_name ],
114				prob=tr.prob,
115				symbol_idx=new_symbol_idx
116			) )
117
118		machine.extend_transitions( added_transitions )
119
120		# Add connector transitions, and adjust transition probabilities to sum to 1
121
122		# the name of the state that is the entry point to this group from the connector
123		m_entry_state_name = rename_state( m.states[ m.start_state ].name, m_gid )
124
125		# get the index of the entry state for this machine
126		m_entry_state_idx = machine.state_idx_map[ m_entry_state_name ]
127
128
129		# Get the within group transitions from m's entry state
130		# ( the probabilities will need to be adjusted )
131		transition_ids_from_m_entry = set()
132		for i, tr in enumerate( machine.transitions ) : 
133			if tr.origin_state_idx == m_entry_state_idx :
134				transition_ids_from_m_entry.add( i )
135		
136		n_from_entry = len( transition_ids_from_m_entry )
137
138		# Pr of staying in this group is distributed over the within group outgoing edges from the entry state 
139		for i in transition_ids_from_m_entry :
140
141			machine.transitions[ i ] = Transition(
142				origin_state_idx=machine.transitions[ i ].origin_state_idx,
143				target_state_idx=machine.transitions[ i ].target_state_idx,
144				prob=mode_residency_factor / n_from_entry,
145				symbol_idx=machine.transitions[ i ].symbol_idx
146			)
147
148		# from m's entry state back to connector
149		escape_pr = 1.0 - mode_residency_factor
150
151		machine.extend_transitions( transitions=[
152			Transition(
153				origin_state_idx=m_entry_state_idx,
154				target_state_idx=machine.start_state,
155				prob=escape_pr,
156				symbol_idx=machine.symbol_idx_map[ exit_symbol ]
157			)
158		] )
159
160		# from the connector to m's entry state
161		machine.extend_alphabet( alphabet=[ enter_symbols[ m_idx ] ] )
162		
163		machine.extend_transitions( transitions=[
164			Transition(
165				origin_state_idx=machine.start_state,
166				target_state_idx=m_entry_state_idx,
167				prob=( 1.0 / n_groups ),
168				symbol_idx=machine.symbol_idx_map[ enter_symbols[ m_idx ] ]
169			)
170		] )
171
172	return machine
173
174
def star(
	exit_symbol          : str,
	enter_symbols        : list[str],
	normal_symbols       : list[str],
	n_modes              : int = 7,
	n_isomorphic         : int = 2,
	randomness           : float = 0.3,
	connectedness        : float = 0.5,
	residency_factor     : float = 0.5,
	n_normal_symbols     : int = 4,
	t_states_per_machine : int = 17 )  -> HMM :
	"""Build a star-shaped HMM out of randomly generated machines.

	For each of the `n_modes` modes a random machine is generated over the
	first `n_normal_symbols` normal symbols and collapsed to its largest
	strongly connected subgraph. It is then accompanied by
	`n_isomorphic - 1` isomorphic copies, each over its own consecutive slice
	of `normal_symbols` and with one more '@' appended to its state names
	than the previous copy. All machines of a mode share the isoclass
	f"{mode index}" and have their corresponding states registered as
	isomorphs of each other. The machines are finally combined with
	`star_join`.

	Args:
		exit_symbol: symbol used to leave a machine for the connector.
		enter_symbols: symbols used to enter the machines; needs at least
			n_modes*n_isomorphic entries.
		normal_symbols: symbols used inside the machines; needs at least
			n_normal_symbols*n_isomorphic entries.
		n_modes: number of randomly generated machines.
		n_isomorphic: number of machines per mode, the generated one included.
		randomness: passed to `random_machine`.
		connectedness: passed to `random_machine`.
		residency_factor: passed to `star_join` as mode_residency_factor.
		n_normal_symbols: alphabet size of each machine.
		t_states_per_machine: number of states requested from `random_machine`.

	Returns:
		The joined HMM.

	Raises:
		ValueError: if n_isomorphic is below 1 or too few symbols are given.
	"""

	if n_isomorphic < 1 :
		raise ValueError( "n_isomorphic must be at least 1" )

	if len( normal_symbols ) < n_normal_symbols*n_isomorphic :
		raise ValueError( "Must have at least n_normal_symbols*n_isomorphic normal symbols" )

	if len( enter_symbols ) < n_modes*n_isomorphic :
		raise ValueError( "Must have at least n_modes*n_isomorphic enter symbols" )

	# one disjoint slice of the normal symbols per machine of a mode:
	# slice 0 for the generated machine, slice k for its k-th isomorphic copy
	alphabets = [
		[ f"{normal_symbols[i]}" for i in range( k*n_normal_symbols, ( k + 1 )*n_normal_symbols ) ]
		for k in range( n_isomorphic )
	]

	random_machines = []

	for i in range( n_modes ) :

		m = random_machine(
			n_states=t_states_per_machine,
			symbols=alphabets[ 0 ],
			randomness=randomness,
			connectedness=connectedness )

		m.collapse_to_largest_strongly_connected_subgraph()

		# the k-th copy gets k decorator characters so that the state names
		# of all machines of this mode stay distinct
		family = [ m ] + [
			isomorphic_to( m, alphabet=alphabets[ k ], decorator='@'*k )
			for k in range( 1, n_isomorphic )
		]

		for member in family :
			member.isoclass = f"{i}"

		# register corresponding states of every pair of machines as isomorphs
		for j in range( len( m.states ) ) :
			for a in family :
				for b in family :
					if a is not b :
						a.states[ j ].add_isomorph( b.states[ j ].name )

		random_machines.extend( family )

	mode_machine = star_join(
		exit_symbol=exit_symbol,
		enter_symbols=enter_symbols,
		machines=random_machines,
		mode_residency_factor=residency_factor
	)

	return mode_machine
228
229
def isomorphic_to(
	m : HMM,
	alphabet : list[str],
	decorator : str = '@' ) -> HMM :
	"""Create a copy of `m` with renamed states and a different alphabet.

	Each state is recreated with `decorator` appended to its name and a deep
	copy of its classes. The transitions are deep-copied unchanged, so a
	symbol index that referred to position k of `m.alphabet` refers to
	position k of the new alphabet.

	Args:
		m: the machine to copy.
		alphabet: symbols for the copy; only the first len( m.alphabet ) are used.
		decorator: suffix appended to every state name.

	Returns:
		A new HMM with the same start state index as `m`.

	Raises:
		ValueError: if `alphabet` has fewer symbols than `m.alphabet`.
	"""

	# make sure there are enough symbols
	if len( alphabet ) < len( m.alphabet ) :
		raise ValueError( "Not enough symbols in the alphabet" )

	# take as many of them as needed
	alphabet_used = alphabet[ 0 : len( m.alphabet ) ]

	states = [
		CausalState(
			name=f"{s.name}{decorator}",
			classes=copy.deepcopy( s.classes )
		)
		for s in m.states
	]

	return HMM(
		states=states,
		transitions=copy.deepcopy( m.transitions ),
		# states and transitions keep their indices, so the copy must start
		# where m starts rather than unconditionally at state 0
		start_state=m.start_state,
		alphabet=alphabet_used
	)
256
def random_machine(
	n_states : int,
	symbols  : list[str],
	connectedness : float,
	randomness : float )  -> HMM  :
	"""Generate a random HMM with `n_states` states named "0", "1", ...

	Every state gets between 1 and len( symbols ) outgoing transitions: one
	guaranteed, plus one per remaining symbol with probability
	`connectedness`. Within a state the transitions use distinct symbols and
	distinct target states, and their probabilities come from
	`exp_uniform_blend( n, alpha=randomness )`.

	Args:
		n_states: number of states to create.
		symbols: the alphabet of the machine.
		connectedness: chance of each additional outgoing transition.
		randomness: forwarded to `exp_uniform_blend` as `alpha`.

	Returns:
		An HMM whose start state is 0 and whose alphabet is a copy of `symbols`.

	Raises:
		ValueError: if states are requested but `symbols` is empty.
	"""

	n_symbols = len( symbols )

	# every state needs at least one transition, which needs a symbol
	if n_states > 0 and n_symbols == 0 :
		raise ValueError( "At least one symbol is required to create transitions" )

	states=[
		CausalState( name=f"{i}" )
		for i in range( n_states  )
	]

	transitions = []

	for state_idx in range( n_states ) :

		n_transitions = sum( random.random() < connectedness for _ in range( n_symbols - 1 ) ) + 1

		# targets are sampled without replacement, so a state cannot have more
		# outgoing transitions than there are states; without this cap
		# random.sample raises ValueError for machines with few states
		n_transitions = min( n_transitions, n_states )

		transition_to = random.sample( range( n_states ), n_transitions )

		transition_probabilities = exp_uniform_blend( n=n_transitions, alpha=randomness )

		# distinct symbols per state
		transition_symbols_indices = random.sample( range( n_symbols ), n_transitions )

		for i, p in enumerate( transition_probabilities ) :
			transitions.append(
				Transition(
					origin_state_idx=state_idx,
					target_state_idx=transition_to[ i ],
					prob=p,
					symbol_idx=transition_symbols_indices[ i ]
				)
			)

	return HMM(
		states=states,
		transitions=transitions,
		start_state=0,
		alphabet=list( symbols )
	)
def star_join( exit_symbol: str, enter_symbols: list[str], machines: list[amachine.am_hmm.HMM], mode_residency_factor: float) -> amachine.am_hmm.HMM:
 12def star_join(
 13	exit_symbol : str, 
 14	enter_symbols : list[str],
 15	machines : list[HMM],
 16	mode_residency_factor : float ) -> HMM :
 17
 18	machine = HMM()
 19
 20	isomorphic_groups = defaultdict(list)
 21	for i, m in enumerate( machines ) :
 22		if m.isoclass :
 23			isomorphic_groups[ m.isoclass ].append( i )
 24
 25	# since we are merging multuple machines which might have name collision
 26	# we need to rename the states to ensure uniqueness
 27	def rename_state( base_name : str, g : int ) :
 28		return f"{g}/{base_name}"
 29
 30	def get_gid( idx : int, isoclass : int | None ) :
 31		return idx if isoclass is None else isoclass
 32
 33	machine.set_alphabet( [ exit_symbol ] )
 34
 35	for m in machines :
 36		machine.extend_alphabet( alphabet=m.alphabet )
 37
 38	# create a connector state and connector state class
 39
 40	connector_state = CausalState(
 41		name=f"/c", 
 42		classes=set({"connector"}) 
 43	)
 44	
 45	# initial states before adding each machines states
 46	machine.set_states( [ connector_state ] )
 47	machine.start_state = 0
 48
 49	# number of machines (groups of states)
 50	n_groups = len( machines )
 51
 52	# make sure we have enough symbols (otherwise connector can't be unifilar)
 53	if n_groups > len(enter_symbols) :
 54		raise Exception(
 55			f"Too few enter symbols given number of machines"
 56		)
 57
 58	# for each given machine
 59	for m_idx, m in enumerate( machines ) : 
 60
 61		# default to the index of the machine in the list
 62		m_gid = get_gid( m_idx, m.isoclass )
 63
 64		# give the states from this machine a class name
 65		m_classes = { 
 66			f"m_{m_idx}", 
 67			f"isoclass_{m.isoclass}" 
 68		}
 69
 70		added_states = []
 71		
 72		# create a state and extend our existing machine to include it
 73		for s_idx, state in enumerate( m.states ) :
 74
 75			isomorphs=set()
 76			if m.isoclass is not None and m.isoclass in isomorphic_groups :
 77				for other_idx in isomorphic_groups[ m.isoclass ] :
 78					
 79					if other_idx == m_idx : 
 80						continue
 81					
 82					other_m = machines[ other_idx ]
 83					isomorphs.add(  
 84						rename_state( 
 85							other_m.states[ s_idx ].name, 
 86							get_gid( other_idx, other_m.isoclass ) )
 87					)
 88
 89			added_states.append( 
 90				CausalState( 
 91					name=rename_state(state.name, m_gid),
 92					classes=( m_classes | state.classes ),
 93					isomorphs=isomorphs
 94				) 
 95			)
 96
 97		machine.extend_states( added_states )
 98
 99		added_transitions = []
100
101		# add all of the transitions from the machine
102		for tr in m.transitions :
103
104			# get the names of the states for the transition
105			origin_state_name = rename_state( m.states[ tr.origin_state_idx ].name, m_gid )
106			target_state_name = rename_state( m.states[ tr.target_state_idx ].name, m_gid )
107
108			# idx of the symbol remaped to this machines alphabet list
109			new_symbol_idx = machine.symbol_idx_map[ m.alphabet[ tr.symbol_idx ] ]
110
111			# create and add the new transition
112			added_transitions.append( Transition(
113				origin_state_idx=machine.state_idx_map[ origin_state_name ],
114				target_state_idx=machine.state_idx_map[ target_state_name ],
115				prob=tr.prob,
116				symbol_idx=new_symbol_idx
117			) )
118
119		machine.extend_transitions( added_transitions )
120
121		# Add connector transitions, and adjust transition probabilities to sum to 1
122
123		# the name of the state that is the entry point to this group from the connector
124		m_entry_state_name = rename_state( m.states[ m.start_state ].name, m_gid )
125
126		# get the index of the entry state for this machine
127		m_entry_state_idx = machine.state_idx_map[ m_entry_state_name ]
128
129
130		# Get the within group transitions from m's entry state
131		# ( the probabilities will need to be adjusted )
132		transition_ids_from_m_entry = set()
133		for i, tr in enumerate( machine.transitions ) : 
134			if tr.origin_state_idx == m_entry_state_idx :
135				transition_ids_from_m_entry.add( i )
136		
137		n_from_entry = len( transition_ids_from_m_entry )
138
139		# Pr of staying in this group is distributed over the within group outgoing edges from the entry state 
140		for i in transition_ids_from_m_entry :
141
142			machine.transitions[ i ] = Transition(
143				origin_state_idx=machine.transitions[ i ].origin_state_idx,
144				target_state_idx=machine.transitions[ i ].target_state_idx,
145				prob=mode_residency_factor / n_from_entry,
146				symbol_idx=machine.transitions[ i ].symbol_idx
147			)
148
149		# from m's entry state back to connector
150		escape_pr = 1.0 - mode_residency_factor
151
152		machine.extend_transitions( transitions=[
153			Transition(
154				origin_state_idx=m_entry_state_idx,
155				target_state_idx=machine.start_state,
156				prob=escape_pr,
157				symbol_idx=machine.symbol_idx_map[ exit_symbol ]
158			)
159		] )
160
161		# from the connector to m's entry state
162		machine.extend_alphabet( alphabet=[ enter_symbols[ m_idx ] ] )
163		
164		machine.extend_transitions( transitions=[
165			Transition(
166				origin_state_idx=machine.start_state,
167				target_state_idx=m_entry_state_idx,
168				prob=( 1.0 / n_groups ),
169				symbol_idx=machine.symbol_idx_map[ enter_symbols[ m_idx ] ]
170			)
171		] )
172
173	return machine
def star( exit_symbol: str, enter_symbols: list[str], normal_symbols: list[str], n_modes: int = 7, n_isomorphic: int = 2, randomness: float = 0.3, connectedness: float = 0.5, residency_factor: float = 0.5, n_normal_symbols: int = 4, t_states_per_machine: int = 17) -> amachine.am_hmm.HMM:
def star(
	exit_symbol          : str,
	enter_symbols        : list[str],
	normal_symbols       : list[str],
	n_modes              : int = 7,
	n_isomorphic         : int = 2,
	randomness           : float = 0.3,
	connectedness        : float = 0.5,
	residency_factor     : float = 0.5,
	n_normal_symbols     : int = 4,
	t_states_per_machine : int = 17 )  -> HMM :
	"""Build a star-shaped HMM out of randomly generated machines.

	For each of the `n_modes` modes a random machine is generated over the
	first `n_normal_symbols` normal symbols and collapsed to its largest
	strongly connected subgraph. It is then accompanied by
	`n_isomorphic - 1` isomorphic copies, each over its own consecutive slice
	of `normal_symbols` and with one more '@' appended to its state names
	than the previous copy. All machines of a mode share the isoclass
	f"{mode index}" and have their corresponding states registered as
	isomorphs of each other. The machines are finally combined with
	`star_join`.

	Args:
		exit_symbol: symbol used to leave a machine for the connector.
		enter_symbols: symbols used to enter the machines; needs at least
			n_modes*n_isomorphic entries.
		normal_symbols: symbols used inside the machines; needs at least
			n_normal_symbols*n_isomorphic entries.
		n_modes: number of randomly generated machines.
		n_isomorphic: number of machines per mode, the generated one included.
		randomness: passed to `random_machine`.
		connectedness: passed to `random_machine`.
		residency_factor: passed to `star_join` as mode_residency_factor.
		n_normal_symbols: alphabet size of each machine.
		t_states_per_machine: number of states requested from `random_machine`.

	Returns:
		The joined HMM.

	Raises:
		ValueError: if n_isomorphic is below 1 or too few symbols are given.
	"""

	if n_isomorphic < 1 :
		raise ValueError( "n_isomorphic must be at least 1" )

	if len( normal_symbols ) < n_normal_symbols*n_isomorphic :
		raise ValueError( "Must have at least n_normal_symbols*n_isomorphic normal symbols" )

	if len( enter_symbols ) < n_modes*n_isomorphic :
		raise ValueError( "Must have at least n_modes*n_isomorphic enter symbols" )

	# one disjoint slice of the normal symbols per machine of a mode:
	# slice 0 for the generated machine, slice k for its k-th isomorphic copy
	alphabets = [
		[ f"{normal_symbols[i]}" for i in range( k*n_normal_symbols, ( k + 1 )*n_normal_symbols ) ]
		for k in range( n_isomorphic )
	]

	random_machines = []

	for i in range( n_modes ) :

		m = random_machine(
			n_states=t_states_per_machine,
			symbols=alphabets[ 0 ],
			randomness=randomness,
			connectedness=connectedness )

		m.collapse_to_largest_strongly_connected_subgraph()

		# the k-th copy gets k decorator characters so that the state names
		# of all machines of this mode stay distinct
		family = [ m ] + [
			isomorphic_to( m, alphabet=alphabets[ k ], decorator='@'*k )
			for k in range( 1, n_isomorphic )
		]

		for member in family :
			member.isoclass = f"{i}"

		# register corresponding states of every pair of machines as isomorphs
		for j in range( len( m.states ) ) :
			for a in family :
				for b in family :
					if a is not b :
						a.states[ j ].add_isomorph( b.states[ j ].name )

		random_machines.extend( family )

	mode_machine = star_join(
		exit_symbol=exit_symbol,
		enter_symbols=enter_symbols,
		machines=random_machines,
		mode_residency_factor=residency_factor
	)

	return mode_machine
def isomorphic_to( m: amachine.am_hmm.HMM, alphabet: list[str], decorator: str = '@') -> amachine.am_hmm.HMM:
def isomorphic_to(
	m : HMM,
	alphabet : list[str],
	decorator : str = '@' ) -> HMM :
	"""Create a copy of `m` with renamed states and a different alphabet.

	Each state is recreated with `decorator` appended to its name and a deep
	copy of its classes. The transitions are deep-copied unchanged, so a
	symbol index that referred to position k of `m.alphabet` refers to
	position k of the new alphabet.

	Args:
		m: the machine to copy.
		alphabet: symbols for the copy; only the first len( m.alphabet ) are used.
		decorator: suffix appended to every state name.

	Returns:
		A new HMM with the same start state index as `m`.

	Raises:
		ValueError: if `alphabet` has fewer symbols than `m.alphabet`.
	"""

	# make sure there are enough symbols
	if len( alphabet ) < len( m.alphabet ) :
		raise ValueError( "Not enough symbols in the alphabet" )

	# take as many of them as needed
	alphabet_used = alphabet[ 0 : len( m.alphabet ) ]

	states = [
		CausalState(
			name=f"{s.name}{decorator}",
			classes=copy.deepcopy( s.classes )
		)
		for s in m.states
	]

	return HMM(
		states=states,
		transitions=copy.deepcopy( m.transitions ),
		# states and transitions keep their indices, so the copy must start
		# where m starts rather than unconditionally at state 0
		start_state=m.start_state,
		alphabet=alphabet_used
	)
def random_machine( n_states: int, symbols: list[str], connectedness, randomness) -> amachine.am_hmm.HMM:
def random_machine(
	n_states : int,
	symbols  : list[str],
	connectedness : float,
	randomness : float )  -> HMM  :
	"""Generate a random HMM with `n_states` states named "0", "1", ...

	Every state gets between 1 and len( symbols ) outgoing transitions: one
	guaranteed, plus one per remaining symbol with probability
	`connectedness`. Within a state the transitions use distinct symbols and
	distinct target states, and their probabilities come from
	`exp_uniform_blend( n, alpha=randomness )`.

	Args:
		n_states: number of states to create.
		symbols: the alphabet of the machine.
		connectedness: chance of each additional outgoing transition.
		randomness: forwarded to `exp_uniform_blend` as `alpha`.

	Returns:
		An HMM whose start state is 0 and whose alphabet is a copy of `symbols`.

	Raises:
		ValueError: if states are requested but `symbols` is empty.
	"""

	n_symbols = len( symbols )

	# every state needs at least one transition, which needs a symbol
	if n_states > 0 and n_symbols == 0 :
		raise ValueError( "At least one symbol is required to create transitions" )

	states=[
		CausalState( name=f"{i}" )
		for i in range( n_states  )
	]

	transitions = []

	for state_idx in range( n_states ) :

		n_transitions = sum( random.random() < connectedness for _ in range( n_symbols - 1 ) ) + 1

		# targets are sampled without replacement, so a state cannot have more
		# outgoing transitions than there are states; without this cap
		# random.sample raises ValueError for machines with few states
		n_transitions = min( n_transitions, n_states )

		transition_to = random.sample( range( n_states ), n_transitions )

		transition_probabilities = exp_uniform_blend( n=n_transitions, alpha=randomness )

		# distinct symbols per state
		transition_symbols_indices = random.sample( range( n_symbols ), n_transitions )

		for i, p in enumerate( transition_probabilities ) :
			transitions.append(
				Transition(
					origin_state_idx=state_idx,
					target_state_idx=transition_to[ i ],
					prob=p,
					symbol_idx=transition_symbols_indices[ i ]
				)
			)

	return HMM(
		states=states,
		transitions=transitions,
		start_state=0,
		alphabet=list( symbols )
	)