amachine.am_vis
import html
from pathlib import Path

import graphviz
from matplotlib import colormaps
from matplotlib.colors import to_hex


def create_digraph(engine="dot"):
    """Build an empty ``graphviz.Digraph`` with this module's default styling.

    Graph-level layout tuning is chosen per layout engine ("dot", "neato" and
    "fdp"); any other engine name gets no graph-level tuning. Node and edge
    defaults are shared by all engines, plus a few engine-specific edge
    defaults.

    Args:
        engine: Name of the Graphviz layout engine, forwarded to
            ``graphviz.Digraph(engine=...)``.

    Returns:
        A new ``graphviz.Digraph`` carrying the selected graph, node and edge
        default attributes.
    """
    # Graph-level attributes per engine.
    engine_graph_attrs = {
        "dot": {
            'rankdir': 'LR',
            'ranksep': '1.0',
            'nodesep': '0.4',
            'splines': 'spline',
            'concentrate': 'false',
        },
        "neato": {
            'overlap': 'scale',
            'overlap_scaling': '-4',
            'esep': '+2.5',
            'sep': '+1.75',
            'model': 'shortpath',
            'damping': '0.85',
            'epsilon': '0.00001',
            'maxiter': '1000000',
            'start': '5',
        },
        "fdp": {
            'overlap': 'prism',
            'sep': '+1.5',
            'K': '1.0',
            'splines': 'true',
            'maxiter': '5000',
        },
    }

    # `constraint` and `len` are edge attributes in Graphviz and have no
    # effect when set on the graph, so they are applied as edge defaults.
    engine_edge_attrs = {
        "dot": {'constraint': 'true'},
        "fdp": {'len': '3.0'},
    }

    graph_attr = engine_graph_attrs.get(engine, {})

    node_attr = {
        'shape': 'box',
        'style': 'rounded, filled',
        'fillcolor': 'lightblue',
        'fontname': 'Helvetica',
    }

    edge_attr = {
        'penwidth': '1.2',
        'color': 'gray40',
        **engine_edge_attrs.get(engine, {}),
    }

    return graphviz.Digraph(
        engine=engine,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
    )


def draw_graph(
        aM,
        output_dir: Path,
        title="am_graph",
        view=True,
        subgraphs=None,
        engine="dot"):
    """Render the states and transitions of ``aM`` as a PDF in ``output_dir``.

    Each state becomes a circular node labelled with its ``name`` (or its
    index when the name is empty). Each transition becomes an edge labelled
    ``symbol(weight)``. The weight is ``tr.pq`` when ``aM.is_q_weighted`` is
    true, and ``tr.prob`` rounded to 4 decimals otherwise.

    Args:
        aM: Object exposing ``states`` (indexable, items with ``.name``),
            ``transitions`` (items with ``origin_state_idx``,
            ``target_state_idx``, ``symbol_idx`` and ``prob``/``pq``),
            ``alphabet`` (indexed by ``symbol_idx``) and ``is_q_weighted``.
        output_dir: Directory the rendered file is written to.
        title: Base file name of the rendered PDF.
        view: Forwarded to graphviz ``render(view=...)``.
        subgraphs: Optional sequence of collections of state indices. A node
            is coloured after the last collection that contains it; nodes in
            none of them are drawn red. Without subgraphs, all nodes share one
            colour.
        engine: Graphviz layout engine, forwarded to ``create_digraph``.

    Returns:
        The path of the rendered file, as returned by graphviz ``render``.
    """
    GV = create_digraph(engine=engine)
    cmap = colormaps['Set3']
    default_color = to_hex(cmap(8))

    # State index -> index of the last subgraph that contains it (later
    # subgraphs override earlier ones, as in an in-order scan).
    membership = {}
    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]
        for i, sg in enumerate(subgraphs):
            for member in sg:
                membership[member] = i

    for node, state in enumerate(aM.states):
        if subgraphs:
            sg_idx = membership.get(node, -1)
            node_color = graph_colors[sg_idx] if sg_idx >= 0 else 'red'
        else:
            node_color = default_color

        node_tex = state.name or str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    for tr in aM.transitions:
        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"

        # Escape &, < and > so symbols containing them cannot break the
        # HTML-like label markup.
        html_label = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{html.escape(label_text, quote=False)}</TD></TR>'
            '</TABLE>>'
        )

        GV.edge(
            str(tr.origin_state_idx), str(tr.target_state_idx),
            label=html_label,
            fontsize='10',
            fontname='Times-Italic',
            labelfloat="true",
        )

    # cleanup=True removes the intermediate DOT source after rendering.
    rendered_path = GV.render(
        title, directory=output_dir, view=view, format='pdf', cleanup=True)
    print(f"Graph rendered as {title}.pdf")
    return rendered_path
def create_digraph(engine='dot'):
def create_digraph(engine="dot"):
    """Build an empty ``graphviz.Digraph`` with this module's default styling.

    Graph-level layout tuning is chosen per layout engine ("dot", "neato" and
    "fdp"); any other engine name gets no graph-level tuning. Node and edge
    defaults are shared by all engines, plus a few engine-specific edge
    defaults.

    Args:
        engine: Name of the Graphviz layout engine, forwarded to
            ``graphviz.Digraph(engine=...)``.

    Returns:
        A new ``graphviz.Digraph`` carrying the selected graph, node and edge
        default attributes.
    """
    # Graph-level attributes per engine.
    engine_graph_attrs = {
        "dot": {
            'rankdir': 'LR',
            'ranksep': '1.0',
            'nodesep': '0.4',
            'splines': 'spline',
            'concentrate': 'false',
        },
        "neato": {
            'overlap': 'scale',
            'overlap_scaling': '-4',
            'esep': '+2.5',
            'sep': '+1.75',
            'model': 'shortpath',
            'damping': '0.85',
            'epsilon': '0.00001',
            'maxiter': '1000000',
            'start': '5',
        },
        "fdp": {
            'overlap': 'prism',
            'sep': '+1.5',
            'K': '1.0',
            'splines': 'true',
            'maxiter': '5000',
        },
    }

    # `constraint` and `len` are edge attributes in Graphviz and have no
    # effect when set on the graph, so they are applied as edge defaults.
    engine_edge_attrs = {
        "dot": {'constraint': 'true'},
        "fdp": {'len': '3.0'},
    }

    graph_attr = engine_graph_attrs.get(engine, {})

    node_attr = {
        'shape': 'box',
        'style': 'rounded, filled',
        'fillcolor': 'lightblue',
        'fontname': 'Helvetica',
    }

    edge_attr = {
        'penwidth': '1.2',
        'color': 'gray40',
        **engine_edge_attrs.get(engine, {}),
    }

    return graphviz.Digraph(
        engine=engine,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
    )
def draw_graph(aM, output_dir: pathlib.Path, title='am_graph', view=True, subgraphs=None, engine='dot'):
def draw_graph(
        aM,
        output_dir: Path,
        title="am_graph",
        view=True,
        subgraphs=None,
        engine="dot"):
    """Render the states and transitions of ``aM`` as a PDF in ``output_dir``.

    Each state becomes a circular node labelled with its ``name`` (or its
    index when the name is empty). Each transition becomes an edge labelled
    ``symbol(weight)``. The weight is ``tr.pq`` when ``aM.is_q_weighted`` is
    true, and ``tr.prob`` rounded to 4 decimals otherwise.

    Args:
        aM: Object exposing ``states`` (indexable, items with ``.name``),
            ``transitions`` (items with ``origin_state_idx``,
            ``target_state_idx``, ``symbol_idx`` and ``prob``/``pq``),
            ``alphabet`` (indexed by ``symbol_idx``) and ``is_q_weighted``.
        output_dir: Directory the rendered file is written to.
        title: Base file name of the rendered PDF.
        view: Forwarded to graphviz ``render(view=...)``.
        subgraphs: Optional sequence of collections of state indices. A node
            is coloured after the last collection that contains it; nodes in
            none of them are drawn red. Without subgraphs, all nodes share one
            colour.
        engine: Graphviz layout engine, forwarded to ``create_digraph``.

    Returns:
        The path of the rendered file, as returned by graphviz ``render``.
    """
    import html

    GV = create_digraph(engine=engine)
    cmap = colormaps['Set3']
    default_color = to_hex(cmap(8))

    # State index -> index of the last subgraph that contains it (later
    # subgraphs override earlier ones, as in an in-order scan).
    membership = {}
    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]
        for i, sg in enumerate(subgraphs):
            for member in sg:
                membership[member] = i

    for node, state in enumerate(aM.states):
        if subgraphs:
            sg_idx = membership.get(node, -1)
            node_color = graph_colors[sg_idx] if sg_idx >= 0 else 'red'
        else:
            node_color = default_color

        node_tex = state.name or str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    for tr in aM.transitions:
        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"

        # Escape &, < and > so symbols containing them cannot break the
        # HTML-like label markup.
        html_label = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{html.escape(label_text, quote=False)}</TD></TR>'
            '</TABLE>>'
        )

        GV.edge(
            str(tr.origin_state_idx), str(tr.target_state_idx),
            label=html_label,
            fontsize='10',
            fontname='Times-Italic',
            labelfloat="true",
        )

    # cleanup=True removes the intermediate DOT source after rendering.
    rendered_path = GV.render(
        title, directory=output_dir, view=view, format='pdf', cleanup=True)
    print(f"Graph rendered as {title}.pdf")
    return rendered_path