import pandas as pd
import os, sys
import seaborn as sns
import numpy as np
import matplotlib.pylab as plt
import matplotlib
from matplotlib.legend_handler import HandlerBase
from matplotlib.lines import Line2D
from matplotlib.text import Text
import copy
import warnings
warnings.filterwarnings("ignore")
mm2inch = 1 / 25.4
def df2stdout(df):
    """Stream a DataFrame to stdout as TSV: header row, then data rows.

    NaN values are written as empty strings. Stops quietly when the
    downstream reader closes the pipe (e.g. ``| head``).
    """
    sys.stdout.write('\t'.join(str(c) for c in df.columns.tolist()) + '\n')
    for _, row in df.iterrows():
        try:
            sys.stdout.write('\t'.join(str(v) for v in row.fillna('').tolist()) + '\n')
        except BrokenPipeError:
            # downstream consumer is gone: close stdout and stop, instead of
            # the original bare `except` that swallowed every error and kept
            # looping over a closed stream
            sys.stdout.close()
            break
def serialize(x):
    """Write `x` to stdout: DataFrames as TSV via df2stdout, None silently skipped."""
    if isinstance(x, pd.DataFrame):
        df2stdout(x)
        return
    if x is not None:
        print(x)
def prepare_color_palette(color_dict=None, outpath="palette.xlsx"):
    """
    Generate a .xlsx file showing color palettes, one sheet per palette.

    Parameters
    ----------
    color_dict : dict of dict
        Outer keys become sheet names; inner dicts map categorical terms
        to HEX color codes. Required.
    outpath : str
        Path of the workbook to write ("~" is expanded).

    Returns
    -------
    None
    """
    if color_dict is None:
        # the original iterated over None and crashed with an opaque
        # TypeError; fail with an explicit message instead
        raise ValueError("color_dict must be a dict of {sheet_name: {term: hex_color}}")
    outpath = os.path.expanduser(outpath)
    writer = pd.ExcelWriter(outpath)
    for key in color_dict:
        data = pd.DataFrame.from_dict(color_dict[key], orient='index', columns=['Hex'])
        data.to_excel(writer, sheet_name=key, index=True)
        workbook = writer.book
        worksheet = writer.sheets[key]
        colors = data.Hex.tolist()
        for i, color in enumerate(colors):
            # paint each Hex cell with its own color (xlsxwriter format);
            # +1 skips the header row
            f = workbook.add_format({'bold': True, 'font_color': 'black', 'bg_color': color})
            worksheet.write(i + 1, 1, color, f)
        width = 20
        cell_fmt = workbook.add_format(
            {'bold': False, 'font_color': 'black',
             'align': 'center', 'valign': 'vcenter'})
        worksheet.set_column(0, 1, width, cell_fmt)
    writer.close()
def mpl_style():
    """Reset matplotlib to the default style with publication-friendly settings
    (TrueType fonts embedded in PDF/PS, Arial, 80 dpi screen / 300 dpi savefig)."""
    import matplotlib as mpl
    mpl.style.use('default')
    settings = {
        'pdf.fonttype': 42,  # embed editable TrueType fonts
        'ps.fonttype': 42,
        'font.family': 'sans-serif',
        'font.sans-serif': 'Arial',
        'figure.dpi': 80,
        'savefig.dpi': 300,
    }
    for key, value in settings.items():
        mpl.rcParams[key] = value
def parse_json(url):
    """Fetch an Allen Brain Atlas ontology JSON and flatten its records
    into a DataFrame (one row per brain structure)."""
    import json
    import requests
    payload = json.loads(requests.get(url).content.decode())
    # JSON field name -> output column name (color_hex_triplet becomes Hex)
    fields = ['acronym', 'color_hex_triplet', 'id', 'name',
              'parent_structure_id', 'safe_name', 'structure_id_path']
    columns = ['acronym', 'Hex', 'id', 'name',
               'parent_structure_id', 'safe_name', 'structure_id_path']
    rows = [[record[f] for f in fields] for record in payload['msg']]
    return pd.DataFrame(rows, columns=columns)
# [docs]
def get_brain_region_structure():
    """
    Download Allen Brain Atlas ontologies and write them to Excel workbooks.

    Sources (https://atlas.brain-map.org/):
      BICAN: https://atlas.brain-map.org/atlasviewer/ontologies/11.json
      HBA:   https://atlas.brain-map.org/atlasviewer/ontologies/7.json

    Writes "AllenBrainRegionStructure.xlsx" (one sheet per ontology, Hex
    cells painted with their own color), then "BICAN_regions.xlsx" holding
    the BICAN rows whose acronym appears in the local "Jon.xlsx" file.

    Returns
    -------
    None
    """
    # pip install XlsxWriter
    writer = pd.ExcelWriter("AllenBrainRegionStructure.xlsx")
    for url,name in zip(['https://atlas.brain-map.org/atlasviewer/ontologies/11.json','https://atlas.brain-map.org/atlasviewer/ontologies/7.json'],['BICAN_Brodmann','HBA_guide']):
        print(name)
        df=parse_json(url)
        # color_hex_triplet from the API comes without the leading '#'
        df['Hex']='#'+df.Hex.map(str)
        df['id']=df['id'].fillna('None').map(str)
        id2acronym=df.set_index('id').acronym.to_dict()
        # human-readable parent acronym column, mapped back via structure id
        df.insert(0,'Parent',df.parent_structure_id.fillna(0).map(int).map(str).map(id2acronym))
        df.parent_structure_id=df.parent_structure_id.map(str)
        # drop the leading and trailing '/' of paths like "/997/8/.../"
        df.structure_id_path=df.structure_id_path.apply(lambda x:x[1:-1])
        # acronym-based path, e.g. "root//grey//..."
        df['structure_path']=df.structure_id_path.apply(lambda x:'//'.join([id2acronym[p] for p in x.split('/') if p!='']))
        df=df.loc[:,['Parent','acronym','Hex','name','safe_name','id','parent_structure_id','structure_path','structure_id_path']]
        df.to_excel(writer,sheet_name=name,index=False)
        workbook = writer.book
        worksheet = writer.sheets[name]
        colors=df.Hex.tolist()
        col_idx=df.columns.tolist().index('Hex')
        for i in range(df.shape[0]):
            # paint each Hex cell with its own color; +1 skips the header row
            color = colors[i]
            f = workbook.add_format({'bold': True, 'font_color': 'black', 'bg_color': color})
            worksheet.write(i+1, col_idx, color,f) #worksheet.write(row, col, *args), row and col are 0-based
    writer.close()
    # restrict the BICAN ontology to the regions curated in Jon.xlsx
    df_bican=pd.read_excel("Jon.xlsx")
    # "ACR (long name)" -> "ACR"
    df_bican.Acronym=df_bican.Acronym.apply(lambda x:x.split('(')[0].strip())
    regions=df_bican.Acronym.tolist()
    df=pd.read_excel("AllenBrainRegionStructure.xlsx",
                     sheet_name="BICAN_Brodmann")
    df=df.loc[df.acronym.isin(regions)]
    df.parent_structure_id=df.parent_structure_id.map(int)
    writer = pd.ExcelWriter("BICAN_regions.xlsx")
    df.to_excel(writer, sheet_name='BICAN', index=False)
    workbook = writer.book
    worksheet = writer.sheets['BICAN']
    colors=df.Hex.tolist()
    col_idx=df.columns.tolist().index('Hex')
    for i in range(df.shape[0]):
        color = colors[i]
        f = workbook.add_format({'bold': True, 'font_color': 'black', 'bg_color': color})
        worksheet.write(i+1, col_idx, color,f)
    writer.close()
def read_google_sheet(url=None, **kwargs):
    """Load a Google Sheet as a DataFrame via its TSV export endpoint.

    `url` must be a standard "edit" link containing the document id
    (after "/d/") and the sheet gid ("?gid=...#gid=..."). Extra kwargs
    are forwarded to pandas.read_csv.
    """
    assert url is not None
    doc_id = url.split('/d/')[1].split('/')[0]
    gid = url.split('?gid=')[1].split('#gid=')[0]
    export_url = (
        f"https://docs.google.com/spreadsheets/d/{doc_id}"
        f"/export?format=tsv&id={doc_id}&gid={gid}"
    )
    return pd.read_csv(export_url, sep='\t', **kwargs)
def despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False):
    """
    Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.

    Returns
    -------
    None
    """
    if fig is None and ax is None:
        axes = plt.gcf().axes
    elif fig is not None:
        axes = fig.axes
    else:
        axes = [ax]
    # explicit side -> flag mapping instead of the fragile locals()[side]
    # lookup, which silently breaks if a local variable shadows a side name
    remove = {"top": top, "right": right, "left": left, "bottom": bottom}
    for ax_i in axes:
        for side, hide in remove.items():
            ax_i.spines[side].set_visible(not hide)
        if left and not right:  # left spine removed, right kept: move y ticks
            maj_on = any(t.tick1line.get_visible() for t in ax_i.yaxis.majorTicks)
            min_on = any(t.tick1line.get_visible() for t in ax_i.yaxis.minorTicks)
            ax_i.yaxis.set_ticks_position("right")
            for t in ax_i.yaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.yaxis.minorTicks:
                t.tick2line.set_visible(min_on)
        if bottom and not top:  # bottom spine removed, top kept: move x ticks
            maj_on = any(t.tick1line.get_visible() for t in ax_i.xaxis.majorTicks)
            min_on = any(t.tick1line.get_visible() for t in ax_i.xaxis.minorTicks)
            ax_i.xaxis.set_ticks_position("top")
            for t in ax_i.xaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.xaxis.minorTicks:
                t.tick2line.set_visible(min_on)
def _make_tiny_axis_label(ax, x, y, arrow_kws=None, fontsize=5):
    """Draw a small arrow cross near the lower-left corner with x/y axis
    labels, in axes coordinates (assumes coords span [0, 1])."""
    # strip ticks, labels and all spines
    ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)
    despine(ax=ax, left=True, bottom=True)
    arrow_style = {"width": 0.003, "linewidth": 0, "color": "black"}
    if arrow_kws is not None:
        arrow_style.update(arrow_kws)
    # vertical then horizontal arrow from the same anchor point
    ax.arrow(0.06, 0.06, 0, 0.06, **arrow_style, transform=ax.transAxes)
    ax.arrow(0.06, 0.06, 0.06, 0, **arrow_style, transform=ax.transAxes)
    x_font = {"fontsize": fontsize, "horizontalalignment": "left",
              "verticalalignment": "center"}
    ax.text(0.06, 0.03, x.upper().replace("_", " "),
            fontdict=x_font, transform=ax.transAxes)
    y_font = {
        "fontsize": fontsize,
        "rotation": 90,
        "rotation_mode": "anchor",
        "horizontalalignment": "left",
        "verticalalignment": "center",
    }
    ax.text(0.03, 0.06, y.upper().replace("_", " "),
            fontdict=y_font, transform=ax.transAxes)
    return
def zoom_min_max(vmin, vmax, scale):
    """Scale the interval [vmin, vmax] by `scale` symmetrically around its center."""
    width = vmax - vmin
    pad = (width * scale - width) / 2
    return vmin - pad, vmax + pad
def zoom_ax(ax, zoom_scale, on="both"):
    """Zoom the limits of `ax` by `zoom_scale` on x, y, or both axes."""
    on = on.lower()
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    if on == "both" or "x" in on:
        ax.set_xlim(zoom_min_max(vmin=xmin, vmax=xmax, scale=zoom_scale))
    if on == "both" or "y" in on:
        ax.set_ylim(zoom_min_max(vmin=ymin, vmax=ymax, scale=zoom_scale))
def _extract_coords(data, coord_base, x, y):
    """Pull 2-D embedding coordinates out of AnnData / xr.Dataset / DataFrame.

    Returns a tuple of (DataFrame with "x"/"y" columns, x name, y name).
    When x or y is missing they default to "<coord_base>_0"/"<coord_base>_1".
    """
    import xarray as xr
    import anndata
    if x is None or y is None:
        x = f"{coord_base}_0"
        y = f"{coord_base}_1"
    if isinstance(data, anndata.AnnData):
        coords = data.obsm[f"X_{coord_base}"]
        _data = pd.DataFrame(
            {"x": coords[:, 0], "y": coords[:, 1]},
            index=data.obs_names,
        )
    elif isinstance(data, xr.Dataset):
        if coord_base not in data.dims:
            raise KeyError(f"xr.Dataset do not contain {coord_base} dim")
        # first data variable whose name starts with the coord base
        data_var = {i for i in data.data_vars.keys() if i.startswith(coord_base)}.pop()
        _data = pd.DataFrame(
            {
                "x": data[data_var].sel({coord_base: f"{coord_base}_0"}).to_pandas(),
                "y": data[data_var].sel({coord_base: f"{coord_base}_1"}).to_pandas(),
            }
        )
    else:
        if (x not in data.columns) or (y not in data.columns):
            raise KeyError(f"{x} or {y} not found in columns.")
        _data = pd.DataFrame({"x": data[x], "y": data[y]})
    return _data, x, y
def _density_based_sample(data: pd.DataFrame, coords: list, portion=None, size=None, seed=None):
    """Down sample data based on density, to prevent overplot in dense region and decrease plotting time."""
    from sklearn.neighbors import LocalOutlierFactor
    lof = LocalOutlierFactor(
        n_neighbors=20,
        algorithm="auto",
        leaf_size=30,
        metric="minkowski",
        p=2,
        metric_params=None,
        contamination=0.1,
    )
    # coords are column names that must already exist in data
    coord_df = data[coords]
    lof.fit(coord_df)
    # negative_outlier_factor_ is negative; larger (closer to 0) means denser
    density = lof.negative_outlier_factor_
    span = density.max() - density.min()
    # invert to a sampling probability: denser points are less likely picked
    prob = 1 - (density - density.min()) / span
    prob = np.sqrt(prob)
    prob = prob / prob.sum()
    if size is None:
        if portion is None:
            raise ValueError("Either portion or size should be provided.")
        size = int(coord_df.index.size * portion)
    if seed is not None:
        np.random.seed(seed)
    chosen = np.random.choice(coord_df.index, size=size, replace=False, p=prob)
    return data.reindex(chosen)
def _auto_size(ax, n_dots):
"""Auto determine dot size based on ax size and n dots"""
bbox = ax.get_window_extent().transformed(ax.get_figure().dpi_scale_trans.inverted())
scale = bbox.width * bbox.height / 14.6 # 14.6 is a 5*5 fig I used to estimate
n = n_dots / scale # larger figure means data look sparser
if n < 500:
s = 14 - n / 100
elif n < 1500:
s = 7
elif n < 3000:
s = 5
elif n < 8000:
s = 3
elif n < 15000:
s = 2
elif n < 30000:
s = 1.5
elif n < 50000:
s = 1
elif n < 80000:
s = 0.8
elif n < 150000:
s = 0.6
elif n < 300000:
s = 0.5
elif n < 500000:
s = 0.4
elif n < 800000:
s = 0.3
elif n < 1000000:
s = 0.2
elif n < 2000000:
s = 0.1
elif n < 3000000:
s = 0.07
elif n < 4000000:
s = 0.05
elif n < 5000000:
s = 0.03
else:
s = 0.02
return s
def _take_data_series(data, k):
    """Return column/variable `k` from xarray, AnnData (via obs), or
    DataFrame-like `data` as a pandas object (copied where applicable)."""
    import xarray as xr
    import anndata
    if isinstance(data, (xr.Dataset, xr.DataArray)):
        return data[k].to_pandas()
    if isinstance(data, anndata.AnnData):
        return data.obs[k].copy()
    return data[k].copy()
def level_one_palette(name_list, order=None, palette="auto"):
    """Build a {name: color} mapping for the unique, non-null values of `name_list`.

    With palette="auto", tab10/tab20/rainbow is chosen by category count.
    `order`, if given, must contain exactly the unique names.
    """
    name_set = set(name_list.dropna())
    if palette == "auto":
        n_names = len(name_set)
        palette = "tab10" if n_names < 10 else ("tab20" if n_names < 20 else "rainbow")
    if order is None:
        try:
            order = sorted(name_set)
        except TypeError:
            # mixed dtypes (e.g., str and np.nan) cannot be sorted; keep as-is
            order = list(name_set)
    elif (set(order) != name_set) or (len(order) != len(name_set)):
        raise ValueError("Order is not equal to set(name_list).")
    colors = sns.color_palette(palette, len(order))
    return dict(zip(order, colors))
def _calculate_luminance(color):
    """
    Calculate the relative luminance of a color according to W3C standards

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1
    """
    rgb = matplotlib.colors.colorConverter.to_rgba_array(color)[:, :3]
    # sRGB -> linear RGB (WCAG formula), then weight by channel sensitivity
    linear = np.where(rgb <= 0.03928, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    luminance = linear.dot([0.2126, 0.7152, 0.0722])
    try:
        # a single input color collapses to a plain float
        return luminance.item()
    except ValueError:
        return luminance
def _text_anno_scatter(
    data: pd.DataFrame,
    ax,
    x: str,
    y: str,
    dodge_text=False,
    anno_col="text_anno",
    text_kws=None,
    text_transform=None,
    dodge_kws=None,
    luminance=0.48
):
    """Add text annotation to a scatter plot.

    One label is drawn at the median (x, y) of each `anno_col` group of
    `data`. `text_kws['color']` and `text_kws['bbox']['facecolor']` may be
    dicts mapping label text -> value. Labels on bright backgrounds
    (luminance above the `luminance` threshold) are forced to black text
    and a black box edge. With dodge_text=True, adjustText (if installed)
    pushes overlapping labels apart.
    """
    import copy
    # prepare kws: font defaults plus a rounded box behind each label
    text_kws={} if text_kws is None else text_kws
    text_kws.setdefault("fontsize",5)
    text_kws.setdefault("fontweight","black")
    text_kws.setdefault("ha","center") #horizontalalignment
    text_kws.setdefault("va","center") #verticalalignment
    text_kws.setdefault("color","black") #c
    bbox=dict(boxstyle='round',edgecolor=(0.5, 0.5, 0.5, 0.2),fill=False,
              facecolor=(0.8, 0.8, 0.8, 0.2),alpha=1,linewidth=0.5)
    text_kws.setdefault("bbox",bbox)
    # fill in any bbox keys the caller's own bbox dict did not provide
    for key in bbox:
        if key not in text_kws['bbox']:
            text_kws['bbox'][key]=bbox[key]
    # plot each text
    text_list = []
    for text, sub_df in data.groupby(anno_col):
        if text_transform is None:
            text = str(text)
        else:
            text = text_transform(text)
        if text.lower() in ["", "nan"]:
            continue
        # anchor the label at the group's median position
        _x, _y = sub_df[[x, y]].median()
        # deep copy: the nested bbox dict is mutated per label below
        use_text_kws=copy.deepcopy(text_kws)
        if isinstance(text_kws['bbox']['facecolor'],dict):
            use_text_kws['bbox']['facecolor']=text_kws['bbox']['facecolor'].get(text,'gray')
        if isinstance(text_kws['color'],dict):
            use_color=text_kws['color'].get(text,'black')
            use_text_kws['color']=use_color
        if not luminance is None and not use_text_kws['bbox']['facecolor'] is None:
            lum = _calculate_luminance(use_text_kws['bbox']['facecolor'])
            if lum > luminance:
                # bright background: force black text/edge for contrast
                use_text_kws['color']='black'
                use_text_kws['bbox']['edgecolor']='black'
        text = ax.text(
            _x,
            _y,
            text,
            **use_text_kws
        )
        text_list.append(text)
    if dodge_text:
        try:
            from adjustText import adjust_text
            _dodge_parms = {
                "force_points": (0.02, 0.05),
                "arrowprops": {
                    "arrowstyle": "->",
                    "fc": "black",
                    "ec": "none",
                    "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
                },
                "autoalign": "xy",
            }
            if dodge_kws is not None:
                _dodge_parms.update(dodge_kws)
            adjust_text(text_list, x=data["x"], y=data["y"], **_dodge_parms)
        except ModuleNotFoundError:
            print("Install adjustText package to dodge text, see its github page for help")
    return
def tight_hue_range(hue_data, portion):
    """Automatic select a SMALLEST data range that covers [portion] of the data."""
    finite = hue_data[np.isfinite(hue_data)]
    quantiles = finite.quantile(q=np.arange(0, 1, 0.01))
    # right edge (quantile q) of the narrowest rolling quantile window
    window = int(portion * 100)
    right_q = quantiles.rolling(window=window).apply(
        lambda q: q.max() - q.min(), raw=True
    ).idxmin()
    left_q = max(0, right_q - portion)
    vmin, vmax = tuple(finite.quantile(q=[left_q, right_q]))
    vmin = max(finite.min(), vmin) if np.isfinite(vmin) else finite.min()
    vmax = min(finite.max(), vmax) if np.isfinite(vmax) else finite.max()
    return vmin, vmax
def density_contour(
    ax,
    data,
    x,
    y,
    groupby=None,
    c="lightgray",
    single_contour_pad=1,
    linewidth=1,
    palette=None,
):
    """Draw one dashed density-boundary contour per group on `ax`.

    The boundary is the LocalOutlierFactor decision level at
    -single_contour_pad, evaluated on a 500x500 grid spanning the data
    range zoomed by 1.2x. Colors come from palette[group] when available,
    otherwise `c`.
    """
    from sklearn.neighbors import LocalOutlierFactor
    plot_data = data.copy()
    if groupby is None:
        plot_data["groupby"] = "one group"
    elif isinstance(groupby, str):
        plot_data["groupby"] = data[groupby]
    else:
        plot_data["groupby"] = groupby
    contour_kws = {"linewidths": linewidth, "levels": (-single_contour_pad,), "linestyles": "dashed"}
    lof_kws = {"n_neighbors": 25, "novelty": True, "contamination": "auto"}
    xmin, ymin = plot_data[[x, y]].min()
    xmax, ymax = plot_data[[x, y]].max()
    xmin, xmax = zoom_min_max(xmin, xmax, 1.2)
    ymin, ymax = zoom_min_max(ymin, ymax, 1.2)
    for group, sub_data in plot_data[[x, y, "groupby"]].groupby("groupby"):
        xx, yy = np.meshgrid(np.linspace(xmin, xmax, 500), np.linspace(ymin, ymax, 500))
        clf = LocalOutlierFactor(**lof_kws)
        clf.fit(sub_data.iloc[:, :2].values)
        z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
        if palette is not None and group in palette:
            color = palette[group]
        else:
            color = c
        ax.contour(xx, yy, z, colors=color, **contour_kws)
    return
# [docs]
def plot_color_dict_legend(
    D, ax=None, title=None, color_text=True,
    kws=None,luminance=0.5
):
    """
    Plot a legend for a color dict (one colored patch per category).

    Parameters
    ----------
    D : dict
        Keys are categorical labels, values are colors.
    ax : matplotlib axes, optional
        Axes to plot the legend on; defaults to the current axes.
    title : str, optional
        Title of the legend.
    color_text : bool
        Whether to color each legend label with its own color from D
        (only dark colors; bright ones stay black for readability).
    kws : dict, optional
        Extra kws passed to plt.legend.
    luminance : float or None
        Luminance threshold above which label text stays black;
        None keeps all labels black.

    Returns
    -------
    matplotlib legend
    """
    import matplotlib.patches as mpatches
    if ax is None:
        ax = plt.gca()
    lgd_kws = kws.copy() if not kws is None else {} # bbox_to_anchor=(x,-0.05)
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws.setdefault("borderpad", 0.1 * mm2inch * 72)  # 0.1mm
    lgd_kws.setdefault("markerscale", 1)
    lgd_kws.setdefault("handleheight", 0.5)  # font size, units is points
    lgd_kws.setdefault("handlelength", 1)  # font size, units is points
    lgd_kws.setdefault(
        "borderaxespad", 0.1
    )  # The pad between the axes and legend border, in font-size units.
    lgd_kws.setdefault(
        "handletextpad", 0.4
    )  # The pad between the legend handle and text, in font-size units.
    lgd_kws.setdefault(
        "labelspacing", 0.15
    )  # gap height between two Patches, 0.05*mm2inch*72
    lgd_kws.setdefault("columnspacing", 0.5)
    lgd_kws.setdefault("bbox_to_anchor", (1, 1))
    lgd_kws.setdefault("title", title)
    lgd_kws.setdefault("markerfirst", True)
    # one colored patch handle per category
    l = [
        mpatches.Patch(color=c, label=l) for l, c in D.items()
    ]
    # "markersize" is not a plt.legend kw; pop it so legend() does not fail
    ms = lgd_kws.pop("markersize", 10)
    L = ax.legend(handles=l, **lgd_kws)
    L._legend_box.align = 'center'
    L.get_title().set_ha('center')
    if color_text:
        for text in L.get_texts():
            try:
                lum = _calculate_luminance(D[text.get_text()])
                if luminance is None:
                    text_color = "black"
                else:
                    # bright colors are unreadable as text: keep them black
                    text_color = "black" if lum > luminance else D[text.get_text()]
                text.set_color(text_color)
            except:
                pass
    ax.grid(False)
    return L
# [docs]
def plot_marker_legend(
    color_dict=None, ax=None, title=None, color_text=True,
    marker='o',kws=None,luminance=0.5
):
    """
    Plot a legend with one colored marker per category.

    Parameters
    ----------
    color_dict : dict
        Keys are categorical labels, values are colors.
    ax : matplotlib axes, optional
        Axes to plot the legend on; defaults to the current axes.
    title : str, optional
        Title of the legend.
    color_text : bool
        Whether to color each label with its own (dark) color.
    marker : str
        Marker style shared by all legend entries.
    kws : dict, optional
        Extra kws passed to plt.legend; a "markersize" entry is extracted
        and applied to the Line2D marker handles.
    luminance : float or None
        Luminance threshold above which label text stays black;
        None keeps all labels black.

    Returns
    -------
    matplotlib legend
    """
    if ax is None:
        ax = plt.gca()
    lgd_kws = kws.copy() if not kws is None else {} # bbox_to_anchor=(x,-0.05)
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws.setdefault("borderpad", 0.2 * mm2inch * 72)  # 0.2mm
    lgd_kws.setdefault("handleheight", 0.5)  # font size, units is points
    lgd_kws.setdefault("handlelength", 1)  # font size, units is points
    lgd_kws.setdefault(
        "borderaxespad", 0.1
    )  # The pad between the axes and legend border, in font-size units.
    lgd_kws.setdefault(
        "handletextpad", 0.4
    )  # The pad between the legend handle and text, in font-size units.
    lgd_kws.setdefault(
        "labelspacing", 0.15
    )  # gap height between two Patches, 0.05*mm2inch*72
    lgd_kws.setdefault("columnspacing", 0.5)
    lgd_kws.setdefault("bbox_to_anchor", (1, 1))
    lgd_kws.setdefault("title", title)
    lgd_kws.setdefault("markerfirst", True)
    # applied to the marker handles below (not a valid plt.legend kw)
    ms = lgd_kws.pop("markersize", 10)
    L = [
        Line2D(
            [0],
            [0],
            color=color,
            marker=marker,
            linestyle="None",
            markersize=ms,
            label=l,
        )
        for l,color in color_dict.items()
    ]
    L = ax.legend(handles=L, **lgd_kws)
    ax.figure.canvas.draw()
    L._legend_box.align = 'center'
    L.get_title().set_ha('center')
    if color_text:
        for text in L.get_texts():
            try:
                lum = _calculate_luminance(color_dict[text.get_text()])
                if luminance is None:
                    text_color = "black"
                else:
                    # bright colors are unreadable as text: keep them black
                    text_color = "black" if lum > luminance else color_dict[text.get_text()]
                text.set_color(text_color)
            except:
                pass
    ax.grid(False)
    return L
# Custom handler for legend: circle text as marker + label
class TextWithCircleHandler(HandlerBase):
    """Legend handler that renders a text "marker" (e.g. a circled number)
    instead of the usual legend handle artwork."""

    def __init__(self, marker_text='', label_text='',
                 text_kws=None, **kwargs):
        """
        Parameters
        ----------
        marker_text : str
            Text drawn where the legend marker normally goes.
        label_text : str
            Label associated with this handle.
        text_kws : dict, optional
            Keyword arguments forwarded to matplotlib.text.Text.
        """
        HandlerBase.__init__(self, **kwargs)
        self.marker_text = marker_text
        # bug fix: the argument used to be silently discarded
        self.label_text = label_text
        # bug fix: the old `text_kws={}` mutable default was shared (and
        # mutated by create_artists) across every handler instance
        self.text_kws = {} if text_kws is None else text_kws

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return the single Text artist used as this entry's legend marker."""
        self.text_kws.setdefault("fontsize", fontsize)
        # shift the marker text right so it clears the legend border padding
        shift = 2 * self.text_kws['fontsize'] * 0.65 / 72 / mm2inch
        circ_text = Text(
            xdescent + legend.borderaxespad + shift, height / 2,
            self.marker_text,
            **self.text_kws
        )
        return [circ_text]
def plot_text_legend(color_dict, code2label, ax=None, title=None, color_text=True,
                     boxstyle='Circle', marker_pad=0.1, legend_kws=None, marker_fontsize=4,
                     text_kws=None, alpha=0.7, luminance=0.5):
    """
    Plot a legend whose "markers" are numeric codes drawn inside colored boxes.

    Parameters
    ----------
    color_dict : dict
        Maps label -> color used as the marker-box facecolor (and,
        optionally, the label text color).
    code2label : dict
        Maps an integer-like code -> label; the code becomes the marker text.
        NOTE(review): keys are converted with int() and then used directly
        for lookup, so they are assumed to be stored as ints — confirm.
    ax : matplotlib axes, optional
        Axes to draw on; defaults to the current axes.
    title : str, optional
        Legend title.
    color_text : bool
        Whether to color each label with its own (dark) color.
    boxstyle : str
        matplotlib boxstyle name for the marker box.
    marker_pad : float
        Padding of the marker box.
    legend_kws : dict, optional
        Extra kws passed to plt.legend.
    marker_fontsize : float
        Font size of the code text inside the box.
    text_kws : dict, optional
        Overrides for the marker Text properties (bbox keys are merged).
    alpha : float
        Alpha of the marker box.
    luminance : float or None
        Threshold above which label text stays black; None keeps all black.

    Returns
    -------
    matplotlib legend
    """
    import copy
    if ax is None:
        # bug fix: ax had a None default but was dereferenced directly,
        # unlike the other legend helpers in this module
        ax = plt.gca()
    lgd_kws = legend_kws.copy() if legend_kws is not None else {}
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws.setdefault("borderpad", 0.1 * mm2inch * 72)  # 0.1mm
    lgd_kws.setdefault("handleheight", 0.5)  # font size, units is points
    lgd_kws.setdefault("handlelength", 1)  # font size, units is points
    lgd_kws.setdefault(
        "borderaxespad", 0.5
    )  # The pad between the axes and legend border, in font-size units.
    lgd_kws.setdefault(
        "handletextpad", 0.4
    )  # The pad between the legend handle and text, in font-size units.
    lgd_kws.setdefault(
        "labelspacing", 0.2
    )  # gap height between two rows of legend
    lgd_kws.setdefault("columnspacing", 0.5)
    lgd_kws.setdefault("bbox_to_anchor", (1, 1))
    lgd_kws.setdefault("title", title)
    lgd_kws.setdefault("markerfirst", True)
    ms = lgd_kws.pop("markersize", 10)  # not a plt.legend kw; discard
    # marker text defaults (the boxed code); caller overrides are merged in
    if text_kws is None:
        text_kws = {}
    default_marker_text_kws = dict(
        bbox=dict(boxstyle=f"{boxstyle},pad={marker_pad}",  # Square, Circle, Round
                  edgecolor='black', linewidth=0.4,
                  fill=True, facecolor='white', alpha=alpha),
        horizontalalignment='center', verticalalignment='center',
        fontsize=marker_fontsize, color='black')
    for k in default_marker_text_kws:
        if k == 'bbox':
            if k not in text_kws:
                text_kws['bbox'] = default_marker_text_kws[k]
            else:
                for k1 in default_marker_text_kws['bbox'].keys():
                    text_kws['bbox'].setdefault(k1, default_marker_text_kws['bbox'][k1])
        if k not in text_kws:
            text_kws.setdefault(k, default_marker_text_kws[k])
    # one invisible Line2D handle per code, rendered by TextWithCircleHandler
    # as the code text inside a colored box
    handles = []
    handler_map = {}
    for code in sorted([int(i) for i in code2label.keys()]):
        label = code2label[code]
        code_text = str(code)
        handle = Line2D([], [], linestyle=None, label=label)
        handles.append(handle)
        color = color_dict.get(label, 'black')
        # deep copy: the nested bbox dict is customized per entry
        text_kws1 = copy.deepcopy(text_kws)
        text_kws1['bbox']['facecolor'] = color
        lum = _calculate_luminance(color)
        if lum <= 0.1:  # for black-like colors, use white marker text
            text_kws1['color'] = 'white'
        handler_map[handle] = TextWithCircleHandler(
            marker_text=code_text, label_text=label,
            text_kws=text_kws1,
        )
    # Draw custom legend
    L = ax.legend(handles=handles, handler_map=handler_map, **lgd_kws)
    L._legend_box.align = 'center'
    L.get_title().set_ha('center')
    ax.figure.canvas.draw()
    if color_text:
        for text in L.get_texts():
            try:
                lum = _calculate_luminance(color_dict[text.get_text()])
                if luminance is None:
                    text_color = "black"
                else:
                    text_color = "black" if lum > luminance else color_dict[text.get_text()]
                text.set_color(text_color)
            except (KeyError, ValueError):
                # label not in color_dict / unparseable color: leave default
                pass
    return L
# [docs]
def plot_cmap_legend(
    cax=None, ax=None, cmap="turbo", label=None, kws=None,
    labelsize=6, linewidth=0.5,ticklabel_size=4,
):
    """
    Plot legend (colorbar) for a cmap.

    Parameters
    ----------
    cax : Axes into which the colorbar will be drawn.
    ax : axes to anchor.
    cmap : turbo, hsv, Set1, Dark2, Paired, Accent,tab20,exp1,exp2,meth1,meth2
    label : title for legend.
    kws : dict
        kws passed to plt.colorbar (matplotlib.figure.Figure.colorbar);
        "vmin"/"vmax"/"center" entries are popped and used to build the norm.
    labelsize : font size of the colorbar title.
    linewidth : width of the tick lines.
    ticklabel_size : font size of the tick labels.

    Returns
    -------
    cbar: axes of legend
    """
    label = "" if label is None else label
    cbar_kws = {} if kws is None else kws.copy()
    cbar_kws.setdefault("label", label)
    cbar_kws.setdefault("orientation", "vertical")
    cbar_kws.setdefault("fraction", 1)
    cbar_kws.setdefault("shrink", 1)
    cbar_kws.setdefault("pad", 0)
    cbar_kws.setdefault("extend", 'both')  # pointed over/under ends
    cbar_kws.setdefault("extendfrac", 0.1)
    vmax = cbar_kws.pop("vmax", 1)
    vmin = cbar_kws.pop("vmin", 0)
    cax.set_ylim([vmin, vmax])
    vcenter= (vmax + vmin) / 2
    center=cbar_kws.pop("center",None)
    if center is None:
        # no explicit center: linear norm around the midpoint
        center=vcenter
        m = plt.cm.ScalarMappable(
            norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap
        )
    else:
        # user-provided center: diverging (two-slope) norm
        m = plt.cm.ScalarMappable(
            norm=matplotlib.colors.TwoSlopeNorm(center,vmin=vmin, vmax=vmax), cmap=cmap
        )
    cbar_kws.setdefault("ticks", [vmin, center, vmax])
    cax.yaxis.set_label_position('right')
    cax.yaxis.set_ticks_position('right')
    cbar = ax.figure.colorbar(m, cax=cax, **cbar_kws)
    # size is for ticks, labelsize is for the numbers on ticks (ticklabels)
    cbar.ax.tick_params(labelsize=ticklabel_size, size=ticklabel_size,width=linewidth)
    cbar.ax.yaxis.label.set_fontsize(labelsize)  # colorbar title fontsize
    cbar.ax.grid(False)
    return cbar
def normalize_mc_by_cell(
        use_adata, normalize_per_cell=True,
        clip_norm_value=10, verbose=1, hypo_score=False):
    """
    Normalize per-cell methylation fractions by each cell's prior mean.

    Parameters
    ----------
    use_adata : AnnData
        Input object; obs must contain a 'prior_mean' column for the
        normalization to run. Modified in place and returned.
    normalize_per_cell : bool
        Whether to divide each cell's fractions by its prior mean.
    clip_norm_value : float or None
        Upper clip applied to the normalized matrix (None disables clipping).
    verbose : int
        Print progress messages when > 0.
    hypo_score : bool
        If True, compute prior_mean / fraction instead, so larger values
        mean more hypomethylated.

    Returns
    -------
    AnnData
        The same object, with uns['normalize_per_cell'] set when normalized.
    """
    from scipy.sparse import issparse
    normalized_flag = use_adata.uns.get('normalize_per_cell', False)
    if normalize_per_cell and not normalized_flag:
        # divide frac by prior mean (determined by alpha and beta) per cell
        cols = use_adata.obs.columns.tolist()
        if 'prior_mean' in cols:
            if verbose > 0:
                print("Normalizing cell level fraction by alpha and beta (prior_mean)")
            na_sum = use_adata.to_df().isna().sum().sum()
            if na_sum > 0:
                # NaNs present: fill each cell's NaNs with its own prior mean
                # first (so they normalize to exactly 1)
                D = use_adata.obs.prior_mean.to_dict()
                if not hypo_score:
                    use_adata.X = use_adata.to_df().apply(
                        lambda x: x.fillna(D[x.name]) / D[x.name], axis=1).values
                else:  # hypo score: the larger the value, the more hypomethylated
                    use_adata.X = use_adata.to_df().apply(
                        lambda x: D[x.name] / x.fillna(D[x.name]), axis=1).values
            else:
                if not hypo_score:  # smaller value = lower methylation fraction
                    use_adata.X = use_adata.X / use_adata.obs.prior_mean.values[:, None]
                else:  # hypo-score
                    use_adata.X = use_adata.obs.prior_mean.values[:, None] / use_adata.X
            if clip_norm_value is not None:
                X = use_adata.X.toarray() if issparse(use_adata.X) else use_adata.X
                use_adata.X = np.clip(X, None, clip_norm_value)
            use_adata.uns['normalize_per_cell'] = True
        else:
            if verbose > 0:
                print("'prior_mean' not found in obs")
    elif normalize_per_cell and normalized_flag:
        # bug fix: this branch called logger.info(), but no `logger` exists
        # anywhere in this module (NameError); report via print like the
        # rest of this function, honoring `verbose`
        if verbose > 0:
            print("Input adata is already normalized, skip normalize_per_cell !")
    return use_adata
def parse_gtf(gtf="gencode.v43.annotation.gtf", outfile=None):
    """
    Parse a GENCODE GTF into a BED-like gene table.

    Parameters
    ----------
    gtf : str
        Path to the GTF file ("~" is expanded).
    outfile : str or None
        If None, return the DataFrame; otherwise write it as TSV.

    Returns
    -------
    pd.DataFrame or None
        Columns: chrom, beg, end, gene_name, gene_id, strand, gene_type.
    """
    df = pd.read_csv(os.path.expanduser(gtf), sep='\t', header=None,
                     comment="#", usecols=[0, 2, 3, 4, 6, 8],
                     names=['chrom', 'record_type', 'beg', 'end', 'strand', 'information'])
    cols = ['gene_id', 'gene_type', 'gene_name']

    def parse_info(x):
        """Parse the semicolon-separated attribute field (col 9) into a dict."""
        x = x.replace('"', '')
        D = {}
        for item in x.strip().rstrip(';').split(';'):
            # bug fix: split only on the FIRST space, so attribute values
            # that themselves contain spaces no longer raise ValueError
            k, v = item.strip().split(' ', 1)
            D[k.strip()] = v.strip()
        return D

    df['info_dict'] = df.information.apply(parse_info)
    for col in cols:
        df[col] = df.info_dict.apply(lambda x: x.get(col, ''))
    df = df.loc[:, ['chrom', 'beg', 'end', 'gene_name', 'gene_id',
                    'strand', 'gene_type']].drop_duplicates()
    if outfile is None:
        return df  # 'chrom','start','end','gene_symbol','strand','gene_type'
    df.to_csv(os.path.expanduser(outfile), sep='\t', index=False)