Coverage for MPT/graph/kinetic_network.py: 87%
93 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 10:44 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 10:44 +0200
1import msmhelper as mh
2import networkx as nx
3import numpy as np
4import math
5import prettypyplot as pplt
6from .curved_edges import curved_edges
7from fa2_modified import ForceAtlas2
8from scipy.spatial import distance_matrix
9from matplotlib import pyplot as plt
10from matplotlib.colors import to_hex
11from matplotlib.collections import LineCollection
12from matplotlib.colors import Normalize
# Apply the prettypyplot figure style once, when the module is imported.
pplt.use_style(figsize=1.8, figratio=1)

# Feature switches read by draw_knetwork().
# USE_FA2: refine the spring-layout positions with ForceAtlas2 (and align them).
USE_FA2 = True
# DRAW_FLUX: draw symmetrised-flux edges as curves instead of transition-matrix edges.
DRAW_FLUX = True
def assign_color(qoft, states, traj, levels):
    """Return one hex colour per state, derived from its mean ``1 - q(t)``.

    For every state, ``qoft`` is averaged over the frames where ``traj``
    equals that state, and the value ``1 - mean`` is formed. These values are
    min-max normalised to [0, 1], quantised into ``levels`` bins with
    :func:`_bin` and mapped to a ``plasma`` colour with :func:`_color`.

    Parameters
    ----------
    qoft : numpy.ndarray
        Per-frame values. They are indexed with the boolean mask
        ``traj == state``, so presumably they have the same length as
        ``traj`` (TODO confirm).
    states : iterable
        State labels. One colour is produced per entry, in this order.
    traj : numpy.ndarray
        Per-frame state labels.
    levels : int
        Number of colour bins.

    Returns
    -------
    list of str
        Hex colour strings, one per entry of ``states``.
    """
    states_qoft = np.array(
        [1 - np.mean(qoft[traj == state]) for state in states],
    )
    norm = Normalize(vmin=states_qoft.min(), vmax=states_qoft.max())
    # Bin the normalised value so that the full colormap range is used.
    states_bin = [_bin(norm(q), levels) for q in states_qoft]
    return [_color(q_bin, levels) for q_bin in states_bin]
def _color(val, levels):
    """Return the hex colour of ``val`` in a ``levels``-step plasma colormap.

    Parameters
    ----------
    val : float
        Position in the colormap, presumably within [0, 1] (as produced by
        :func:`_bin`).
    levels : int
        Number of discrete colours the ``plasma`` colormap is resampled to.

    Returns
    -------
    str
        Hex colour string as returned by ``matplotlib.colors.to_hex``.
    """
    rgba = plt.get_cmap("plasma", levels)(val)
    return to_hex(rgba)
36def _bin(val, levels):
37 # get bin
38 bins = np.linspace(0, 1, levels + 1)
40 for rlower, rhigher in zip(bins[:-1], bins[1:]): 40 ↛ 44line 40 didn't jump to line 44 because the loop on line 40 didn't complete
41 if rlower <= val <= rhigher:
42 return rlower
44 return bins[-1]
def get_luminance(hex_color):
    """Return the weighted RGB luminance of a ``'#rrggbb'`` colour string.

    The first character (the ``'#'``) is skipped, and the next six hex digits
    are read as 8-bit red, green and blue channels. Any further characters
    (e.g. an alpha pair) are ignored. The channels are weighted with the
    BT.709 coefficients 0.2126 / 0.7152 / 0.0722 directly on the 0-255
    values, without gamma linearisation, so the result lies in [0, 255].
    """
    digits = hex_color[1:]
    red, green, blue = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
    return red * 0.2126 + green * 0.7152 + blue * 0.0722
def _basin_given(idx):
    """Return True if ``idx`` selects basin nodes for the layout rotation.

    ``None``, a scalar ``0`` (the historical "not set" default) and empty
    sequences count as "not given". Any other scalar, or a non-empty sequence
    of node indices, counts as given.
    """
    if idx is None:
        return False
    if np.isscalar(idx):
        return idx != 0
    return np.size(idx) > 0


def _rotate_to_basin_axis(pos, u, f):
    """Return ``pos`` rotated so that the ``f`` -> ``u`` basin axis points along +x.

    The axis runs from the mean coordinate of the nodes indexed by ``f`` to
    the mean coordinate of the nodes indexed by ``u``. Indices are row
    positions in ``pos.values()``. If either basin is not given (see
    :func:`_basin_given`), the rotation angle is 0. The result has the same
    keys as ``pos`` and ``(x, y)`` tuples as values.
    """
    coords = np.asarray(list(pos.values()))
    if _basin_given(u) and _basin_given(f):
        # atleast_1d accepts a single node index as well as a list of them.
        center_u = coords[np.atleast_1d(u)].mean(axis=0)
        center_f = coords[np.atleast_1d(f)].mean(axis=0)
        dx, dy = center_u - center_f
        theta = math.atan2(dy, dx)
    else:
        theta = 0
    cos_t, sin_t = math.cos(-theta), math.sin(-theta)
    xs, ys = coords[:, 0], coords[:, 1]
    rotated = zip(xs * cos_t - ys * sin_t, xs * sin_t + ys * cos_t)
    return dict(zip(pos.keys(), rotated))


def draw_knetwork(traj, tlag, qoft, out, u=0, f=0, set_min_node_size=True):
    """Draw the kinetic network of a Markov state model and save it to ``out``.

    A Markov state model is estimated from ``traj`` at lag ``tlag`` with
    msmhelper.

    - Nodes are coloured with :func:`assign_color` (10 levels) and sized by
      ``1000 * log(1 + equilibrium population)``.
    - Edges carry the symmetrised flux ``(pi_i T_ij + pi_j T_ji) / 2``.
      Self-flux and fluxes below ``2e-5`` are dropped.
    - Positions come from a spring layout. When ``USE_FA2`` is set, they are
      refined with ForceAtlas2 and rotated so that the axis between the
      ``f`` and ``u`` basins is horizontal.

    Parameters
    ----------
    traj : numpy.ndarray
        State trajectory passed to ``mh.msm.estimate_markov_model``.
    tlag : int
        Lag time passed to ``mh.msm.estimate_markov_model``.
    qoft : numpy.ndarray
        Per-frame values used to colour the nodes (see :func:`assign_color`).
    out : str or path-like
        Output file passed to ``pplt.savefig``.
    u, f : int or sequence of int, optional
        Node indices (0-based positions in ``states``) of the two basins that
        define the horizontal axis. The rotation is applied only when both
        are given. The default ``0`` (or ``None``) disables it.
    set_min_node_size : bool, optional
        If True, every node smaller than the midpoint of the size range is
        resized to 70 % of that midpoint. Labels on dark nodes are then drawn
        in white.
    """
    fig, ax = plt.subplots()
    tmat, states = mh.msm.estimate_markov_model(traj, tlag)
    color_list = assign_color(qoft, states, traj, levels=10)

    # Symmetrised equilibrium flux, which satisfies detailed balance by construction.
    pop_eq = mh.msm.equilibrium_population(tmat)
    mat = tmat * pop_eq[:, np.newaxis]
    mat = 0.5 * (mat + mat.T)

    # Remove self-flux and negligible fluxes so they do not become graph edges.
    mat[np.diag_indices_from(mat)] = 0
    mat[mat < 2e-5] = 0

    node_size = 1000 * np.log(pop_eq + 1)
    if set_min_node_size:
        # Presumably keeps small nodes visible: lift everything below the
        # midpoint of the size range to 70 % of that midpoint.
        midpoint = (np.min(node_size) + np.max(node_size)) / 2
        node_size = np.where(node_size < midpoint, 0.7 * midpoint, node_size)

    graph = nx.from_numpy_array(mat, create_using=nx.Graph)

    # Initial guess from a simple spring model.
    pos = nx.spring_layout(
        graph,
        fixed=None,
        iterations=1000,
        threshold=1e-4,
        scale=0.1,
        weight="weight",
    )
    if USE_FA2:
        # Improve the positions with ForceAtlas2, then align the basins.
        forceatlas2 = ForceAtlas2(
            adjustSizes=False,
            verbose=False,
            strongGravityMode=True,
            scalingRatio=1000,
            gravity=0.0,
        )
        pos = forceatlas2.forceatlas2_networkx_layout(
            graph,
            pos=pos,
            iterations=1000,
        )
        pos = _rotate_to_basin_axis(pos, u, f)

    if DRAW_FLUX:
        # Curved edges whose width follows the symmetrised flux.
        edge_width = 0.1 + 300 * np.array(
            [graph[i][j]["weight"] for i, j in graph.edges],
        )
        curves = curved_edges(graph, pos)
        lc = LineCollection(
            curves,
            color="black",
            linewidth=edge_width,
            alpha=1,
        )
        ax.add_collection(lc)
    else:
        # Directed graph from the raw transition matrix to draw the edges.
        digraph = nx.from_numpy_array(tmat, create_using=nx.DiGraph)
        edge_width = 0.2 + 5 * np.array(
            [digraph[i][j]["weight"] for i, j in digraph.edges],
        )
        nx.draw_networkx_edges(
            digraph,
            arrowstyle="-",
            pos=pos,
            connectionstyle="arc3,rad=0.4",
            width=edge_width,
            edge_color="black",
            node_size=node_size,
            arrowsize=3,
        )

    nx.draw_networkx_nodes(
        graph,
        pos=pos,
        node_color=color_list,
        node_size=node_size,
        linewidths=0.55,
        edgecolors="black",
    )

    # Node labels show the 1-based state id.
    # NOTE(review): white text on dark nodes is only used when
    # set_min_node_size is True; this coupling looks accidental, confirm intent.
    for node_idx, (x, y) in pos.items():
        is_dark = get_luminance(color_list[node_idx]) < 140
        c_text = "white" if is_dark and set_min_node_size else "black"
        pplt.text(
            x, y, states[node_idx] + 1, contour=False, fontsize="medium", color=c_text
        )

    # Pad the data range by the largest node size.
    # NOTE(review): node_size is a marker area (points^2) while pos is in data
    # units, so this padding mixes units; kept as-is to preserve the figure.
    max_size = np.max(node_size)
    coords = np.asarray(list(pos.values()))
    ax.set_xlim(coords[:, 0].min() - max_size, coords[:, 0].max() + max_size)
    ax.set_ylim(coords[:, 1].min() - max_size, coords[:, 1].max() + max_size)

    ax.set_axis_off()
    plt.tight_layout()
    pplt.savefig(out)
    # Release the figure so that repeated calls do not accumulate open figures.
    plt.close(fig)