Coverage for MPP/graph/curved_edges.py: 95%
33 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 14:46 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-11 14:46 +0200
1#!/usr/bin/env python3
2"""Plotting curved edges for undirected graphs with NetworkX<3.
4Taken from https://github.com/beyondbeneath/bezier-curved-edges-networkx
6MIT License
8Copyright (c) 2018 Geoff Sims
10Permission is hereby granted, free of charge, to any person obtaining a copy
11of this software and associated documentation files (the "Software"), to deal
12in the Software without restriction, including without limitation the rights
13to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14copies of the Software, and to permit persons to whom the Software is
15furnished to do so, subject to the following conditions:
17The above copyright notice and this permission notice shall be included in all
18copies or substantial portions of the Software.
20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26SOFTWARE.
28"""
29import bezier
30import networkx as nx
31import numpy as np
def curved_edges(G, pos, dist_ratio=0.2, bezier_precision=20, polarity='random'):
    """Sample cubic Bezier arcs for drawing the edges of ``G`` as curves.

    For every edge the two end nodes become the outer control points of a
    cubic Bezier curve.  The two inner control points start ``dist`` along the
    straight segment from each end (``dist = dist_ratio * edge length``) and
    are then shifted ``dist`` sideways, perpendicular to the segment, which
    makes the curve bow to one side.

    Parameters
    ----------
    G : graph-like
        Anything whose ``edges()`` yields ``(u, v)`` pairs, e.g. a NetworkX
        graph.  Node labels may be any hashable (ints, strings, tuples, ...):
        they are only used as keys into ``pos`` and as arguments to ``hash``.
    pos : mapping
        Node -> ``(x, y)`` position, e.g. the output of a NetworkX layout.
    dist_ratio : float, default 0.2
        Control-point offset as a fraction of the edge length; larger values
        give more pronounced arcs, 0 gives straight lines.
    bezier_precision : int, default 20
        Number of points sampled per curve, at parameter values
        ``np.linspace(0, 1, bezier_precision)``.
    polarity : str, default 'random'
        ``'random'`` picks the side each edge bows towards from the global
        ``np.random`` state.  Any other value derives it from the parity of
        ``hash(u) + hash(v)`` so repeated calls agree (handy for animations);
        note that ``str`` hashes differ between interpreter runs unless
        ``PYTHONHASHSEED`` is fixed.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_edges, bezier_precision, 2)`` with the
        sampled ``(x, y)`` points of each curve, in ``G.edges()`` order.  A
        graph without edges yields shape ``(0, bezier_precision, 2)``.
    """
    edge_list = list(G.edges())
    n_edges = len(edge_list)
    if n_edges == 0:
        return np.empty((0, bezier_precision, 2))

    if polarity == 'random':
        # Random side per edge: -1 or +1.
        rnd = np.where(np.random.randint(2, size=n_edges) == 0, -1, 1)
    else:
        # Fixed (hash-derived) side so the same edge always bows the same way.
        parity = np.array([(hash(u) + hash(v)) % 2 for u, v in edge_list])
        rnd = np.where(parity == 0, -1, 1)

    # Look positions up with the original node objects.  Converting the edge
    # list to a NumPy array first would coerce the labels (tuple nodes gain an
    # extra axis, mixed int/str labels all become str) and break ``pos[...]``.
    coords = np.array([[pos[u], pos[v]] for u, v in edge_list], dtype=float)
    start, end = coords[:, 0, :], coords[:, 1, :]

    # Orient each edge left-to-right so the bow direction for a given
    # polarity does not depend on how the edge happens to be stored.
    swap = (start[:, 0] > end[:, 0])[:, None]
    start, end = np.where(swap, end, start), np.where(swap, start, end)

    delta = end - start
    length = np.hypot(delta[:, 0], delta[:, 1])
    dist = dist_ratio * length

    # Unit vectors along and across the edge, computed from the direction
    # vector rather than from slopes: slopes are infinite/zero for vertical/
    # horizontal edges and produced NaN control points.  Zero-length edges
    # (self-loops, coincident nodes) get zero vectors, so their "curve"
    # collapses onto the node instead of becoming NaN.
    safe_length = np.where(length > 0, length, 1.0)[:, None]
    along = delta / safe_length
    # (dy, -dx) is perpendicular to the edge; flipping it when dy < 0 keeps
    # its x component non-negative, which is the same normal the original
    # slope-based formula (1, -1/m) / |(1, -1/m)| selected.
    flip = np.where(delta[:, 1] < 0, -1.0, 1.0)[:, None]
    normal = flip * np.column_stack((delta[:, 1], -delta[:, 0])) / safe_length

    offset_along = dist[:, None] * along
    offset_side = (rnd * dist)[:, None] * normal
    ctrl1 = start + offset_along + offset_side
    ctrl2 = end - offset_along + offset_side

    # Evaluate all cubic Beziers at once in Bernstein form:
    #   B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3.
    # Broadcasting (1, P, 1) parameters against (E, 1, 2) points gives (E, P, 2).
    t = np.linspace(0, 1, bezier_precision)[None, :, None]
    s = 1 - t
    curves = (s ** 3 * start[:, None, :]
              + 3 * s ** 2 * t * ctrl1[:, None, :]
              + 3 * s * t ** 2 * ctrl2[:, None, :]
              + t ** 3 * end[:, None, :])
    return curves