Coverage for MPT/graph/curved_edges.py: 95%

33 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-08 16:59 +0200

1#!/usr/bin/env python3 

2"""Plotting curved edges for undirected graphs with NetworkX<3. 

3 

4Taken from https://github.com/beyondbeneath/bezier-curved-edges-networkx 

5 

6MIT License 

7 

8Copyright (c) 2018 Geoff Sims 

9 

10Permission is hereby granted, free of charge, to any person obtaining a copy 

11of this software and associated documentation files (the "Software"), to deal 

12in the Software without restriction, including without limitation the rights 

13to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

14copies of the Software, and to permit persons to whom the Software is 

15furnished to do so, subject to the following conditions: 

16 

17The above copyright notice and this permission notice shall be included in all 

18copies or substantial portions of the Software. 

19 

20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

21IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

22FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

23AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

24LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

25OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

26SOFTWARE. 

27 

28""" 

29import bezier 

30import networkx as nx 

31import numpy as np 

32 

def curved_edges(G, pos, dist_ratio=0.2, bezier_precision=20, polarity='random'):
    """Compute cubic Bezier curves that draw each edge of ``G`` as an arc.

    Each edge (u, v) becomes a cubic Bezier curve from ``pos[u]`` to
    ``pos[v]``.  The two inner control points lie ``dist_ratio * |uv|`` along
    the edge from each end and the same distance off to one side of it, so
    the curve bulges away from the straight segment.

    Parameters
    ----------
    G : graph
        Any object whose ``edges()`` yields ``(u, v)`` node pairs, e.g. a
        NetworkX graph.
    pos : mapping
        Node -> (x, y) position, e.g. the output of a NetworkX layout.
    dist_ratio : float
        Control-point offset as a fraction of the edge length.
    bezier_precision : int
        Number of points sampled along each curve.
    polarity : str
        ``'random'`` picks the side of each arc at random (via
        ``np.random``).  Any other value derives the side from the parity of
        ``hash(u) + hash(v)``, so the choice is repeatable, e.g. across
        animation frames.  Note that ``str`` hashes are salted per
        interpreter process unless PYTHONHASHSEED is set.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_edges, bezier_precision, 2)`` holding the sampled
        (x, y) points of every curve; shape ``(0, bezier_precision, 2)`` for a
        graph without edges.
    """
    edges = list(G.edges())
    n_edges = len(edges)
    if n_edges == 0:
        # Nothing to draw; return an empty array with the usual trailing shape.
        return np.empty((0, bezier_precision, 2))

    if polarity == 'random':
        # Random polarity of each curve (which side of the edge it arcs to).
        rnd = np.where(np.random.randint(2, size=n_edges) == 0, -1, 1)
    else:
        # Fixed (hashed) polarity, useful e.g. for animations.
        rnd = np.array(
            [-1 if (hash(u) + hash(v)) % 2 == 0 else 1 for u, v in edges]
        )

    # Endpoint coordinates, shape (n_edges, 2 endpoints, 2 coordinates).
    # Looked up per edge so any hashable node label (including tuples) works.
    coords = np.array([[pos[u], pos[v]] for u, v in edges], dtype=float)
    start, end = coords[:, 0, :], coords[:, 1, :]

    # Order endpoints left-to-right so the side selected by rnd is consistent.
    swap = (start[:, 0] > end[:, 0])[:, np.newaxis]
    start, end = np.where(swap, end, start), np.where(swap, start, end)

    # Use unit direction vectors instead of slopes: a slope is infinite for
    # vertical edges and its negative reciprocal is infinite for horizontal
    # ones, which would otherwise turn those control points into NaN.
    delta = end - start
    length = np.hypot(delta[:, 0], delta[:, 1])
    dist = dist_ratio * length
    # Zero-length edges have dist == 0, so any non-zero divisor is fine here;
    # it only avoids 0/0 and collapses such curves onto their node.
    safe_length = np.where(length > 0, length, 1.0)[:, np.newaxis]
    along = delta / safe_length
    # Perpendicular pointing toward +x (toward -y for horizontal edges),
    # the same side convention as the slope-based formula this replaces.
    side = np.where(delta[:, 1] < 0, -1.0, 1.0)[:, np.newaxis]
    across = side * np.column_stack([delta[:, 1], -delta[:, 0]]) / safe_length

    # Control points: step dist along the edge from each end, then dist
    # sideways; rnd flips which side of the edge the curve arcs to.
    along_offset = dist[:, np.newaxis] * along
    across_offset = (rnd * dist)[:, np.newaxis] * across
    ctrl1 = start + along_offset + across_offset
    ctrl2 = end - along_offset + across_offset

    # Evaluate every cubic Bezier curve at once in Bernstein form:
    # B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3.
    # t broadcasts as (1, P, 1) against control points shaped (E, 1, 2).
    t = np.linspace(0, 1, bezier_precision)[np.newaxis, :, np.newaxis]
    s = 1 - t
    curves = (
        s**3 * start[:, np.newaxis, :]
        + 3 * s**2 * t * ctrl1[:, np.newaxis, :]
        + 3 * s * t**2 * ctrl2[:, np.newaxis, :]
        + t**3 * end[:, np.newaxis, :]
    )
    return curves