Coverage for /home/deng/Projects/metatree_drawer/metatreedrawer/treeprofiler/src/utils.py: 54%
427 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-07 10:33 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-07 10:33 +0200
1from __future__ import annotations
2from treeprofiler.src import b64pickle
3from ete4.parser.newick import NewickError
4from ete4.core.operations import remove
5from ete4 import Tree, PhyloTree
6from Bio import AlignIO
7from Bio.Align import MultipleSeqAlignment
8from Bio.Align.AlignInfo import SummaryInfo
9from itertools import chain
10from distutils.util import strtobool
11import matplotlib.pyplot as plt
12import matplotlib as mpl
13import matplotlib.colors as mcolors
14import numpy as np
15from scipy import stats
16import random
17import colorsys
18import operator
19import math
20import Bio
21import re
22import sys
23from io import StringIO
# Comparison operators usable in conditional expressions, keyed by the symbol
# written in the condition string; `call` and `counter_call` dispatch on it.
operator_dict = dict([
    ('<', operator.lt),
    ('<=', operator.le),
    ('=', operator.eq),
    ('!=', operator.ne),
    ('>', operator.gt),
    ('>=', operator.ge),
])

# Lower-case spellings recognised by `str2bool`.
_true_set = set('yes true t y 1'.split())
_false_set = set('no false f n 0'.split())
def str2bool(value, raise_exc=False):
    """Interpret a yes/no style string as a boolean.

    Matching is case-insensitive against ``_true_set`` / ``_false_set``.

    :param value: Value to interpret. Only ``str`` values are lower-cased and
        matched; any other value is treated as unrecognised.
    :param raise_exc: If True, raise ``ValueError`` for unrecognised values
        instead of returning ``None``.
    :return: ``True``, ``False``, or ``None`` when the value is unrecognised
        and ``raise_exc`` is False.
    :raises ValueError: For unrecognised values when ``raise_exc`` is True.
    """
    if isinstance(value, str):
        lowered = value.lower()
        if lowered in _true_set:
            return True
        if lowered in _false_set:
            return False
    if raise_exc:
        # sorted() keeps the message stable; set iteration order varies between runs.
        raise ValueError('Expected "%s"' % '", "'.join(sorted(_true_set | _false_set)))
    return None
def str2bool_exc(value):
    """Strict form of ``str2bool``: unrecognised values raise ``ValueError``."""
    strict = True
    return str2bool(value, raise_exc=strict)
def check_float_array(array):
    """
    Check whether every element of ``array`` can be treated as a float.

    :param array: Sequence of values to check.
    :return: True if NumPy can convert all elements to float64 and none of the
        results is NaN; False otherwise. This includes 'nan' strings, ``None``,
        and objects that cannot be converted at all.
    """
    try:
        as_float = np.array(array, dtype=np.float64)
    except (TypeError, ValueError):
        # ValueError: unparsable strings / ragged input.
        # TypeError: objects float() rejects outright (dicts, arbitrary objects).
        return False
    # Check if the conversion resulted in any NaNs.
    return not np.isnan(as_float).any()
def check_nan(value):
    """Return True if ``value`` converts to a float NaN, otherwise False.

    Values that cannot be converted to float at all (non-numeric strings,
    ``None``, containers) are reported as not-NaN instead of raising.
    """
    try:
        return math.isnan(float(value))
    except (TypeError, ValueError):
        return False
def counter_call(node, internal_prop, leaf_prop, datatype, operator_string, right_value):
    """Evaluate a condition against one entry of a node's counter property.

    Counter properties are strings such as ``"a--3||b--1"``: items are separated
    by ``"||"`` and each item is ``key--count``. The count stored under
    ``leaf_prop`` is compared with ``right_value``, both converted to float,
    using ``operator_dict[operator_string]``.

    :return: The comparison result. Returns False when ``datatype`` is not
        ``str``, when the node lacks ``internal_prop``, or when ``leaf_prop``
        is absent from the counter.
    """
    pair_delimiter = "--"
    item_separator = "||"
    if datatype != str:
        return False
    counter_props = node.props.get(internal_prop)
    if not counter_props:
        return False
    for counter_data in counter_props.split(item_separator):
        # Split on the last delimiter: the count is numeric, so any extra
        # "--" belongs to the key.
        key, count = counter_data.rsplit(pair_delimiter, 1)
        if key == leaf_prop:
            return operator_dict[operator_string](float(count), float(right_value))
    return False
def call(node, prop, datatype, operator_string, right_value):
    """Evaluate ``<prop> <operator_string> <right_value>`` for one node.

    Supported operators depend on ``datatype``:

    * ``str``: ``contains`` and ``in`` both test whether ``right_value`` occurs
      in the node's property value. Other non-ordering operators go through
      ``operator_dict``. Ordering operators (``<``, ``<=``, ``>``, ``>=``)
      always give False.
    * ``float``: any operator in ``operator_dict``, applied to both sides
      converted with ``float()``.
    * ``list``: only ``contains`` (membership); anything else gives False.

    :return: The comparison result, or False when the property is missing,
        or when the datatype/operator combination is unsupported.
    """
    num_operators = ('<', '<=', '>', '>=')
    left_value = node.props.get(prop)

    if datatype == str:
        if operator_string in num_operators or not left_value:
            return False
        if operator_string in ('contains', 'in'):
            return right_value in left_value
        return operator_dict[operator_string](left_value, right_value)

    if datatype == float:
        # A stored 0 / 0.0 is a real value; only None or '' mean "missing".
        if left_value is None or left_value == '':
            return False
        return operator_dict[operator_string](float(left_value), float(right_value))

    if datatype == list:
        if operator_string == 'contains' and left_value:
            return right_value in left_value
        return False

    return False
def to_code(string):
    """Parse a conditional expression string into ``[left, op, right]`` triples.

    ``string`` holds one or more conditions separated by commas, for example
    ``"size>=10,name contains foo"``. Each condition is split at the first
    operator it contains. Supported operators are ``<``, ``<=``, ``>``,
    ``>=``, ``=``, ``!=``, ``in`` and ``contains``. The two word operators
    only match as whole words, so names such as ``binary`` or ``in_degree``
    are not split apart. Whitespace around each part is stripped, and blank
    segments (e.g. from a trailing comma) are skipped.

    :param string: Comma-separated condition string.
    :return: List of ``[left_value, operator, right_value]`` lists, one per
        non-blank condition.
    :raises ValueError: If a condition has no operator or is missing a side.
    """
    operators = ['<', '<=', '>', '>=', '=', '!=', 'in', 'contains']
    # Longest first so '<=' wins over '<'; word operators need \b boundaries.
    alternatives = []
    for op in sorted(operators, key=len, reverse=True):
        escaped = re.escape(op)
        alternatives.append(r'\b' + escaped + r'\b' if op.isalpha() else escaped)
    condition_re = re.compile(
        r'\s*(.+?)\s*(' + '|'.join(alternatives) + r')\s*(.+?)\s*', re.DOTALL)

    conditional_output = []
    for condition_string in string.split(','):
        if not condition_string.strip():
            continue
        match = condition_re.fullmatch(condition_string)
        if match is None:
            raise ValueError(f"Cannot parse condition {condition_string!r}; "
                             f"expected '<left> <operator> <right>'.")
        conditional_output.append(list(match.groups()))
    return conditional_output
# Public alias of Biopython's SeqRecord. Importing the submodule explicitly
# avoids relying on another Bio import having loaded `Bio.SeqRecord` first.
from Bio.SeqRecord import SeqRecord
def get_consensus_seq(matrix_string: str, threshold: float = 0.7) -> Bio.Seq.Seq:
    """Build a consensus sequence from FASTA-formatted alignment text.

    Adapted from
    https://stackoverflow.com/questions/73702044/how-to-get-a-consensus-of-multiple-sequence-alignments-using-biopython

    :param matrix_string: FASTA text of the aligned sequences. This is the
        content itself, not a path; it is wrapped in ``StringIO``.
    :param threshold: Passed to ``SummaryInfo.dumb_consensus`` as the minimum
        fraction a residue needs at a column to be called.
    :return: The object returned by ``SummaryInfo.dumb_consensus``, with ``"-"``
        passed as the ambiguous character. This is a sequence, not a
        ``SeqRecord``.
    """
    # Records from every alignment parsed out of the text are merged into one.
    common_alignment = MultipleSeqAlignment(
        chain(*AlignIO.parse(StringIO(matrix_string), "fasta"))
    )
    summary = SummaryInfo(common_alignment)
    # NOTE(review): newer Biopython releases deprecate dumb_consensus; confirm
    # against the pinned Biopython version.
    return summary.dumb_consensus(threshold, "-")
def counter2ratio(node, prop, minimum=0.05, count_missing=True):
    """Return the fraction of "true" entries in a boolean counter property.

    ``node.props[prop]`` is a counter string such as ``"True--3||False--1"``:
    items are separated by ``"||"`` and each item is ``key--count``. Keys are
    interpreted with ``strtobool``. Keys for which ``check_nan`` is true are
    treated as missing data.

    :param node: Node whose ``props`` holds the counter string.
    :param prop: Name of the counter property.
    :param minimum: Non-zero ratios below this value are raised to it.
    :param count_missing: If True (the default), missing-data entries are
        included in the denominator.
    :return: ``positive / total``, or 0 when the total is 0.
    """
    counter_separator = '||'
    items_separator = '--'
    total = 0.0
    positive = 0.0
    for counter_prop in node.props.get(prop).split(counter_separator):
        key, value = counter_prop.rsplit(items_separator, 1)
        count = float(value)
        is_missing = check_nan(key)
        if count_missing or not is_missing:
            total += count
        # Sum every truthy key ('True', 'yes', '1', ...) rather than keeping
        # only the last one seen.
        if not is_missing and strtobool(key):
            positive += count

    total = int(total)
    ratio = positive / total if total != 0 else 0
    if ratio != 0 and ratio < minimum:  # show minimum color for too low
        ratio = minimum
    return ratio
def categorical2ratio(node, prop, all_values, minimum=0.05):
    """Return the share of each category in ``all_values`` within a counter property.

    ``node.props[prop]`` is a counter string such as ``"a--3||b--1"``.
    Categories absent from the counter get 0.

    :param node: Node whose ``props`` holds the counter string.
    :param prop: Name of the counter property.
    :param all_values: Categories to report, in output order.
    :param minimum: Non-zero ratios below this value are raised to it.
    :return: List of ratios aligned with ``all_values``. All entries are 0
        when the counter total is 0.
    """
    counter_separator = '||'
    items_separator = '--'
    # Split on the last separator: the count is numeric, the key may contain '--'.
    counter_dict = dict(
        item.rsplit(items_separator, 1)
        for item in node.props.get(prop).split(counter_separator))
    total = sum(int(v) for v in counter_dict.values())

    ratios = []
    for value in all_values:
        positive = int(counter_dict.get(value, 0))
        ratio = positive / total if total else 0
        if ratio != 0 and ratio < minimum:  # show minimum color for too low
            ratio = minimum
        ratios.append(ratio)
    return ratios
def dict_to_string(d, pair_seperator="--", item_seperator="||"):
    """Serialise ``d`` as ``key<pair_seperator>value`` items joined by ``item_seperator``.

    With the default separators this is the format read by ``string_to_dict``.
    """
    pieces = []
    for key, value in d.items():
        pieces.append(f"{key}{pair_seperator}{value}")
    return item_seperator.join(pieces)
def string_to_dict(s, pair_seperator="--", item_seperator="||"):
    """Parse ``"k1--v1||k2--v2"`` into ``{'k1': 'v1', 'k2': 'v2'}``.

    Each item is split only at its first ``pair_seperator``, so a value that
    itself contains the separator is kept intact. Keys and values are strings.

    :return: The parsed dict; ``{}`` for an empty string.
    :raises IndexError: If an item contains no ``pair_seperator``.
    """
    if not s:
        return {}
    result = {}
    for item in s.split(item_seperator):
        parts = item.split(pair_seperator, 1)
        result[parts[0]] = parts[1]
    return result
# Tree-format validation
class TreeFormatError(Exception):
    """Raised when a tree file cannot be loaded in the requested format."""
def validate_tree(tree_path, input_type, internal_parser=None):
    """Load a tree from ``tree_path`` in the requested format.

    :param tree_path: Path to the tree file.
    :param input_type: ``'ete'`` (b64pickle-encoded tree), ``'newick'``, or
        ``'auto'``. In ``'auto'`` mode ``'ete'`` is tried first; if it fails,
        the file is parsed as newick.
    :param internal_parser: Forwarded to ``ete4_parse`` for newick input.
    :return: ``(tree, eteformat_flag)``. ``eteformat_flag`` is True when the
        tree was loaded from the ete format. ``tree`` is None if
        ``input_type`` is none of the values above.
    :raises TreeFormatError: If ``input_type == 'ete'`` and loading fails.
    """
    tree = None  # Initialize tree to None
    eteformat_flag = False
    if input_type in ('ete', 'auto'):
        try:
            with open(tree_path, 'r') as f:
                file_content = f.read()
            # NOTE(review): encoder='pickle' presumably unpickles the content;
            # only load trusted files.
            tree = b64pickle.loads(file_content, encoder='pickle', unpack=False)
            eteformat_flag = True
        except Exception as e:
            # In 'auto' mode any failure just means "not ete format".
            if input_type == 'ete':
                raise TreeFormatError(f"Error loading tree in 'ete' format: {e}") from e

    if input_type in ('newick', 'auto') and tree is None:
        # Parse inside `with` so the file handle is always closed.
        with open(tree_path) as handle:
            tree = ete4_parse(handle, internal_parser=internal_parser)

    return tree, eteformat_flag
# parse ete4 Tree
def get_internal_parser(internal_parser="name"):
    """Map an internal-node label option to the ete4 newick ``parser`` code.

    ``"name"`` gives 1 and ``"support"`` gives 0; any other value (including
    ``None``) gives ``None``.
    """
    if internal_parser == "support":
        return 0
    return 1 if internal_parser == "name" else None
def ete4_parse(newick, internal_parser="name"):
    """Parse ``newick`` into a ``PhyloTree``, correcting trees without branch lengths.

    If no node has a positive ``dist``, every node returned by
    ``tree.descendants()`` gets ``dist = 1``.
    """
    tree = PhyloTree(newick, parser=get_internal_parser(internal_parser))
    has_positive_dist = any(
        node.dist and float(node.dist) > 0 for node in tree.traverse())
    if not has_positive_dist:
        for node in tree.descendants():
            node.dist = 1
    return tree
def _prune_children(node):
    """Remove every child of ``node`` with ete4 ``remove``, printing each name."""
    for child in list(node.children):
        print("prune", child.name)
        remove(child)


# pruning
def taxatree_prune(tree, rank_limit='subspecies'):
    """Collapse a taxonomy-annotated tree at ``rank_limit``.

    In preorder, children are removed from nodes whose ``rank`` property
    equals ``rank_limit``. Nodes whose ``lca`` property (a ``key--value||...``
    string) has an entry for ``rank_limit`` are renamed to that entry, and
    their children are removed as well.

    :return: ``tree``, modified in place.
    """
    for node in tree.traverse("preorder"):
        if node.props.get('rank') == rank_limit:
            _prune_children(node)

        lca_string = node.props.get('lca')
        if not lca_string:
            continue
        lca_dict = string_to_dict(lca_string)
        lca = lca_dict.get(rank_limit, None) if lca_dict else None
        if lca:
            node.name = lca
            _prune_children(node)
    return tree
def _condition_holds(node, condition, prop2type):
    """Evaluate one parsed ``[left, op, right]`` condition on ``node``."""
    left, op, right = condition
    if op == 'in':
        # "<value> in <prop>": the property name is on the right-hand side.
        return call(node, right, prop2type.get(right), op, left)
    if ':' in left:
        # "<counter_prop>:<key> <op> <value>" compares one counter entry.
        internal_prop, leaf_prop = left.split(':')
        return counter_call(node, internal_prop, leaf_prop,
                            prop2type[internal_prop], op, right)
    return call(node, left, prop2type.get(left), op, right)


def conditional_prune(tree, conditions_input, prop2type):
    """Detach every non-root node that satisfies any of the given conditions.

    Each entry of ``conditions_input`` is a comma-separated condition string
    parsed with ``to_code``. Conditions within one entry are ANDed; the
    entries themselves are ORed.

    :param tree: Tree to prune in place.
    :param conditions_input: Iterable of condition strings.
    :param prop2type: Maps property names to their datatype (str, float, list, ...).
    :return: ``tree``.
    """
    condition_groups = [to_code(line) for line in conditions_input]

    def matches(node):
        # Any falsy result (False or None) fails the AND group. An empty
        # group must not match every node.
        return any(
            group and all(_condition_holds(node, cond, prop2type) for cond in group)
            for group in condition_groups)

    # Collect first, then detach, so the tree is not mutated mid-traversal
    # and no node is detached twice.
    to_detach = [n for n in tree.traverse() if not n.is_root and matches(n)]
    for node in to_detach:
        node.detach()
    return tree
343 #array = [n.props.get(prop) if n.props.get(prop) else 'NaN' for n in nodes]
344 array = [n.props.get(prop) for n in nodes if n.props.get(prop) ]
345 return array
def tree_prop_array(node, prop, leaf_only=False, numeric=False):
    """Collect the values of ``prop`` from a subtree into a flat list.

    :param node: Root of the subtree to scan.
    :param prop: Property name read from each node's ``props``.
    :param leaf_only: If True, read only ``node.leaves()``; otherwise every
        node yielded by ``node.traverse()``.
    :param numeric: If True, the string ``'NaN'`` is replaced by ``np.nan``.
        This applies whether or not ``leaf_only`` is set.
    :return: List of values. Nodes lacking the property are skipped, and
        set-valued properties are expanded into their elements.
    """
    nodes = node.leaves() if leaf_only else node.traverse()
    array = []
    for n in nodes:
        prop_value = n.props.get(prop)
        if prop_value is None:
            continue
        if isinstance(prop_value, set):
            array.extend(prop_value)
        elif numeric and prop_value == 'NaN':
            array.append(np.nan)
        else:
            array.append(prop_value)
    return array
def children_prop_array(nodes, prop):
    """Gather ``prop`` from each node in ``nodes`` into one flat list.

    Nodes whose value is ``None`` are skipped; set-valued properties
    contribute each of their elements.
    """
    collected = []
    for node in nodes:
        value = node.props.get(prop)
        if value is None:
            continue
        if isinstance(value, set):
            collected.extend(value)
        else:
            collected.append(value)
    return collected
def children_prop_array_missing(nodes, prop):
    """Collect ``prop`` from each node, using ``'NaN'`` for missing values.

    A value is missing when it is ``None`` or an empty string/container.
    Numeric values, including ``0``, ``0.0`` and ``False``, are real
    observations and are kept unchanged.

    :return: List aligned with ``nodes``, one entry per node.
    """
    array = []
    for n in nodes:
        value = n.props.get(prop)
        is_number = isinstance(value, (int, float, np.number, np.bool_))
        if value is None or (not is_number and not value):
            value = 'NaN'
        array.append(value)
    return array
def convert_to_int_or_float(column):
    """
    Convert a column to int64 when that is lossless, otherwise to float64.

    Args:
        column (list): The input column data (numbers or numeric strings).

    Returns:
        np.ndarray: An int64 array if every value is an integer (e.g. ``'3'``,
        ``3`` or ``3.0``); otherwise a float64 array. Float data with a
        fractional part, NaN/inf, or values outside the int64 range is never
        truncated to int.
    """
    np_column = np.asarray(column)

    if np_column.dtype.kind == 'f':
        # astype(int64) silently truncates 1.5 -> 1 and turns NaN into an
        # arbitrary integer, so only accept exactly-integral, in-range floats.
        integral = (
            np.isfinite(np_column).all()
            and np.all(np_column == np.trunc(np_column))
            and np.all(np.abs(np_column) < 2.0 ** 63)
        )
        return np_column.astype(np.int64 if integral else np.float64)

    # Strings such as '1.5' raise ValueError; None in object arrays raises TypeError.
    try:
        return np_column.astype(np.int64)
    except (ValueError, TypeError, OverflowError):
        return np_column.astype(np.float64)
def flatten(nasted_list):
    """
    Flatten one level of nesting.

    input: nasted_list - a list whose items may themselves be lists.
    ------------------------
    output: a new list in which each list item is replaced by its elements and
    each non-list item is kept as it is. Deeper nesting is not expanded.
    """
    flat = []
    for item in nasted_list:
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat
def random_color(h=None, l=None, s=None, num=None, sep=None, seed=None):
    """Return the hex code of a random color, or a list of evenly spaced colors.

    Hue (h), Lightness (l) and Saturation (s) of the generated color
    can be specified as arguments; ``0`` is a valid value for each.

    :param h: Starting hue. If None, a random hue ``1 / randint(1, 360)`` is used.
    :param l: Lightness; defaults to 0.5.
    :param s: Saturation; defaults to 0.5.
    :param num: If given (non-zero), return a list of ``num`` colors.
    :param sep: Hue step between consecutive colors; defaults to ``1 / num``.
    :param seed: Seed for the random hue. A private ``random.Random`` is used,
        so the global ``random`` state is left untouched.
    :return: A ``'#rrggbb'`` string, or a list of them when ``num`` is given.
    """
    def rgb2hex(rgb):
        return '#%02x%02x%02x' % rgb

    def hls2hex(hue, lightness, saturation):
        rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
        return rgb2hex(tuple(int(x * 255) for x in rgb))

    if h is None:
        rng = random.Random(seed) if seed is not None else random
        start_hue = 1.0 / rng.randint(1, 360)
    else:
        start_hue = h

    if num:
        count = num
        step = sep if sep else 1.0 / count
    else:
        count = 1
        step = 1

    lightness = 0.5 if l is None else l
    saturation = 0.5 if s is None else s

    colors = [hls2hex(start_hue + step * i, lightness, saturation) for i in range(count)]
    return colors if num else colors[0]
def assign_color_to_values(values, paired_colors):
    """Map each value to a color and return the mapping sorted by value.

    When ``paired_colors`` has at least as many colors as there are values,
    they are used in order. Otherwise ``assign_colors`` generates colors from
    the 'terrain' colormap.
    """
    if len(values) > len(paired_colors):
        color_dict = assign_colors(values, cmap_name='terrain')
    else:
        color_dict = dict(zip(values, paired_colors))
    return {key: color_dict[key] for key in sorted(color_dict)}
def rgba_to_hex(rgba):
    """Convert an RGB(A) sequence of floats in [0, 1] to ``'#rrggbb'``; alpha is ignored."""
    red, green, blue = (int(channel * 255) for channel in rgba[:3])
    return '#{:02x}{:02x}{:02x}'.format(red, green, blue)
def assign_colors(variables, cmap_name='tab20'):
    """Assign a hex color to each variable from a matplotlib colormap.

    The colormap is resampled to ``len(variables)`` entries so the colors are
    spread over the whole map.

    :param variables: Sized collection of items to color.
    :param cmap_name: Name of the matplotlib colormap.
    :return: Dict mapping each variable to ``'#rrggbb'``; empty for no variables.
    """
    if len(variables) == 0:
        return {}
    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(cmap_name, len(variables))
    colors = [rgba_to_hex(cmap(i)) for i in range(cmap.N)]  # Generate colors in hex format
    return dict(zip(variables, colors))
def build_color_gradient(n_colors, colormap_name="viridis"):
    """
    Sample ``n_colors`` evenly spaced colors from a matplotlib colormap.

    Parameters:
    n_colors (int): Number of colors to sample.
    colormap_name (str): Name of the matplotlib colormap, e.g. "viridis",
        "plasma", "inferno" or "magma".

    Returns:
    dict: Maps indices 1..n_colors to hex colors, sampled from position 0 to
        position 1 of the colormap.
    """
    cmap = plt.get_cmap(colormap_name)
    gradient = {}
    for index, position in enumerate(np.linspace(0, 1, n_colors), start=1):
        gradient[index] = mcolors.rgb2hex(cmap(position))
    return gradient
def build_custom_gradient(n_colors, min_color, max_color, mid_color=None):
    """
    Build a linear color gradient between two (optionally three) colors.

    Parameters:
    n_colors (int): Number of colors in the gradient.
    min_color (str): Hex code or named color for the start of the gradient.
    max_color (str): Hex code or named color for the end of the gradient.
    mid_color (str, optional): If given, indices ``1..n_colors // 2`` run from
        ``min_color`` to ``mid_color`` and the remaining indices run from
        ``mid_color`` to ``max_color``.

    Returns:
    dict: Maps indices 1..n_colors to hex colors; empty when ``n_colors < 1``.
        A segment containing a single color yields its start color instead
        of dividing by zero.
    """
    def lerp(start_rgb, end_rgb, step, steps):
        # steps == 0 means the segment holds one color: use its start.
        if steps <= 0:
            return list(start_rgb)
        return [(end_c - start_c) * step / steps + start_c
                for start_c, end_c in zip(start_rgb, end_rgb)]

    # Convert min and max colors to RGB
    min_rgb = mcolors.to_rgb(min_color)
    max_rgb = mcolors.to_rgb(max_color)
    mid_rgb = mcolors.to_rgb(mid_color) if mid_color else None

    color_gradient = {}
    if mid_color:
        # Halfway point for the gradient transition
        mid_point = n_colors // 2
        for i in range(1, n_colors + 1):
            if i <= mid_point:
                rgb = lerp(min_rgb, mid_rgb, i - 1, mid_point - 1)
            else:
                rgb = lerp(mid_rgb, max_rgb, i - mid_point - 1, n_colors - mid_point - 1)
            color_gradient[i] = mcolors.to_hex(rgb)
    else:
        for i in range(1, n_colors + 1):
            color_gradient[i] = mcolors.to_hex(lerp(min_rgb, max_rgb, i - 1, n_colors - 1))
    return color_gradient
def clear_extra_features(forest, features):
    """Drop node properties not listed in ``features`` and stringify sets.

    ``name``, ``dist`` and ``support`` are always kept. Any remaining
    set-valued property is replaced by its elements joined with commas.
    Every tree in ``forest`` is modified in place.
    """
    keep = {'name', 'dist', 'support'}.union(features)
    for tree in forest:
        for node in tree.traverse():
            for extra in set(node.props).difference(keep):
                node.del_prop(extra)
            for key in list(node.props):
                value = node.props[key]
                if isinstance(value, set):
                    node.props[key] = ','.join(str(v) for v in value)
def add_suffix(name, suffix, delimiter='_'):
    """Return ``str(name)`` and ``str(suffix)`` joined by ``delimiter``."""
    return ''.join((str(name), delimiter, str(suffix)))
def normalize_values(values, normalization_method="min-max"):
    """
    Normalizes a list of numeric values using the specified method.

    Parameters:
    - values: Iterable of numbers or numeric strings; 'nan' (any case) or a
      float NaN marks a missing value.
    - normalization_method: "min-max", "mean-norm" or "z-score".

    Returns:
    - A numpy array of the normalized non-missing values. Missing values are
      dropped, so the result can be shorter than ``values``. If all values
      are equal there is no spread and every result is 0.0. With no
      non-missing values an empty array is returned.

    Raises:
    - ValueError: for a value that cannot be read as a number, or an
      unsupported method.
    """
    def try_convert_to_float(values):
        """Convert values to floats, mapping 'nan' markers to np.nan."""
        converted = []
        for v in values:
            try:
                # str() so numeric inputs (including float NaN) are accepted too.
                if str(v).lower() != 'nan':
                    converted.append(float(v))
                else:
                    converted.append(np.nan)
            except (TypeError, ValueError):
                raise ValueError(f"Cannot treat value '{v}' as a number.") from None
        return np.array(converted, dtype=np.float64)

    if normalization_method not in ("min-max", "mean-norm", "z-score"):
        raise ValueError("Unsupported normalization method.")

    numeric_values = try_convert_to_float(values)
    valid_values = numeric_values[~np.isnan(numeric_values)]
    if valid_values.size == 0:
        return valid_values

    spread = valid_values.max() - valid_values.min()
    if normalization_method == "min-max":
        if spread == 0:
            return np.zeros_like(valid_values)
        return (valid_values - valid_values.min()) / spread
    if normalization_method == "mean-norm":
        if spread == 0:
            return np.zeros_like(valid_values)
        return (valid_values - valid_values.mean()) / spread
    # z-score: a constant column has zero standard deviation.
    if valid_values.std() == 0:
        return np.zeros_like(valid_values)
    return stats.zscore(valid_values)
def find_bool_representations(column, rep=True):
    """Count the entries of ``column`` that spell a boolean value.

    Each value is stringified, stripped and lower-cased. 'nan', 'none' and
    empty strings are ignored.

    :param rep: If truthy, count true-like spellings ('true', 't', 'yes',
        'y', '1'); otherwise count false-like ones ('false', 'f', 'no', 'n', '0').
    :return: Number of matching entries.
    """
    true_values = {'true', 't', 'yes', 'y', '1'}
    false_values = {'false', 'f', 'no', 'n', '0'}
    ignore_values = {'nan', 'none', ''}
    wanted = true_values if rep else false_values
    normalized = (str(value).strip().lower() for value in column)
    return sum(1 for text in normalized if text not in ignore_values and text in wanted)
def color_gradient(c1, c2, mix=0):
    """Linearly interpolate from color ``c1`` (mix=0) to ``c2`` (mix=1); return hex."""
    # https://stackoverflow.com/questions/25668828/how-to-create-colour-gradient-in-python
    start = np.array(mpl.colors.to_rgb(c1))
    end = np.array(mpl.colors.to_rgb(c2))
    blended = start * (1 - mix) + end * mix
    return mpl.colors.to_hex(blended)
def make_color_darker_log(hex_color, total, base=10):
    """Darken ``hex_color`` by ``log_base(1 + total) / 50``, via ``make_color_darker``."""
    factor = math.log(1 + total, base) / 50  # the divisor controls how fast colors darken
    return make_color_darker(hex_color, factor)
def make_color_darker(hex_color, darkening_factor):
    """Subtract ``darkening_factor`` from each RGB channel, clamping at 0; return hex."""
    channels = mcolors.hex2color(hex_color)
    return mcolors.to_hex([max(0, channel - darkening_factor) for channel in channels])
def make_color_darker_scaled(hex_color, positive, maximum, base=10, scale_factor=10, min_darkness=0.6):
    """
    Darken ``hex_color`` according to how close ``positive`` is to ``maximum``.

    The fraction ``positive / maximum`` is mapped onto a logarithmic curve,
    ``log(1 + fraction * (scale_factor - 1)) / log(scale_factor)``, giving a
    darkening level in [0, 1]. Each RGB channel is then multiplied by
    ``1 - level``.

    :param hex_color: The original color in hex format.
    :param positive: The current count (0 <= positive <= maximum).
    :param maximum: The count corresponding to the strongest darkening. When
        0, no darkening is applied.
    :param base: Logarithm base. It cancels out in the ratio above, so it
        does not change the result.
    :param scale_factor: Curvature of the logarithmic mapping; must be > 1.
    :param min_darkness: Despite its name, this is a *cap* on the darkening
        level. Channels are never scaled below ``1 - min_darkness`` of their
        original value.
    :return: The darkened hex color.
    :raises ValueError: If ``positive`` exceeds ``maximum`` or is negative,
        or if ``scale_factor <= 1``.
    """
    if positive > maximum:
        raise ValueError("Positive count cannot exceed the maximum specified.")
    if positive < 0:
        raise ValueError("Positive count cannot be negative.")
    if scale_factor <= 1:
        raise ValueError("scale_factor must be greater than 1.")

    # Calculate the normalized position of 'positive' between 0 and 'maximum'
    normalized_position = positive / maximum if maximum != 0 else 0

    # Calculate the logarithmic scale position
    log_position = math.log(1 + normalized_position * (scale_factor - 1), base) / math.log(scale_factor, base)

    # Cap the darkening level at min_darkness.
    log_position = min(log_position, min_darkness)

    # Convert hex to RGB, then apply the darkening based on log_position
    rgb = mcolors.hex2color(hex_color)
    darkened_rgb = [(1 - log_position) * channel for channel in rgb]
    return mcolors.to_hex(darkened_rgb)
689# def transform_columns(columns, treat_as_whole=True, normalization_method="min-max"):
690# transformed = defaultdict(dict)
692# def normalize(values, method):
693# if method == "min-max":
694# return (values - values.min()) / (values.max() - values.min())
695# elif method == "mean-norm":
696# return (values - values.mean()) / (values.max() - values.min())
697# elif method == "z-score":
698# return stats.zscore(values)
699# else:
700# raise ValueError("Unsupported normalization method.")
702# def try_convert_to_float(values):
703# converted = []
704# for v in values:
705# try:
706# if v != 'NaN': # Assuming 'NaN' is used to denote missing values
707# converted.append(float(v))
708# else:
709# converted.append(np.nan) # Convert 'NaN' string to numpy NaN for consistency
710# except ValueError:
711# raise ValueError(f"Cannot treat value '{v}' as a number.")
712# return np.array(converted)
714# if treat_as_whole:
715# # Attempt to concatenate all numeric columns into one array
716# all_numeric_values = np.concatenate([
717# try_convert_to_float(values) for values in columns.values()
718# ])
720# normalized_values = normalize(all_numeric_values[~np.isnan(all_numeric_values)], normalization_method)
722# start_idx = 0
723# for prop, values in columns.items():
724# num_values = len([v for v in values if v != 'NaN'])
725# transformed[prop][normalization_method] = normalized_values[start_idx:start_idx+num_values]
726# start_idx += num_values
727# else:
728# for prop, values in columns.items():
729# numeric_values = try_convert_to_float(values)
730# # Normalize only the non-NaN values
731# transformed[prop][normalization_method] = normalize(numeric_values[~np.isnan(numeric_values)], normalization_method)
733# return transformed