# Module: amachine.am_vis

import html
from pathlib import Path

import graphviz
from matplotlib import colormaps
from matplotlib.colors import to_hex

  6def create_digraph(engine="dot"):
  7
  8    engine_configs = {
  9        "dot": {
 10            'rankdir': 'LR',
 11            'ranksep': '1.0',
 12            'nodesep': '0.4',
 13            'splines': 'spline',
 14			'constraint' : 'true',
 15            'concentrate': 'false'
 16        },
 17		"neato": {
 18			'overlap': 'scale',
 19			'overlap_scaling': '-4',  
 20			'esep': '+2.5',           
 21			'sep': '+1.75',           
 22			'model': 'shortpath',
 23			'damping': '0.85',
 24			'epsilon': '0.00001',
 25			'maxiter': '1000000',
 26			'start': '5',
 27		},
 28        "fdp": {
 29            'overlap': 'prism',
 30            'sep': '+1.5',
 31            'K': '1.0',
 32            'splines': 'true',
 33			'len' : '3.0',
 34            'maxiter': '5000'
 35        }
 36    }
 37
 38    graph_attr = engine_configs.get(engine, {})
 39
 40    node_attr = {
 41        'shape': 'box',
 42        'style': 'rounded, filled',
 43        'fillcolor': 'lightblue',
 44        'fontname': 'Helvetica'
 45    }
 46    
 47    edge_attr = {
 48        'penwidth': '1.2',
 49        'color': 'gray40'
 50    }
 51
 52    return graphviz.Digraph(
 53        engine=engine,
 54        graph_attr=graph_attr,
 55        node_attr=node_attr,
 56        edge_attr=edge_attr
 57    )
 58
 59
 60def draw_graph( 
 61	aM, 
 62	output_dir : Path,
 63	title="am_graph", 
 64	view=True, 
 65	subgraphs=None,
 66	engine="dot" ):
 67
 68	GV = create_digraph(engine=engine)
 69
 70	cmap = colormaps['Set3']
 71
 72	if subgraphs :
 73		graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]
 74
 75	for node in range( len(aM.states) ) :
 76		
 77		if subgraphs :
 78			
 79			node_subgraph = -1
 80			for i, sg in enumerate( subgraphs ) :
 81				if node in sg :
 82					node_subgraph = i
 83
 84			node_color = graph_colors[ node_subgraph ] if node_subgraph >= 0 else 'red'
 85
 86		else :
 87			node_color = to_hex(cmap(8))
 88
 89		node_tex = aM.states[ node ].name
 90		if not node_tex :
 91			node_tex = str(node)
 92
 93		GV.node(
 94			str(node), 
 95			label=node_tex, 
 96			shape='circle',
 97			style='bold,filled',
 98        	fillcolor=node_color,
 99        	color='black',
100			width='0.5' )
101
102
103	for tr in aM.transitions :
104
105		u = tr.origin_state_idx
106		v = tr.target_state_idx
107
108		pr_str = str( tr.pq ) if aM.is_q_weighted else str( round( tr.prob, 4 ) )
109		label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"
110
111		html_label = (
112            f'<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
113            f'<TR><TD BGCOLOR="#FFFFFF">{label_text}</TD></TR>'
114            f'</TABLE>>'
115        )
116
117		GV.edge(
118			str(u), str(v), 
119			label=html_label,
120			fontsize='10',
121			fontname='Times-Italic',
122			labelfloat="true",
123		)
124
125	GV.render(title, directory=output_dir, view=view, format='pdf', cleanup=True)
126	print(f"Graph rendered as {title}.pdf")
def create_digraph(engine="dot", graph_attr=None, node_attr=None, edge_attr=None):
    """Return an empty ``graphviz.Digraph`` with layout defaults for *engine*.

    ``"dot"``, ``"neato"`` and ``"fdp"`` get hand-tuned graph/edge
    attributes; any other engine name is passed through to Graphviz with
    only the shared node and edge styling.

    Args:
        engine: Graphviz layout engine name.
        graph_attr: optional extra graph attributes; entries override the
            engine defaults.
        node_attr: optional extra default node attributes; entries override
            the defaults below.
        edge_attr: optional extra default edge attributes; entries override
            the defaults below.

    Returns:
        A new, empty ``graphviz.Digraph``.
    """
    # Graph-level (``graph [...]``) attributes per engine.
    engine_graph_attrs = {
        "dot": {
            "rankdir": "LR",
            "ranksep": "1.0",
            "nodesep": "0.4",
            "splines": "spline",
            "concentrate": "false",
        },
        "neato": {
            "overlap": "scale",
            "overlap_scaling": "-4",
            "esep": "+2.5",
            "sep": "+1.75",
            "model": "shortpath",
            "damping": "0.85",
            "epsilon": "0.00001",
            "maxiter": "1000000",
            "start": "5",
        },
        "fdp": {
            "overlap": "prism",
            "sep": "+1.5",
            "K": "1.0",
            "splines": "true",
            "maxiter": "5000",
        },
    }

    # ``constraint`` and ``len`` are *edge* attributes in Graphviz. Set as
    # graph attributes they are ignored, so they belong in the default edge
    # attributes for the engine that uses them.
    engine_edge_attrs = {
        "dot": {"constraint": "true"},
        "fdp": {"len": "3.0"},
    }

    merged_graph_attr = dict(engine_graph_attrs.get(engine, {}))
    merged_graph_attr.update(graph_attr or {})

    merged_node_attr = {
        "shape": "box",
        "style": "rounded, filled",
        "fillcolor": "lightblue",
        "fontname": "Helvetica",
    }
    merged_node_attr.update(node_attr or {})

    merged_edge_attr = {
        "penwidth": "1.2",
        "color": "gray40",
    }
    merged_edge_attr.update(engine_edge_attrs.get(engine, {}))
    merged_edge_attr.update(edge_attr or {})

    return graphviz.Digraph(
        engine=engine,
        graph_attr=merged_graph_attr,
        node_attr=merged_node_attr,
        edge_attr=merged_edge_attr,
    )
def draw_graph(
    aM,
    output_dir: Path,
    title="am_graph",
    view=True,
    subgraphs=None,
    engine="dot",
):
    """Render the states and transitions of *aM* to ``output_dir/title.pdf``.

    Each state becomes a circular node labelled with its ``name``, or with its
    index when the name is empty. Each transition becomes an edge labelled
    ``symbol(weight)``. The weight is ``tr.pq`` when ``aM.is_q_weighted`` is
    true, and ``tr.prob`` rounded to 4 decimals otherwise.

    Args:
        aM: machine to draw. Must expose ``states`` (items with ``.name``),
            ``transitions`` (items with ``origin_state_idx``,
            ``target_state_idx``, ``symbol_idx`` and ``pq``/``prob``),
            ``alphabet`` and ``is_q_weighted``.
        output_dir: directory the PDF is written to.
        title: output file stem.
        view: open the rendered PDF with the system viewer.
        subgraphs: optional sized iterable of collections of state indices.
            A state is coloured by the last collection that contains it.
            States in none of them are drawn red. Without ``subgraphs`` every
            state gets the same colour.
        engine: Graphviz layout engine forwarded to ``create_digraph``.

    Returns:
        The path of the rendered file, as returned by ``Digraph.render``.
    """
    GV = create_digraph(engine=engine)

    cmap = colormaps["Set3"]
    default_color = to_hex(cmap(8))

    # Build state index -> subgraph index once. Later subgraphs overwrite
    # earlier ones, so a state listed in several keeps the last match.
    membership = {}
    palette = []
    if subgraphs:
        # Set3 is a small qualitative palette; wrap around when there are
        # more subgraphs than colours.
        palette = [to_hex(cmap(i % cmap.N)) for i in range(len(subgraphs))]
        for sg_idx, sg in enumerate(subgraphs):
            for state_idx in sg:
                membership[state_idx] = sg_idx

    for node, state in enumerate(aM.states):
        if subgraphs:
            sg_idx = membership.get(node)
            node_color = palette[sg_idx] if sg_idx is not None else "red"
        else:
            node_color = default_color

        GV.node(
            str(node),
            label=state.name or str(node),
            shape="circle",
            style="bold,filled",
            fillcolor=node_color,
            color="black",
            width="0.5",
        )

    for tr in aM.transitions:
        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        # The label is embedded in a Graphviz HTML-like label, so '&', '<'
        # and '>' in the symbol or weight must be escaped. Otherwise the
        # HTML is malformed and rendering fails.
        label_text = html.escape(
            f"{aM.alphabet[tr.symbol_idx]}({pr_str})", quote=False
        )
        html_label = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{label_text}</TD></TR>'
            "</TABLE>>"
        )

        GV.edge(
            str(tr.origin_state_idx),
            str(tr.target_state_idx),
            label=html_label,
            fontsize="10",
            fontname="Times-Italic",
            labelfloat="true",
        )

    rendered = GV.render(
        title, directory=output_dir, view=view, format="pdf", cleanup=True
    )
    print(f"Graph rendered as {title}.pdf")
    return rendered