Source code for adataviz.plotting

from re import L
import os, sys
import pandas as pd
import anndata
import scanpy as sc
import matplotlib.pylab as plt
import numpy as np
from matplotlib.colors import Normalize
import seaborn as sns
from .utils import (
    normalize_mc_by_cell,categorical_scatter,continuous_scatter
)
from .tools import load_adata,load_obs,load_color_palette
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from loguru import logger as logger
# NOTE(review): loguru is usually pre-configured with its own stderr sink
# (DEBUG and above), so this extra sink likely makes every log line appear
# twice. It also reconfigures the process-global loguru logger as a side
# effect of importing this module. Consider leaving sink configuration to the
# calling application -- confirm before removing.
logger.add(sys.stderr, level="DEBUG")
# logger.add(sys.stderr, level="ERROR")

def use_scientific_style():
    """Switch matplotlib to a compact, print-oriented style for the session.

    Updates the module-level ``plt.rcParams`` in place, so every figure
    created afterwards is affected:
    - Arial text at 8-11 pt.
    - 1.0-1.2 pt lines and spines.
    - 300 dpi rendering and saving.
    - A 6.5 x 4.5 inch default figure size.
    - Transparent, tight ``savefig`` output with no padding.
    - Fonts embedded as TrueType (type 42) in PDF/PS output.
    """
    plt.rcParams.update({
        # Fonts. 'font.family' used to be listed twice ('sans-serif', then
        # 'Arial'). In a dict literal only the last value survives, so the
        # effective setting has always been 'Arial'; keep just that entry.
        'font.family': 'Arial',
        'font.sans-serif': 'Arial',
        # Base font size (set so text is ~7-8 pt at final print size)
        'font.size': 8,  # main text (labels, ticks, legend)
        'axes.labelsize': 9,  # axis labels (x/y)
        'axes.titlesize': 10,  # panel titles
        'figure.titlesize': 11,  # figure titles
        'xtick.labelsize': 8,  # x-tick labels
        'ytick.labelsize': 8,  # y-tick labels
        'legend.fontsize': 8,  # legend text
        'legend.title_fontsize': 9,  # legend title (if used)
        # Lines and elements for clarity
        'lines.linewidth': 1.2,  # data lines
        'axes.linewidth': 1.0,  # axis spines
        'xtick.major.width': 1.0,
        'ytick.major.width': 1.0,
        'xtick.major.size': 4,
        'ytick.major.size': 4,
        # General figure appearance
        'figure.dpi': 300,  # high resolution for export
        'savefig.dpi': 300,  # when using plt.savefig()
        # Starting size; adjust as needed (e.g. ~17 cm wide for a full page)
        'figure.figsize': (6.5, 4.5),
        # 'figure.constrained_layout.use': True,
        'figure.autolayout': True,
        'savefig.transparent': True,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0,
        'pdf.fonttype': 42,
        'ps.fonttype': 42,
    })
# plt.rcParams.keys()
def _resolve_color_limit(values, limit):
    """Convert a ``vmin``/``vmax`` specification into an absolute colour limit.

    A string is read as a percentile of ``values``: ``'p1'`` is the 1st
    percentile and ``'p99.5'`` the 99.5th. Missing values are skipped by
    ``Series.quantile``. Any other value (a number, or ``None`` for
    "let plotly decide") is returned unchanged.
    """
    if isinstance(limit, str):
        return values.quantile(float(limit.lstrip('p')) / 100)
    return limit


def interactive_embedding(
        adata=None, obs=None, variable=None, gene=None,
        coord="umap", vmin='p1', vmax='p99', cmap='jet', title=None,
        width=900, height=750, colors=None, palette_path=None,
        size=None, show=True, downsample=None, target_fill=0.05,
        normalize_per_cell=True, clip_norm_value=10,
        renderer="notebook"):
    """Plot an interactive plotly scatter of a 2-D embedding.

    Cells are coloured either by an obs column (``variable``) or by the
    expression of one gene (``gene``).

    Parameters
    ----------
    adata : AnnData or path, optional
        Passed to ``load_adata``. Required when ``gene`` is given. When it is
        file-backed, its file is closed before plotting.
    obs : DataFrame or path, optional
        Cell metadata passed to ``load_obs``. When both ``adata`` and ``obs``
        are given, only cells present in both are plotted. If ``adata`` is not
        given, ``obs`` must already contain the columns ``f'{coord}_0'`` and
        ``f'{coord}_1'``.
    variable : str, optional
        obs column used for colouring. Ignored for colouring if ``gene`` is
        set, but still shown on hover.
    gene : str, optional
        Variable name in ``adata`` whose values colour the cells.
    coord : str, default "umap"
        Embedding name. Coordinates come from the first two columns of
        ``adata.obsm[f'X_{coord}']`` when that key exists.
    vmin, vmax : str or number, default 'p1' / 'p99'
        Colour limits for continuous values. A string ``'pN'`` means the Nth
        percentile (fractional N allowed); a number is an absolute limit.
    cmap : str, default 'jet'
        plotly continuous colour scale.
    title : str, optional
        Figure title. Defaults to ``"<COORD> Visualization (Colored by <col>)"``.
    width, height : int, default 900 / 750
        Figure size in pixels.
    colors : dict, optional
        ``{category: colour}`` for categorical colouring. The caller's dict is
        not modified; a filtered copy is used.
    palette_path : optional
        Passed to ``load_color_palette`` when ``colors`` is None.
    size : number, optional
        Marker size. When None, it is derived from the canvas area, the number
        of plotted points and ``target_fill``, then clamped to [4, 20].
    show : bool, default True
        If True, display the figure with ``show_fig`` and return None.
        Otherwise return the figure.
    downsample : int, optional
        Randomly keep at most this many cells (sampled without replacement,
        unseeded).
    target_fill : float, default 0.05
        Approximate fraction of the canvas area that markers should cover when
        ``size`` is None.
    normalize_per_cell, clip_norm_value :
        Forwarded to ``normalize_mc_by_cell`` for ``gene`` values. It is
        applied only when ``normalize_per_cell`` is true.
    renderer : str or None, default "notebook"
        If not None, assigned to the global ``plotly.io.renderers.default``.
        Examples: 'notebook', 'jupyterlab', 'vscode', 'browser', 'png', 'svg'.

    Returns
    -------
    plotly.graph_objects.Figure or None
        The figure when ``show`` is False, otherwise None.
    """
    if renderer is not None:
        pio.renderers.default = renderer
    use_col = gene if gene is not None else variable
    if gene is not None:
        assert adata is not None, "`gene` provided, `adata` must be provided too."

    # ---- resolve the AnnData that supplies expression / embeddings ----
    if adata is not None:
        adata = load_adata(adata)
    use_adata = None
    if gene is not None:  # adata is non-None here (asserted above)
        if adata.isbacked:
            use_adata = adata[:, gene].to_memory()
        else:
            use_adata = adata[:, gene].copy()
        if normalize_per_cell:
            use_adata = normalize_mc_by_cell(
                use_adata=use_adata, normalize_per_cell=normalize_per_cell,
                clip_norm_value=clip_norm_value, hypo_score=False)
    elif adata is not None:
        use_adata = adata  # possibly still file-backed

    # ---- assemble the per-cell table ----
    if obs is None and use_adata is None:
        raise ValueError("Either `adata` or `obs` must be provided.")
    if obs is None:
        obs = use_adata.obs.copy()
    else:
        obs = load_obs(obs)
        if use_adata is not None:
            overlap_idx = obs.index.intersection(use_adata.obs_names)
            obs = obs.loc[overlap_idx]
            use_adata = use_adata[overlap_idx, :]
    if gene is not None:
        obs[gene] = use_adata.to_df()[gene].tolist()
    obsm_key = f'X_{coord}'
    # With only `obs` given there is no AnnData; the coordinates must then
    # already be present as obs columns.
    if use_adata is not None and obsm_key in use_adata.obsm:
        embedding = use_adata.obsm[obsm_key]
        obs[f'{coord}_0'] = embedding[:, 0].tolist()
        obs[f'{coord}_1'] = embedding[:, 1].tolist()
    if adata is not None and adata.isbacked:
        adata.file.close()

    # ---- optional downsampling of large datasets ----
    if downsample is not None and obs.shape[0] > downsample:
        sample_idx = np.random.choice(obs.shape[0], size=downsample, replace=False)
        obs = obs.iloc[sample_idx]
    # Count *after* downsampling: marker size and opacity below must reflect
    # the number of markers actually drawn.
    n_points = obs.shape[0]

    # ---- colour mapping ----
    if obs.dtypes[use_col] not in ['object', 'category']:
        range_color = [_resolve_color_limit(obs[use_col], vmin),
                       _resolve_color_limit(obs[use_col], vmax)]
        color_discrete_map = None
    else:
        if colors is None:
            color_discrete_map = load_color_palette(
                palette_path=palette_path, adata=use_adata, groups=use_col)
        else:
            color_discrete_map = colors
        if color_discrete_map is not None:
            # Filtered copy: never mutate the caller's `colors`, and keep only
            # categories present among the plotted cells.
            present = set(obs[use_col].unique().tolist())
            color_discrete_map = {
                k: v for k, v in dict(color_discrete_map).items() if k in present}
        range_color = None

    keep_cols = ['cell', f'{coord}_0', f'{coord}_1']
    if variable is not None:
        keep_cols.append(variable)
    if gene is not None:
        keep_cols.append(gene)
    obs = obs.reset_index(names="cell").loc[:, keep_cols]

    # Fields shown on hover: cell id and coordinates rounded to 3 decimals.
    hover_data = {
        "cell": True,
        f'{coord}_0': ":0.3f",
        f'{coord}_1': ":0.3f",
    }
    if variable is not None:
        hover_data[variable] = True
    # When plotting gene expression, the grouping variable is still shown on hover.
    if gene is not None:
        hover_data[gene] = ":.3f"

    fig = px.scatter(
        obs,
        x=f'{coord}_0',
        y=f'{coord}_1',
        color=use_col,
        hover_data=hover_data,
        range_color=range_color,
        color_discrete_sequence=px.colors.qualitative.D3,
        color_discrete_map=color_discrete_map,
        color_continuous_scale=cmap,
        template="plotly_white",
        render_mode='webgl',  # WebGL rendering for large datasets
    )
    fig.update_xaxes(range=[obs[f'{coord}_0'].min() - 0.5, obs[f'{coord}_0'].max() + 0.5],
                     tickfont_size=12)
    fig.update_yaxes(range=[obs[f'{coord}_1'].min() - 0.5, obs[f'{coord}_1'].max() + 0.5],
                     tickfont_size=12)

    if size is None:
        # Blend an area-based estimate (markers covering ~target_fill of the
        # canvas) with a log-based fallback, so both the point count and the
        # canvas size matter.
        n_for_size = max(n_points, 1)  # avoid division by zero / log10(0)
        marker_diam_area = 2 * np.sqrt((width * height * target_fill) / (np.pi * n_for_size))
        marker_diam_log = 16 - 2 * np.log10(n_for_size)
        marker_diam = 0.7 * marker_diam_area + 0.5 * marker_diam_log
        size = int(np.clip(marker_diam, 4, 20))
    opacity = 0.8 if n_points < 500000 else 0.6
    fig.update_traces(
        marker=dict(size=size, opacity=opacity, line=dict(width=0.12, color='black')),
        selector=dict(mode='markers'),
    )
    if title is None:
        title = f"{coord.upper()} Visualization (Colored by {use_col})"
    fig.update_layout(
        title=dict(text=title, font_size=16, x=0.5, pad=dict(t=10)),  # centred title
        xaxis_title=f'{coord}_0'.upper(),
        yaxis_title=f'{coord}_1'.upper(),
        autosize=True, width=width, height=height,
        legend_title=use_col,
        legend=dict(
            font_size=12,
            itemsizing='constant',  # legend markers not scaled with scatter markers
            itemwidth=30,
            borderwidth=0.1,
        ),
    )
    if show:
        show_fig(fig, filename=f"{coord}.{use_col}")
    else:
        return fig
# html=fig2div(fig,filename='umap_plot') # return HttpResponse(html)
def show_fig(fig, filename="plot"):
    """Display a plotly figure with an editable toolbar that exports SVG.

    Parameters
    ----------
    fig : plotly figure
        Anything exposing ``show(config=...)``.
    filename : str, default "plot"
        Base file name offered by the toolbar's image-download button. The
        image is downloaded as SVG.
    """
    # Layout elements the user may edit directly in the rendered figure.
    editable_parts = dict(
        titleText=True, legendPosition=True, colorbarTitleText=True,
        shapePosition=True, annotationPosition=True, annotationText=True,
        axisTitleText=True, legendText=True, colorbarPosition=True,
    )
    export_options = dict(format='svg', filename=filename)
    plotly_config = dict(
        displayModeBar='hover',
        showLink=False,
        linkText='Edit on plotly',
        scrollZoom=True,
        displaylogo=False,
        toImageButtonOptions=export_options,
        modeBarButtonsToRemove=['sendDataToCloud'],
        editable=True,
        autosizable=True,
        fillFrame=True,
        edits=editable_parts,
    )
    fig.show(config=plotly_config)
def plot_categorical(
        adata, ax=None, basis='umap', groupby='MajorType',
        coding=True, coded_marker=True,
        save=None, palette_path=None, sheet_name=None,
        show=True, figsize=(4, 3.5),
        ncol=None, fontsize=5, legend_fontsize=5,
        legend_kws=None, legend_title_fontsize=5,
        marker_fontsize=4, marker_pad=0.1,
        linewidth=0.5, axis_format='tiny', alpha=0.7,
        text_kws=None, **kwargs):
    """Scatter cells on an embedding, coloured and labelled by a categorical column.

    Parameters
    ----------
    adata : AnnData or path
        Passed to ``load_adata``. ``adata.obs[groupby]`` is converted to
        ``category`` dtype in place if needed, and ``adata.uns[f'{groupby}_colors']``
        may be written.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new ``figsize`` figure at 300 dpi is created.
    basis : str, default 'umap'
        Embedding name. A leading ``'X_'`` is stripped.
    groupby : str, default 'MajorType'
        obs column used for colour (``hue``) and text annotation. Cells with a
        missing value are not drawn.
    coding, coded_marker : bool
        Always forwarded to ``categorical_scatter``, overriding ``kwargs``.
    save : str, None or False
        If a path, the current figure is saved there (``~`` expanded).
    palette_path : str or dict, optional
        Either an Excel workbook (index = category, column ``Hex``) or a
        ``{category: colour}`` dict. Categories missing from the palette are
        drawn gray. If the workbook cannot be read, a warning is logged and
        colours fall back to ``adata.uns[f'{groupby}_colors']``; scanpy
        creates these if absent.
    sheet_name : str, optional
        Workbook sheet to read. Defaults to ``groupby``.
    show : bool, default True
        Call ``plt.show()`` at the end.
    ncol, fontsize, legend_fontsize, legend_kws, legend_title_fontsize :
        Text and legend options. ``text_kws``/``legend_kws`` are copied, never
        modified.
    marker_fontsize, marker_pad, linewidth, axis_format, alpha, text_kws :
        Defaults forwarded to ``categorical_scatter`` unless overridden in
        ``kwargs``.
    **kwargs
        Extra keyword arguments for ``categorical_scatter``.
    """
    if basis.startswith("X_"):
        basis = basis[2:]  # strip only the prefix
    if sheet_name is None:
        sheet_name = groupby
    adata = load_adata(adata)
    # pandas' is_categorical_dtype is deprecated; check the dtype directly.
    if not isinstance(adata.obs[groupby].dtype, pd.CategoricalDtype):
        adata.obs[groupby] = adata.obs[groupby].astype('category')
    categories = adata.obs[groupby].cat.categories.tolist()

    colors = None
    if palette_path is not None:
        try:
            if isinstance(palette_path, dict):
                colors = dict(palette_path)
            elif isinstance(palette_path, str):
                colors = pd.read_excel(
                    os.path.expanduser(palette_path), sheet_name=sheet_name,
                    index_col=0).Hex.to_dict()
        except Exception as err:  # best effort: fall back to default colours
            logger.warning(
                f"Could not read palette {palette_path!r} (sheet {sheet_name!r}): "
                f"{err}; falling back to default colours.")
            colors = None
        if colors is not None:
            # Keep palette order for categories present in the data, then
            # append gray for present categories the palette lacks.
            present = adata.obs[groupby].unique().tolist()
            ordered = {k: v for k, v in colors.items() if k in present}
            for k in present:
                ordered.setdefault(k, 'gray')
            colors = ordered

    if colors is None:
        if f'{groupby}_colors' not in adata.uns:
            # Let scanpy assign its default categorical colours (stored in
            # adata.uns). Then close the throw-away figure(s) it drew so that
            # plt.show() below does not display them.
            figs_before = set(plt.get_fignums())
            sc.pl.embedding(adata, basis=f"X_{basis}", color=[groupby], show=False)
            for num in set(plt.get_fignums()) - figs_before:
                plt.close(num)
        colors = dict(zip(categories, adata.uns[f'{groupby}_colors']))
    else:
        adata.uns[groupby + '_colors'] = [colors.get(k, 'grey') for k in categories]

    text_kws = {} if text_kws is None else dict(text_kws)
    text_kws.setdefault("fontsize", fontsize)
    kwargs.setdefault("hue", groupby)
    kwargs.setdefault("text_anno", groupby)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs["coding"] = coding
    kwargs["coded_marker"] = coded_marker

    legend_kws = {} if legend_kws is None else dict(legend_kws)
    default_lgd_kws = dict(
        fontsize=legend_fontsize, title=groupby, title_fontsize=legend_title_fontsize)
    if ncol is not None:
        default_lgd_kws['ncol'] = ncol
    for key, value in default_lgd_kws.items():
        legend_kws.setdefault(key, value)
    kwargs.setdefault("dodge_kws", {
        "arrowprops": {
            "arrowstyle": "->",
            "fc": 'grey',
            "ec": "none",
            "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
        },
        'autoalign': 'xy'})

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=300)
    categorical_scatter(
        data=adata[adata.obs[groupby].notna(), ],
        ax=ax, basis=basis, palette=colors, legend_kws=legend_kws,
        **kwargs)
    if save not in (None, False):
        plt.savefig(os.path.expanduser(save), bbox_inches='tight')
    if show:
        plt.show()
def plot_gene(
        adata, obs=None, groupby=None, gene='CADM1', query_str=None,
        title=None, palette_path=None, hue_norm=None,
        cbar_kws=None, axis_format="tiny", scatter_kws=None,
        obsm=None, basis='umap', normalize_per_cell=True,
        stripplot=False, hypo_score=False, ylim=None,
        clip_norm_value=10, min_cells=3, cmap='parula',
        prefix=None):
    """Plot one gene on an embedding and, optionally, its distribution per group.

    Writes ``f"{prefix}.{basis}.pdf"``, a continuous scatter of ``gene``. When
    ``groupby`` is given, also writes ``f"{prefix}.boxplot.pdf"``: a violin
    plot per group, sorted by median, with cell counts in the tick labels.

    Parameters
    ----------
    adata : str or AnnData
        Either a path to an ``.h5ad`` file (opened backed, only ``gene`` read,
        file closed) or an AnnData object. The AnnData object itself is not
        modified; ``gene`` is copied out.
    obs : str or DataFrame, optional
        Cell metadata. A path is read as a tab-separated table indexed by cell
        id. Defaults to ``adata.obs``. Only cells present in both are kept.
    groupby : str, optional
        obs column for the per-group violin plot.
    gene : str, default 'CADM1'
    query_str : str, optional
        ``DataFrame.query`` expression used to subset cells.
    title : str, optional
        Defaults to ``query_str``, then ``groupby``, then ``gene``.
    palette_path : str, optional
        Either an Excel workbook (sheet ``groupby``, column ``Hex``) or, if no
        such file exists, the name of an obs column with each cell's group
        colour.
    hue_norm, axis_format, basis, cmap :
        Forwarded to ``continuous_scatter``.
    cbar_kws : dict, optional
        Colour-bar options for ``continuous_scatter``. Default
        ``{'extendfrac': 0.1}``.
    scatter_kws : dict, optional
        Extra keyword arguments for ``continuous_scatter``.
    obsm : str or AnnData, optional
        Source of embeddings (``.obsm``) and extra obs columns. Only cells
        shared with ``adata`` are kept.
    normalize_per_cell, clip_norm_value, hypo_score :
        Forwarded to ``normalize_mc_by_cell``. It is applied only when
        ``normalize_per_cell`` is true.
    stripplot : bool, default False
        Overlay individual cells on the violin plot.
    ylim : tuple, optional
        y-axis bounds of the violin plot.
    min_cells : int, default 3
        Groups with fewer cells are left out of the violin plot.
    prefix : str, optional
        Output file prefix. Defaults to ``f"{title}.{gene}.{groupby}"``.

    Returns
    -------
    AnnData
        The filtered, in-memory single-gene AnnData that was plotted.
    """
    # Fresh dicts per call (no shared mutable defaults).
    cbar_kws = dict(extendfrac=0.1) if cbar_kws is None else cbar_kws
    scatter_kws = {} if scatter_kws is None else scatter_kws
    if title is None:
        if query_str is not None:
            title = query_str
        else:
            title = groupby if groupby is not None else gene

    # ---- load only `gene` ----
    if isinstance(adata, anndata.AnnData):
        adata = adata[:, gene].to_memory() if adata.isbacked else adata[:, gene].copy()
    else:
        raw_adata = anndata.read_h5ad(os.path.expanduser(adata), backed='r')
        adata = raw_adata[:, gene].to_memory()
        raw_adata.file.close()  # close the file to save memory
    if normalize_per_cell:
        adata = normalize_mc_by_cell(
            use_adata=adata, normalize_per_cell=normalize_per_cell,
            clip_norm_value=clip_norm_value, hypo_score=hypo_score)

    # ---- borrow embeddings / extra obs columns from `obsm` ----
    if obsm is not None:
        opened_here = isinstance(obsm, str)
        if opened_here:
            obsm = anndata.read_h5ad(os.path.expanduser(obsm), backed='r')
        try:
            assert isinstance(obsm, anndata.AnnData), \
                "obsm should be an anndata object or a path to an h5ad file."
            keep_cells = list(set(adata.obs_names.tolist()) & set(obsm.obs_names.tolist()))
            adata = adata[keep_cells, :]
            adata.obsm = obsm[keep_cells].obsm
            cur_cols = adata.obs.columns.tolist()
            for col in obsm.obs.columns.tolist():
                if col not in cur_cols:
                    adata.obs[col] = obsm.obs.loc[adata.obs_names, col].tolist()
        finally:
            if opened_here:  # only close files this function opened
                obsm.file.close()

    # ---- cell metadata / filtering ----
    if obs is not None:
        if isinstance(obs, str):
            obs = pd.read_csv(os.path.expanduser(obs), sep='\t', index_col=0)
        else:
            obs = obs.copy()
    else:
        obs = adata.obs.copy()
    if query_str is not None:
        obs = obs.query(query_str)
    overlapped_cells = list(set(adata.obs_names.tolist()) & set(obs.index.tolist()))
    obs = obs.loc[overlapped_cells]
    adata = adata[overlapped_cells, :]
    adata.obs = obs.loc[adata.obs_names.tolist()]
    print(adata.shape)

    # ---- colour palette for the violin plot ----
    if groupby is not None and palette_path is not None:
        if os.path.exists(os.path.expanduser(palette_path)):
            palette_path = os.path.abspath(os.path.expanduser(palette_path))
            sheets = pd.read_excel(palette_path, sheet_name=None, index_col=0)
            color_palette = sheets[groupby].Hex.to_dict()
        else:
            # `palette_path` names an obs column holding each cell's colour
            color_palette = adata.obs.reset_index().loc[:, [groupby, palette_path]] \
                .drop_duplicates().dropna().set_index(groupby)[palette_path].to_dict()
    else:
        color_palette = None

    # ---- gene on the embedding ----
    if prefix is None:
        prefix = f"{title}.{gene}.{groupby}"
    adata.obs[gene] = adata.to_df().loc[adata.obs_names.tolist(), gene].tolist()
    fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
    continuous_scatter(
        data=adata, ax=ax, cmap=cmap, hue_norm=hue_norm, cbar_kws=cbar_kws,
        hue=gene, axis_format=axis_format, text_anno=None, basis=basis,
        **scatter_kws)
    fig.savefig(f"{prefix}.{basis}.pdf", bbox_inches='tight')

    # ---- distribution per group ----
    if groupby is not None:
        data = adata.to_df()
        data[groupby] = adata.obs.loc[data.index.tolist(), groupby].tolist()
        counts = data[groupby].value_counts()
        # Always honour `min_cells`; with a palette, also drop groups that
        # have no colour in it.
        keep_groups = set(counts[counts >= min_cells].index.tolist())
        if color_palette is not None:
            keep_groups &= set(color_palette.keys())
        data = data.loc[data[groupby].isin(keep_groups)]
        vc = counts.to_dict()
        order = data.groupby(groupby)[gene].median().sort_values().index.tolist()
        width = max(5, len(order) * 0.5)  # sized for the groups actually drawn
        plt.figure(figsize=(width, 3.5))
        if stripplot:
            ax = sns.stripplot(
                data=data, jitter=0.4, edgecolor='white', x=groupby, y=gene,
                palette=color_palette, order=order, size=0.5)
        else:
            ax = None
        # Box plots mislead when many values are 0 (median and lower quartile
        # collapse to zero), so a violin plot is used instead.
        ax = sns.violinplot(
            data=data, x=groupby, y=gene, palette=color_palette, ax=ax,
            saturation=0.6, order=order, density_norm='width', cut=0, bw_adjust=0.5)
        if ylim is not None:
            ax.set_ybound(ylim)
        ax.set_xticklabels([f"{label} ({vc[label]})" for label in order])
        ax.set_title(title.replace(' ', '.'))
        ax.xaxis.label.set_visible(False)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=-45, ha='left')
        plt.savefig(f"{prefix}.boxplot.pdf", bbox_inches='tight')
    return adata
def stacked_barplot(
        obs="cell_obs_with_annotation.csv", groupby='Age',
        column='CellClass', x_order=None, y_order=None, linewidth=0.1,
        palette_path=None, width=None, height=None,
        xticklabels_kws=None, save=False,
        lgd_kws=None, gap=0.05, sort_by=None):
    """Plot a stacked bar chart of ``column`` composition within each ``groupby`` value.

    Each bar is one ``groupby`` value (e.g. age or brain region). Segments are
    the fractions of its cells in each ``column`` category, so every bar sums
    to 1.

    Examples
    --------
    stacked_barplot(column='MajorType', width=3.5, height=6)
    stacked_barplot(column='CellClass', width=3.5, height=3)

    Parameters
    ----------
    obs : DataFrame or str
        Cell metadata, or a path. An ``.h5ad`` path uses its ``obs``; a
        ``.csv`` path is read comma-separated; any other path is read
        tab-separated. File tables use the first column as index.
    groupby, column : str
        Bar variable and stacked-category variable.
    x_order, y_order : list, optional
        Bar order and category order. Both default to sorted order and are
        ignored when ``sort_by`` is given.
    sort_by : optional
        Passed to ``DataFrame.sort_values`` on the fraction table (a category
        name or list of them) to order bars.
    linewidth, gap : float
        Bar edge width, and the gap between bars (bar width is ``1 - gap``).
    palette_path : dict or str, optional
        ``{category: colour}`` dict, or an Excel workbook with a ``column``
        sheet (column ``Hex``); categories missing from the workbook are gray.
        Any other string is passed through as the ``color`` argument of
        ``DataFrame.plot.bar``.
    width, height : float, optional
        Figure size in inches. Defaults: ``max(0.45 * n_bars, 10)`` and
        ``max(0.5 * n_categories, 8)``. Values are never smaller than 2.5 and
        3.5 respectively.
    xticklabels_kws, lgd_kws : dict, optional
        Tick-label and legend options. They are copied, never modified.
        ``loc`` is always ``'upper left'``.
    save : str or False
        Output path (``~`` expanded; parent directories created). When falsy,
        the figure is shown instead.

    Returns
    -------
    matplotlib.axes.Axes
    """
    if isinstance(obs, pd.DataFrame):
        data = obs.copy()
    elif isinstance(obs, str) and obs.endswith('.h5ad'):
        obs_path = os.path.abspath(os.path.expanduser(obs))
        adata = anndata.read_h5ad(obs_path, backed='r')
        data = adata.obs.copy()
        adata.file.close()  # release the backed file handle
    elif obs.endswith('.csv'):
        obs_path = os.path.abspath(os.path.expanduser(obs))
        data = pd.read_csv(obs_path, index_col=0)
    else:
        obs_path = os.path.abspath(os.path.expanduser(obs))
        data = pd.read_csv(obs_path, index_col=0, sep='\t')

    xticklabels_kws = {} if xticklabels_kws is None else dict(xticklabels_kws)
    xticklabels_kws.setdefault('rotation', -45)
    xticklabels_kws.setdefault("rotation_mode", "anchor")
    xticklabels_kws.setdefault('horizontalalignment', 'left')  # see matplotlib.axes.Axes.tick_params
    xticklabels_kws.setdefault('verticalalignment', 'center')

    if palette_path is None:
        color_palette = None
    elif isinstance(palette_path, dict):
        color_palette = palette_path.copy()
    elif isinstance(palette_path, str) and os.path.exists(os.path.expanduser(palette_path)):
        palette_file = os.path.abspath(os.path.expanduser(palette_path))
        sheets = pd.read_excel(palette_file, sheet_name=None, index_col=0)
        color_palette = sheets[column].Hex.to_dict()
        for k in data[column].unique():
            color_palette.setdefault(k, 'gray')
    else:
        color_palette = palette_path

    # rows: groupby values, columns: `column` categories, values: fractions
    df = data.groupby(groupby)[column].value_counts(normalize=True).unstack(level=column)
    if sort_by is not None:
        df = df.sort_values(sort_by, ascending=True)
    else:
        if x_order is None:
            x_order = sorted(df.index.tolist())
        if y_order is None:
            y_order = sorted(df.columns.tolist())
        df = df.loc[x_order, y_order]

    if width is None:
        width = max(df.shape[0] * 0.45, 10)
    width = max(width, 2.5)
    if height is None:
        height = max(df.shape[1] * 0.5, 8)
    height = max(height, 3.5)

    # Draw into an explicit Axes: DataFrame.plot without `ax` opens its own
    # figure, so a preceding plt.figure() would leave an empty figure behind.
    fig, ax = plt.subplots(figsize=(width, height))
    df.plot.bar(stacked=True, align='edge', width=1 - gap, edgecolor='black',
                linewidth=linewidth, color=color_palette, ax=ax)
    ax.set_xlim(0, df.shape[0])
    ax.set_ylim(0, 1)
    labels = [tick.get_text() for tick in ax.get_xticklabels()]
    # bars are left-aligned at integer positions, so centre ticks at +0.5
    ax.set_xticks(ticks=np.arange(0.5, df.shape[0], 1), labels=labels, **xticklabels_kws)
    ax.xaxis.label.set_visible(False)
    ax.tick_params(
        axis="y", which="both", left=False, right=False, labelleft=False,
        labelright=False, top=False, labeltop=False,
    )

    lgd_kws = {} if lgd_kws is None else dict(lgd_kws)
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws.setdefault("borderpad", 0.1 * (1 / 25.4) * 72)  # 0.1 mm
    lgd_kws.setdefault("markerscale", 1)
    lgd_kws.setdefault("handleheight", 1)  # in font-size units
    lgd_kws.setdefault("handlelength", 1)  # in font-size units
    lgd_kws.setdefault("borderaxespad", 0.1)  # pad between axes and legend border
    lgd_kws.setdefault("handletextpad", 0.3)  # pad between legend handle and text
    lgd_kws.setdefault("labelspacing", 0.1)  # vertical gap between entries
    lgd_kws.setdefault("columnspacing", 1)
    lgd_kws.setdefault("bbox_to_anchor", (1, 1))
    lgd_kws.setdefault("title", column)
    ax.legend(**lgd_kws)
    ax.grid(False)

    if save:
        # Expand `~` once and use the same path for the directory and the
        # file. abspath() keeps dirname non-empty for bare file names.
        save_path = os.path.abspath(os.path.expanduser(save))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight')
    else:
        plt.show()
    return ax
def pieplot(obs, groupby='Age', palette_path=None, order=None,
            save=None, explode=0.05,
            title='Distribution of #of cells across different stages'):
    """Draw and save a pie chart of cell counts per ``groupby`` value.

    Parameters
    ----------
    obs : DataFrame or str
        Cell metadata, or a path. An ``.h5ad`` path uses its ``obs``; a
        ``.csv`` path is read comma-separated; any other path is read
        tab-separated.
    groupby : str, default 'Age'
        Column whose value counts form the wedges.
    palette_path : str or dict, optional
        Excel workbook with a ``groupby`` sheet (column ``Hex``), or a
        ``{value: colour}`` dict. Values missing from the palette are gray.
        Without a palette, matplotlib's default colour cycle is used.
    order : list, optional
        Wedge order. Defaults to sorted values.
    save : str, optional
        Output path. Defaults to ``f'{groupby}.piechart.pdf'``.
    explode : float, default 0.05
        Radial offset applied to every wedge.
    title : str
        Chart title.
    """
    if isinstance(obs, pd.DataFrame):
        data = obs.copy()
    elif isinstance(obs, str) and obs.endswith('.h5ad'):
        obs_path = os.path.abspath(os.path.expanduser(obs))
        print(f"Reading adata: {obs}")
        adata = anndata.read_h5ad(obs_path, backed='r')
        data = adata.obs.copy()
        adata.file.close()  # release the backed file handle
    elif obs.endswith('.csv'):
        obs_path = os.path.abspath(os.path.expanduser(obs))
        data = pd.read_csv(obs_path, index_col=0)
    else:
        obs_path = os.path.abspath(os.path.expanduser(obs))
        data = pd.read_csv(obs_path, index_col=0, sep='\t')

    if palette_path is None:
        color_palette = None
    elif isinstance(palette_path, dict):
        color_palette = palette_path
    else:
        palette_file = os.path.abspath(os.path.expanduser(palette_path))
        sheets = pd.read_excel(palette_file, sheet_name=None, index_col=0)
        color_palette = sheets[groupby].Hex.to_dict()

    counts = data[groupby].value_counts()
    if order is None:
        order = sorted(counts.keys())
    # Previously a missing palette crashed on None[k]. Now matplotlib picks
    # the colours, and palette gaps become gray.
    wedge_colors = None if color_palette is None else [color_palette.get(k, 'gray') for k in order]
    plt.figure()
    plt.pie([counts[k] for k in order], labels=order, colors=wedge_colors,
            explode=[explode] * len(order), autopct='%.1f%%')
    plt.title(title)
    if save is not None:
        output = os.path.abspath(os.path.expanduser(save))
    else:
        output = f'{groupby}.piechart.pdf'
    plt.savefig(output, bbox_inches='tight')
    plt.show()
def plot_pseudotime(
        pseudotime="dpt_pseudotime.tsv", groupby='Age', y='dpt_pseudotime',
        hue=None, figsize=(5, 3.5), save=None, rotate=None, ylabel='Pseudotime',
        palette_path=None,
):
    """Draw and save violin plots of pseudotime per group.

    Examples
    --------
    plot_pseudotime(figsize=(6, 3.5), groupby='MajorType', rotate=-45)
    plot_pseudotime(figsize=(3.5, 3), groupby='CellClass')
    plot_pseudotime(figsize=(3.5, 3), groupby='Age')

    Parameters
    ----------
    pseudotime : str
        Tab-separated table (first column = index) containing ``groupby``,
        ``y`` and optionally ``hue``.
    groupby : str
        x-axis grouping. Groups are ordered by mean ``y``.
    y : str
        Pseudotime column. Infinite values are replaced by 1.
    hue : str, optional
        Split violins by this column, ordered by mean ``y``.
    figsize : tuple
        Figure size in inches.
    save : str, optional
        Output path. Defaults to ``'<groupby>[_<hue>].pseudotime_violin.pdf'``.
    rotate : float, optional
        x tick label rotation in degrees.
    ylabel : str
        y-axis label.
    palette_path : str, optional
        Excel workbook with a ``groupby`` sheet (column ``Hex``).
    """
    if palette_path is not None:
        palette_path = os.path.abspath(os.path.expanduser(palette_path))
        sheets = pd.read_excel(palette_path, sheet_name=None, index_col=0)
        color_palette = sheets[groupby].Hex.to_dict()
    else:
        color_palette = None
    data = pd.read_csv(os.path.expanduser(pseudotime), sep='\t', index_col=0)
    # Assign back instead of a chained `Series.replace(..., inplace=True)`:
    # that form is a no-op under pandas copy-on-write. It also targeted the
    # hard-coded 'dpt_pseudotime' column rather than the plotted `y`.
    data[y] = data[y].replace({np.inf: 1})
    order = data.groupby(groupby)[y].mean().sort_values(ascending=True).index.tolist()
    if hue is not None:
        hue_order = data.groupby(hue)[y].mean().sort_values(ascending=True).index.tolist()
    else:
        hue_order = None
    plt.figure(figsize=figsize)
    # seaborn >= 0.13 names (as used elsewhere in this module):
    # density_norm replaces scale=, bw_method replaces bw=.
    ax = sns.violinplot(
        data=data, x=groupby, y=y, density_norm='count', bw_method=0.2,
        inner=None, saturation=0.6, palette=color_palette, order=order,
        hue=hue, hue_order=hue_order)
    ax.set_ylabel(ylabel)
    if rotate is not None:
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=rotate,
                 rotation_mode="anchor", horizontalalignment='left')
    if save is None:
        if hue is None:
            outname = groupby + '.pseudotime_violin.pdf'
        else:
            outname = groupby + f'_{hue}.pseudotime_violin.pdf'
    else:
        outname = os.path.abspath(os.path.expanduser(save))
    plt.savefig(outname, bbox_inches='tight')
    plt.show()
def stacked_violinplot(adata, use_genes=None, groupby='Age', cell_groups=None,
                       parent=None, figsize=(6, 4), cmap='viridis', title=None):
    """Show a scanpy stacked violin plot, optionally restricted to one parent group.

    Parameters
    ----------
    adata : AnnData
    use_genes : list of str
        Variables to plot (``var_names`` of ``sc.pl.stacked_violin``).
    groupby : str, default 'Age'
        obs column defining the violin groups.
    cell_groups : str or list of str, optional
        obs column (or a list whose first element is the column) used to
        subset cells. Only cells equal to ``parent`` are kept. No subsetting
        happens unless both ``cell_groups`` and ``parent`` are given.
    parent : optional
        Value of that column to keep.
    figsize : tuple, default (6, 4)
    cmap : str, default 'viridis'
    title : str, optional
        Plot title. Previously this referenced an undefined name
        (``use_key``) and raised NameError.

    Returns
    -------
    dict
        Axes returned by ``sc.pl.stacked_violin(..., show=False)``.
    """
    import scanpy as sc
    from matplotlib import ticker  # was used without being imported

    if cell_groups is not None and parent is not None:
        parent_col = cell_groups if isinstance(cell_groups, str) else cell_groups[0]
        plot_adata = adata[adata.obs[parent_col] == parent]
    else:
        plot_adata = adata
    axes = sc.pl.stacked_violin(
        plot_adata, var_names=use_genes, title=title,
        colorbar_title="Avg mc frac", groupby=groupby, dendrogram=True,
        swap_axes=False, cmap=cmap, figsize=figsize, scale='count',
        standard_scale='obs', inner='quart',
        show=False, layer=None)
    main_ax = axes['mainplot_ax']
    # dashed horizontal guide line at every integer y position
    main_ax.yaxis.set_minor_locator(ticker.MultipleLocator(1))
    main_ax.yaxis.set_tick_params(which='minor', left=True)
    main_ax.grid(axis='y', linestyle='--', color='black', alpha=1, zorder=-5, which='minor')
    plt.show()
    return axes
def _annotation_colors(values, palette=None):
    """Return a ``{value: colour}`` dict for the unique entries of ``values``.

    Unique values keep their first-appearance order. With a ``palette``
    dict, each value gets ``palette[value]``, or ``'gray'`` if the palette
    lacks it. Without a palette, evenly spaced seaborn "husl" colours are
    generated so the annotation can still be drawn.
    """
    keys = pd.Series(values).unique().tolist()
    if palette is None:
        return dict(zip(keys, sns.color_palette('husl', len(keys)).as_hex()))
    return {k: palette.get(k, 'gray') for k in keys}


def plot_genes(
        adata="/home/x-wding2/Projects/BICAN/adata/HMBA_v2/HMBA.Group.downsample_1500.h5ad",
        query_str=None,
        obs=None,  # e.g. "~/Projects/BG/clustering/100kb/annotations.tsv"
        groupby='Subclass',
        parent_col=None,
        modality='RNA',  # 'RNA', 'ATAC', or a methylation modality such as 'mc'
        use_raw=True,  # True for RNA
        expression_cutoff='p5',  # RNA/ATAC: a number, 'median', 'mean' or a percentile such as 'p5'/'p95'
        genes=None,
        cell_type_order=None,
        gene_order=None,
        row_cluster=False,
        col_cluster=False,
        cmap='Greens_r',
        group_legend=False,
        parent_legend=False,
        title=None,
        palette_path=None,  # e.g. "/home/x-wding2/Projects/BICAN/adata/HMBA_v2/HMBA_color_palette.xlsx"
        legend_kws=None,
        normalize_per_cell=True,
        clip_norm_value=10,
        hypo_score=False,
        figsize=(10, 3.5),
        marker='o',
        plot_kws=None, transpose=False,
        outname="test.pdf"):
    """Draw a dot plot of per-group mean and fraction for ``genes`` and save it.

    For each ``groupby`` group and gene, dot colour is the group mean and dot
    size is a fraction, taken from the first applicable source:
    - ``adata.layers['frac']`` when present. Its keys are ``(obs_name,
      gene)``, so this presumably expects pseudobulk input whose obs_names
      are the group names -- confirm.
    - For methylation modalities, the fraction of cells with normalised
      value < 1.
    - For 'RNA'/'ATAC', the fraction of cells with value > ``expression_cutoff``.

    Plotting uses PyComplexHeatmap's ``DotClustermapPlotter``. The figure is
    saved to ``outname`` and then shown.

    Parameters
    ----------
    adata : str
        Path to an ``.h5ad`` file. Only ``genes`` are read; missing genes are
        reported and skipped.
    query_str : str, optional
        ``DataFrame.query`` expression used to subset cells.
    obs : str or DataFrame, optional
        Cell metadata. A path is read tab-separated. Defaults to ``adata.obs``.
    groupby : str or list of str
        Grouping column. A list is combined into one ``'A+B'`` column.
    parent_col : str, optional
        Coarser grouping shown as an extra annotation and used to sort groups.
    modality, use_raw, expression_cutoff, normalize_per_cell, clip_norm_value, hypo_score :
        Control value preprocessing. ``normalize_mc_by_cell`` is only applied
        to non-RNA/ATAC modalities.
    genes : list of str
        Genes to plot (required).
    cell_type_order, gene_order : list, optional
        Group order and gene order.
    row_cluster, col_cluster, cmap, marker, figsize, transpose :
        Plot layout options.
    group_legend, parent_legend : bool
        Show legends for the group / parent annotations.
    title : str, optional
        Used for the default file name when ``outname`` is None.
    palette_path : str, optional
        Either an Excel workbook with one sheet per grouping (column ``Hex``)
        or the name of an obs column holding colours. Missing palettes or
        entries fall back to generated / gray colours.
    legend_kws : dict, optional
        Colour-bar legend options. Default
        ``dict(extendfrac=0.1, extend='both', label='Mean mCG')``.
    plot_kws : dict, optional
        Overrides for the default ``DotClustermapPlotter`` keyword arguments.
        The dict is never modified.
    outname : str or None, default "test.pdf"
        Output path. When None, ``f"{title}.pdf"`` is used.

    Returns
    -------
    tuple
        ``(plot_data, df_cols, cm1)``: the long-form plotted table, the group
        annotation table, and the ``DotClustermapPlotter`` instance.
    """
    from PyComplexHeatmap import (
        HeatmapAnnotation, anno_label, anno_simple, DotClustermapPlotter)

    assert genes is not None, "Please provide genes to plot."
    if legend_kws is None:
        legend_kws = dict(extendfrac=0.1, extend='both', label='Mean mCG')

    # ---- load requested genes (single-cell or pseudobulk adata) ----
    raw_adata = anndata.read_h5ad(os.path.expanduser(adata), backed='r')
    all_vars = set(raw_adata.var_names.tolist())
    keep_genes = list(all_vars & set(genes))
    kept = set(keep_genes)
    error_genes = [g for g in genes if g not in kept]
    if len(error_genes) > 0:
        print(f"genes not found in adata: {error_genes}")
    adata = raw_adata[:, keep_genes].to_memory()
    if use_raw and adata.raw is not None:
        adata_raw = adata.raw[:, adata.var_names.tolist()].to_adata()
        adata.X = adata_raw[adata.obs_names.tolist(), adata.var_names.tolist()].X.copy()
        del adata_raw
    raw_adata.file.close()  # close the file to save memory

    # ---- cell metadata / filtering ----
    if obs is not None:
        if isinstance(obs, str):
            obs = pd.read_csv(os.path.expanduser(obs), sep='\t', index_col=0)
        else:
            obs = obs.copy()
    else:
        obs = adata.obs.copy()
    if query_str is not None:
        obs = obs.query(query_str)
    overlapped_cells = list(set(adata.obs_names.tolist()) & set(obs.index.tolist()))
    obs = obs.loc[overlapped_cells]
    adata = adata[overlapped_cells, :]
    if isinstance(groupby, list):
        # combine several columns into one 'A+B' grouping column
        groupby1 = "+".join(groupby)
        obs[groupby1] = obs.loc[:, groupby].apply(
            lambda x: '+'.join(x.astype(str).tolist()), axis=1)
        groupby = groupby1
    adata.obs[groupby] = obs.loc[adata.obs_names.tolist(), groupby].tolist()
    if title is None:
        if query_str is not None:
            title = query_str
        else:
            title = groupby if groupby is not None else '-'.join(genes)
    if parent_col is not None and parent_col not in adata.obs.columns.tolist():
        adata.obs[parent_col] = obs.loc[adata.obs_names.tolist(), parent_col].tolist()
    if modality not in ['RNA', 'ATAC'] and normalize_per_cell:
        adata = normalize_mc_by_cell(
            use_adata=adata, normalize_per_cell=normalize_per_cell,
            clip_norm_value=clip_norm_value, hypo_score=hypo_score)
    print(adata.shape)

    # ---- colour palettes: {grouping column: {category: colour}} ----
    # Stays empty without palette_path; _annotation_colors then generates colours.
    color_palette = {}
    if palette_path is not None:
        if os.path.exists(os.path.expanduser(palette_path)):
            palette_path = os.path.abspath(os.path.expanduser(palette_path))
            D = pd.read_excel(palette_path, sheet_name=None, index_col=0)
            if groupby in D:
                color_palette[groupby] = D[groupby].Hex.to_dict()
            else:
                assert '+' in groupby, f"{groupby} not found in the palette file."
                for group in groupby.split('+'):
                    assert group in D, f"{group} not found in the palette file."
                    color_palette[group] = D[group].Hex.to_dict()
            if parent_col is not None:
                color_palette[parent_col] = D[parent_col].Hex.to_dict()
        else:
            # `palette_path` names an obs column holding each cell's colour
            color_palette[groupby] = adata.obs.reset_index().loc[:, [groupby, palette_path]] \
                .drop_duplicates().dropna().set_index(groupby)[palette_path].to_dict()
            if parent_col is not None:
                color_palette[parent_col] = adata.obs.reset_index().loc[:, [parent_col, palette_path]] \
                    .drop_duplicates().dropna().set_index(parent_col)[palette_path].to_dict()

    # ---- per-group mean and fraction ----
    data = adata.to_df()  # rows are cells (or pseudobulk groups), columns are genes
    if modality in ['RNA', 'ATAC'] and isinstance(expression_cutoff, str):
        if expression_cutoff == 'median':
            cutoff = data.stack().median()
        elif expression_cutoff == 'mean':
            cutoff = data.stack().mean()
        else:  # percentile such as 'p5' or 'p95'
            pct = float(expression_cutoff.replace('p', ''))
            cutoff = data.stack().quantile(pct / 100)
        expression_cutoff = cutoff
    data[groupby] = adata.obs.loc[data.index.tolist(), groupby].tolist()
    if parent_col is not None and parent_col in adata.obs.columns.tolist():
        group2parent = adata.obs.loc[:, [groupby, parent_col]].drop_duplicates() \
            .set_index(groupby)[parent_col].to_dict()
    plot_data = data.groupby(groupby).mean().stack().reset_index()
    plot_data.columns = [groupby, 'Gene', 'Mean']
    if 'frac' in adata.layers:
        D = adata.to_df(layer='frac').stack().to_dict()
    elif modality not in ['RNA', 'ATAC']:  # methylation: cutoff = 1
        assert normalize_per_cell == True, "Normalized methylation fraction is required"
        # fraction of cells showing hypomethylation for each gene
        hypo_frac = data.groupby(groupby).agg(lambda x: x[x < 1].shape[0] / x.shape[0])
        D = hypo_frac.stack().to_dict()
    else:  # RNA / ATAC
        print(f"Using expression cutoff: {expression_cutoff}")
        # value > expression_cutoff means the gene counts as expressed
        frac = data.groupby(groupby).agg(lambda x: x[x > expression_cutoff].shape[0] / x.shape[0])
        D = frac.stack().to_dict()
    plot_data['frac'] = plot_data.loc[:, [groupby, 'Gene']] \
        .apply(lambda x: tuple(x.tolist()), axis=1).map(D)

    # ---- group annotation table ----
    df_cols = pd.DataFrame(sorted(adata.obs[groupby].unique().tolist()), columns=[groupby])
    if parent_col is not None:
        df_cols[parent_col] = df_cols[groupby].map(group2parent)
        df_cols.sort_values([parent_col, groupby], inplace=True)
    df_cols.index = df_cols[groupby].tolist()
    if cell_type_order is not None:
        present = set(df_cols.index.tolist())
        rows = [ct for ct in cell_type_order if ct in present]
        df_cols = df_cols.loc[rows].copy()

    axis = 1 if not transpose else 0  # 1: column (top) annotation, 0: row (left) annotation
    col_ha_dict = {}
    if '+' in groupby:
        individual_groups = groupby.split('+')
        for ig in individual_groups:
            part = individual_groups.index(ig)
            df_cols[ig] = df_cols[groupby].apply(lambda x, part=part: x.split('+')[part])
            group_colors = _annotation_colors(df_cols[ig], color_palette.get(ig))
            col_ha_dict[ig] = anno_simple(df_cols[ig], colors=group_colors, add_text=False,
                                          legend=group_legend, height=3, label=ig)
    df_cols.dropna(inplace=True)

    if parent_col is not None:
        parent_colors = _annotation_colors(df_cols[parent_col], color_palette.get(parent_col))
        if '+' not in groupby:
            group_colors = _annotation_colors(df_cols[groupby], color_palette.get(groupby))
            col_ha = HeatmapAnnotation(
                axis=axis,
                label=anno_label(df_cols[groupby], colors=group_colors, merge=True,
                                 rotation=45, fontsize=12, arrowprops=dict(visible=False)),
                group=anno_simple(df_cols[groupby], colors=group_colors, add_text=False,
                                  legend=group_legend, height=3, label=groupby),
                parent=anno_simple(df_cols[parent_col], colors=parent_colors, add_text=False,
                                   legend=parent_legend, height=3, label=parent_col),
            )
        else:
            col_ha_dict[parent_col] = anno_simple(df_cols[parent_col], colors=parent_colors,
                                                  add_text=False, legend=parent_legend,
                                                  height=3, label=parent_col)
            col_ha = HeatmapAnnotation(**col_ha_dict, axis=axis, verbose=0)
        colnames = False
    else:
        if '+' not in groupby:
            group_colors = _annotation_colors(df_cols[groupby], color_palette.get(groupby))
            col_ha = HeatmapAnnotation(
                axis=axis,
                group=anno_simple(df_cols[groupby], colors=group_colors, add_text=False,
                                  legend=group_legend, height=3, label=groupby),
            )
        else:
            col_ha = HeatmapAnnotation(**col_ha_dict, axis=axis, verbose=0)
        colnames = True

    if not transpose:
        top_annotation = col_ha
        left_annotation = None
        x = groupby
        y = 'Gene'
        x_order = df_cols.index.tolist()
        y_order = gene_order
        show_colnames = colnames
        show_rownames = True
    else:
        top_annotation = None
        left_annotation = col_ha
        y = groupby
        x = 'Gene'
        y_order = df_cols.index.tolist()
        x_order = gene_order
        show_rownames = colnames
        show_colnames = True

    default_plot_kws = dict(
        marker=marker, grid=None, legend_gap=8, dot_legend_marker=marker,
        cmap_legend_kws=legend_kws,
        row_cluster=row_cluster, col_cluster=col_cluster,
        row_cluster_method='ward', row_cluster_metric='euclidean',
        col_cluster_method='ward', col_cluster_metric='euclidean',
        col_names_side='top', row_names_side='left',
        show_rownames=show_rownames, show_colnames=show_colnames, row_dendrogram=False,
        xticklabels_kws={'labelrotation': 45, 'labelcolor': 'blue', 'labelsize': 10, 'top': True},
        yticklabels_kws={'labelcolor': 'blue', 'labelsize': 10, 'left': True},
        spines=False,
    )
    # Merge into a fresh dict. The old code filled defaults into the shared
    # `plot_kws={}` default, so values from the first call (e.g.
    # show_rownames for a given `transpose`) leaked into later calls.
    merged_plot_kws = dict(default_plot_kws)
    if plot_kws:
        merged_plot_kws.update(plot_kws)

    plt.figure(figsize=figsize)
    cm1 = DotClustermapPlotter(
        data=plot_data, top_annotation=top_annotation, left_annotation=left_annotation,
        x_order=x_order, y_order=y_order,
        x=x, y=y, value='Mean', c='Mean', s='frac',
        cmap=cmap, verbose=1, **merged_plot_kws,
    )
    for ax in cm1.heatmap_axes.ravel():
        ax.grid(axis='both', which='minor', color='grey', linestyle='--', alpha=0.6, zorder=0)
    if outname is None:
        outname = f"{title}.pdf"
    plt.savefig(os.path.expanduser(outname), bbox_inches='tight')
    plt.show()
    return plot_data, df_cols, cm1
def get_genes_mean_frac(
    adata, obs=None, groupby='Subclass', modality='RNA', layer="mean",
    use_raw=False, expression_cutoff='p5',
    genes=None,
    normalize_per_cell=True, clip_norm_value=10, hypo_score=False,
):
    """
    Summarise genes per group as a mean value and a fraction of cells.

    Two kinds of input are supported:

    * single-cell level (no ``'mean'`` layer): cells are grouped by
      ``obs[groupby]``; the per-group mean of each gene and the fraction of
      cells passing a threshold are computed here.
    * pseudobulk level (``'mean'`` layer present): rows are already groups;
      means are read from ``layer`` and fractions from the ``'frac'`` layer.

    Parameters
    ----------
    adata : anndata.AnnData or str
        AnnData object, or path to a ``.h5ad`` file (opened in backed mode).
        If ``adata`` is backed, its file is closed once the requested genes
        have been loaded into memory.
    obs : pandas.DataFrame or str, optional
        Cell metadata used for grouping at single-cell level, or a path to a
        ``.tsv``/``.txt`` (tab-separated) or comma-separated file whose first
        column is the cell index. Defaults to ``adata.obs``. Only cells present
        in both ``adata`` and ``obs`` are used.
    groupby : str
        Column of ``obs`` used to group cells.
    modality : str
        ``'RNA'`` or ``'ATAC'`` are treated as expression data; any other value
        is treated as methylation data.
    layer : str
        Layer holding per-group means for pseudobulk input.
    use_raw : bool
        Single-cell level only: replace ``X`` with the ``adata.raw`` values.
    expression_cutoff : str or float
        RNA/ATAC only. ``'median'``, ``'mean'``, a percentile such as ``'p5'``
        (computed over all selected values), or a number. A cell counts as
        expressing a gene when its value is strictly greater than the cutoff.
    genes : str or list of str
        Gene name(s). Names missing from ``adata.var_names`` are skipped and
        logged at DEBUG level. Output keeps the order given here.
    normalize_per_cell, clip_norm_value, hypo_score
        Non-RNA/ATAC data only: forwarded to ``normalize_mc_by_cell`` (called
        only when ``normalize_per_cell`` is true). Without a ``'frac'`` layer,
        ``frac`` is then the fraction of cells with a value below 1.

    Returns
    -------
    pandas.DataFrame
        Long-format table with columns ``[groupby, 'Gene', 'Mean', 'frac']``.

    Raises
    ------
    AssertionError
        If ``genes`` is None, ``obs`` has an unsupported type, or methylation
        data without a ``'frac'`` layer is used with ``normalize_per_cell=False``.
    ValueError
        If none of ``genes`` is present in ``adata.var_names``.
    """
    assert not genes is None, "Please provide genes to plot."
    if isinstance(genes, str):
        # A lone gene name; set("GENE") would otherwise split it into letters.
        genes = [genes]
    if isinstance(adata, str):
        adata = anndata.read_h5ad(os.path.expanduser(adata), backed='r')
    all_vars = set(adata.var_names.tolist())
    # Keep the caller's order (de-duplicated). A set intersection would order
    # strings by hash, which changes between interpreter runs.
    keep_genes = list(dict.fromkeys(g for g in genes if g in all_vars))
    error_genes = [g for g in genes if g not in all_vars]
    if len(error_genes) > 0:
        logger.debug(f"genes not found in adata: {error_genes}")
    if len(keep_genes) == 0:
        if adata.isbacked:
            adata.file.close()
        raise ValueError(f"None of the requested genes were found in adata: {list(genes)}")
    use_adata = adata[:, keep_genes].to_memory()  # type: ignore
    if adata.isbacked:
        adata.file.close()  # close the file to save memory

    if 'mean' not in use_adata.layers:
        # Single-cell level: aggregate cells within each group.
        if use_raw and use_adata.raw is not None:
            use_adata_raw = use_adata.raw[:, use_adata.var_names.tolist()].to_adata()
            use_adata.X = use_adata_raw[use_adata.obs_names.tolist(), use_adata.var_names.tolist()].X.copy()  # type: ignore
            del use_adata_raw
        if obs is not None:
            if isinstance(obs, str):
                sep = '\t' if obs.endswith('.tsv') or obs.endswith('.txt') else ','
                obs = pd.read_csv(os.path.expanduser(obs), sep=sep, index_col=0)
            assert isinstance(obs, pd.DataFrame), "obs should be a dataframe or a path to a csv/tsv file."
        else:
            obs = use_adata.obs.copy()
        # Keep adata's cell order so that the subset is deterministic.
        obs_cells = set(obs.index.tolist())
        overlapped_cells = [c for c in use_adata.obs_names.tolist() if c in obs_cells]
        obs = obs.loc[overlapped_cells]
        use_adata = use_adata[overlapped_cells, :]  # type: ignore
        if modality not in ['RNA', 'ATAC'] and normalize_per_cell:
            use_adata = normalize_mc_by_cell(
                use_adata=use_adata, normalize_per_cell=normalize_per_cell,
                clip_norm_value=clip_norm_value, hypo_score=hypo_score)
        data = use_adata.to_df()  # rows are cells, columns are genes
        if modality in ['RNA', 'ATAC'] and isinstance(expression_cutoff, str):
            # Resolve a named or percentile cutoff against all selected values.
            values = data.stack()
            if expression_cutoff == 'median':
                expression_cutoff = values.median()
            elif expression_cutoff == 'mean':
                expression_cutoff = values.mean()
            else:  # percentile, such as p5, p95
                expression_cutoff = values.quantile(float(expression_cutoff.replace('p', '')) / 100)
        data[groupby] = obs.loc[data.index.tolist(), groupby].tolist()  # type: ignore
        plot_data = data.groupby(groupby).mean().stack().reset_index()
        plot_data.columns = [groupby, 'Gene', 'Mean']
        if 'frac' in use_adata.layers:
            D = use_adata.to_df(layer='frac').stack().to_dict()
        elif modality not in ['RNA', 'ATAC']:
            # Methylation: fraction of cells showing hypomethylation (value < 1).
            assert normalize_per_cell == True, "Normalized methylation fraction is required"
            hypo_frac = data.groupby(groupby).agg(lambda x: x[x < 1].shape[0] / x.shape[0])
            D = hypo_frac.stack().to_dict()
        else:
            # RNA/ATAC: a value above expression_cutoff means the gene is expressed.
            logger.info(f"Using expression cutoff: {expression_cutoff}")
            frac = data.groupby(groupby).agg(lambda x: x[x > expression_cutoff].shape[0] / x.shape[0])
            D = frac.stack().to_dict()
    else:
        # Pseudobulk level: rows are already groups.
        plot_data = use_adata.to_df(layer=layer).stack().reset_index()
        plot_data.columns = [groupby, 'Gene', 'Mean']
        D = use_adata.to_df(layer='frac').stack().to_dict()
    # Look up the fraction for each (group, gene) pair.
    plot_data['frac'] = plot_data.loc[:, [groupby, 'Gene']].apply(
        lambda x: tuple(x.tolist()), axis=1).map(D)
    return plot_data
def interactive_dotHeatmap(
    adata=None, obs=None, genes=None, groupby='Subclass',
    modality="RNA", title=None, use_raw=False,
    expression_cutoff='p5', normalize_per_cell=True,
    clip_norm_value=10,
    width=900, height=700, gene_order=None, colorscale='greens',
    vmin='p1', vmax='p99', show=True,
    reversescale=False, size_min=5, size_max=30,
    renderer="notebook"
):
    """
    Interactive (plotly) dot heatmap of per-group gene summaries.

    Dot colour encodes the group mean and dot size encodes the fraction of
    cells, both as computed by ``get_genes_mean_frac``.

    Parameters
    ----------
    adata, obs, genes, groupby, modality, use_raw, expression_cutoff,
    normalize_per_cell, clip_norm_value
        Passed to ``get_genes_mean_frac`` (``hypo_score`` is fixed to False).
    title : str, optional
        Figure title. Defaults to ``groupby``.
    width, height : int
        Figure size in pixels.
    gene_order : list of str, optional
        Top-to-bottom order of genes on the y axis. Names not present in the
        data are dropped from the ordering.
    colorscale : str
        Plotly colour scale for the mean values.
    vmin, vmax : str, float or None
        Colour-scale limits. A string ``'pX'`` is the X-th percentile of the
        mean values (X may be fractional, e.g. ``'p2.5'``). A number is used
        as-is. None lets plotly choose.
    show : bool
        If True, display via ``show_fig`` and return None. Otherwise return
        the figure.
    reversescale : bool
        Reverse the colour scale.
    size_min, size_max : float
        Marker size values mapped linearly from frac (0 -> size_min,
        1 -> size_max). A missing frac is treated as 0.
    renderer : str or None
        If not None, set as the default plotly renderer.

    Returns
    -------
    plotly.graph_objects.Figure or None
    """
    if renderer is not None:
        pio.renderers.default = renderer
    plot_data = get_genes_mean_frac(
        adata, obs=obs, groupby=groupby, modality=modality,
        use_raw=use_raw, expression_cutoff=expression_cutoff,
        genes=genes,
        normalize_per_cell=normalize_per_cell,
        clip_norm_value=clip_norm_value, hypo_score=False,
    )  # columns: [groupby, 'Gene', 'Mean', 'frac']

    # Scatter markers on categorical axes: x = groups (columns), y = genes (rows).
    x_labels = plot_data[groupby].unique().tolist()
    present_genes = plot_data['Gene'].unique().tolist()
    if gene_order is None:
        y_labels = present_genes
    else:
        present = set(present_genes)
        y_labels = [g for g in gene_order if g in present]

    # Marker sizes: scale 'frac' (0-1) into [size_min, size_max].
    frac_vals = plot_data['frac'].fillna(0).astype(float)
    sizes = (frac_vals * (size_max - size_min) + size_min).tolist()
    # Marker colours: the group mean.
    mean_vals = plot_data['Mean'].astype(float).tolist()
    hover_text = [
        f"Group: {g}<br>Gene: {ge}<br>Mean: {m:.4g}<br>Frac: {f:.3g}"
        for g, ge, m, f in zip(plot_data[groupby].tolist(),
                               plot_data['Gene'].tolist(), mean_vals, frac_vals)
    ]

    def _colour_limit(value):
        # 'pX' -> X-th percentile of the means; number -> itself; None -> auto.
        if value is None:
            return None
        if isinstance(value, str):
            return plot_data['Mean'].quantile(float(value.replace('p', '')) / 100)
        return float(value)

    marker_dict = dict(
        size=sizes, color=mean_vals, colorscale=colorscale,
        showscale=True, colorbar=dict(title='Mean'),
        reversescale=reversescale, sizemode='area', opacity=0.9,
        cmin=_colour_limit(vmin), cmax=_colour_limit(vmax),
    )
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=plot_data[groupby].tolist(),
        y=plot_data['Gene'].tolist(),
        mode='markers',
        marker=marker_dict,
        text=hover_text,
        hoverinfo='text',
    ))
    # Categorical axes with explicit ordering. The y array is reversed so that
    # the first gene ends up at the top.
    fig.update_xaxes(type='category', categoryorder='array',
                     categoryarray=x_labels, tickangle=-45)
    fig.update_yaxes(type='category', categoryorder='array',
                     categoryarray=list(reversed(y_labels)))
    if title is None:
        title = groupby
    fig.update_layout(title=title or '', xaxis_title=groupby, yaxis_title='Gene',
                      width=width, height=height, plot_bgcolor='white')
    if show:
        show_fig(fig, filename=f"dotHeatmap.{groupby}")
        return None
    return fig
def get_boxplot_data(adata, variable, gene, obs=None):
    """
    Collect one gene's per-cell values together with a grouping column.

    Parameters
    ----------
    adata : anndata.AnnData
        Source of the gene values (backed or in memory).
    variable : str
        Column of the metadata to return alongside the gene.
    gene : str
        Gene name in ``adata.var_names``.
    obs : pandas.DataFrame or str, optional
        Cell metadata, or a path to a ``.tsv``/``.txt`` (tab-separated) or
        comma-separated file with the cell index in the first column.
        Defaults to ``adata.obs``.

    Returns
    -------
    pandas.DataFrame
        Columns ``[variable, gene]``, with one row per cell present in both
        the metadata and ``adata``.
    """
    assert isinstance(adata, anndata.AnnData)
    if adata.isbacked:  # type: ignore
        use_adata = adata[:, gene].to_memory()  # type: ignore
    else:
        use_adata = adata[:, gene].copy()  # type: ignore
    if obs is None:
        # The signature advertises obs as optional: fall back to adata's own metadata.
        data = adata.obs.copy()
    elif isinstance(obs, str):
        obs_path = os.path.abspath(os.path.expanduser(obs))
        sep = '\t' if obs_path.endswith('.tsv') or obs_path.endswith('.txt') else ','
        data = pd.read_csv(obs_path, index_col=0, sep=sep)
    else:
        assert isinstance(obs, pd.DataFrame)
        data = obs.copy()
    overlap_idx = data.index.intersection(use_adata.obs_names)
    data = data.loc[overlap_idx]
    use_adata = use_adata[overlap_idx, :]  # type: ignore
    # Both sides are now in overlap_idx order, so positional assignment aligns.
    data[gene] = use_adata.to_df()[gene].tolist()  # type: ignore
    return data.loc[:, [variable, gene]]
def has_stats(adata):
    """
    Return True if ``adata`` carries every precomputed box-plot statistic layer.

    The required layers are 'min', 'q25', 'q50', 'q75', 'max', 'mean' and 'std'.

    Parameters
    ----------
    adata : anndata.AnnData or str
        AnnData object, or path to a ``.h5ad`` file. A file opened here (in
        backed mode) is closed again before returning.

    Returns
    -------
    bool
    """
    required = ('min', 'q25', 'q50', 'q75', 'max', 'mean', 'std')
    if not isinstance(adata, str):
        return all(k in adata.layers for k in required)
    # Opened only to inspect layer names; release the file handle afterwards.
    backed = anndata.read_h5ad(os.path.expanduser(adata), backed='r')
    try:
        return all(k in backed.layers for k in required)
    finally:
        backed.file.close()
def plot_interactive_boxlot_from_data(
    adata, obs, variable, gene, palette_path=None,
    width=1100, height=700, title=None,
):
    """
    Plotly box plot of one gene's per-cell values, grouped by ``variable``.

    Parameters
    ----------
    adata : anndata.AnnData
        Source of the gene values.
    obs : pandas.DataFrame or str
        Cell metadata (or a path to it), passed to ``get_boxplot_data``.
    variable : str
        Grouping column; one box is drawn per category.
    gene : str
        Gene to plot.
    palette_path : str, optional
        Passed to ``load_color_palette``. Only colours for groups present in
        the data are kept; other groups use the D3 qualitative sequence.
    width, height : int
        Figure size in pixels.
    title : str, optional
        Defaults to ``"Boxplot: {gene} by {variable}"``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    plot_df = get_boxplot_data(adata, variable, gene, obs=obs)
    # Limit the visible y range to the 1st-99th percentiles. No data is removed;
    # extreme values only fall outside the view.
    range_y = [plot_df[gene].quantile(0.01), plot_df[gene].quantile(0.99)]
    color_discrete_map = load_color_palette(palette_path=palette_path, adata=adata, groups=variable)
    if color_discrete_map is not None:
        # Keep only groups that occur in the data. Build a new dict instead of
        # deleting keys from the mapping returned by the loader.
        present = set(plot_df[variable].unique().tolist())
        color_discrete_map = {k: v for k, v in color_discrete_map.items() if k in present}  # type: ignore
    if title is None:
        title = f"Boxplot: {gene} by {variable}"
    fig = px.box(
        plot_df,
        x=variable,
        y=gene,
        color=variable,
        color_discrete_sequence=px.colors.qualitative.D3,  # fallback palette
        color_discrete_map=color_discrete_map,
        range_y=range_y,
        points=False,
        title=title,
        template="plotly_white",  # white background style
    )
    fig.update_xaxes(tickangle=-90, automargin=True)
    fig.update_traces(
        line_width=1.2,  # thinner lines for a more refined look
        notched=False,  # standard (un-notched) boxplot
    )
    fig.update_layout(
        xaxis_title=variable,
        yaxis_title=gene,
        legend_title=variable,
        width=width,
        height=height,
    )
    return fig
def plot_interacrive_boxplot_from_stats(
    adata, variable, gene, palette_path=None,
    title=None, width=1100, height=700):
    """
    Plotly box plot drawn from precomputed per-group statistics.

    Each row (obs) of ``adata`` is one group. The layers 'q25', 'q50' and
    'q75' give the quartiles of ``gene``, and 'min' / 'max' give the whisker
    ends. Boxes are ordered by increasing median ('q50'). If ``adata`` is
    backed, its file is closed after the gene has been loaded.

    Parameters
    ----------
    adata : anndata.AnnData
        Pseudobulk AnnData with the statistic layers listed above.
    variable : str
        Used for axis/legend labels and passed to ``load_color_palette``.
    gene : str
        Gene to plot.
    palette_path : str, optional
        Passed to ``load_color_palette``. Groups without a colour cycle
        through the D3 qualitative palette.
    title : str, optional
        Defaults to ``"Boxplot: {gene} by {variable}"``.
    width, height : int
        Figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    assert isinstance(adata, anndata.AnnData)
    if adata.isbacked:  # type: ignore
        use_adata = adata[:, gene].to_memory()  # type: ignore
        adata.file.close()  # type: ignore
    else:
        use_adata = adata[:, gene].copy()  # type: ignore

    stat_keys = ['min', 'q25', 'q50', 'q75', 'max']
    # Rows are groups and columns are stats, ordered by median.
    plot_data = pd.concat(
        {k: use_adata.to_df(layer=k)[gene] for k in stat_keys}, axis=1
    ).sort_values('q50')

    color_discrete_map = load_color_palette(palette_path=palette_path, adata=adata, groups=variable)
    palette = px.colors.qualitative.D3
    fig = go.Figure()
    # enumerate gives the position directly (no O(n) list.index per group).
    for i, (group, row) in enumerate(plot_data.iterrows()):
        if color_discrete_map is not None and group in color_discrete_map:
            color = color_discrete_map[group]
        else:
            color = palette[i % len(palette)]
        # One box per group, built from the precomputed statistics.
        fig.add_trace(
            go.Box(
                x=[group],
                q1=[row['q25']],
                median=[row['q50']],
                q3=[row['q75']],
                lowerfence=[row['min']],
                upperfence=[row['max']],
                boxpoints=False,
                marker=dict(color=color),
                name=str(group),
                showlegend=True,
            )
        )
    if title is None:
        title = f"Boxplot: {gene} by {variable}"
    fig.update_xaxes(tickangle=-90, automargin=True)
    fig.update_layout(
        title=title,
        xaxis_title=variable,
        yaxis_title=gene,
        legend_title=variable,
        template='plotly_white',
        width=width,
        height=height,
    )
    return fig
def interactive_boxplot(
    adata, variable, gene, obs=None, palette_path=None,
    title=None, width=1100, height=700, show=True, renderer='notebook'):
    """
    Interactive box plot of one gene grouped by ``variable``.

    If ``adata`` has all precomputed statistic layers (see ``has_stats``), the
    plot is built from them. Otherwise it is built from per-cell values.

    Parameters
    ----------
    adata : anndata.AnnData or str
        AnnData object, or a path to a ``.h5ad`` file. A file opened here (in
        backed mode) is closed before this function returns.
    variable : str
        Grouping column.
    gene : str
        Gene to plot.
    obs : pandas.DataFrame or str, optional
        Cell metadata for the per-cell path. Defaults to ``adata.obs``.
    palette_path : str, optional
        Passed to ``load_color_palette``.
    title : str, optional
        Figure title.
    width, height : int
        Figure size in pixels.
    show : bool
        If True, display via ``show_fig`` and return None. Otherwise return
        the figure.
    renderer : str or None
        If not None, set as the default plotly renderer.

    Returns
    -------
    plotly.graph_objects.Figure or None
    """
    if renderer is not None:
        pio.renderers.default = renderer
    opened_here = isinstance(adata, str)
    if opened_here:
        adata = anndata.read_h5ad(os.path.expanduser(adata), backed='r')
    else:
        assert isinstance(adata, anndata.AnnData)
    try:
        if obs is None:
            obs = adata.obs.copy()  # type: ignore
        if not has_stats(adata):
            fig = plot_interactive_boxlot_from_data(
                adata, obs, variable, gene, palette_path=palette_path,
                title=title, width=width, height=height
            )
        else:  # pseudobulk level with precomputed stats
            fig = plot_interacrive_boxplot_from_stats(
                adata, variable, gene, palette_path=palette_path,
                title=title, width=width, height=height)
    finally:
        # Release a file opened here. The stats path may already have closed
        # it; closing a closed h5py file is a no-op.
        if opened_here:
            adata.file.close()  # type: ignore
    if show:
        show_fig(fig, filename=f"boxplot.{variable}.{gene}")
        return None
    return fig