# Source code for cocoatree.visualization

"""Module to visualize phylogenetic trees along with sectors"""

# User provided file:
# - phylogenetic tree in newick format
# - multiple sequence alignment used to generate the tree in fasta format
# - annotation table in csv format

# Import necessary packages
from ete3 import ProfileFace, TreeStyle, NodeStyle, TextFace, \
    add_face_to_node, SeqMotifFace, RectFace
from pandas.api.types import is_numeric_dtype  # type: ignore
import pandas as pd
import numpy as np
from PyQt5 import QtGui
import matplotlib.colors as colors
import matplotlib.cm as cmx
import matplotlib.pyplot as plt

from .msa import compute_seq_identity


def _annot_to_color(attribute, tree, df_annot, cmap='jet'):
    """
    Map every annotated tree leaf to the color of its category for one
    annotation column.

    Parameters
    ----------
    attribute : str,
            name of the annotation column to color by

    tree : ete3's tree object,
            as imported by io.load_tree_ete3(); only its leaf names are used

    df_annot : pandas dataframe of the annotation file,
            must contain a 'Seq_ID' column and the `attribute` column.
            The caller's dataframe is not modified.

    cmap : str or dict, optional, default: 'jet'
            - str: name of the colormap passed to _get_color_palette
            - dict: colors for each category {category: color}; it must
              contain every category found among the tree leaves
              (including 'unknown' when values are missing)

    Returns
    -------
    att_dict : dictionnary in which keys are the sequence IDs and the values
            are the colors associated with it

    color_dict : dictionnary in which keys are the attribute's categories and
            the values are the colors associated to each category

    Raises
    ------
    KeyError
            if `cmap` is a dict that lacks one of the categories
    """
    leaf_names = tree.get_leaf_names()

    # fillna returns a new frame, so the assignments below never touch the
    # dataframe owned by the caller
    df_annot = df_annot.fillna('unknown')
    if is_numeric_dtype(df_annot['Seq_ID']):
        # leaf names are strings: numeric identifiers must be compared as text
        df_annot['Seq_ID'] = df_annot['Seq_ID'].astype('str')

    # Only the rows describing a leaf of the tree are of interest
    in_tree = df_annot['Seq_ID'].isin(leaf_names)
    seq_ids = df_annot.loc[in_tree, 'Seq_ID']
    categories = df_annot.loc[in_tree, attribute]

    unique_categories = list(categories.unique())
    if isinstance(cmap, str):
        color_dict = _get_color_palette(unique_categories, cmap)
    else:
        color_dict = {cat: cmap[cat] for cat in unique_categories}

    # Built directly from the two columns: no temporary '<attribute>_color'
    # column is written on a slice, and an empty selection simply yields {}
    att_dict = {
        seq_id: color_dict[cat] for seq_id, cat in zip(seq_ids, categories)}

    return att_dict, color_dict


def _generate_colors_from_colormaps(n_colors, cmap="jet", as_hex=True):
    """
    Sample `n_colors` evenly spaced colors from a Matplotlib colormap.

    Parameters
    ----------
    n_colors : int, number of colors to generate

    cmap : optional, default: "jet"
            colormap name; it is converted with str() before the lookup

    as_hex : boolean, optional, default: True
            whether to convert each color to a hexadecimal string

    Returns
    -------
    list of `n_colors` colors: hexadecimal strings when `as_hex` is true,
    otherwise the values returned by the colormap itself
    """
    palette = plt.get_cmap(str(cmap))

    sampled = []
    # positions span the whole colormap, both ends included
    for position in np.linspace(0, 1, n_colors):
        shade = palette(position)
        sampled.append(colors.to_hex(shade) if as_hex else shade)
    return sampled


# TO DO: allow user to choose which color to use for 'unknown'
# (currently: white by default)
def _get_color_palette(values, cmap):
    """
    Associate each value with one color of a colormap.

    Parameters
    ----------
    values : list of the categories to color

    cmap : colormap name, forwarded to _generate_colors_from_colormaps

    Returns
    -------
    color_dict : dictionnary in which keys are the values and the values are
            hexadecimal colors; the 'unknown' category is always white
    """
    palette = _generate_colors_from_colormaps(
        len(values), cmap=cmap, as_hex=True)

    # 'unknown' still takes its slot in the palette: the color generated for
    # its position is simply replaced by white
    return {
        value: '#FFFFFF' if value == 'unknown' else shade
        for value, shade in zip(values, palette)
    }


def _get_color_gradient(self):
    """
    Build the color scale of a heatmap from a Matplotlib colormap.

    Meant to be used as a method: the colormap name is read from
    `self.colorscheme`.
    Adapted from:
    https://github.com/lthiberiol/virfac/blob/master/get_color_gradient.py

    Returns
    -------
    color_scale : list of 201 QtGui.QColor sampled at evenly spaced levels
            between 0 and 1
    """
    mapper = cmx.ScalarMappable(
        norm=colors.Normalize(vmin=0, vmax=1),
        cmap=plt.get_cmap(self.colorscheme))

    gradient = []
    for level in np.linspace(0, 1, 201):
        # bytes=True requests integer channels instead of floats in [0, 1]
        red, green, blue, alpha = mapper.to_rgba(level, bytes=True)
        gradient.append(QtGui.QColor(red, green, blue, alpha))
    return gradient


def update_tree_ete3_and_return_style(
        tree_ete3, df_annot,
        sector_id=None,
        sector_seq=None,
        meta_data=None,
        show_leaf_name=True,
        fig_title='',
        linewidth=1,
        linecolor="#000000",
        bootstrap_style={},
        tree_scale=200,
        metadata_colors=None,
        t_sector_seq=False,
        t_sector_heatmap=False,
        colormap='inferno'
        ):
    """
    Update ete3 tree with sector info and attributes and return
    tree_style for further visualization.

    Parameters
    ----------
    tree_ete3 : ete3's tree object,
        as imported by io.load_tree_ete3()

    df_annot : pandas dataframe of the annotation file

    sector_id : list of sector identifiers, as imported by io.load_msa()
        the ids must match with the tree's leaves id

    sector_seq : corresponding list of sector sequences to display,
        as imported by io.load_msa()

    meta_data : tuple of annotations to display
        (from annotation file's header)

    show_leaf_name : boolean, optional, default: True
        whether to show leaf names.

    fig_title : figure title (str)

    linewidth : int, optional, default: 1
        width of the lines in the tree

    linecolor : str, optional, default: "#000000"
        color of the lines

    bootstrap_style : dict, optional,
        `fgcolor`: color of the bootstrap node, default: "darkred"
        `size`: size of the bootstrap node, default: 10
        `support`: minimum support level for display, default: 95

    tree_scale : int, optional, default: 200
        sets the scale of the tree in ETE3: the higher, the larger the tree
        will be (in width)

    metadata_colors : dict, str, or None, optional, default: None
        colors for the metadata:
            - None: generates automatically the colors ("jet" colormap)
            - str: uses a Matplotlib colormap to generate the colors
            - dict: specifies colors for each metadata entry {key: color}

    t_sector_seq : boolean,
        whether to show the sequences of the sector

    t_sector_heatmap : boolean,
        whether to add a heatmap of the identity matrix between sector
        sequences

    colormap : str, optional, default: 'inferno'
        Matplotlib colormap used for the heatmap

    Returns
    -------
    tree_style : TreeStyle class from ete3

    column_end : int, the number of columns after the tree. If you want to
        plot anything else alongside the tree, the column number should be
        equal to this value.
    """
    tree_style = TreeStyle()
    tree_style.scale = tree_scale
    tree_style.layout_fn = []
    tree_style.show_leaf_name = show_leaf_name

    # Node style for well supported nodes.
    # bootstrap_style is only read, never modified, so its shared default
    # dictionary is harmless.
    boot_style = NodeStyle()
    boot_style["fgcolor"] = bootstrap_style.get("fgcolor", "darkred")
    boot_style["size"] = bootstrap_style.get("size", 10)
    support = bootstrap_style.get("support", 95)
    boot_style["hz_line_width"] = linewidth
    boot_style["vt_line_width"] = linewidth
    boot_style["vt_line_color"] = linecolor
    boot_style["hz_line_color"] = linecolor

    # Node style for every other node: same lines, no marker
    empty_style = NodeStyle()
    empty_style["size"] = 0
    empty_style["vt_line_width"] = linewidth
    empty_style["hz_line_width"] = linewidth
    empty_style["vt_line_color"] = linecolor
    empty_style["hz_line_color"] = linecolor

    for node in tree_ete3.traverse():
        if node.support >= support:
            node.set_style(boot_style)
        else:
            node.set_style(empty_style)

    column_layout = 0
    col_legend_rectface = 0

    if metadata_colors is None:
        metadata_colors = "jet"

    # If no metadata, do nothing
    if meta_data:
        # Computed once per annotation column and shared by the layout
        # function and the legend, instead of once per leaf and per column
        # each time a leaf is laid out
        color_maps = [
            _annot_to_color(col, tree_ete3, df_annot, cmap=metadata_colors)
            for col in meta_data]
        last_column = len(meta_data) - 1

        def layout_attribute(node, column=column_layout):
            if not node.is_leaf():
                return
            for i, (leaf_colors, _) in enumerate(color_maps):
                leaf_color = leaf_colors[node.name]
                rect_face = RectFace(50, 20,
                                     fgcolor=leaf_color,
                                     bgcolor=leaf_color)
                rect_face.margin_left = 5
                # extra room after the last annotation column
                rect_face.margin_right = 30 if i == last_column else 0
                add_face_to_node(rect_face, node, column=column + i,
                                 position='aligned')

        tree_style.layout_fn.append(layout_attribute)

        # Add legend: two legend columns (color box, label) per annotation
        for col, (_, category_colors) in zip(meta_data, color_maps):
            tree_style.legend.add_face(TextFace(col, fsize=10, bold=True),
                                       column=col_legend_rectface)
            # otherwise text is not in front of RectFace
            tree_style.legend.add_face(TextFace(""),
                                       column=col_legend_rectface + 1)
            for category, category_color in category_colors.items():
                swatch = RectFace(50, 20,
                                  fgcolor=category_color,
                                  bgcolor=category_color)
                swatch.margin_right = 5
                swatch.margin_left = 10
                tree_style.legend.add_face(swatch,
                                           column=col_legend_rectface)
                tree_style.legend.add_face(TextFace(category, fsize=10),
                                           column=col_legend_rectface + 1)
            col_legend_rectface += 2

    column_layout += len(meta_data) if meta_data else 0

    if t_sector_seq:
        tree_style, column_layout = add_sector_sequences_to_tree(
            tree_style, tree_ete3, sector_id, sector_seq,
            column_start=column_layout)

    if t_sector_heatmap:
        # forward the requested colormap: it was previously accepted by this
        # function but never reached the heatmap
        tree_style, column_layout = add_heatmap_to_tree(
            tree_style, tree_ete3, sector_id, sector_seq,
            column_start=column_layout, colormap=colormap)

    # Add title
    tree_style.title.add_face(TextFace(fig_title, fsize=20), column=0)

    return tree_style, column_layout
def add_sector_sequences_to_tree(tree_style, tree_ete3, sector_id, sector_seq,
                                 column_start=0):
    """
    Add sector sequence to ETE3's tree style

    Parameters
    ----------
    tree_style : ETE3's tree_style object
        its `layout_fn` list receives the new layout function

    tree_ete3 : ete3's tree object,
        as imported by io.load_tree_ete3()
        (not used by this function)

    sector_id : list of sector identifiers, as imported by io.load_msa()
        the ids must match with the tree's leaves id

    sector_seq : corresponding list of sector sequences to display,
        as imported by io.load_msa()

    column_start : int, optional, default : 0
        the column on which to start plotting

    Returns
    -------
    tree_style : TreeStyle class from ete3

    column_end : int, the number of columns after the tree. If you want to
        plot anything else alongside the tree, the column number should be
        equal to this value.
    """
    sector_dict = {
        seq_id: str(sector_seq[i]) for i, seq_id in enumerate(sector_id)}

    def layout_SeqMotifFace(node, column=column_start):
        if not node.is_leaf():
            return
        # every face spans the length of the first sector sequence
        seq_length = len(sector_seq[0])
        # leaves without a sector sequence are drawn as a run of gaps
        seq = sector_dict.get(node.name, '-' * seq_length)
        seq_face = SeqMotifFace(seq,
                                motifs=[[0, seq_length, "seq",
                                         20, 20, None, None, None]],
                                scale_factor=1)
        seq_face.margin_right = 30
        add_face_to_node(seq_face, node, column=column,
                         position='aligned')

    tree_style.layout_fn.append(layout_SeqMotifFace)

    # the sequences occupy a single column
    column_start += 1
    return tree_style, column_start
def add_heatmap_to_tree(tree_style, tree_ete3, sector_id, sector_seq,
                        column_start=0, width=20, colormap="inferno"):
    """
    Add heatmap to ETE3's tree style

    Parameters
    ----------
    tree_style : ETE3's tree_style object
        returned unchanged: the heatmap faces are attached to the leaves

    tree_ete3 : ete3's tree object,
        as imported by io.load_tree_ete3()

    sector_id : list of sector identifiers, as imported by io.load_msa()
        the ids must match with the tree's leaves id

    sector_seq : corresponding list of sector sequences to display,
        as imported by io.load_msa()

    column_start : int, optional, default : 0
        the column on which to start plotting

    width : int, optional, default : 20
        the width of each square of the heatmap. If width == 20, the heatmap
        will be squared.

    colormap : str, optional, default: "inferno"
        any Matplotlib's colormap

    Returns
    -------
    tree_style : TreeStyle class from ete3

    column_end : int, the number of columns after the tree. If you want to
        plot anything else alongside the tree, the column number should be
        equal to this value.

    Raises
    ------
    KeyError
        if a leaf of the tree has no entry in `sector_id`
    """
    leaves_id = tree_ete3.get_leaf_names()
    nb_leaves = len(leaves_id)

    # allow to choose among Matplotlib's colormaps
    # NOTE: this replaces the method on the ProfileFace class itself, so it
    # applies to every ProfileFace and not only to the ones created here
    ProfileFace.get_color_gradient = _get_color_gradient

    # Check that sequences in the similarity matrix are ordered as in the
    # tree leaves and keep only sequences that are present in the tree
    sequences = pd.DataFrame(index=sector_id, data={"seq": sector_seq})
    reordered_sequences = sequences.loc[leaves_id, "seq"].values
    id_mat = compute_seq_identity(reordered_sequences)
    nb_sequences = id_mat.shape[0]

    # Add heatmap profile to each leaf
    for i, leaf in enumerate(tree_ete3.iter_leaves()):
        leaf.add_features(profile=id_mat[i])
        leaf.add_features(deviation=[0] * nb_sequences)
        leaf.add_face(ProfileFace(max_v=1, min_v=0.0, center_v=0.5,
                                  width=(nb_leaves * width), height=20,
                                  style='heatmap',
                                  colorscheme=colormap),
                      column=column_start, position="aligned")

    # The whole heatmap is a single face per leaf and therefore occupies a
    # single column: the next free column is the following one. (Advancing
    # by nb_leaves * width mixed up a width in pixels with a column index.)
    column_start += 1
    return tree_style, column_start