Coverage for MPP/graph/kinetic_network.py: 87%

93 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-11 14:46 +0200

1import msmhelper as mh 

2import networkx as nx 

3import numpy as np 

4import math 

5import prettypyplot as pplt 

6from .curved_edges import curved_edges 

7from fa2_modified import ForceAtlas2 

8from scipy.spatial import distance_matrix 

9from matplotlib import pyplot as plt 

10from matplotlib.colors import to_hex 

11from matplotlib.collections import LineCollection 

12from matplotlib.colors import Normalize 

13 

# Apply the shared prettypyplot figure style (figsize=1.8, figratio=1)
# once, when this module is imported.
pplt.use_style(figratio=1, figsize=1.8)

# Module-wide switches read by draw_knetwork:
# USE_FA2   -- refine the spring layout with ForceAtlas2 (the optional basin
#              rotation is applied on this path as well).
# DRAW_FLUX -- draw curved edges weighted by the symmetrised flux instead of
#              edges built from the raw transition matrix.
DRAW_FLUX = True
USE_FA2 = True

18 

19 

def assign_color(qoft, states, traj, levels):
    """Return one hex colour per state from the state-averaged ``1 - qoft``.

    For every state, ``qoft`` is averaged over the frames where ``traj``
    equals that state, and the mean is subtracted from 1. The per-state
    values are then processed in three steps:

    1. Min-max normalise them to [0, 1].
    2. Snap each to the lower edge of one of ``levels`` equal-width bins
       (``_bin``).
    3. Map the result onto a ``levels``-step plasma colormap (``_color``).

    Parameters
    ----------
    qoft : np.ndarray
        Per-frame values; must be indexable by the boolean mask
        ``traj == state``.
    states : iterable
        State labels; the returned colours follow this order.
    traj : np.ndarray
        State trajectory aligned frame-by-frame with ``qoft``.
    levels : int
        Number of discrete colour levels.

    Returns
    -------
    list of str
        Hex colour strings, one per entry of ``states``.
    """
    states_qoft = np.array(
        [1 - np.mean(qoft[traj == state]) for state in states],
    )
    # Rescale to [0, 1] so the full colormap is used regardless of the
    # absolute range of the observable.
    norm = Normalize(vmin=states_qoft.min(), vmax=states_qoft.max())
    return [_color(_bin(norm(value), levels), levels) for value in states_qoft]

27 

28 

def _color(val, levels):
    """Return the hex colour of ``val`` on a plasma colormap with ``levels`` colours."""
    discrete_plasma = plt.get_cmap("plasma", levels)
    rgba = discrete_plasma(val)
    return to_hex(rgba)

34 

35 

def _bin(val, levels):
    """Snap ``val`` to the lower edge of its bin among ``levels`` bins on [0, 1].

    The bins are equal-width and closed on both ends, and they are scanned
    from the bottom. A value lying exactly on an inner edge therefore goes
    to the lower bin (for example, 0.1 -> 0.0 when ``levels=10``). Values
    outside [0, 1], and NaN, match no bin and fall through to the top edge,
    1.0.
    """
    edges = np.linspace(0, 1, levels + 1)
    lower_edges, upper_edges = edges[:-1], edges[1:]
    return next(
        (lo for lo, hi in zip(lower_edges, upper_edges) if lo <= val <= hi),
        edges[-1],
    )

45 

46 

def get_luminance(hex_color):
    """Return a brightness score for a ``#rrggbb`` colour string.

    The red, green and blue bytes are weighted by 0.2126, 0.7152 and 0.0722
    and summed. The result runs from 0 (black) to roughly 255 (white). The
    weights are applied to the raw 0-255 channel values, without gamma
    linearisation. The leading character is skipped and anything after the
    sixth hex digit (e.g. an alpha byte) is ignored.
    """
    digits = hex_color[1:]
    red, green, blue = (int(digits[start:start + 2], 16) for start in (0, 2, 4))
    return 0.2126 * red + 0.7152 * green + 0.0722 * blue

53 

54 

def _symmetric_flux(tmat, pop_eq, threshold=2e-5):
    """Return the symmetrised equilibrium flux matrix used as edge weights.

    The flux ``pop_eq[i] * tmat[i, j]`` is averaged with its transpose,
    which enforces detailed balance. The diagonal (self-transitions) is
    zeroed, and entries below ``threshold`` are zeroed so they do not
    become graph edges. ``tmat`` itself is not modified.
    """
    flux = tmat * pop_eq[:, np.newaxis]
    flux = 0.5 * (flux + flux.T)
    np.fill_diagonal(flux, 0)
    flux[flux < threshold] = 0
    return flux


def _node_sizes(pop_eq, set_min_node_size):
    """Return marker sizes ``1000 * log(1 + pop_eq)``, optionally floored.

    With ``set_min_node_size``, let ``m`` be the midpoint between the
    smallest and largest size. Every node smaller than ``m`` is set to
    ``0.7 * m``, so sparsely populated states stay visible.
    """
    sizes = 1000 * np.log(pop_eq + 1)
    if set_min_node_size:
        size_sum = np.min(sizes) + np.max(sizes)
        sizes = np.where(sizes < size_sum / 2, 0.7 * size_sum / 2, sizes)
    return sizes


def _basin_indices(basin):
    """Return ``basin`` as a 1-D index array, or ``None`` when it is unset.

    ``None``, the scalar ``0`` (the ``draw_knetwork`` default) and empty
    sequences all mean "no basin". A single index is promoted to a
    length-1 array, so it can be averaged like a list of indices. To
    select node 0 on its own, pass ``[0]``.
    """
    if basin is None:
        return None
    if np.ndim(basin) == 0:
        return None if basin == 0 else np.atleast_1d(basin)
    indices = np.asarray(basin)
    return indices if indices.size else None


def _rotate_layout(pos, u, f):
    """Rotate a layout so the axis from basin ``f`` to basin ``u`` lies along +x.

    ``pos`` maps node -> (x, y). ``u`` and ``f`` index the layout rows in
    ``pos`` iteration order; for graphs built with ``nx.from_numpy_array``
    this is the node index. Each basin's positions are averaged before the
    angle is taken. If either basin is unset (see ``_basin_indices``), the
    angle is 0 and the coordinates are only converted to ``(x, y)`` tuples.
    Returns a new dict with the same keys.
    """
    keys = list(pos.keys())
    coords = np.asarray(list(pos.values()), dtype=float)
    basin_u = _basin_indices(u)
    basin_f = _basin_indices(f)
    if basin_u is None or basin_f is None:
        theta = 0
    else:
        dx, dy = coords[basin_u].mean(axis=0) - coords[basin_f].mean(axis=0)
        theta = math.atan2(dy, dx)
    cos_t, sin_t = math.cos(-theta), math.sin(-theta)
    rotated = [(x * cos_t - y * sin_t, x * sin_t + y * cos_t) for x, y in coords]
    return dict(zip(keys, rotated))


def draw_knetwork(traj, tlag, qoft, out, u=0, f=0, set_min_node_size=True):
    """Draw the kinetic network of a Markov state model and save it to ``out``.

    A Markov state model is estimated with ``msmhelper`` from ``traj`` at lag
    ``tlag``. Each state becomes a node:

    * its marker size grows with the state's equilibrium population;
    * its colour encodes the state-averaged ``1 - qoft`` (see
      ``assign_color``);
    * it is labelled ``states[i] + 1``.

    Edges are weighted by the symmetrised equilibrium flux (``DRAW_FLUX``)
    or by the transition matrix otherwise. The layout is a spring layout,
    refined by ForceAtlas2 when ``USE_FA2`` is set.

    Parameters
    ----------
    traj : array-like
        State trajectory passed to ``mh.msm.estimate_markov_model``; must
        also support ``traj == state`` masking against ``qoft``.
    tlag : int
        Lag time for the Markov state model (presumably in frames).
    qoft : np.ndarray
        Per-frame observable aligned with ``traj``; used for node colours.
    out : str or path-like
        Output file passed to ``prettypyplot.savefig``.
    u, f : int or sequence of int, optional
        Node indices (positions in the estimated ``states``) of two basins.
        When both are given (and ``USE_FA2`` is set), the layout is rotated
        so the axis from basin ``f`` to basin ``u`` points along +x. A
        single index or a sequence of indices is accepted. ``0``, the
        default, means "not given".
    set_min_node_size : bool, optional
        Enlarge small nodes (see ``_node_sizes``). Only in this mode are
        labels on dark nodes drawn in white.
    """
    _, ax = plt.subplots()
    tmat, states = mh.msm.estimate_markov_model(traj, tlag)
    color_list = assign_color(qoft, states, traj, levels=10)

    pop_eq = mh.msm.equilibrium_population(tmat)
    graph = nx.from_numpy_array(
        _symmetric_flux(tmat, pop_eq),
        create_using=nx.Graph,
    )
    node_size = _node_sizes(pop_eq, set_min_node_size)

    # Initial positions from a simple spring model.
    pos = nx.spring_layout(
        graph,
        fixed=None,
        iterations=1000,
        threshold=1e-4,
        scale=0.1,
        weight="weight",
    )
    if USE_FA2:
        # Refine the spring layout with ForceAtlas2.
        forceatlas2 = ForceAtlas2(
            adjustSizes=False,
            verbose=False,
            strongGravityMode=True,
            scalingRatio=1000,
            gravity=0.0,
        )
        pos = forceatlas2.forceatlas2_networkx_layout(
            graph,
            pos=pos,
            iterations=1000,
        )
        # NOTE(review): the basin-axis rotation is only applied on the
        # ForceAtlas2 path; confirm this is intended.
        pos = _rotate_layout(pos, u, f)

    if DRAW_FLUX:
        # Undirected flux edges drawn as curves; width is linear in the flux.
        edge_width = 0.1 + 300 * np.array(
            [graph[i][j]["weight"] for i, j in graph.edges],
        )
        curves = curved_edges(graph, pos)
        ax.add_collection(
            LineCollection(curves, color="black", linewidth=edge_width, alpha=1),
        )
    else:
        # Directed graph built from the raw transition matrix.
        digraph = nx.from_numpy_array(tmat, create_using=nx.DiGraph)
        edge_width = 0.2 + 5 * np.array(
            [digraph[i][j]["weight"] for i, j in digraph.edges],
        )
        nx.draw_networkx_edges(
            digraph,
            arrowstyle="-",
            pos=pos,
            connectionstyle="arc3,rad=0.4",
            width=edge_width,
            edge_color="black",
            node_size=node_size,
            arrowsize=3,
        )

    nx.draw_networkx_nodes(
        graph,
        pos=pos,
        node_color=color_list,
        node_size=node_size,
        linewidths=0.55,
        edgecolors="black",
    )

    # Write node labels: white on dark nodes, but only when small nodes have
    # been enlarged; black otherwise.
    for node_idx, (x, y) in pos.items():
        is_dark = get_luminance(color_list[node_idx]) < 140
        c_text = "white" if is_dark and set_min_node_size else "black"
        pplt.text(
            x, y, states[node_idx] + 1, contour=False, fontsize="medium", color=c_text,
        )

    # Pad the view by the largest node size on every side.
    # NOTE(review): node_size is a marker size for draw_networkx_nodes, not
    # a data-space length, so this padding is only a heuristic.
    margin = np.max(node_size)
    xy = np.asarray(list(pos.values()))
    ax.set_xlim(xy[:, 0].min() - margin, xy[:, 0].max() + margin)
    ax.set_ylim(xy[:, 1].min() - margin, xy[:, 1].max() + margin)

    ax.set_axis_off()
    plt.tight_layout()
    pplt.savefig(out)