Source code for adataviz.plotting

from re import L
import os, sys
import pandas as pd
import anndata
import scanpy as sc
import matplotlib.pylab as plt
import numpy as np
from matplotlib.colors import Normalize
import seaborn as sns
from .utils import (
	_make_tiny_axis_label, despine,
	zoom_ax, _extract_coords,
	_density_based_sample,_auto_size,
	_take_data_series, level_one_palette, 
	tight_hue_range,_text_anno_scatter,
	density_contour,plot_color_dict_legend,
	plot_marker_legend,plot_text_legend,plot_cmap_legend,
	normalize_mc_by_cell
)
from .tools import load_adata,load_obs,load_color_palette
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from loguru import logger as logger
# logger.add(sys.stderr, level="DEBUG")
logger.add(sys.stderr, level="ERROR")

[docs] def use_scientific_style(): import matplotlib.pylab as plt plt.rcParams.update({ 'font.family': 'sans-serif', 'font.sans-serif': 'Arial', 'font.family': 'Arial', # Base font size (set so text is ~7–8 pt at final print size) 'font.size': 8, # main text (labels, ticks, legend) 'axes.labelsize': 9, # axis labels (x/y) 'axes.titlesize': 10, # panel titles / figure titles 'figure.titlesize':11, 'xtick.labelsize': 8, # x-tick labels 'ytick.labelsize': 8, # y-tick labels 'legend.fontsize': 8, # legend text 'legend.title_fontsize': 9, # legend title (if used) # Lines and elements for clarity 'lines.linewidth': 1.2, # data lines 'axes.linewidth': 1.0, # axis spines 'xtick.major.width': 1.0, 'ytick.major.width': 1.0, 'xtick.major.size': 4, 'ytick.major.size': 4, # General figure appearance 'figure.dpi': 300, # high resolution for export 'savefig.dpi': 300, # when using plt.savefig() 'figure.figsize': (6.5, 4.5), # example starting size (adjust to your needs; e.g., ~17 cm wide for full page) 'figure.constrained_layout.use': True, 'savefig.transparent': True, 'savefig.bbox': 'tight', 'pdf.fonttype':42, 'ps.fonttype':42, })
# plt.rcParams.keys() # def get_colors(adata,variable=None,palette_path=None): # if not palette_path is None: # try: # colors=pd.read_excel(palette_path,sheet_name=variable,index_col=0).Hex.to_dict() # except: # return None # else: # if adata is None: # return None # if isinstance(adata,str): # adata=anndata.read_h5ad(adata,backed='r') # if f'{variable}_colors' not in adata.uns: # colors={cluster:color for cluster,color in zip(adata.obs[variable].cat.categories.tolist(),adata.uns[f'{variable}_colors'])} # else: # colors=None # color_discrete_map=colors # return color_discrete_map
[docs] def interactive_embedding( adata=None,obs=None,variable=None,gene=None, coord="umap",vmin='p1',vmax='p99',cmap='jet',title=None, width=900,height=750,colors=None,palette_path=None, size=None,show=True,downsample=None,target_fill=0.05, normalize_per_cell=True,clip_norm_value=10, renderer="notebook"): """ Plot interactive embedding plot with plotly for a given AnnData object or path of .h5ad. Parameters ---------- adata : _type_ _description_ obs : _type_, optional _description_, by default None variable : _type_, optional _description_, by default None gene : _type_, optional _description_, by default None coord : str, optional _description_, by default "umap" vmin : str, optional _description_, by default 'p1' vmax : str, optional _description_, by default 'p99' cmap : str, optional _description_, by default 'jet' title : _type_, optional _description_, by default None width : int, optional _description_, by default 1000 height : int, optional _description_, by default 800 colors : _type_, optional _description_, by default None palette_path : _type_, optional _description_, by default None size : _type_, optional _description_, by default None target_fill : float, optional _description_, by default 0.05 show : bool, optional _description_, by default True renderer : str, optional _description_, by default "notebook" Available renderers: ['plotly_mimetype', 'jupyterlab', 'nteract', 'vscode', 'notebook', 'notebook_connected', 'kaggle', 'azure', 'colab', 'cocalc', 'databricks', 'json', 'png', 'jpeg', 'jpg', 'svg', 'pdf', 'browser', 'firefox', 'chrome', 'chromium', 'iframe', 'iframe_connected', 'sphinx_gallery', 'sphinx_gallery_png'] Returns ------- _type_ _description_ """ if not renderer is None: pio.renderers.default = renderer use_col=gene if not gene is None else variable if not gene is None: assert not adata is None, "`gene` provided, `adata` must be provided too." if not adata is None: adata=load_adata(adata) use_adata=None if not gene is None: # adata is not None if adata.isbacked: # type: ignore use_adata=adata[:,gene].to_memory() # type: ignore else: use_adata=adata[:,gene].copy() # type: ignore if normalize_per_cell: use_adata = normalize_mc_by_cell( use_adata=use_adata, normalize_per_cell=normalize_per_cell, clip_norm_value=clip_norm_value,hypo_score=False) else: if not adata is None: use_adata=adata # backed if obs is None and use_adata is None: raise ValueError("Either `adata` or `obs` must be provided.") if obs is None: obs=use_adata.obs.copy() # type: ignore else: # obs not none obs=load_obs(obs) if not use_adata is None: overlap_idx=obs.index.intersection(use_adata.obs_names) obs=obs.loc[overlap_idx] use_adata=use_adata[overlap_idx,:] # type: ignore if not gene is None: obs[gene]=use_adata.to_df()[gene].tolist() # type: ignore cols=set(obs.columns.tolist()) if not f'{coord}_0' in cols or not f'{coord}_1' in cols: assert f'X_{coord}' in use_adata.obsm # type: ignore # print(use_adata.obsm[f'X_{coord}']) obs[f'{coord}_0']=use_adata.obsm[f'X_{coord}'][:,0].tolist() # type: ignore obs[f'{coord}_1']=use_adata.obsm[f'X_{coord}'][:,1].tolist() # type: ignore # print(obs.head()) if not adata is None and adata.isbacked: # type: ignore adata.file.close() # type: ignore # downsample obs for large dataset n_points = obs.shape[0] if not downsample is None and n_points > downsample: sample_idx = np.random.choice(n_points, size=downsample, replace=False) # numbers obs = obs.iloc[sample_idx] if not obs.dtypes[use_col] in ['object','category']: vmin_quantile=float(int(vmin.replace('p','')) / 100) vmax_quantile=float(int(vmax.replace('p','')) / 100) # print(vmin_quantile,vmax_quantile,obs[use_col],obs.dtypes[use_col]) range_color=[obs[use_col].quantile(vmin_quantile), obs[use_col].quantile(vmax_quantile)] color_discrete_map=None else: if colors is None: # color_discrete_map=get_colors(use_adata,use_col,palette_path=palette_path) color_discrete_map=load_color_palette(palette_path=palette_path,adata=use_adata,groups=use_col) else: color_discrete_map=colors if not color_discrete_map is None: keys=list(color_discrete_map.keys()) # type: ignore for k in keys: if k not in obs[use_col].unique().tolist(): del color_discrete_map[k] # type: ignore range_color=None keep_cols=['cell',f'{coord}_0',f'{coord}_1'] if not variable is None: keep_cols.append(variable) if not gene is None: keep_cols.append(gene) obs=obs.reset_index(names="cell").loc[:,keep_cols] # Create Plotly interactive scatter plot hover_data={ # Fields to show on hover "cell": True, # cell ID f'{coord}_0': ":0.3f",# UMAP coordinates rounded to 3 decimals f'{coord}_1': ":0.3f", } if not variable is None: hover_data[variable]=True # type: ignore # when plotting gene expression, also show cell types when mouse hover if not gene is None: hover_data[gene]=":.3f" # type: ignore fig = px.scatter( obs, x=f'{coord}_0', # UMAP first dimension → X axis y=f'{coord}_1', # UMAP second dimension → Y axis color=use_col, hover_data=hover_data, range_color=range_color, color_discrete_sequence=px.colors.qualitative.D3, # color palette (professional, unobtrusive) color_discrete_map=color_discrete_map, color_continuous_scale=cmap, #["blue", "red"], template="plotly_white", render_mode='webgl' # use WebGL rendering for better performance with large datasets ) fig.update_xaxes(range=[obs[f'{coord}_0'].min()-0.5, obs[f'{coord}_0'].max()+0.5],tickfont_size=12) fig.update_yaxes(range=[obs[f'{coord}_1'].min()-0.5, obs[f'{coord}_1'].max()+0.5],tickfont_size=12) if size is None: # Blend an area-based marker estimate with a log-based fallback so total point count and canvas size both matter. # Increased target_fill and scaling to make markers bigger marker_diam_area = 2 * np.sqrt((width * height * target_fill) / (np.pi * n_points)) marker_diam_log = 16 - 2 * np.log10(n_points) marker_diam = 0.7 * marker_diam_area + 0.5 * marker_diam_log size = int(np.clip(marker_diam, 4, 20)) if n_points < 500000: opacity = 0.8 else: opacity = 0.6 # logger.debug(f"{variable},{gene},{use_col}") # print(color_discrete_map,size,opacity) fig.update_traces( marker=dict(size=size, opacity=opacity, line=dict(width=0.12, color='black')), selector=dict(mode='markers') ) if title is None: title = f"{coord.upper()} Visualization (Colored by {use_col})" fig.update_layout( title=dict( text=title, font_size=16, x=0.5, # center the title pad=dict(t=10) ), xaxis_title=f'{coord}_0'.upper(), yaxis_title=f'{coord}_1'.upper(), autosize=True,width=width,height=height, legend_title=use_col, # legend title legend=dict( font_size=12, itemsizing='constant', # important: fix legend marker size so it's not affected by scatter points itemwidth=30, borderwidth=0.1 # legend item width; larger value increases the marker size ) ) if show: filename=f"{coord}.{use_col}" show_fig(fig,filename=filename) else: return fig
# html=fig2div(fig,filename='umap_plot') # return HttpResponse(html)
[docs] def show_fig(fig,filename="plot"): interactive_config={ 'displayModeBar':'hover','showLink':False,'linkText':'Edit on plotly', 'scrollZoom':True,"displaylogo": False, 'toImageButtonOptions':{'format':'svg','filename':filename}, 'modeBarButtonsToRemove':['sendDataToCloud'], # 'zoomIn2d','zoomOut2d','zoom2d','zoom3d','pan2d' 'editable':True,'autosizable':True,'responsive':True, 'fillFrame':True, 'edits':{ 'titleText':True,'legendPosition':True,'colorbarTitleText':True, 'shapePosition':True,'annotationPosition':True,'annotationText':True, 'axisTitleText':True,'legendText':True,'colorbarPosition':True} } fig.show(config=interactive_config)
[docs] def stacked_barplot( obs="cell_obs_with_annotation.csv",groupby='Age', column='CellClass',x_order=None,y_order=None,linewidth=0.1, palette="~/Projects/mouse_pfc/obs/mpfc_color_palette.xlsx", width=None,height=None,xticklabels_kws=None,save=False, lgd_kws=None,gap=0.05,sort_by=None): """ Plot stacked barplto to show the cell type composition in each `groupby` ( such as Age, brain regions and so on.) For example: stacked_barplot(column='MajorType',width=3.5,height=6) stacked_barplot(column='CellClass',width=3.5,height=3) """ if isinstance(obs,pd.DataFrame): data=obs.copy() elif isinstance(obs, str) and obs.endswith('.h5ad'): obs_path = os.path.abspath(os.path.expanduser(obs)) adata = anndata.read_h5ad(obs_path,backed='r') data=adata.obs del adata elif obs.endswith('.csv'): obs_path = os.path.abspath(os.path.expanduser(obs)) data=pd.read_csv(obs_path,index_col=0) else: obs_path = os.path.abspath(os.path.expanduser(obs)) data = pd.read_csv(obs_path, index_col=0,sep='\t') xticklabels_kws={} if xticklabels_kws is None else xticklabels_kws xticklabels_kws.setdefault('rotation',-45) xticklabels_kws.setdefault("rotation_mode", "anchor") xticklabels_kws.setdefault('horizontalalignment', 'left') #see ?matplotlib.axes.Axes.tick_params xticklabels_kws.setdefault('verticalalignment', 'center') if not palette is None: if isinstance(palette,dict): color_palette=palette.copy() elif isinstance(palette,str) and os.path.exists(os.path.expanduser(palette)): palette=os.path.abspath(os.path.expanduser(palette)) D=pd.read_excel(palette, sheet_name=None, index_col=0) color_palette=D[column].Hex.to_dict() keys=list(color_palette.keys()) for k in data[column].unique(): if k not in keys: color_palette[k]='gray' else: color_palette = palette else: color_palette=None df=data.groupby(groupby)[column].value_counts(normalize=True).unstack(level=column) if not sort_by is None: df.sort_values(sort_by,ascending=True,inplace=True) else: if x_order is None: x_order = sorted(df.index.tolist()) if y_order is None: y_order = sorted(df.columns.tolist()) df=df.loc[x_order,y_order] if width is None: width=max(df.shape[0]*0.45,10) if width < 2.5: width=2.5 if height is None: height = max(df.shape[1]*0.5, 8) if height < 3.5: height = 3.5 plt.figure() ax=df.plot.bar(stacked=True,align='edge', width=1-gap,edgecolor='black', linewidth=linewidth,figsize=(width,height), color=color_palette) ax.set_xlim(0,df.shape[0]) ax.set_ylim(0,1) labels=[tick.get_text() for tick in ax.get_xticklabels()] ax.set_xticks(ticks=np.arange(0.5,df.shape[0],1), labels=labels,**xticklabels_kws) # ax.xaxis.set_major_locator(ticker.FixedLocator(np.arange(0.5,df.shape[1],1))) #ticker.MultipleLocator(0.5) ax.xaxis.label.set_visible(False) ax.tick_params( axis="y", #both which="both",left=False,right=False,labelleft=False,labelright=False, top=False,labeltop=False,#bottom=False,labelbottom=False ) # ax.xaxis.set_tick_params(axis='x') lgd_kws = lgd_kws if not lgd_kws is None else {} # bbox_to_anchor=(x,-0.05) lgd_kws.setdefault("frameon", True) lgd_kws.setdefault("ncol", 1) lgd_kws["loc"] = "upper left" lgd_kws.setdefault("borderpad", 0.1 * (1 / 25.4) * 72) # 0.1mm lgd_kws.setdefault("markerscale", 1) lgd_kws.setdefault("handleheight", 1) # font size, units is points lgd_kws.setdefault("handlelength", 1) # font size, units is points lgd_kws.setdefault("borderaxespad", 0.1) # The pad between the axes and legend border, in font-size units. lgd_kws.setdefault("handletextpad", 0.3) # The pad between the legend handle and text, in font-size units. lgd_kws.setdefault("labelspacing", 0.1) # gap height between two Patches, 0.05*mm2inch*72 lgd_kws.setdefault("columnspacing", 1) lgd_kws.setdefault("bbox_to_anchor", (1, 1)) lgd_kws.setdefault("title",column) ax.legend(**lgd_kws) if save: outdir=os.path.dirname(os.path.expanduser(save)) if not os.path.exists(outdir): os.mkdir(outdir) plt.savefig(save) # transparent=True,bbox_inches='tight',dpi=300 else: plt.show()
[docs] def pieplot(obs="cell_obs_with_annotation.csv",groupby='Age',outdir="figures", palette_path="~/Projects/mouse_pfc/obs/mpfc_color_palette.xlsx", order=None,explode=0.05): outdir = os.path.abspath(os.path.expanduser(outdir)) if not os.path.exists(outdir): os.mkdir(outdir) # colors=None if isinstance(obs,pd.DataFrame): data=obs.copy() elif isinstance(obs, str) and obs.endswith('.h5ad'): obs_path = os.path.abspath(os.path.expanduser(obs)) print(f"Reading adata: {obs}") adata = anndata.read_h5ad(obs_path, backed='r') # if f'{groupby}_colors' in adata.uns: # colors={k:v for k,v in zip(adata.obs[groupby].cat.categories.tolist(), # adata.uns[f'{groupby}_colors'])} # else: # colors=None data = adata.obs del adata elif obs.endswith('.csv'): obs_path = os.path.abspath(os.path.expanduser(obs)) data = pd.read_csv(obs_path, index_col=0) else: obs_path = os.path.abspath(os.path.expanduser(obs)) data = pd.read_csv(obs_path, index_col=0, sep='\t') if not palette_path is None: palette_path=os.path.abspath(os.path.expanduser(palette_path)) D=pd.read_excel(palette_path, sheet_name=None, index_col=0) color_palette=D[groupby].Hex.to_dict() else: color_palette=None D=data[groupby].value_counts() if order is None: order=list(sorted(D.keys())) plt.figure() plt.pie([D[k] for k in order], labels=order, colors=[color_palette[k] for k in order], explode=[explode]*len(order), autopct='%.1f%%') # Add title to the chart plt.title('Distribution of #of cells across different stages') plt.savefig(os.path.join(outdir, groupby + '.piechart.pdf')) # transparent=True,bbox_inches='tight',dpi=300 plt.show()
[docs] def plot_pseudotime( pseudotime="dpt_pseudotime.tsv",groupby='Age',y='dpt_pseudotime', hue=None,figsize=(5,3.5),outdir="figures",rotate=None,ylabel='Pseudotime', palette_path="~/Projects/mouse_pfc/obs/mpfc_color_palette.xlsx", ): """ Plot pseudotime. plot_pseudotime(figsize=(6,3.5),groupby='MajorType', rotate=-45); plot_pseudotime(figsize=(3.5,3),groupby='CellClass') plot_pseudotime(figsize=(3.5,3),groupby='Age') Parameters ---------- pseudotime : groupby : y : hue : figsize : outdir : rotate : palette_path : Returns ------- """ outdir = os.path.abspath(os.path.expanduser(outdir)) if not os.path.exists(outdir): os.mkdir(outdir) if not palette_path is None: palette_path=os.path.abspath(os.path.expanduser(palette_path)) D=pd.read_excel(palette_path, sheet_name=None, index_col=0) color_palette=D[groupby].Hex.to_dict() else: color_palette=None data=pd.read_csv(os.path.expanduser(pseudotime),sep='\t',index_col=0) data.dpt_pseudotime.replace({np.inf: 1},inplace=True) order=data.groupby(groupby)[y].mean().sort_values(ascending=True).index.tolist() if not hue is None: hue_order=data.groupby(hue)[y].mean().sort_values(ascending=True).index.tolist() else: hue_order=None plt.figure(figsize=figsize) # ax = sns.swarmplot(data=data, palette=color_palette, \ # edgecolor='white', x=groupby, y=y, \ # order=order) ax=sns.violinplot(data=data, x=groupby, y=y, scale='count', bw=.2, inner=None, saturation=0.6, palette=color_palette, order=order,hue=hue,hue_order=hue_order) # plt.legend(frameon=True) ax.set_ylabel(ylabel) if not rotate is None: plt.setp(ax.xaxis.get_majorticklabels(), rotation=rotate, rotation_mode="anchor",horizontalalignment='left') outname=groupby + '.pseudotime_violin.pdf' if hue is None else groupby + f'_{hue}.pseudotime_violin.pdf' plt.savefig(os.path.join(outdir, outname)) # transparent=True,bbox_inches='tight',dpi=300 plt.show()
[docs] def stacked_violinplot(adata, use_genes=None, groupby='Age', cell_groups=None, parent=None, figsize=(6, 4), cmap='viridis'): import scanpy as sc ax = sc.pl.stacked_violin( adata[adata.obs[cell_groups[0]]==parent], var_names=use_genes, title=use_key, colorbar_title="Avg mc frac", groupby=groupby, dendrogram=True, swap_axes=False, cmap=cmap, figsize=figsize, scale='count', standard_scale='obs', inner='quart', # stripplot=False,jitter=False, show=False, layer=None) ax1 = ax['mainplot_ax'] ax1.yaxis.set_minor_locator(ticker.MultipleLocator(1)) ax1.yaxis.set_tick_params(which='minor',left=True) ax1.grid(axis='y', linestyle='--', color='black', alpha=1, zorder=-5, which='minor') plt.savefig(f"{fig_basename}.{groupby}.stacked_violin.pdf") plt.show()
[docs] def categorical_scatter( data,ax=None, coord_base="umap",x=None,y=None, # coords hue=None,palette="auto",color=None, # color text_anno=None,text_kws=None,luminance=None,text_transform=None, dodge_text=False,dodge_kws=None, # text annotation show_legend=False,legend_kws=None, # legend s="auto",size=None,sizes=None, # sizes is a dict size_norm=None,size_portion=0.95, axis_format="tiny",max_points=50000, labelsize=4,linewidth=0.5,zoomxy=1.05, outline=None,outline_pad=3,alpha=0.7, outline_kws=None,scatter_kws=None, rasterized="auto",coding=False, id_marker=True,legend_color_text=True, rectangle_marker=False,marker_fontsize=4,marker_pad=0.1, ): """ This function was copied from ALLCools and made some modifications. Plot categorical scatter plot with versatile options. Parameters ---------- rasterized Whether to rasterize the figure. return_fig Whether to return the figure. size_portion The portion of the figure to be used for the size norm. data Dataframe that contains coordinates and categorical variables ax this function do not generate ax, must provide an ax coord_base coords name, if provided, will automatically search for x and y x x coord name y y coord name hue : str categorical col name or series for color hue. palette : str or dict palette for color hue. color specify single color for all the dots text_anno categorical col name or series for text annotation. text_kws kwargs pass to plt.text, see: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.text.html including bbox, to see parameter for bbox, go to: https://matplotlib.org/stable/api/_as_gen/matplotlib.patches.FancyBboxPatch.html#matplotlib.patches.FancyBboxPatch commonly used parameters are:: text_kws=dict(fontsize=5,fontweight='black', color='black', # color could be a dict, keys are text to be annotated bbox=dict(boxstyle='round',edgecolor=(0.5, 0.5, 0.5, 0.2),fill=False, facecolor=(0.8, 0.8, 0.8, 0.2), # facecolor could also be a dict alpha=1,linewidth=0.5) ) text_transform transform for text annotation. dodge_text whether to dodge text annotation. dodge_kws kwargs for dodge text annotation. show_legend whether to show legend. legend_kws kwargs for legend. s single size value of all the dots. size mappable size of the dots. sizes mapping size to the sizes value. size_norm normalize size range for mapping. axis_format axis format. max_points maximum number of points to plot. labelsize label size pass to `ax.text` linewidth line width pass to `ax.scatter` zoomxy zoom factor for x and y-axis. outline categorical col name or series for outline. outline_pad outline padding. outline_kws kwargs for outline. scatter_kws kwargs for scatter. Returns ------- if return_fig is True, return the figure and axes. else, return None. """ if ax is None: # fig, ax = plt.subplots(figsize=(4, 4), dpi=300) ax = plt.gca() # add coords _data, x, y = _extract_coords(data, coord_base, x, y) # _data has 2 cols: "x" and "y", index are obs_names # down sample plot data if needed. if max_points is not None: if _data.shape[0] > max_points: _data = _density_based_sample(_data, seed=1, size=max_points, coords=["x", "y"]) n_dots = _data.shape[0] # determine rasterized if rasterized == "auto": if n_dots > 200: rasterized = True else: rasterized = False # auto size if user didn't provide one if s == "auto": s = _auto_size(ax, n_dots) # default scatter options _scatter_kws = {"linewidth": 0, "s": s, "legend": None, "palette": palette, "rasterized": rasterized} if color is not None: if hue is not None: raise ValueError("Only one of color and hue can be provided") _scatter_kws["color"] = color if scatter_kws is not None: _scatter_kws.update(scatter_kws) # deal with color palette_dict = None if hue is not None: if isinstance(hue, str): _data["hue"] = _take_data_series(data, hue) else: _data["hue"] = hue.copy() _data["hue"] = _data["hue"].astype("category").cat.remove_unused_categories() # if the object has get_palette method, use it (AnnotZarr) palette = _scatter_kws["palette"] # deal with other color palette if palette_dict is None: if isinstance(palette, str) or isinstance(palette, list): palette_dict = level_one_palette(_data["hue"], order=None, palette=palette) elif isinstance(palette, dict): palette_dict = palette else: raise TypeError(f"Palette can only be str, list or dict, " f"got {type(palette)}") _scatter_kws["palette"] = palette_dict # deal with size if size is not None: if isinstance(size, str): _data["size"] = _take_data_series(data, size).astype(float) else: _data["size"] = size.astype(float) size = "size" if size_norm is None: # get the smallest range that include "size_portion" of data size_norm = tight_hue_range(_data["size"], size_portion) # snorm is the normalizer for size size_norm = Normalize(vmin=size_norm[0], vmax=size_norm[1]) # discard s from _scatter_kws and use size in sns.scatterplot s = _scatter_kws.pop("s") if sizes is None: sizes = (min(s, 1), s) sns.scatterplot( x="x", y="y", data=_data, ax=ax, hue="hue", size=size, sizes=sizes, size_norm=size_norm, **_scatter_kws, ) # deal with text annotation code2label=None if text_anno is not None: # data if isinstance(text_anno, str): _data["text_anno"] = _take_data_series(data, text_anno) else: _data["text_anno"] = text_anno.copy() if str(_data["text_anno"].dtype) == "category": _data["text_anno"] = _data["text_anno"].cat.remove_unused_categories() # text kws text_kws = {} if text_kws is None else text_kws default_text_kws = dict( color='white', # color for the text, could be a dict, keys are text to be annotated fontweight="bold", #fontsize=labelsize, bbox=dict(facecolor=palette_dict, # if None, use default color boxstyle='round', #ellipse, round edgecolor='white', fill=True, linewidth=linewidth, alpha=alpha)) # coding & id_marker text_anno='text_anno' if not coding is None and coding!=False: if coding == True: _data['code'] = _data['hue'].cat.codes #int else: assert isinstance(coding,str) _data["code"] = _take_data_series(data, coding) _data=_data.loc[_data['code'].notna()] _data["code"]=_data["code"].astype(int) _data["code"] = _data["code"].astype("category").cat.remove_unused_categories() text_anno='code' _data['color'] = _data['hue'].map(palette_dict) code2label=_data.loc[:,['code','hue']].drop_duplicates().set_index('code').hue.to_dict() _data['code']=_data['code'].astype(str) code_colors=_data.loc[:,['code','color']].drop_duplicates().set_index('code').color.to_dict() default_text_kws['bbox']['facecolor']=code_colors # background colors for text annotation default_text_kws['bbox']['boxstyle'] = 'circle' for k in default_text_kws: if k !='bbox': text_kws.setdefault(k, default_text_kws[k]) else: if 'bbox' not in text_kws: text_kws['bbox']={} for k1 in default_text_kws['bbox']: text_kws['bbox'].setdefault(k1, default_text_kws['bbox'][k1]) _text_anno_scatter( data=_data[["x", "y", text_anno]], ax=ax, x="x", y="y", dodge_text=dodge_text, dodge_kws=dodge_kws, text_transform=text_transform, anno_col=text_anno, text_kws=text_kws, luminance=luminance, ) # deal with outline if not outline is None: if isinstance(outline, str): _data["outline"] = _take_data_series(data, outline) else: _data["outline"] = outline.copy() _outline_kws = { "linewidth": linewidth, "palette": None, "c": "lightgray", "single_contour_pad": outline_pad, } if outline_kws is not None: _outline_kws.update(outline_kws) density_contour(ax=ax, data=_data, x="x", y="y", groupby="outline", **_outline_kws) # clean axis if axis_format == "tiny": _make_tiny_axis_label(ax, x, y, arrow_kws=None, fontsize=labelsize) elif (axis_format == "empty") or (axis_format is None): despine(ax=ax, left=True, bottom=True) ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None) else: pass # deal with legend if show_legend and (hue is not None): n_hue = len(palette_dict) ncol=1 if n_hue <= 40 else 2 if n_hue <= 100 else 3 if legend_kws is None: legend_kws = {} default_lgd_kws = dict( ncol=ncol,fontsize=labelsize, bbox_to_anchor=(1, 1),loc="upper left", # borderpad=0.4, # pad between marker (text) and border # labelspacing=0.2, #The vertical space between the legend entries, in font-size units # handleheight=0.5, #The height of the legend handles, in font-size units. # handletextpad=0.2, # The pad between the legend handle (marker) and text, in font-size units. # borderaxespad=0.3, # The pad between the Axes and legend border, in font-size units # columnspacing=0.2, #The spacing between columns, in font-size units markersize=labelsize #legend_kws["fontsize"], ) for k in default_lgd_kws: legend_kws.setdefault(k, default_lgd_kws[k]) exist_hues = _data["hue"].unique() color_dict={hue_name: color for hue_name, color in palette_dict.items() if hue_name in exist_hues} if not code2label is None and id_marker: boxstyle='Circle' if not rectangle_marker else 'Round' plot_text_legend(color_dict, code2label, ax, title=hue, color_text=legend_color_text, boxstyle=boxstyle,marker_pad=marker_pad, legend_kws=legend_kws,marker_fontsize=marker_fontsize, alpha=alpha,luminance=luminance) else: if rectangle_marker: ## plot Patch legend (rectangle marker) plot_color_dict_legend( D=color_dict, ax=ax, title=hue, color_text=legend_color_text, kws=legend_kws,luminance=luminance ) else: # plot marker legend (for example, circle marker) plot_marker_legend( color_dict=color_dict, ax=ax, title=hue, color_text=legend_color_text, marker='o',kws=legend_kws,luminance=luminance ) if zoomxy is not None: zoom_ax(ax, zoomxy) return _data
[docs] def get_cmap(cmap): try: return plt.colormaps.get(cmap) # matplotlib >= 3.5.1? except: return plt.get_cmap(cmap) # matplotlib <=3.4.3?
[docs] def continuous_scatter( data, ax=None, coord_base="umap", x=None, y=None, scatter_kws=None, hue=None, hue_norm=None, hue_portion=0.95, color=None, cmap="viridis", colorbar=True, size=None, size_norm=None, size_portion=0.95, sizes=None, sizebar=True, text_anno=None, dodge_text=False, dodge_kws=None, text_kws=None,luminance=0.48, text_transform=None, axis_format="tiny", max_points=50000, s="auto", labelsize=6, ticklabel_size=4, linewidth=0.5, zoomxy=1.05, outline=None, outline_kws=None, outline_pad=2, return_fig=False, rasterized="auto", cbar_kws=None,cbar_width=3, ): """ Plot scatter on given adata. Parameters ---------- data : _type_ _description_ ax : _type_, optional _description_, by default None coord_base : str, optional _description_, by default "umap" x : _type_, optional _description_, by default None y : _type_, optional _description_, by default None scatter_kws : _type_, optional _description_, by default None hue : _type_, optional _description_, by default None hue_norm : _type_, optional _description_, by default None hue_portion : float, optional _description_, by default 0.95 color : _type_, optional _description_, by default None cmap : str, optional _description_, by default "viridis" colorbar : bool, optional _description_, by default True size : _type_, optional _description_, by default None size_norm : _type_, optional _description_, by default None size_portion : float, optional _description_, by default 0.95 sizes : _type_, optional _description_, by default None sizebar : bool, optional _description_, by default True text_anno : _type_, optional _description_, by default None dodge_text : bool, optional _description_, by default False dodge_kws : _type_, optional _description_, by default None text_kws : _type_, optional _description_, by default None luminance : float, optional _description_, by default 0.48 text_transform : _type_, optional _description_, by default None axis_format : str, optional _description_, by default "tiny" max_points : int, optional _description_, by default 50000 s : str, optional _description_, by default "auto" labelsize : int, optional _description_, by default 6 ticklabel_size : int, optional _description_, by default 4 linewidth : float, optional _description_, by default 0.5 zoomxy : float, optional _description_, by default 1.05 outline : _type_, optional _description_, by default None outline_kws : _type_, optional _description_, by default None outline_pad : int, optional _description_, by default 2 return_fig : bool, optional _description_, by default False rasterized : str, optional _description_, by default "auto" cbar_kws : _type_, optional _description_, by default None cbar_width : int, optional width of colorbar, by default 3 mm Returns ------- _type_ _description_ Raises ------ ValueError _description_ TypeError _description_ """ import seaborn as sns import copy from matplotlib.cm import ScalarMappable # init figure if not provided if ax is None: fig, ax = plt.subplots(figsize=(4, 4), dpi=300) else: fig = None # add coords _data, x, y = _extract_coords(data, coord_base, x, y) # _data has 2 cols: "x" and "y" # down sample plot data if needed. if max_points is not None: if _data.shape[0] > max_points: _data = _density_based_sample(_data, seed=1, size=max_points, coords=["x", "y"]) n_dots = _data.shape[0] # determine rasterized if rasterized == "auto": if n_dots > 200: rasterized = True else: rasterized = False # auto size if user didn't provide one if s == "auto": s = _auto_size(ax, n_dots) # default scatter options _scatter_kws = {"linewidth": 0, "s": s, "legend": None, "rasterized": rasterized} if color is not None: if hue is not None: raise ValueError("Only one of color and hue can be provided") _scatter_kws["color"] = color if scatter_kws is not None: _scatter_kws.update(scatter_kws) # deal with color if hue is not None: if isinstance(hue, str): _data["hue"] = _take_data_series(data, hue).astype(float) colorbar_label = hue else: _data["hue"] = hue.astype(float) colorbar_label = hue.name if hue_norm is None: # get the smallest range that include "hue_portion" of data # hue_norm = tight_hue_range(_data["hue"], hue_portion) hue_norm=(_data["hue"].quantile(1-hue_portion),_data["hue"].quantile(hue_portion)) # cnorm is the normalizer for color cnorm = Normalize(vmin=hue_norm[0], vmax=hue_norm[1]) if isinstance(cmap, str): # from here, cmap become colormap object cmap = copy.copy(get_cmap(cmap)) cmap.set_bad(color=(0.5, 0.5, 0.5, 0.5)) else: if not isinstance(cmap, ScalarMappable): raise TypeError(f"cmap can only be str or ScalarMappable, got {type(cmap)}") else: hue_norm = None cnorm = None colorbar_label = "" # deal with size if size is not None: if isinstance(size, str): _data["size"] = _take_data_series(data, size).astype(float) else: _data["size"] = size.astype(float) size = "size" if size_norm is None: # get the smallest range that include "size_portion" of data size_norm = tight_hue_range(_data["size"], size_portion) # snorm is the normalizer for size size_norm = Normalize(vmin=size_norm[0], vmax=size_norm[1]) # replace s with sizes s = _scatter_kws.pop("s") if sizes is None: sizes = (min(s, 1), s) else: size_norm = None sizes = None sns.scatterplot( x="x", y="y", data=_data, hue="hue", palette=cmap, hue_norm=cnorm, size=size, sizes=sizes, size_norm=size_norm, ax=ax, **_scatter_kws, ) if text_anno is not None: if isinstance(text_anno, str): _data["text_anno"] = _take_data_series(data, text_anno) else: _data["text_anno"] = text_anno if str(_data["text_anno"].dtype) == "category": _data["text_anno"] = _data["text_anno"].cat.remove_unused_categories() _text_anno_scatter( data=_data[["x", "y", "text_anno"]], ax=ax, x="x", y="y", dodge_text=dodge_text, dodge_kws=dodge_kws, text_transform=text_transform, anno_col="text_anno", text_kws=text_kws, luminance=luminance, ) # deal with outline if outline: if isinstance(outline, str): _data["outline"] = _take_data_series(data, outline) else: _data["outline"] = outline _outline_kws = { "linewidth": linewidth, "palette": None, "c": "lightgray", "single_contour_pad": outline_pad, } if outline_kws is not None: _outline_kws.update(outline_kws) density_contour(ax=ax, data=_data, x="x", y="y", groupby="outline", **_outline_kws) # clean axis if axis_format == "tiny": _make_tiny_axis_label(ax, x, y, arrow_kws=None, fontsize=labelsize) elif (axis_format == "empty") or (axis_format is None): despine(ax=ax, left=True, bottom=True) ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None) else: pass return_axes = [ax] # make color bar if colorbar and (hue is not None): # small ax for colorbar # default_cbar_kws=dict(loc="upper left", borderpad=0,width="3%", height="20%") #bbox_to_anchor=(1,1) if cbar_kws is None: cbar_kws={} # for k in default_cbar_kws: # if k not in cbar_kws: # cbar_kws[k]=default_cbar_kws[k] mm2inch = 1 / 25.4 space=0 legend_width = ( cbar_width * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().width ) # mm to px to fraction pad = (space + ax.yaxis.labelpad * 1.2 * ax.figure.dpi / 72) / ax.figure.get_window_extent().width # labelpad unit is points left = ax.get_position().x1 + pad ax_legend = ax.figure.add_axes( [left, ax.get_position().height * 0.8, legend_width, ax.get_position().height * 0.2] ) # left, bottom, width, height # print("test:",hue_norm) # cbar_kws.setdefault('vmin',hue_norm[0]) # cbar_kws.setdefault('vmax',hue_norm[1]) cbar_kws['vmin']=hue_norm[0] cbar_kws['vmax']=hue_norm[1] cbar = plot_cmap_legend( ax=ax, cax=ax_legend, cmap=cmap, label=hue, kws=cbar_kws.copy(),labelsize=labelsize, linewidth=linewidth,ticklabel_size=ticklabel_size, ) return_axes.append([ax_legend,cbar]) # make size bar if sizebar and (size is not None): # TODO plot dot size bar pass if zoomxy is not None: zoom_ax(ax, zoomxy) if return_fig: return (fig, tuple(return_axes)), _data else: return
[docs] def plot_cluster( adata_path,ax=None,coord_base='tsne',cluster_col='MajorType', palette_path=None,coding=True,id_marker=True, output=None, show=True,figsize=(4, 3.5),sheet_name=None, ncol=None,fontsize=5,legend_fontsize=5, legend_kws=None,legend_title_fontsize=5, marker_fontsize=4,marker_pad=0.1, linewidth=0.5,axis_format='tiny',alpha=0.7, text_kws=None,**kwargs): """ Plot cluster. Parameters ---------- adata_path : ax : coord_base : cluster_col : palette_path : coding : output : show : figsize : sheet_name : ncol : fontsize : legend_fontsize : int legend fontsize, default 5 legend_kws: dict kwargs passed to ax.legend legend_title_fontsize: int legend title fontsize, default 5 marker_fontsize: int Marker fontsize, default 3 if id_marker is True, and coding is True. legend marker will be a circle (or rectangle) with code linewidth : float Line width of the legend marker (circle or rectangle), default 0.5 kwargs : dict set text_anno=None to plot clustering without text annotations, coding=True to plot clustering without code annotations, set show_legend=False to remove the legend Returns ------- """ from pandas.api.types import is_categorical_dtype if coord_base.startswith("X_"): coord_base=coord_base.replace('X_','') if sheet_name is None: sheet_name=cluster_col if isinstance(adata_path,str): adata=anndata.read_h5ad(adata_path,backed='r') else: adata=adata_path if not is_categorical_dtype(adata.obs[cluster_col]): adata.obs[cluster_col] = adata.obs[cluster_col].astype('category') if not palette_path is None: if isinstance(palette_path,str): colors=pd.read_excel(os.path.expanduser(palette_path),sheet_name=sheet_name,index_col=0).Hex.to_dict() keys=list(colors.keys()) existed_vals=adata.obs[cluster_col].unique().tolist() for k in existed_vals: if k not in keys: colors[k]='gray' for k in keys: if k not in existed_vals: del colors[k] else: colors=palette_path adata.uns[cluster_col + '_colors'] = [colors.get(k, 'grey') for k in adata.obs[cluster_col].cat.categories.tolist()] else: if f'{cluster_col}_colors' not in adata.uns: sc.pl.embedding(adata,basis=f"X_{coord_base}",color=[cluster_col],show=False) colors={cluster:color for cluster,color in zip(adata.obs[cluster_col].cat.categories.tolist(),adata.uns[f'{cluster_col}_colors'])} hue=cluster_col text_anno = cluster_col text_kws = {} if text_kws is None else text_kws text_kws.setdefault("fontsize", fontsize) kwargs.setdefault("hue",hue) kwargs.setdefault("text_anno", text_anno) kwargs.setdefault("text_kws", text_kws) kwargs.setdefault("luminance", 0.65) kwargs.setdefault("dodge_text", False) kwargs.setdefault("axis_format", axis_format) kwargs.setdefault("show_legend", True) kwargs.setdefault("marker_fontsize", marker_fontsize) kwargs.setdefault("marker_pad", marker_pad) kwargs.setdefault("linewidth", linewidth) kwargs.setdefault("alpha", alpha) kwargs["coding"]=coding kwargs["id_marker"]=id_marker legend_kws={} if legend_kws is None else legend_kws default_lgd_kws=dict( fontsize=legend_fontsize, title=cluster_col,title_fontsize=legend_title_fontsize) if not ncol is None: default_lgd_kws['ncol']=ncol for k in default_lgd_kws: legend_kws.setdefault(k, default_lgd_kws[k]) kwargs.setdefault("dodge_kws", { "arrowprops": { "arrowstyle": "->", "fc": 'grey', "ec": "none", "connectionstyle": "angle,angleA=-90,angleB=180,rad=5", }, 'autoalign': 'xy'}) if ax is None: fig, ax = plt.subplots(figsize=figsize, dpi=300) p = categorical_scatter( data=adata[adata.obs[cluster_col].notna(),], ax=ax, coord_base=coord_base, palette=colors,legend_kws=legend_kws, **kwargs) if not output is None: plt.savefig(os.path.expanduser(output)) # transparent=True,bbox_inches='tight',dpi=300 if show: plt.show()
[docs] def plot_gene( adata_path="~/Projects/BG/adata/BG.gene-CGN.h5ad",obs=None, group_col=None,gene='CADM1',query_str=None,title=None, palette_path=None,hue_norm=None, cbar_kws=dict(extendfrac=0.1),axis_format="tiny",scatter_kws={}, obsm=None,coord_base='umap',normalize_per_cell=True, stripplot=False,hypo_score=False,ylim=None, clip_norm_value=10,min_cells=3,cmap='parula',figdir="figures"): """ Plot gene expression in a given region or group on embedding of adata. Parameters ---------- adata_path : str, optional _description_, by default "~/Projects/BG/adata/BG.gene-CGN.h5ad" group_col : str, optional _description_, for example: 'Region', by default None gene : str, optional _description_, by default 'CADM1' query_str : _type_, optional _description_, by default None title : _type_, optional _description_, by default None palette_path : str, optional _description_, by default "~/Projects/BG/obs/HMBA_color_palette.xlsx" obsm : str, optional _description_, by default "~/Projects/BG/clustering/100kb/annotated.adata.h5ad" coord_base : str, optional _description_, by default 'umap' normalize_per_cell : bool, optional _description_, by default True stripplot : bool, optional _description_, by default False clip_norm_value : int, optional _description_, by default 10 min_cells : int, optional _description_, by default 3 cmap: str, optional _description_, by default 'parula' figdir : str, optional _description_, by default "figures" """ # sc.set_figure_params(dpi=100,dpi_save=300,frameon=False) if title is None: if not query_str is None: title=query_str else: title=group_col if not group_col is None else gene if not os.path.exists(figdir): os.makedirs(figdir, exist_ok=True) raw_adata = anndata.read_h5ad(os.path.expanduser(adata_path), backed='r') adata = raw_adata[:, gene].to_memory() raw_adata.file.close() # close the file to save memory if normalize_per_cell: adata = normalize_mc_by_cell( use_adata=adata, normalize_per_cell=normalize_per_cell, clip_norm_value=clip_norm_value,hypo_score=hypo_score) is_open=False if not obsm is None: if isinstance(obsm, str): obsm = anndata.read_h5ad(os.path.expanduser(obsm),backed='r') is_open=True assert isinstance(obsm, anndata.AnnData), "obsm should be an anndata object or a path to an h5ad file." keep_cells = list(set(adata.obs_names.tolist()) & set(obsm.obs_names.tolist())) adata = adata[keep_cells, :] adata.obsm = obsm[keep_cells].obsm cur_cols = adata.obs.columns.tolist() for col in obsm.obs.columns.tolist(): if col not in cur_cols: adata.obs[col] = obsm.obs.loc[adata.obs_names, col].tolist() if is_open: obsm.file.close() if not obs is None: if isinstance(obs,str): obs=pd.read_csv(os.path.expanduser(obs), sep='\t',index_col=0) else: obs=obs.copy() else: obs=adata.obs.copy() if not query_str is None: obs = obs.query(query_str) overlapped_cells=list(set(adata.obs_names.tolist()) & set(obs.index.tolist())) obs=obs.loc[overlapped_cells] adata=adata[overlapped_cells,:] # type: ignore adata.obs=obs.loc[adata.obs_names.tolist()] print(adata.shape) # read color palette if not group_col is None and not palette_path is None: if os.path.exists(os.path.expanduser(palette_path)): palette_path = os.path.abspath(os.path.expanduser(palette_path)) D = pd.read_excel(palette_path, sheet_name=None, index_col=0) color_palette = D[group_col].Hex.to_dict() else: color_palette = adata.obs.reset_index().loc[:, [group_col, \ palette_path]].drop_duplicates().dropna().set_index(group_col)[ palette_path].to_dict() else: color_palette = None # plot gene on given cordinate base fig, ax = plt.subplots(figsize=(4, 4), dpi=300) output=os.path.join(figdir, f"{title}.{gene}.{coord_base}.pdf") sc.pl.embedding(adata, basis=coord_base, wspace=0.1, color=[gene],use_raw=False, ncols=2, vmin='p5', vmax='p95', frameon=False, show=False,cmap=cmap,ax=ax) colorbar = fig.axes[-1] cur_pos=colorbar.get_position() colorbar.set_position([cur_pos.x0,(1-cur_pos.height/2)/2,cur_pos.width, cur_pos.height / 2]) fig.savefig(output) # transparent=True,bbox_inches='tight',dpi=300 adata.obs[gene]=adata.to_df().loc[adata.obs_names.tolist(), gene].tolist() # print(hue_norm) fig, ax = plt.subplots(figsize=(4, 4), dpi=300) output=os.path.join(figdir, f"{title}.{gene}.{coord_base}1.pdf") continuous_scatter( data=adata, ax=ax,cmap=cmap, hue_norm=hue_norm, cbar_kws=cbar_kws, hue=gene,axis_format=axis_format, text_anno=None, coord_base=coord_base,**scatter_kws) fig.savefig(output) # transparent=True,bbox_inches='tight',dpi=300 if not group_col is None: output=os.path.join(figdir, f"{title}.{group_col}.{coord_base}.pdf") if not os.path.exists(output): # plot embedding colored by group_col if not color_palette is None: use_cells = adata.obs.loc[adata.obs[group_col].isin(list(color_palette.keys()))].index.tolist() else: use_cells = adata.obs.index.tolist() plot_cluster(adata_path=adata,coord_base=coord_base, cluster_col=group_col, coding=False,palette_path=palette_path,ncol=1, output=output,text_anno=None) # boxplot data = adata.to_df() data[group_col] = adata.obs.loc[data.index.tolist(), group_col].tolist() vc = data[group_col].value_counts() N=vc.shape[0] if not color_palette is None: keep_groups = list(set(list(color_palette.keys())) & set(vc[vc >= min_cells].index.tolist())) data = data.loc[data[group_col].isin(keep_groups)] vc = vc.to_dict() order = data.groupby(group_col)[gene].median().sort_values().index.tolist() width = max(5, N*0.5) plt.figure(figsize=(width, 3.5)) if stripplot: ax = sns.stripplot(data=data, jitter=0.4, edgecolor='white', x=group_col, y=gene, palette=color_palette, \ order=order, size=0.5) else: ax = None # ax = sns.boxplot(data=data, x=group_col, y=gene, palette=color_palette, ax=ax, # hue=group_col, # fliersize=0.5, notch=False, showfliers=False, saturation=0.6, order=order) # boxplot are incorrect for some cases when there are many 0, median and lower quartile are often at zero; use violinplot in stead. ax = sns.violinplot(data=data, x=group_col, y=gene, palette=color_palette, ax=ax, # hue=group_col, saturation=0.6, order=order,density_norm='width',cut=0,bw_adjust=0.5) # ax=sns.swarmplot(data=data,palette=color_palette,\ # edgecolor='white',x=group_col,y=gene,\ # order=order) if not ylim is None: ax.set_ybound(ylim) ax.set_xticklabels([f"{label} ({vc[label]})" for label in order]) title=title.replace(' ','.') ax.set_title(title) ax.xaxis.label.set_visible(False) plt.setp(ax.xaxis.get_majorticklabels(), rotation=-45, ha='left') plt.savefig(os.path.join(figdir, f"{title}.{gene}.{group_col}.boxplot.pdf")) return adata
[docs] def plot_genes( adata_path="/home/x-wding2/Projects/BICAN/adata/HMBA_v2/HMBA.Group.downsample_1500.h5ad", query_str=None, obs=None, #"~/Projects/BG/clustering/100kb/annotations.tsv", group_col='Subclass', parent_col=None, modality='RNA', # mc or RNA use_raw=True, # True for RNA expression_cutoff='p5', # for RNA, could be int, median, mean of p5, p95 and so on genes=None, cell_type_order=None, gene_order=None, row_cluster=False, col_cluster=False, cmap='Greens_r', group_legend=False, parent_legend=False, title=None, palette_path=None,#"/home/x-wding2/Projects/BICAN/adata/HMBA_v2/HMBA_color_palette.xlsx" legend_kws=dict(extendfrac=0.1,extend='both',label='Mean mCG'), normalize_per_cell=True, clip_norm_value=10, hypo_score=False, figsize=(10, 3.5), marker='o', plot_kws={},transpose=False, outname="test.pdf"): """ _summary_ Parameters ---------- couldbeint : _type_ _description_ median : _type_ _description_ meanofp5 : _type_ _description_ adata_path : str, optional _description_, by default "/home/x-wding2/Projects/BICAN/adata/HMBA_v2/HMBA.Group.downsample_1500.h5ad" query_str : _type_, optional _description_, by default None obs : _type_, optional _description_, by default None group_col : str, optional _description_, by default 'Subclass' parent_col : str, optional _description_, by default 'Neighborhood' modality : str, optional _description_, by default 'RNA' p95andsoongenes : _type_, optional _description_, by default None cell_type_order : _type_, optional _description_, by default None gene_order : _type_, optional _description_, by default None row_cluster : bool, optional _description_, by default False col_cluster : bool, optional _description_, by default False cmap : str, optional _description_, by default 'Greens_r' group_legend : bool, optional _description_, by default False parent_legend : bool, optional _description_, by default False title : str, optional _description_, by default 'test' palette_path : _type_, optional _description_, by default None obsm : _type_, optional _description_, by default None normalize_per_cell : bool, optional _description_, by default True clip_norm_value : int, optional _description_, by default 10 hypo_score : bool, optional _description_, by default False figsize : tuple, optional _description_, by default (10, 3.5) cmap : str, optional _description_, by default 'Greens_r' marker : str, optional _description_, by default 'o' plot_kws : dict, optional _description_, by default {} outname : str, optional _description_, by default "test.pdf" """ from PyComplexHeatmap import HeatmapAnnotation,anno_label,anno_simple,DotClustermapPlotter assert not genes is None, "Please provide genes to plot." # adata could be single cell level or pseudobulk level (adata.layers['frac'] should be existed) raw_adata = anndata.read_h5ad(os.path.expanduser(adata_path), backed='r') all_vars=set(raw_adata.var_names.tolist()) keep_genes=list(set(all_vars) & set(genes)) # keep_genes=[g for g in all_vars if g in genes] error_genes=[g for g in genes if g not in keep_genes] if len(error_genes)>0: print(f"genes not found in adata: {error_genes}") adata = raw_adata[:, keep_genes].to_memory() # type: ignore if use_raw and not adata.raw is None: adata_raw=adata.raw[:,adata.var_names.tolist()].to_adata() adata.X=adata_raw[adata.obs_names.tolist(),adata.var_names.tolist()].X.copy() # type: ignore del adata_raw raw_adata.file.close() # close the file to save memory if not obs is None: if isinstance(obs,str): obs=pd.read_csv(os.path.expanduser(obs), sep='\t',index_col=0) else: obs=obs.copy() else: obs=adata.obs.copy() if not query_str is None: obs = obs.query(query_str) overlapped_cells=list(set(adata.obs_names.tolist()) & set(obs.index.tolist())) obs=obs.loc[overlapped_cells] adata=adata[overlapped_cells,:] # type: ignore if isinstance(group_col,list): group_col1="+".join(group_col) obs[group_col1]=obs.loc[:,group_col].apply(lambda x:'+'.join(x.astype(str).tolist()),axis=1) group_col=group_col1 adata.obs[group_col]=obs.loc[adata.obs_names.tolist(),group_col].tolist() if title is None: if not query_str is None: title=query_str else: title=group_col if not group_col is None else '-'.join(genes) if not parent_col is None and parent_col not in adata.obs.columns.tolist(): adata.obs[parent_col]=obs.loc[adata.obs_names.tolist(),parent_col].tolist() if modality!='RNA' and normalize_per_cell: adata = normalize_mc_by_cell( use_adata=adata, normalize_per_cell=normalize_per_cell, clip_norm_value=clip_norm_value,hypo_score=hypo_score) print(adata.shape) # read color palette color_palette={} if not palette_path is None: if os.path.exists(os.path.expanduser(palette_path)): palette_path = os.path.abspath(os.path.expanduser(palette_path)) D = pd.read_excel(palette_path, sheet_name=None, index_col=0) if group_col in D: color_palette[group_col] = D[group_col].Hex.to_dict() else: assert '+' in group_col, f"{group_col} not found in the palette file." for group in group_col.split('+'): assert group in D, f"{group} not found in the palette file." color_palette[group] = D[group].Hex.to_dict() if not parent_col is None: color_palette[parent_col] = D[parent_col].Hex.to_dict() else: color_palette[group_col] = adata.obs.reset_index().loc[:, [group_col, \ palette_path]].drop_duplicates().dropna().set_index(group_col)[ palette_path].to_dict() color_palette[parent_col] = adata.obs.reset_index().loc[:, [parent_col, \ palette_path]].drop_duplicates().dropna().set_index(parent_col)[ palette_path].to_dict() else: color_palette = None data=adata.to_df() # rows are cells or cell types, columns are genes if modality=='RNA' and isinstance(expression_cutoff,str): if expression_cutoff=='median': cutoff=data.stack().median() elif expression_cutoff=='mean': cutoff=data.stack().mean() else: # quantile, such as p5,p95 f=float(expression_cutoff.replace('p','')) cutoff=data.stack().quantile(f/100) expression_cutoff=cutoff data[group_col]=adata.obs.loc[data.index.tolist(),group_col].tolist() if not parent_col is None and parent_col in adata.obs.columns.tolist(): group2parent=adata.obs.loc[:,[group_col,parent_col]].drop_duplicates().set_index(group_col)[parent_col].to_dict() plot_data=data.groupby(group_col).mean().stack().reset_index() plot_data.columns=[group_col,'Gene','Mean'] if 'frac' in adata.layers: D=adata.to_df(layer='frac').stack().to_dict() else: if modality!='RNA': # methylation, cutoff = 1 assert normalize_per_cell==True,"Normalized methylation fraction is required" hypo_frac=data.groupby(group_col).agg(lambda x:x[x< 1].shape[0] / x.shape[0]) # fraction of cells showing hypomethylation for the corresponding genes D=hypo_frac.stack().to_dict() else: # for RNA print(f"Using expression cutoff: {expression_cutoff}") frac=data.groupby(group_col).agg(lambda x:x[x>expression_cutoff].shape[0] / x.shape[0]) # raw count > expression_cutoff means the gene is expressed D=frac.stack().to_dict() plot_data['frac']=plot_data.loc[:,[group_col,'Gene']].apply(lambda x:tuple(x.tolist()),axis=1).map(D) # plot_data df_cols=pd.DataFrame(list(sorted(adata.obs[group_col].unique().tolist())),columns=[group_col]) if not parent_col is None: df_cols[parent_col]=df_cols[group_col].map(group2parent) df_cols.sort_values([parent_col,group_col],inplace=True) df_cols.index=df_cols[group_col].tolist() if not cell_type_order is None: rows=[ct for ct in cell_type_order if ct in df_cols.index.tolist()] df_cols=df_cols.loc[rows] col_ha_dict={} if '+' in group_col: individual_groups=group_col.split('+') for ig in individual_groups: df_cols[ig]=df_cols[group_col].apply(lambda x:x.split('+')[individual_groups.index(ig)]) group_colors={} for k in df_cols[ig].unique().tolist(): group_colors[k]=color_palette[ig][k] col_ha_dict[ig]=anno_simple(df_cols[ig],colors=group_colors, add_text=False,legend=group_legend,height=3,label=ig) df_cols.dropna(inplace=True) # df_cols.head() if not parent_col is None: parent_colors={} axis=1 if not transpose else 0 # 1 for vertical (col annotation), 0 for horizontal for k in df_cols[parent_col].unique().tolist(): parent_colors[k]=color_palette[parent_col][k] if '+' not in group_col: group_colors={} for k in df_cols[group_col].unique().tolist(): group_colors[k]=color_palette[group_col][k] col_ha=HeatmapAnnotation(axis=axis, label=anno_label(df_cols[group_col], colors=group_colors,merge=True, rotation=45,fontsize=12,arrowprops = dict(visible=False)), group=anno_simple(df_cols[group_col],colors=group_colors, add_text=False,legend=group_legend,height=3,label=group_col), parent=anno_simple(df_cols[parent_col],colors=parent_colors, add_text=False,legend=parent_legend,height=3,label=parent_col), ) else: col_ha_dict[parent_col]=anno_simple(df_cols[parent_col],colors=parent_colors, add_text=False,legend=parent_legend,height=3,label=parent_col) col_ha = HeatmapAnnotation(**col_ha_dict,axis=axis, verbose=0) colnames=False else: axis=1 if not transpose else 0 # 1 for vertical (col annotation), 0 for horizontal if '+' not in group_col: group_colors={} for k in df_cols[group_col].unique().tolist(): group_colors[k]=color_palette[group_col][k] col_ha=HeatmapAnnotation(axis=axis, group=anno_simple(df_cols[group_col],colors=group_colors, add_text=False,legend=group_legend,height=3,label=group_col), ) else: col_ha = HeatmapAnnotation(**col_ha_dict,axis=axis, verbose=0) colnames=True if not transpose: top_annotation=col_ha left_annotation=None x=group_col y='Gene' x_order=df_cols.index.tolist() y_order=gene_order show_colnames=colnames show_rownames=True else: top_annotation=None left_annotation=col_ha y=group_col x='Gene' y_order=df_cols.index.tolist() x_order=gene_order show_rownames=colnames show_colnames=True default_plot_kws=dict( marker=marker,grid=None,legend_gap=8,dot_legend_marker=marker,cmap_legend_kws=legend_kws, row_cluster=row_cluster,col_cluster=col_cluster, row_cluster_method='ward',row_cluster_metric='euclidean', col_cluster_method='ward',col_cluster_metric='euclidean', col_names_side='top',row_names_side='left', show_rownames=show_rownames,show_colnames=show_colnames,row_dendrogram=False, # vmin=0,vmax=1.5, xticklabels_kws={'labelrotation': 45, 'labelcolor': 'blue','labelsize':10,'top':True}, yticklabels_kws={'labelcolor': 'blue','labelsize':10,'left':True}, spines=False, ) for k in default_plot_kws: if k not in plot_kws: plot_kws[k]=default_plot_kws[k] plt.figure(figsize=figsize) cm1 = DotClustermapPlotter( data=plot_data, top_annotation=top_annotation,left_annotation=left_annotation, x_order=x_order,y_order=y_order, x=x,y=y,value='Mean',c='Mean',s='frac', cmap=cmap,verbose=1,**plot_kws, ) for ax in cm1.heatmap_axes.ravel(): ax.grid(axis='both', which='minor', color='grey', linestyle='--',alpha=0.6,zorder=0) if outname is None: outname=f"{title}.pdf" plt.savefig(os.path.expanduser(outname)) plt.show() return plot_data,df_cols,cm1
[docs] def get_genes_mean_frac( adata,obs=None,group_col='Subclass',modality='RNA',layer="mean", use_raw=False,expression_cutoff='p5', genes=None, normalize_per_cell=True,clip_norm_value=10,hypo_score=False, ): assert not genes is None, "Please provide genes to plot." # adata could be single cell level or pseudobulk level (adata.layers['frac'] should be existed) if isinstance(adata,str): adata=anndata.read_h5ad(os.path.expanduser(adata), backed='r') all_vars=set(adata.var_names.tolist()) keep_genes=list(set(all_vars) & set(genes)) # keep_genes=[g for g in all_vars if g in genes] error_genes=[g for g in genes if g not in keep_genes] if len(error_genes)>0: logger.debug(f"genes not found in adata: {error_genes}") use_adata = adata[:, keep_genes].to_memory() # type: ignore if adata.isbacked: adata.file.close() # close the file to save memory if 'mean' not in use_adata.layers: #raw count of single cell level adata # calculate mean and frac for each gene from single cell data if use_raw and not use_adata.raw is None: # use_adata.X=use_adata.raw.X.copy() use_adata_raw=use_adata.raw[:,use_adata.var_names.tolist()].to_adata() use_adata.X=use_adata_raw[use_adata.obs_names.tolist(),use_adata.var_names.tolist()].X.copy() # type: ignore del use_adata_raw if not obs is None: if isinstance(obs,str): sep='\t' if obs.endswith('.tsv') or obs.endswith('.txt') else ',' obs=pd.read_csv(os.path.expanduser(obs), sep=sep,index_col=0) assert isinstance(obs,pd.DataFrame), "obs should be a dataframe or a path to a csv/tsv file." else: obs=use_adata.obs.copy() overlapped_cells=list(set(use_adata.obs_names.tolist()) & set(obs.index.tolist())) obs=obs.loc[overlapped_cells] use_adata=use_adata[overlapped_cells,:] # type: ignore if modality!='RNA' and normalize_per_cell: use_adata = normalize_mc_by_cell( use_adata=use_adata, normalize_per_cell=normalize_per_cell, clip_norm_value=clip_norm_value,hypo_score=hypo_score) data=use_adata.to_df() # rows are cells or cell types, columns are genes if modality=='RNA' and isinstance(expression_cutoff,str): if expression_cutoff=='median': cutoff=data.stack().median() elif expression_cutoff=='mean': cutoff=data.stack().mean() else: # quantile, such as p5,p95 f=float(expression_cutoff.replace('p','')) cutoff=data.stack().quantile(f/100) expression_cutoff=cutoff data[group_col]=obs.loc[data.index.tolist(),group_col].tolist() # type: ignore plot_data=data.groupby(group_col).mean().stack().reset_index() plot_data.columns=[group_col,'Gene','Mean'] if 'frac' in use_adata.layers: D=use_adata.to_df(layer='frac').stack().to_dict() else: if modality!='RNA': # methylation, cutoff = 1 assert normalize_per_cell==True,"Normalized methylation fraction is required" hypo_frac=data.groupby(group_col).agg(lambda x:x[x< 1].shape[0] / x.shape[0]) # fraction of cells showing hypomethylation for the corresponding genes D=hypo_frac.stack().to_dict() else: # for RNA logger.info(f"Using expression cutoff: {expression_cutoff}") frac=data.groupby(group_col).agg(lambda x:x[x>expression_cutoff].shape[0] / x.shape[0]) # raw count > expression_cutoff means the gene is expressed D=frac.stack().to_dict() plot_data['frac']=plot_data.loc[:,[group_col,'Gene']].apply(lambda x:tuple(x.tolist()),axis=1).map(D) else: plot_data=use_adata.to_df(layer=layer).stack().reset_index() plot_data.columns=[group_col,'Gene','Mean'] D=use_adata.to_df(layer='frac').stack().to_dict() plot_data['frac']=plot_data.loc[:,[group_col,'Gene']].apply(lambda x:tuple(x.tolist()),axis=1).map(D) return plot_data
[docs] def interactive_dotHeatmap( adata=None,obs=None,genes=None,group_col='Subclass', modality="RNA",title=None,use_raw=False, expression_cutoff='p5',normalize_per_cell=True, clip_norm_value=10, width=900,height=700,gene_order=None,colorscale='greens', vmin='p1',vmax='p99',show=True, reversescale=False,size_min=5,size_max=30, renderer="notebook" ): if not renderer is None: pio.renderers.default = renderer plot_data=get_genes_mean_frac( adata,obs=obs,group_col=group_col,modality=modality, use_raw=use_raw,expression_cutoff=expression_cutoff, genes=genes, normalize_per_cell=normalize_per_cell, clip_norm_value=clip_norm_value,hypo_score=False, ) # columns: [group_col,'Gene','Mean','frac'] # Build a Plotly dot-heatmap using scatter markers on categorical axes. # x: groups (columns), y: genes (rows) x_labels = plot_data[group_col].unique().tolist() if gene_order is None: y_labels = plot_data['Gene'].unique().tolist() else: y_labels = [g for g in gene_order if g in plot_data['Gene'].unique()] # Ensure ordering plot_data['x_cat'] = pd.Categorical(plot_data[group_col], categories=x_labels) plot_data['y_cat'] = pd.Categorical(plot_data['Gene'], categories=y_labels) # marker sizes: scale 'frac' (0-1) to reasonable pixel sizes frac_vals = plot_data['frac'].fillna(0).astype(float) sizes = (frac_vals * (size_max - size_min) + size_min).tolist() # marker colors: use Mean mean_vals = plot_data['Mean'].astype(float).tolist() hover_text = [f"Group: {g}<br>Gene: {ge}<br>Mean: {m:.4g}<br>Frac: {f:.3g}" for g,ge,m,f in zip(plot_data[group_col].tolist(), plot_data['Gene'].tolist(), mean_vals, frac_vals)] vmin_quantile=float(int(vmin.replace('p','')) / 100) vmax_quantile=float(int(vmax.replace('p','')) / 100) marker_dict = dict(size=sizes, color=mean_vals, colorscale=colorscale, showscale=True,colorbar=dict(title='Mean'), reversescale=reversescale, sizemode='area', opacity=0.9, cmin=plot_data['Mean'].quantile(vmin_quantile), cmax=plot_data['Mean'].quantile(vmax_quantile) ) fig = go.Figure() fig.add_trace(go.Scatter( x=plot_data[group_col].tolist(), y=plot_data['Gene'].tolist(), mode='markers', marker=marker_dict, text=hover_text, hoverinfo='text' )) # Layout: categorical axes with explicit ordering fig.update_xaxes(type='category', categoryorder='array', categoryarray=x_labels, tickangle= -45) fig.update_yaxes(type='category', categoryorder='array', categoryarray=list(reversed(y_labels))) if title is None: title=group_col fig.update_layout(title=title or '', xaxis_title=group_col, yaxis_title='Gene', width=width, height=height, plot_bgcolor='white') if show: filename=f"dotHeatmap.{group_col}" show_fig(fig,filename=filename) else: return fig
[docs] def get_boxplot_data(adata,variable,gene,obs=None): assert isinstance(adata,anndata.AnnData) if adata.isbacked: # type: ignore use_adata=adata[:,gene].to_memory() # type: ignore else: use_adata=adata[:,gene].copy() # type: ignore if isinstance(obs,str): obs_path = os.path.abspath(os.path.expanduser(obs)) sep='\t' if obs_path.endswith('.tsv') or obs_path.endswith('.txt') else ',' data = pd.read_csv(obs_path, index_col=0,sep=sep) else: assert isinstance(obs,pd.DataFrame) data=obs.copy() overlap_idx=data.index.intersection(use_adata.obs_names) data=data.loc[overlap_idx] use_adata=use_adata[overlap_idx,:] # type: ignore if not gene is None: data[gene]=use_adata.to_df()[gene].tolist() # type: ignore return data.loc[:,[variable,gene]]
[docs] def has_stats(adata): if isinstance(adata,str): adata=anndata.read_h5ad(adata,backed='r') flag=True for k in ['min','q25','q50','q75','max','mean','std']: if not k in adata.layers: flag=False break return flag
[docs] def plot_interactive_boxlot_from_data( adata,obs,variable,gene,palette_path=None, width=1100,height=700,title=None, ): plot_df = get_boxplot_data(adata,variable,gene,obs=obs) # Preserve existing Y-axis extreme filtering logic (remove 1% and 99% extremes) range_y=[plot_df[gene].quantile(0.01), plot_df[gene].quantile(0.99)] # color_discrete_map=get_colors(adata,variable,palette_path=palette_path) color_discrete_map=load_color_palette(palette_path=palette_path,adata=adata,groups=variable) if not color_discrete_map is None: keys=list(color_discrete_map.keys()) # type: ignore for k in keys: if not k in plot_df[variable].unique().tolist(): del color_discrete_map[k] # type: ignore if title is None: title=f"Boxplot: {gene} by {variable}" fig = px.box( plot_df, x=variable, y=gene, color=variable, color_discrete_sequence=px.colors.qualitative.D3, # color palette (professional, unobtrusive) color_discrete_map=color_discrete_map, range_y=range_y, points=False, title=title, template="plotly_white" # keep white background style ) fig.update_xaxes(tickangle=-90, automargin=True) fig.update_traces( line_width=1.2, # thinner lines for a more refined look notched=False # no notch, standard boxplot style ) fig.update_layout( xaxis_title=variable, yaxis_title=gene, legend_title=variable, width=width, height=height ) return fig
[docs] def plot_interacrive_boxplot_from_stats( adata,variable,gene,palette_path=None, title=None,width=1100,height=700): assert isinstance(adata,anndata.AnnData) if adata.isbacked: # type: ignore use_adata=adata[:,gene].to_memory() # type: ignore else: use_adata=adata[:,gene].copy() # type: ignore if adata.isbacked: # type: ignore adata.file.close() # type: ignore stat_keys=['min','q25','q50','q75','max','mean','std'] plot_data=[] for k in stat_keys: df=use_adata.to_df(layer=k)[gene] df.name=k plot_data.append(df) plot_data=pd.concat(plot_data,axis=1) # rows are cell types and columns are stats groups=plot_data.sort_values('q50').index.tolist() plot_data=plot_data.loc[groups] # build figure with one Box per group using precomputed quartiles/fences fig = go.Figure() # optional color mapping # color_discrete_map=get_colors(adata,variable,palette_path=palette_path) color_discrete_map=load_color_palette(palette_path=palette_path,adata=adata,groups=variable) palette = px.colors.qualitative.D3 color = None for group, row in plot_data.iterrows(): i=groups.index(group) q1 = row['q25'] med = row['q50'] q3 = row['q75'] low = row['min'] high = row['max'] mean = row['mean'] std = row['std'] if color_discrete_map is not None and group in color_discrete_map: color = color_discrete_map[group] else: color = palette[i % len(palette)] # Box from precomputed stats (single-element arrays) fig.add_trace( go.Box( x=[group], q1=[q1], median=[med], q3=[q3], lowerfence=[low], upperfence=[high], boxpoints=False, marker=dict(color=color), name=str(group), showlegend=True ) ) # mean as a scatter point with std error bar # fig.add_trace( # go.Scatter( # x=[group], # y=[mean], # mode='markers', # marker=dict(symbol='diamond', size=8, color='black'), # error_y=dict(type='data', array=[std], visible=True), # name='mean', # showlegend=False # ) # ) if title is None: title=f"Boxplot: {gene} by {variable}" fig.update_xaxes(tickangle=-90, automargin=True) fig.update_layout( title=title, xaxis_title=variable, yaxis_title=gene, legend_title=variable, template='plotly_white', width=width, height=height ) return fig
[docs] def interactive_boxplot( adata,variable,gene,obs=None,palette_path=None, title=None,width=1100,height=700,show=True,renderer='notebook'): if not renderer is None: pio.renderers.default = renderer if isinstance(adata,str): adata=anndata.read_h5ad(adata,backed='r') else: assert isinstance(adata,anndata.AnnData) if obs is None: obs=adata.obs.copy() # type: ignore if not has_stats(adata): fig=plot_interactive_boxlot_from_data( adata,obs,variable,gene,palette_path=palette_path, title=title,width=width,height=height ) else: # pseudobulk level with precomputed stats fig=plot_interacrive_boxplot_from_stats( adata,variable,gene,palette_path=palette_path, title=title,width=width,height=height) if show: filename=f"boxplot.{variable}.{gene}" show_fig(fig,filename=filename) return None else: return fig