Source code for dandelion.plotting._plotting

#!/usr/bin/env python
# @Author: Kelvin
# @Date:   2020-05-18 00:15:00
# @Last Modified by:   Kelvin
# @Last Modified time: 2021-02-20 09:44:55

import seaborn as sns
import pandas as pd
import numpy as np
from ..utilities._utilities import *
from ..utilities._core import *
from ..utilities._io import *
from ..tools._diversity import rarefun
from scanpy.plotting._tools.scatterplots import embedding
import matplotlib.pyplot as plt
from anndata import AnnData
import random
from adjustText import adjust_text
from plotnine import ggplot, theme_classic, aes, geom_line, xlab, ylab, options, ggtitle, labs, scale_color_manual
from scanpy.plotting import palettes
from time import sleep
import matplotlib.pyplot as plt
from itertools import combinations
from typing import Union, Sequence, Tuple, Dict
from matplotlib.axes import Axes
from matplotlib.figure import Figure


def clone_rarefaction(self: Union[AnnData, Dandelion],
                      color: str,
                      clone_key: Union[None, str] = None,
                      palette: Union[None, Sequence] = None,
                      figsize: Tuple[Union[int, float], Union[int, float]] = (6, 4),
                      save: Union[None, str] = None) -> ggplot:
    """
    Plots rarefaction curve for cell numbers vs clone size.

    Only rows whose `bcr_QC_pass` column is True (or the string 'True') are
    used. Groups in `color` that end up with zero counts are reported and
    dropped before the curves are computed.

    Parameters
    ----------
    self : `AnnData`, `Dandelion`
        `AnnData` or `Dandelion` object.
    color : str
        Column name to split the calculation of clone numbers for a given number of cells for e.g. sample, patient etc.
    clone_key : str, optional
        Column name specifying the clone_id column in metadata/obs. None defaults to 'clone_id'.
    palette : Sequence, optional
        Color mapping for unique elements in color. If None, will try to retrieve from the AnnData `.uns` slot
        (`<color>_colors`) if present, otherwise falls back to a scanpy default palette.
    figsize : Tuple[Union[int,float], Union[int,float]]
        Size of plot as (width, height). Also used as the size of the saved figure.
    save : str, optional
        Save suffix; the plot is written to 'figures/rarefaction' + save.

    Returns
    -------
        rarefaction curve plot.

    Raises
    ------
    TypeError
        if `self` is neither an `AnnData` nor a `Dandelion` object.
    """
    if isinstance(self, AnnData):
        metadata = self.obs.copy()
    elif isinstance(self, Dandelion):
        metadata = self.metadata.copy()
    else:
        raise TypeError(
            "Please provide an AnnData or Dandelion object instead of %s." % self.__class__)

    clonekey = 'clone_id' if clone_key is None else clone_key

    # groups are collected before QC filtering so that groups losing all
    # their cells are still reported below
    groups = list(set(metadata[color]))
    # explicit copy: the column assignment below must not target a view
    metadata = metadata[metadata['bcr_QC_pass'].isin([True, 'True'])].copy()
    metadata[clonekey] = metadata[clonekey].cat.remove_unused_categories()

    res = {}
    for g in groups:
        _metadata = metadata[metadata[color] == g]
        res[g] = _metadata[clonekey].value_counts()
    res_ = pd.DataFrame.from_dict(res, orient='index')

    # remove those with no counts
    tot = res_.sum(axis=1)
    no_counts = (tot == 0).values
    if no_counts.any():
        print('removing due to zero counts:', ', '.join(
            str(g) for g in res_.index[no_counts]))
        # NOTE(review): the pause presumably lets the message flush before
        # the progress bar starts drawing - confirm
        sleep(0.5)
    res_ = res_[~no_counts]
    tot = tot[~no_counts]

    # calculate the rarefaction curve for each remaining group
    sample_sizes = []
    rarecurve = {}
    for i in tqdm(range(res_.shape[0]), desc='Calculating rarefaction curve '):
        total = tot.iloc[i]
        # subsample every 10 cells and always finish on the full group size;
        # np.arange excludes the stop value, so the total is never present
        n = np.append(np.arange(1, total, step=10), total)
        sample_sizes.append(n)
        counts = np.array(res_.iloc[i])
        rarecurve[res_.index[i]] = [rarefun(counts, z) for z in n]

    # both frames are built from the same subsample sizes, so after melting
    # they line up row for row
    y = pd.DataFrame([rarecurve[c] for c in rarecurve]).T
    pred = pd.DataFrame(sample_sizes, index=res_.index).T
    y = y.melt()
    pred = pred.melt()
    pred['yhat'] = y['value']

    options.figure_size = figsize

    # resolve the palette: explicit argument > colors stored in .uns > default
    pal = palette
    if pal is None and isinstance(self, AnnData):
        try:
            pal = self.uns[str(color)+'_colors']
        except KeyError:
            pal = None
    if pal is None:
        pal = _rarefaction_default_palette(len(set(pred['variable'])))

    p = (ggplot(pred, aes(x="value", y="yhat", color="variable"))
         + theme_classic()
         + xlab('number of cells')
         + ylab('number of clones')
         + ggtitle('rarefaction curve')
         + labs(color=color))
    if pal is not None:
        p = p + scale_color_manual(values=(pal))
    p = p + geom_line()

    if save:
        # figsize is (width, height), matching what is displayed
        p.save(filename='figures/rarefaction'+str(save),
               width=figsize[0],
               height=figsize[1],
               units='in',
               dpi=plt.rcParams["savefig.dpi"])
    return p


def _rarefaction_default_palette(n_groups: int) -> Union[None, Sequence]:
    """Return the smallest scanpy default palette covering `n_groups`, or None if there are more than 102 groups."""
    if n_groups <= 20:
        return palettes.default_20
    if n_groups <= 28:
        return palettes.default_28
    if n_groups <= 102:
        return palettes.default_102
    return None
def random_palette(n: int) -> Sequence:
    """
    Draw `n` colours at random, without replacement.

    The pool is every named colour in seaborn's xkcd table plus `n` evenly
    spaced 'husl' colours.

    Parameters
    ----------
    n : int
        number of colours to draw.

    Returns
    -------
        list of `n` colours.
    """
    xkcd_names = list(sns.xkcd_rgb)
    husl_colours = list(sns.color_palette('husl', n))
    pool = sns.xkcd_palette(xkcd_names) + husl_colours
    return random.sample(pool, n)
def clone_network(adata: AnnData,
                  basis: str = 'bcr',
                  edges: bool = True,
                  **kwargs) -> Union[Figure, Axes, None]:
    """
    Using scanpy's plotting module to plot the network.

    Only thing that is changed is the default options: `basis = 'bcr'` and `edges = True`.

    Parameters
    ----------
    adata : AnnData
        AnnData object.
    basis : str
        key for embedding. Default is 'bcr'.
    edges : bool
        whether or not to plot edges. Default is True.
    **kwargs
        passed to `sc.pl.embedding`.

    Returns
    -------
        whatever `sc.pl.embedding` returns for the given options.
    """
    # hand the result back so that options asking scanpy to return the
    # figure/axes are not silently discarded
    return embedding(adata, basis=basis, edges=edges, **kwargs)
def barplot(self: Union[AnnData, Dandelion],
            color: str,
            palette: str = 'Set1',
            figsize: Tuple[Union[int, float], Union[int, float]] = (12, 4),
            normalize: bool = True,
            sort_descending: bool = True,
            title: Union[None, str] = None,
            xtick_rotation: Union[None, Union[int, float]] = None,
            min_clone_size: Union[None, int] = None,
            clone_key: Union[None, str] = None,
            **kwargs) -> Tuple[Figure, Axes]:
    """
    A barplot function to plot usage of V/J genes in the data.

    Parameters
    ----------
    self : Dandelion, AnnData
        `Dandelion` or `AnnData` object.
    color : str
        column name in metadata for plotting in bar plot.
    palette : str
        Colors to use for the different levels of the color variable. Should be something that can be interpreted by [color_palette](https://seaborn.pydata.org/generated/seaborn.color_palette.html#seaborn.color_palette), or a dictionary mapping hue levels to matplotlib colors. See [seaborn.barplot](https://seaborn.pydata.org/generated/seaborn.barplot.html).
    figsize : Tuple[Union[int,float], Union[int,float]]
        figure size. Default is (12, 4).
    normalize : bool
        if True, will return as proportion out of 1, otherwise False will return counts. Default is True.
    sort_descending : bool
        if True, bars are ordered by decreasing value; if False, they are ordered by the sorted index. Default is True.
    title : str, optional
        title of plot. None defaults to '<color> usage'.
    xtick_rotation : int, float, optional
        rotation of x tick labels. None defaults to 90.
    min_clone_size : int, optional
        minimum clone size to keep. Defaults to 1 if left as None.
    clone_key : str, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        passed to `sns.barplot`.

    Returns
    -------
        a seaborn barplot, as a (figure, axes) tuple.

    Raises
    ------
    TypeError
        if `self` is neither a `Dandelion` nor an `AnnData` object.
    """
    if isinstance(self, Dandelion):
        data = self.metadata.copy()
    elif isinstance(self, AnnData):
        data = self.obs.copy()
    else:
        raise TypeError(
            "Please provide an AnnData or Dandelion object instead of %s." % self.__class__)

    min_size = 1 if min_clone_size is None else int(min_clone_size)
    clone_ = 'clone_id' if clone_key is None else clone_key

    # keep only cells that belong to clones of at least `min_size` cells
    size = data[clone_].value_counts()
    keep = list(size[size >= min_size].index)
    data_ = data[data[clone_].isin(keep)]

    # note: this changes seaborn's global style
    sns.set_style('whitegrid', {'axes.grid': False})

    counts = data_[color].value_counts(normalize=normalize)
    if not sort_descending:
        counts = counts.sort_index()
    # name the columns explicitly: the labels produced by
    # value_counts()/reset_index() differ between pandas versions
    res = pd.DataFrame({'index': counts.index, color: counts.values})

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=figsize)

    # plot
    sns.barplot(x='index', y=color, data=res, palette=palette, **kwargs)

    # change some parts
    if title is None:
        ax.set_title(color.replace('_', ' ')+' usage')
    else:
        ax.set_title(title)
    ax.set_ylabel('proportion' if normalize else 'count')
    ax.set_xlabel('')
    plt.xticks(rotation=90 if xtick_rotation is None else xtick_rotation)
    return fig, ax
def stackedbarplot(self: Union[AnnData, Dandelion],
                   color: str,
                   groupby: Union[None, str],
                   figsize: Tuple[Union[int, float], Union[int, float]] = (12, 4),
                   normalize: bool = False,
                   title: Union[None, str] = None,
                   sort_descending: bool = True,
                   xtick_rotation: Union[None, Union[float, int]] = None,
                   hide_legend: bool = True,
                   legend_options: Tuple[str, Tuple[float, float], int] = None,
                   labels: Union[None, Sequence] = None,
                   min_clone_size: Union[None, int] = None,
                   clone_key: Union[None, str] = None,
                   **kwargs) -> Tuple[Figure, Axes]:
    """
    A stackedbarplot function to plot usage of V/J genes in the data split by groups.

    Parameters
    ----------
    self : Dandelion, AnnData
        `Dandelion` or `AnnData` object.
    color : str
        column name in metadata for plotting in bar plot.
    groupby : str
        column name in metadata to split by during plotting.
    figsize : Tuple[Union[int,float], Union[int,float]]
        figure size. Default is (12, 4).
    normalize : bool
        if True, will return as proportion out of 1, otherwise False will return counts. Default is False.
    title : str, optional
        title of plot. None defaults to 'multiple stacked bar plot : <color> usage'.
    sort_descending : bool
        True orders the bars by decreasing overall count of `color`, False by increasing count,
        None sorts them by index. Default is True.
    xtick_rotation : int, float, optional
        rotation of x tick labels. None defaults to 90.
    hide_legend : bool
        whether or not to hide the legend.
    legend_options : Tuple[str, Tuple[float, float], int]
        a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : Sequence, optional
        Names of objects will be used for the legend if list of multiple dataframes supplied.
    min_clone_size : int, optional
        minimum clone size to keep. Defaults to 1 if left as None.
    clone_key : str, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
        stacked bar plot, as a (figure, axes) tuple.

    Raises
    ------
    TypeError
        if `self` is neither a `Dandelion` nor an `AnnData` object.
    """
    if isinstance(self, Dandelion):
        data = self.metadata.copy()
    elif isinstance(self, AnnData):
        data = self.obs.copy()
    else:
        raise TypeError(
            "Please provide an AnnData or Dandelion object instead of %s." % self.__class__)

    # quick fix to prevent dropping of nan
    data[groupby] = [str(l) for l in data[groupby]]

    min_size = 1 if min_clone_size is None else int(min_clone_size)
    clone_ = 'clone_id' if clone_key is None else clone_key

    # keep only cells that belong to clones of at least `min_size` cells
    size = data[clone_].value_counts()
    keep = list(size[size >= min_size].index)
    data_ = data[data[clone_].isin(keep)]

    # count (or proportion) of every groupby level within every color level
    dat_ = pd.DataFrame(data_.groupby(color)[groupby].value_counts(
        normalize=normalize).unstack(fill_value=0).stack(), columns=['value'])
    dat_.reset_index(drop=False, inplace=True)
    # the bar order is taken from the unfiltered data
    dat_order = pd.DataFrame(data[color].value_counts(normalize=normalize))
    dat_ = dat_.pivot(index=color, columns=groupby, values='value')
    if sort_descending is True:
        dat_ = dat_.reindex(dat_order.index)
    elif sort_descending is False:
        dat_ = dat_.reindex(dat_order.index[::-1])
    elif sort_descending is None:
        dat_ = dat_.sort_index()

    def _plot_bar_stacked(dfall: pd.DataFrame,
                          labels: Union[None, Sequence] = None,
                          figsize: Tuple[Union[int, float], Union[int, float]] = (12, 4),
                          title: str = "multiple stacked bar plot",
                          xtick_rotation: Union[None, Union[float, int]] = None,
                          legend_options: Tuple[str, Tuple[float, float], int] = None,
                          hide_legend: bool = True,
                          H: str = "/",
                          **kwargs) -> Tuple[Figure, Axes]:
        """
        Given a list of dataframes, with identical columns and index, create a clustered stacked bar plot.

        Parameters
        ----------
        labels
            a list of the dataframe objects. Names of objects will be used for the legend.
        title
            string for the title of the plot
        H
            is the hatch used for identification of the different dataframes
        **kwargs
            other kwargs passed to `pandas.DataFrame.plot`
        """
        if not isinstance(dfall, list):
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)

        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)

        for df in dfall:  # for each data frame
            ax = df.plot(kind="bar", linewidth=0, stacked=True, ax=ax,
                         legend=False, grid=False, **kwargs)  # make bar plots

        h, l = ax.get_legend_handles_labels()  # get the handles we want to modify
        for i in range(0, n_df * n_col, n_col):  # len(h) = n_col * n_df
            for pa in h[i:i+n_col]:
                for rect in pa.patches:  # for each index
                    # shift and narrow the bars so each dataframe gets its own slot
                    rect.set_x(rect.get_x() + 1 / float(n_df + 1) * i / float(n_col))
                    rect.set_hatch(H * int(i / n_col))  # edited part
                    rect.set_width(1 / float(n_df + 1))

        ax.set_xticks((np.arange(0, 2 * n_ind, 2) + 1 / float(n_df + 1)) / 2.)
        # tick labels come from the last dataframe plotted
        ax.set_xticklabels(df.index, rotation=0)
        ax.set_title(title)
        # `normalize` is read from the enclosing function
        ax.set_ylabel('proportion' if normalize else 'count')

        # Add invisible data to add another legend
        n = []
        for i in range(n_df):
            n.append(ax.bar(0, 0, color="grey", hatch=H * i))

        if legend_options is None:
            Legend = ('center right', (1.15, 0.5), 1)
        else:
            Legend = legend_options

        if hide_legend is False:
            l1 = ax.legend(h[:n_col], l[:n_col], loc=Legend[0],
                           bbox_to_anchor=Legend[1], ncol=Legend[2], frameon=False)
            if labels is not None:
                plt.legend(n, labels, loc=Legend[0],
                           bbox_to_anchor=Legend[1], ncol=Legend[2], frameon=False)
            # creating a second legend replaces the first, so add it back
            ax.add_artist(l1)

        plt.xticks(rotation=90 if xtick_rotation is None else xtick_rotation)
        return fig, ax

    if title is None:
        title = "multiple stacked bar plot : " + \
            color.replace('_', ' ') + ' usage'

    return _plot_bar_stacked(dat_, labels=labels, figsize=figsize, title=title,
                             xtick_rotation=xtick_rotation,
                             legend_options=legend_options,
                             hide_legend=hide_legend, **kwargs)
def spectratype(self: Union[AnnData, Dandelion],
                color: str,
                groupby: str,
                locus: str,
                clone_key: Union[None, str] = None,
                figsize: Tuple[Union[int, float], Union[int, float]] = (6, 4),
                width: Union[None, Union[int, float]] = None,
                title: Union[None, str] = None,
                xtick_rotation: Union[None, Union[float, int]] = None,
                hide_legend: bool = True,
                legend_options: Tuple[str, Tuple[float, float], int] = None,
                labels: Union[None, Sequence] = None,
                **kwargs) -> Tuple[Figure, Axes]:
    """
    A spectratype function to plot usage of CDR3 length in the data split by groups.

    Parameters
    ----------
    self : Dandelion, pandas.DataFrame
        `Dandelion` object (its `.data` table is used), or an object whose `.copy()` returns a
        table with a `.columns` attribute that includes 'locus', e.g. a pandas data frame.
    color : str
        column name in the table for plotting in bar plot. Values are converted to numbers.
    groupby : str
        column name in the table to split by during plotting.
    locus : str
        either IGH or IGL. A list of loci is also accepted.
    clone_key : str, optional
        currently unused by this function.
    figsize : Tuple[Union[int,float], Union[int,float]]
        figure size. Default is (6, 4).
    width : float, int, optional
        width of bars. None derives the width from the number of bars.
    title : str, optional
        title of plot.
    xtick_rotation : int, float, optional
        rotation of x tick labels. None defaults to 0.
    hide_legend : bool
        whether or not to hide the legend.
    legend_options : Tuple[str, Tuple[float, float], int]
        a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : Sequence, optional
        Names of objects will be used for the legend if list of multiple dataframes supplied.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
        spectratype plot, as a (figure, axes) tuple.

    Raises
    ------
    AttributeError
        if `self` cannot be copied, or the table has no 'locus' column.
    ValueError
        if no numeric `color` values remain for the requested locus.
    """
    if isinstance(self, Dandelion):
        data = self.data.copy()
    else:
        try:
            data = self.copy()
        except AttributeError as err:
            raise AttributeError(
                "Please provide a <class 'Dandelion'> class object or a pandas dataframe instead of %s."
                % self.__class__) from err

    if 'locus' not in data.columns:
        raise AttributeError("Please ensure dataframe contains 'locus' column")
    if not isinstance(locus, list):
        locus = [locus]
    # explicit copy: the column assignment below must not target a view
    data = data[data['locus'].isin(locus)].copy()
    data[groupby] = [str(l) for l in data[groupby]]

    # count of every groupby level for every observed value of `color`
    dat_ = pd.DataFrame(data.groupby(color)[groupby].value_counts(
        normalize=False).unstack(fill_value=0).stack(), columns=['value'])
    dat_.reset_index(drop=False, inplace=True)
    dat_[color] = pd.to_numeric(dat_[color], errors='coerce')
    # values that are not numbers cannot be placed on the length axis
    dat_ = dat_.dropna(subset=[color])
    if dat_.empty:
        raise ValueError(
            "No numeric values found in column '%s' for locus %s." % (color, locus))

    dat_2 = dat_.pivot(index=color, columns=groupby, values='value')
    # one bar for every integer from 0 to the largest observed value
    new_index = range(0, int(dat_[color].max())+1)
    dat_2 = dat_2.reindex(new_index, fill_value=0)

    def _plot_spectra_stacked(dfall: pd.DataFrame,
                              labels: Union[None, Sequence] = None,
                              figsize: Tuple[Union[int, float], Union[int, float]] = (6, 4),
                              title: str = "multiple stacked bar plot",
                              width: Union[None, Union[int, float]] = None,
                              xtick_rotation: Union[None, Union[float, int]] = None,
                              legend_options: Tuple[str, Tuple[float, float], int] = None,
                              hide_legend: bool = True,
                              H: str = "/",
                              **kwargs) -> Tuple[Figure, Axes]:
        """Draw one or more count tables as stacked bars, hatching each table differently."""
        if not isinstance(dfall, list):
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)
        if width is None:
            wdth = 0.1 * n_ind/60+0.8
        else:
            wdth = width

        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)
        for df in dfall:  # for each data frame
            ax = df.plot(kind="bar", linewidth=0, stacked=True, ax=ax,
                         legend=False, grid=False, **kwargs)  # make bar plots

        h, l = ax.get_legend_handles_labels()  # get the handles we want to modify
        for i in range(0, n_df * n_col, n_col):  # len(h) = n_col * n_df
            for pa in h[i:i+n_col]:
                for rect in pa.patches:  # for each index
                    rect.set_x(rect.get_x() + 1 / float(n_df + 1) * i / float(n_col))
                    rect.set_hatch(H * int(i / n_col))  # edited part
                    # need to see if there's a better way to toggle this.
                    rect.set_width(wdth)

        # Keeps every 5th label visible and hides the rest
        every = 5
        for pos, tick_label in enumerate(ax.xaxis.get_ticklabels()):
            if pos % every != 0:
                tick_label.set_visible(False)

        ax.set_title(title)
        ax.set_ylabel('count')

        # Add invisible data to add another legend
        n = []
        for i in range(n_df):
            n.append(ax.bar(0, 0, color="gray", hatch=H * i))

        if legend_options is None:
            Legend = ('center right', (1.25, 0.5), 1)
        else:
            Legend = legend_options

        if hide_legend is False:
            l1 = ax.legend(h[:n_col], l[:n_col], loc=Legend[0],
                           bbox_to_anchor=Legend[1], ncol=Legend[2], frameon=False)
            if labels is not None:
                plt.legend(n, labels, loc=Legend[0],
                           bbox_to_anchor=Legend[1], ncol=Legend[2], frameon=False)
            # creating a second legend replaces the first, so add it back
            ax.add_artist(l1)

        plt.xticks(rotation=0 if xtick_rotation is None else xtick_rotation)
        return fig, ax

    return _plot_spectra_stacked(dat_2, labels=labels, figsize=figsize,
                                 title=title, width=width,
                                 xtick_rotation=xtick_rotation,
                                 legend_options=legend_options,
                                 hide_legend=hide_legend, **kwargs)
def clone_overlap(self: Union[AnnData, Dandelion],
                  groupby: str,
                  colorby: str,
                  min_clone_size: Union[None, int] = None,
                  clone_key: Union[None, str] = None,
                  color_mapping: Union[None, Sequence, Dict] = None,
                  node_labels: bool = True,
                  node_label_layout: Literal[None, 'rotation', 'numbers'] = 'rotation',
                  group_label_position: Literal['beginning', 'middle', 'end'] = 'middle',
                  group_label_offset: int = 8,
                  figsize: Tuple[Union[int, float], Union[int, float]] = (8, 8),
                  return_graph: bool = False,
                  save: Union[None, str] = None,
                  **kwargs):
    """
    A plot function to visualise clonal overlap as a circos-style plot. Requires nxviz.

    Parameters
    ----------
    self : Dandelion, AnnData
        `Dandelion` or `AnnData` object. For `AnnData`, a table stored in `.uns['clone_overlap']`
        is used as is when present; otherwise the overlap table is computed from obs/metadata.
    groupby : str
        column name in obs/metadata for collapsing to nodes in circos plot.
    colorby : str
        column name in obs/metadata for grouping and color of nodes in circos plot.
    min_clone_size : int, optional
        minimum size of clone for plotting connections. Defaults to 2 if left as None.
        With 1 or 2, a clone counts as present in a group as soon as it has one cell there;
        with larger values it needs at least that many cells in the group.
    clone_key : str, optional
        column name for clones. None defaults to 'clone_id'.
    color_mapping : Dict, Sequence, optional
        custom color mapping provided as a sequence (corresponding to order of categories or alpha-numeric order if dtype is not category), or dictionary containing custom {category:color} mapping.
    node_labels : bool, optional
        whether to use node objects as labels or not
    node_label_layout : str, optional
        which/whether (a) node layout is used. One of 'rotation', 'numbers' or None.
    group_label_position : str
        The position of the group label. One of 'beginning', 'middle' or 'end'.
    group_label_offset : int, float
        how much to offset the group labels, so that they are not overlapping with node labels.
    figsize : Tuple[Union[int,float], Union[int,float]]
        figure size. Default is (8, 8).
    return_graph : bool
        whether or not to return the graph for fine tuning. Default is False.
    save : str, optional
        path passed to `matplotlib.pyplot.savefig`. Nothing is saved if None.
    **kwargs
        passed to `matplotlib.pyplot.savefig`.

    Returns
    -------
        a `nxviz.CircosPlot` if `return_graph` is True, otherwise None.

    Raises
    ------
    ImportError
        if `nxviz` cannot be imported.
    TypeError
        if `self` is neither an `AnnData` nor a `Dandelion` object.
    ValueError
        if the overlap table has to be computed and the minimum clone size is smaller than 1.
    """
    import networkx as nx
    try:
        import nxviz as nxv
    except ImportError as err:
        raise ImportError(
            "Unable to import module `nxviz`. Have you done install nxviz? Try pip install git+https://github.com/zktuong/nxviz.git") from err

    min_size = 2 if min_clone_size is None else int(min_clone_size)
    clone_ = 'clone_id' if clone_key is None else clone_key

    precomputed = None
    if isinstance(self, AnnData):
        data = self.obs.copy()
        if 'clone_overlap' in self.uns:
            precomputed = self.uns['clone_overlap'].copy()
    elif isinstance(self, Dandelion):
        data = self.metadata.copy()
    else:
        raise TypeError(
            "Please provide an AnnData or Dandelion object instead of %s." % self.__class__)

    # get rid of problematic rows that appear because of category conversion?
    data = data[~(data[clone_].isin(
        [np.nan, 'nan', 'NaN', 'No_BCR', 'unassigned', None]))]

    if precomputed is not None:
        # NOTE(review): a stored table is used untouched (no min_clone_size
        # thresholding) - presumably it is already a 0/1 table; confirm
        overlap = precomputed
    else:
        overlap = _clone_overlap_table(data, clone_, groupby, min_size)

    # every pair of groups sharing a clone becomes an edge tagged with the clone
    edges = {}
    for x in overlap.index:
        row = overlap.loc[x]
        if row.sum() > 1:
            shared = list(row[row == 1].index)
            edges[x] = [pair + ({str(clone_): x},)
                        for pair in combinations(shared, 2)]

    # create graph
    G = nx.Graph()
    # add in the nodes
    G.add_nodes_from([(p, {str(colorby): d})
                      for p, d in zip(data[groupby], data[colorby])])
    # unpack the edgelist and add to the graph
    for edge in edges:
        G.add_edges_from(edges[edge])

    groupby_dict = dict(zip(data[groupby], data[colorby]))

    # resolve colors per `colorby` level; None leaves nxviz's own colors.
    # The dtype test is on `colorby` because that is the column whose
    # categories are read below.
    colorby_is_categorical = isinstance(data[colorby].dtype, pd.CategoricalDtype)
    colorby_dict = None
    if color_mapping is None:
        if isinstance(self, AnnData) and colorby_is_categorical:
            try:
                colorby_dict = dict(zip(
                    list(self.obs[str(colorby)].cat.categories),
                    self.uns[str(colorby)+'_colors']))
            except KeyError:
                # no colors stored for this column
                pass
    elif isinstance(color_mapping, dict):
        colorby_dict = color_mapping
    elif colorby_is_categorical:
        colorby_dict = dict(
            zip(list(data[str(colorby)].cat.categories), color_mapping))
    else:
        colorby_dict = dict(
            zip(sorted(set(data[str(colorby)])), color_mapping))

    # one row per node, ordered so that nodes of the same `colorby` level are adjacent
    if groupby == colorby:
        df = data[[groupby]].sort_values(groupby)
    else:
        df = data[[groupby, colorby]].sort_values(colorby)
    df = df.drop_duplicates(subset=groupby, keep="first").reset_index(drop=True)

    c = nxv.CircosPlot(G,
                       node_color=colorby,
                       node_grouping=colorby,
                       node_labels=node_labels,
                       node_label_layout=node_label_layout,
                       group_label_position=group_label_position,
                       group_label_offset=group_label_offset,
                       figsize=figsize)
    c.nodes = list(df[groupby])
    if colorby_dict is not None:
        c.node_colors = [colorby_dict[groupby_dict[node]] for node in c.nodes]
    c.compute_group_label_positions()
    c.compute_group_colors()
    c.draw()

    if save is not None:
        plt.savefig(save, bbox_inches='tight', **kwargs)
    if return_graph:
        return c


def _clone_overlap_table(data: pd.DataFrame, clone_: str, groupby: str, min_size: int) -> pd.DataFrame:
    """Build the clone x group 0/1 presence table used to derive circos edges."""
    if min_size < 1:
        raise ValueError('min_size must be greater than 0.')
    # NOTE(review): the original also built a table with '|'-joined clone ids
    # split per cell, but never used it; the cross-tabulation has always been
    # done on the unsplit ids - confirm whether the split table was intended
    overlap = pd.crosstab(data[clone_], data[groupby])
    # a min_size of 1 or 2 means "present at all"; larger values require
    # that many cells of the clone within the group
    threshold = min_size if min_size > 2 else 1
    overlap = (overlap >= threshold).astype(int)
    overlap.index.name = None
    overlap.columns.name = None
    return overlap