# AUTOGENERATED! DO NOT EDIT! File to edit: StaticPlots.ipynb (unless otherwise specified).
# Names exported by a star-import of this module.
# NOTE(review): some entries (e.g. 'init_figure', 'select_pdos', 'plot_potential') are not
# defined in the portion of the file reviewed here - confirm they exist elsewhere in the module.
__all__ = ['modify_axes', 'init_figure', 'plot_bands', 'quick_bplot', 'add_text', 'add_legend', 'add_colorbar',
'color_wheel', 'get_pros_data', 'make_line_collection', 'plot_collection', 'quick_rgb_lines',
'quick_color_lines', 'select_pdos', 'collect_dos', 'quick_dos_lines', 'plt_to_html', 'plot_potential']
# Cell
import os
import numpy as np
from io import BytesIO
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap as LSC
from matplotlib.collections import LineCollection
import pivotpy.vr_parser as vp
import pivotpy.g_utils as gu
from IPython import get_ipython
from IPython.display import HTML, set_matplotlib_formats #HTML for plt_to_html
# print SVG in ipython
try:
    shell = get_ipython().__class__.__name__
    # 'ZMQInteractiveShell' is the Jupyter kernel shell, 'Shell' is reported by Colab.
    if shell in ('ZMQInteractiveShell', 'Shell'):
        set_matplotlib_formats('svg')
except Exception:
    # Best effort only: outside a notebook (or if the backend refuses SVG) keep defaults.
    # `Exception` instead of a bare except so KeyboardInterrupt/SystemExit still propagate.
    pass
# Global matplotlib settings, applied once at import time.
mpl.rcParams['axes.linewidth'] = 0.4 # thin axes frame for all figures
mpl.rcParams['font.serif'] = "STIXGeneral"
mpl.rcParams['font.family'] = "serif"
mpl.rcParams['mathtext.fontset'] = "stix"
# Cell
def modify_axes(ax=None,xticks=[],xt_labels=[],xlim=[],\
                yticks=[],yt_labels=[],ylim=[],xlabel=None,ylabel=None,\
                vlines=True,zeroline=True):
    """
    - Returns None, applies given settings on axes. Preferred to use before other plotting.
    - **Parameters**
        - ax  : Matplotlib axes object.
        - (x,y)ticks : List/array of positions on (x,y) axes.
        - (xt,yt)_labels : List of labels on (x,y) ticks points.
        - (x,y)lim : [min, max] of (x,y) axes.
        - (x,y)label : axes labels.
        - vlines : If True, dashed vertical lines are drawn at every `xticks` position, only when `ylim` is not empty.
        - zeroline : If True, a dashed horizontal line at y = 0 is drawn, only when `xlim` is not empty.
    - **Raises**
        - ValueError if `ax` is not given.
    """
    if ax is None:
        raise ValueError("Matplotlib axes (ax) is not given.")

    def _given(seq):
        # Length based emptiness test: plain truthiness is ambiguous for numpy arrays.
        return seq is not None and len(seq) > 0

    guide_style = dict(color=(0,0,0,0.6), linestyle='dashed', lw=0.3)
    if _given(xticks):
        ax.set_xticks(xticks)
        ax.set_xticklabels(xt_labels)
    if _given(yticks):
        ax.set_yticks(yticks)
        ax.set_yticklabels(yt_labels)
    if _given(xlim):
        ax.set_xlim(xlim)
        if zeroline:
            ax.hlines(0, min(xlim), max(xlim), **guide_style)
    if _given(ylim):
        ax.set_ylim(ylim)
        if vlines and _given(xticks):
            y_low, y_high = min(ylim), max(ylim)
            for tick in xticks:
                ax.vlines(tick, y_low, y_high, **guide_style)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    # Ticks on all four sides, pointing inwards.
    ax.tick_params(direction='in', top=True,bottom=True,left=True,right=True,length=4, width=0.3, grid_alpha=0.8)
    return None
# Cell
# Cell
def plot_bands(ax=None,kpath=None,bands=None,showlegend=False,E_Fermi=None,\
               color1=(0,0,0.8),style1='solid',lw1=0.7,color2=(0.8,0,0),style2='dashed',lw2=0.7):
    """
    - Returns axes object and plot on which all matplotlib allowed actions could be performed.
    - **Parameters**
        - ax         : Matplotlib axes object, if not given, one is created.
        - kpath      : 1D array from `get_kpts`().kpath or `export_vasprun`().kpath.
        - bands      : Dictionary Object from `get_evals` or `export_vasprun`().bands.
        - showlegend : Boolean, default is False, if True, gives legend for spin-polarized calculations.
        - E_Fermi    : If not given, automatically picked from bands object.
        - **kwargs   : lines color, width and style to distinguish spin Up (`*1`) and Down (`*2`).
    - **Returns**
        - ax : matplotlib axes object with plotted bands. If eigenvalues are empty, a hint is printed and None is returned.
    - **Raises**
        - ValueError if `kpath` or `bands` is not given.
    """
    # Validate with identity checks: `array == None` is elementwise and has no single truth value.
    if kpath is None:
        raise ValueError("kpath is not provided. Use get_kpath() or export_vasprun().kpath to generate it.")
    if bands is None:
        raise ValueError("bands object is not provided. Use get_evals() or export_vasprun().bands to generate it.")
    # Only create a figure once the inputs are known to be usable.
    if ax is None:
        ax = init_figure()
    if E_Fermi is None:
        E_Fermi = bands.E_Fermi
    # Non spin-polarized: a single set of eigenvalues.
    if bands.ISPIN == 1:
        if not bands.evals.any():
            print(gu.color.y("Can not plot an empty eigenvalues object."))
            return print(gu.color.g("Try with large energy range."))
        ax.plot(kpath, bands.evals - E_Fermi, color=color1, lw=lw1, ls=style1)
    # Spin-polarized: Up and Down are drawn with separate styles.
    if bands.ISPIN == 2:
        if not bands.evals.SpinUp.any():
            print(gu.color.y("Can not plot an empty eigenvalues object."))
            return print(gu.color.g("Try with large energy range."))
        ax.plot(kpath, bands.evals.SpinUp - E_Fermi, color=color1, lw=lw1, ls=style1)
        ax.plot(kpath, bands.evals.SpinDown - E_Fermi, color=color2, lw=lw2, ls=style2)
        # Legend only for spin polarized; empty proxy lines carry the labels.
        if showlegend:
            ax.plot([], [], color=color1, lw=lw1, ls=style1, label='Spin Up')
            ax.plot([], [], color=color2, lw=lw2, ls=style2, label='Spin Down')
            ax.legend(fontsize='small', frameon=False, ncol=2, bbox_to_anchor=(0, 1), loc='lower left')
    return ax
# Cell
def quick_bplot(path_evr=None,ax=None,skipk=None,joinPathAt=[],elim=[],xt_indices=[],\
                xt_labels=[],E_Fermi=None,figsize=(3.4,2.6),txt=None,xytxt=[0.05,0.9],ctxt='black'):
    """
    - Returns axes object and plot on which all matplotlib allowed actions could be performed.
    - **Parameters**
        - path_evr   : path/to/vasprun.xml or output of `export_vasprun`. Auto picks in CWD.
        - ax         : Matplotlib axes object, if not given, one is created.
        - skipk      : Number of kpoints to skip, default will be from IBZKPT.
        - joinPathAt : Points where kpath is broken.
        - elim       : [min,max] of energy range.
        - E_Fermi    : If not given, automatically picked from `export_vasprun`.
        - xt_indices : High symmetry kpoints indices.
        - xt_labels  : High Symmetry kpoints labels.
        - **kwargs   : figsize=(3.4,2.6). Text, its position and color.
    - **Returns**
        - ax : matplotlib axes object with plotted bands. If the input can not be used, a message is printed and None is returned.
    """
    # Resolve the input: exported data is used directly, None or a file path is parsed.
    if isinstance(path_evr, vp.Dict2Data):
        vr = path_evr
    elif path_evr is None or os.path.isfile(path_evr):
        vr = vp.export_vasprun(path=path_evr,skipk=skipk,elim=elim,joinPathAt=joinPathAt)
    else:
        return print("path_evr = `{}` does not exist".format(path_evr))
    # The object must be the top tree of export_vasprun(), i.e. expose bands and kpath.
    try:
        vr.bands;vr.kpath
    except (AttributeError, KeyError):
        return print("Object: \n{} \nis like a lower tree of export_vasprun(). Expects top tree.".format(vr))
    # Main working here.
    K = vr.kpath
    xticks = [K[i] for i in xt_indices]
    xlim = [min(K), max(K)]
    has_elim = elim is not None and len(elim) > 0
    ylim = [min(elim), max(elim)] if has_elim else []
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)
    modify_axes(ax=ax,ylabel='Energy (eV)',xticks=xticks,xt_labels=xt_labels,xlim=xlim,ylim=ylim)
    plot_bands(ax=ax,kpath=K,bands=vr.bands,showlegend=True,E_Fermi=E_Fermi,lw1=0.9)
    # Annotate with user text, falling back to the SYSTEM name of the calculation.
    label = txt if txt is not None else vr.sys_info.SYSTEM
    ax.text(*xytxt,label,bbox=dict(edgecolor='white',facecolor='white', alpha=0.9),transform=ax.transAxes,color=ctxt)
    return ax
# Cell
def add_text(ax=None,xs=0.05,ys=0.9,txts='[List]',colors='r'):
    """
    - Adds text entries on axes, given single string or list.
    - **Parameters**
        - ax    : Matplotlib axes object.
        - xs    : List of x coordinates relative to axes or single coordinate.
        - ys    : List of y coordinates relative to axes or single coordinate.
        - txts  : List/tuple of strings or one string.
        - colors: List of colors of txts or one color (a single color string is applied to every text).
    - **Raises**
        - ValueError if `ax` is not given.
    """
    if ax is None:
        raise ValueError("Matplotlib axes (ax) is not given.")

    def _put(x, y, text, color):
        # Fresh bbox dict per call, white background keeps text readable over lines.
        ax.text(x,y,text,bbox=dict(edgecolor='white',facecolor='white', alpha=0.9),transform=ax.transAxes,color=color)

    if isinstance(txts, str):
        _put(xs, ys, txts, colors)
    elif isinstance(txts, (list, tuple)):
        count = len(txts)
        # Broadcast single values so the documented "list or single" inputs can be mixed;
        # zipping a scalar raises, and zipping a color string would iterate its characters.
        x_seq = [xs]*count if np.ndim(xs) == 0 else xs
        y_seq = [ys]*count if np.ndim(ys) == 0 else ys
        c_seq = [colors]*count if isinstance(colors, str) else colors
        for x, y, text, color in zip(x_seq, y_seq, txts, c_seq):
            _put(x, y, text, color)
    return None
# Cell
def add_legend(ax=None,colors=[],labels=[],styles='solid',\
               widths=0.7,anchor=(0,1), ncol=3,loc='lower left',fontsize='small',frameon=False,**legend_kwargs):
    """
    - Adds custom legends on a given axes, returns None.
    - **Parameters**
        - ax       : Matplotlib axes.
        - colors   : List of colors.
        - labels   : List of labels.
        - styles   : str or list of line styles.
        - widths   : number or list of line widths.
        - anchor, ncol, loc, fontsize, frameon : Passed to matplotlib's legend (`anchor` as `bbox_to_anchor`).
        - **kwargs : Matplotlib's legend arguments.
    - **Raises**
        - ValueError if `ax` is not given.
    """
    if ax is None:
        raise ValueError("Matplotlib axes (ax) is not given.")
    entries = len(labels)
    # Broadcast a single width/style to every entry; np.ndim also accepts numpy scalars.
    width_seq = [widths]*entries if np.ndim(widths) == 0 else widths
    style_seq = [styles]*entries if isinstance(styles, str) else styles
    # Empty proxy lines only exist to create legend handles.
    for color, label, width, style in zip(colors, labels, width_seq, style_seq):
        ax.plot([],[],color=color,lw=width,linestyle=style,label=label)
    ax.legend(fontsize=fontsize,frameon=frameon,ncol=ncol, bbox_to_anchor=anchor, loc=loc,**legend_kwargs)
    return None
# Cell
def add_colorbar(ax=None,cmap_or_clist=None,N=256,ticks=[1/6,1/2,5/6],\
                 ticklabels=['r','g','b'],vertical=False,fontsize=8):
    """
    - Plots colorbar on a given axes. This axes should be only for colorbar. Returns None or throws ValueError if axes is missing.
    - **Parameters**
        - ax         : Matplotlib axes object.
        - cmap_or_clist: List/array of colors or colormap's name. If None (default), first tries to get latest `quick_rgb_lines` colormap `RGB_f` and if no success, a built-in RGB wheel is used. If a given list can not be turned into a colormap, matplotlib's default colormap is plotted.
        - N          : int, number of color points. Default 256.
        - ticks      : List of tick values to show on colorbar in interval [0,1].
        - ticklabels : List of labels for ticks.
        - vertical   : Boolean, default is False.
        - fontsize   : Default 8. Adjustable according to plot space.
    - **Note**: Use `'RGB_f'` to map colors (after) plotted in `quick_rgb_lines` and use `'RGB_m'` to plot DOS in same colors.
    """
    if ax is None:
        raise ValueError("Matplotlib axes (ax) is not given.")
    # Resolve the colormap; None lets matplotlib pick its default in imshow.
    if cmap_or_clist is None:
        try:
            _hsv_ = plt.cm.get_cmap('RGB_f')
        except (ValueError, KeyError, AttributeError):
            # 'RGB_f' is not registered (or lookup is unavailable): build the RGB wheel.
            colors = np.array([[1,0,1],[1,0,0],[1,1,0],[0,1,0],[0,1,1],[0,0,1],[1,0,1]])
            _hsv_ = LSC.from_list('_hsv_',colors,N=N)
    elif isinstance(cmap_or_clist,(list,np.ndarray)):
        try:
            _hsv_ = LSC.from_list('_hsv_',cmap_or_clist,N=N)
        except Exception as e:
            # Deliberate best effort: any invalid color list falls back to the default map.
            print(e,"\nFalling back to default color map!")
            _hsv_ = None
    elif isinstance(cmap_or_clist,str):
        _hsv_ = cmap_or_clist # colormap name
    else:
        _hsv_ = None # default fallback
    # Two identical rows of a 0..1 ramp make a thin gradient image.
    c_vals = np.linspace(0,1,N)
    c_vals = np.vstack((c_vals,c_vals))
    tick_positions = [np.floor(N*t) for t in ticks]
    # labelsize applies to whichever axis carries the labels, so both orientations honor fontsize.
    ticks_param = dict(direction='out',length=2,width=0.3,top=False,right=False,
                       grid_color=(1,1,1,0), grid_alpha=0, labelsize=fontsize)
    if vertical:
        ax.imshow(c_vals.transpose(),aspect='auto',cmap=_hsv_,origin='lower')
        ax.set_xticks([])
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(ticklabels,rotation=90)
        ticks_param.update({'bottom':False})
    else:
        ax.imshow(c_vals,aspect='auto',cmap=_hsv_,origin='lower')
        ax.set_yticks([])
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(ticklabels,rotation=0)
        ticks_param.update({'left':False})
    ax.tick_params(**ticks_param)
    # Fade the frame so only the gradient stands out.
    for child in ax.get_children():
        if isinstance(child, mpl.spines.Spine):
            child.set_color((1,1,1,0.4))
# Cell
def color_wheel(ax=None,
                xy=(1,1),
                scale = 0.12,
                rlim=(0.2,1),
                N=256,
                color_map=None,
                ticks=[1/6,1/2,5/6],
                labels=['s','p','d'],
                showlegend=True):
    """
    - Returns cax i.e. color wheel axes.
    - **Parameters**
        - ax        : Axes on which color wheel will be drawn. Auto created if not given.
        - xy        : (x,y) of center of wheel, relative to `ax`.
        - scale     : Scale of the cax internally generated by color wheel.
        - rlim      : Values in [0,1] interval, to make donut like shape.
        - N         : Number of segments in color wheel.
        - color_map : Matplotlib's color map name. Auto picks `RGB_f` and if fails, fallbacks to `viridis`.
        - ticks     : Ticks in fractions in interval [0,1].
        - labels    : Ticks labels. With `showlegend=True` at most the first three are shown.
        - showlegend: True or False. If True, colored labels are written beside the wheel instead of tick labels.
    """
    if ax is None:
        ax = init_figure()
    if color_map is None:
        try:
            color_map = plt.cm.get_cmap('RGB_f')
        except (ValueError, KeyError, AttributeError):
            color_map = 'viridis' # 'RGB_f' is not registered yet.
    # Place a small square polar axes centered at xy (in units of the parent axes box).
    pos = ax.get_position()
    cpos = [pos.x0+pos.width*xy[0]-scale/2, pos.y0+pos.height*xy[1]-scale/2, scale, scale]
    cax = ax.get_figure().add_axes(cpos,projection='polar')
    norm = mpl.colors.Normalize(0.0,2*np.pi)
    t = np.linspace(0,2*np.pi,N)
    r = np.linspace(*rlim,2)
    _, tg = np.meshgrid(r,t)
    # The angle itself is the color value, so the map wraps once around the wheel.
    cax.pcolormesh(t,r,tg.T,norm=norm,cmap=color_map,edgecolor='face')
    cax.set_yticklabels([])
    cax.spines['polar'].set_visible(False)
    if showlegend:
        colors = plt.cm.get_cmap(color_map)(ticks) # Get colors values.
        marked = ["◾ "+l for l in labels]
        # Hack to stack labels around a single anchor point using newlines;
        # guarded so fewer than three labels do not raise IndexError.
        if marked:
            marked[0] = marked[0]+'\n'
        if len(marked) > 2:
            marked[2] = '\n'+ marked[2]
        for l,p,c in zip(marked,['bottom','center','top'],colors):
            cax.text(rlim[1]+0.1,0.5,l,va=p,ha='left',color=c,transform=cax.transAxes,fontsize=9)
        cax.set_xticklabels([])
    else:
        cax.set_xticks([tick*2*np.pi for tick in ticks])
        cax.set_xticklabels(labels)
    return cax
# Cell
def get_pros_data(kpath       = None,
                  evals_set   = None,
                  pros_set    = None,
                  elements    = [[0],],
                  orbs        = [[0],],
                  interpolate = False,
                  scale_data  = False,
                  n           = 5,
                  k           = 3
                  ):
    """
    - Returns selected elements/orbitals data.
    - **Parameters**
        - kpath    : `export_vasprun`().kpath or `get_kpts`().kpath.
        - evals_set: `export_vasprun`().bands.evals or `get_evals`().evals. If calculations are spin-polarized, it will be `...evals.SpinUp/SpinDown` for both.
        - pros_set : `export_vasprun().pro_bands.pros` or `get_bands_pro_set`().pros. If calculations are spin-polarized, it will be `...pros.SpinUp/SpinDown` for both.
        - elements : List of lists of ions to project on, could be `range(start,stop,step)` as well, remember that `stop` is not included in python. So `range(0,2)` will generate 0 and 1 indices.
        - orbs     : List of lists of orbitals indices.
        - scale_data : If True, normalizes projection data to 1.
        - interpolate: Default is False, if True, it will add n points between nearest kpoints.
        - n        : int, default is 5. Adds n points between nearest kpoints.
        - k        : int, order of interpolation, default is 3. `n > k` should hold.
    - **Returns**
        - A dictionary with keys 'kpath', 'evals', 'pros' that can be unpacked in next function's argument. If `pros_set` is empty, a hint is printed and None is returned.
            - 'kpath' : Given or interpolated kpath of shape (NKPTS,).
            - 'evals' : Given or interpolated evals_set of shape (NKPTS,NBANDS).
            - 'pros'  : An array of shape (NKPTS,NBANDS,n) where n is length of input elements list. If scale_data = True, normalizes this array to 1.
    - **Raises**
        - ValueError if an item of `elements` or `orbs` is a bare int.
    """
    # `gu` is the module-level import. It must not be re-imported inside this function:
    # a local import would make `gu` a local name and break this early branch.
    if not np.any(pros_set):
        print(gu.color.y("Can not process an empty eigenvalues/projection object."))
        return print(gu.color.g("Try with large `elim` in parent function."))
    # Empty orbitals/elements are still allowed on purpose for rgb_lines.
    for elem, orb in zip(elements,orbs):
        if isinstance(elem,int) or isinstance(orb,int):
            raise ValueError("Items in elements and orbs should be list, tuple or range, not int.")
    # One (NKPTS,NBANDS) array per (elements, orbs) pair: sum over picked ions, then orbitals.
    arrays_list = []
    for elem, orb in zip(elements,orbs):
        ions_sum = np.take(pros_set,elem,axis=0).sum(axis=0)
        arrays_list.append(np.take(ions_sum,orb,axis=2).sum(axis=2))
    # Stack pairs along a new last axis -> (NKPTS,NBANDS,len(elements)).
    _colors = np.stack(arrays_list,axis=-1)
    if scale_data: # Normalize overall data
        c_max = np.max(_colors)
        if c_max > 0.0000001: # Avoid division error
            _colors = _colors/c_max
    if interpolate:
        min_d, max_d = np.min(_colors),np.max(_colors) # For clipping spline overshoot
        knew, evals = gu.interpolate_data(kpath,evals_set,n=n,k=k)
        _colors = gu.interpolate_data(kpath,_colors,n=n,k=k)[1].clip(min=min_d,max=max_d)
        return {'kpath':knew,'evals': evals,'pros': _colors}
    # Return if no interpolation
    return {'kpath':kpath,'evals': evals_set,'pros': _colors}
def make_line_collection(max_width   = None,
                         colors_list = None,
                         rgb         = False,
                         uni_width   = False,
                         scale_color = False,
                         **pros_data):
    """
    - Returns a tuple of line collections. If rgb = True (at len(orbs) = 3 in `get_pros_data`), returns a tuple of two entries, multicolored line collection and RGB maximum values, otherwise returns tuple of single colored multiple lines. Prints a message and returns None if required data is missing.
    - **Parameters**
        - **pros_data: Output dictionary from `get_pros_data` containing kpath, evals and pros arrays.
        - max_width  : Default is None and max linewidth = 2.5. Linewidths are rescaled with max_width if an int or float is given.
        - colors_list: List of colors for multiple lines, length equal to 3rd axis length of pros.
        - rgb        : Default is False. If True and np.shape(pros)[-1] == 3, RGB line collection is returned together with per-channel maxima.
        - uni_width  : Default is False, If True, makes linewidth uniform at width = max_width/2.
        - scale_color: If True, normalizes each point's color value, as (0,0,0.5) --> (0,0,1). If False, clips colors in range [0,1] but does not affect linewidth.
    """
    if not pros_data:
        return print("No pros_data given.")
    kpath = pros_data.get('kpath')
    evals = pros_data.get('evals')
    pros = pros_data.get('pros')
    for a,t in zip([kpath, evals, pros],['kpath', 'evals', 'pros']):
        if not np.any(a):
            return print("Missing {!r} from output of `get_pros_data()`".format(t))
    # Average pros on two consecutive KPOINTS to get that patch color.
    colors = pros[1:,:,:]/2 + pros[:-1,:,:]/2
    # Flatten to (patches, channels), band by band. Must happen before linewidths.
    colors = colors.transpose((1,0,2)).reshape((-1,np.shape(colors)[-1]))
    # Linewidths follow unscaled data, so compute them before scale_color.
    if uni_width:
        uniform = max_width/2 if max_width else 2.5
        lws = [uniform]*len(colors)
    else:
        # default linewidth = 2.5 unless specified otherwise
        if rgb: # Single collection: width from sum of channels, residual linewidth 0.1
            lws = 0.1 + 2.5*np.sum(colors,axis=1)
        else: # Separate lines: one row of widths per channel (.T to access in for loop)
            lws = 0.1 + 2.5*colors.T
        if max_width: # Only if not None or zero.
            # Division does not fail as min(lws) == 0.1.
            # NOTE(review): largest width becomes max_width/2.5, not max_width as documented - confirm intent.
            lws = max_width*lws/(2.5*np.max(lws))
    # scale_color only after linewidths
    if scale_color: # Normalize each point color by its largest channel
        c_max = np.max(colors,axis=1)
        c_max[c_max == 0] = 1 # Avoid division error
        colors = colors/c_max[:,np.newaxis] # works for any number of channels
    else:
        colors = colors.clip(min=0,max=1) # make sure RGB range.
    if np.any(colors_list):
        lc_colors = colors_list
    else:
        cmap = plt.cm.get_cmap('viridis')
        lc_colors = cmap(np.linspace(0,1,np.shape(colors)[-1]))
        lc_colors = lc_colors[:,:3] # Skip Alpha
    # Reshape data same as colors above, but making line patches too:
    # marr holds one [(k0,e0),(k1,e1)] segment per consecutive kpoint pair and band.
    kgrid = np.repeat(kpath,np.shape(evals)[1],axis=0).reshape((-1,np.shape(evals)[1]))
    narr = np.concatenate((kgrid,evals),axis=0).reshape((2,-1,np.shape(evals)[1]))
    marr = np.concatenate((narr[:,:-1,:],narr[:,1:,:]),axis=0).transpose().reshape((-1,2,2))
    # Make Line collection
    if rgb and np.shape(colors)[-1] == 3:
        _scales = np.max(colors,axis=0) # Need for colorbar
        return (LineCollection(marr,colors=colors,linewidths=lws),_scales)
    return tuple(LineCollection(marr,colors=_cl,linewidths=lw) for _cl,lw in zip(lc_colors,lws))
def plot_collection(gpd_args,mlc_args,axes=None):
    """
    - Plots line collection from the output of get_pros_data and make_line_collection on axes.
    - **Parameters**
        - gpd_args: Dictionary of arguments for function `get_pros_data`. Do not unpack it.
        - mlc_args: Dictionary of arguments for function `make_line_collection`. Do not unpack it.
        - axes    : A single or list of matplotlib's Axes. len(list) should be equal to len(orbs) given in `get_pros_data`. If axes = None, auto generated.
    - **Returns**
        - axes: axes to return are specially useful when axes = None, you can perform other actions on those axes. It will be a list/array of axes and all items could be same, depending on whether one or many axes were given/generated. Returned unchanged (nothing plotted) if either helper reports missing data.
    """
    if axes is None or np.size(axes) == 0:
        axes = init_figure()
    axes = np.array([axes]).ravel() # Safe flat list for any axes size
    pros_data = get_pros_data(**gpd_args)
    if pros_data is None: # helper already printed why
        return axes
    line_coll = make_line_collection(**pros_data,**mlc_args)
    if line_coll is None: # helper already printed why
        return axes
    # Number of axes must match number of collections, otherwise draw all on the first.
    if len(axes) != len(line_coll):
        axes = [axes[0] for _ in line_coll]
    for ax, lc in zip(axes,line_coll):
        ax.add_collection(lc)
    for ax in axes:
        ax.autoscale_view()
    return axes
# Cell
def quick_rgb_lines(path_evr    = None,
                    ax          = None,
                    skipk       = None,
                    joinPathAt  = [],
                    elim        = [],
                    elements    = [[0],[],[]],
                    orbs        = [[0],[],[]],
                    labels      = ['Elem0-s','',''],
                    max_width   = None,
                    xt_indices  = [0,-1],
                    xt_labels   = [r'$\Gamma$','M'],
                    E_Fermi     = None,
                    figsize     = (3.4,2.6),
                    txt         = None,
                    xytxt       = [0.05,0.9],
                    ctxt        = 'black',
                    uni_width   = False,
                    interpolate = False,
                    spin        = 'both',
                    n           = 5,
                    k           = 3,
                    scale_color = True,
                    scale_data  = True,
                    colorbar    = True,
                    color_matrix= None
                    ):
    """
    - Returns axes object and plot on which all matplotlib allowed actions could be performed. In this function, orbs, labels, elements all have list of length 3. Inside list, sublists or strings could be any length but should be there even if empty.
    - **Parameters**
        - path_evr   : path/to/vasprun.xml or output of `export_vasprun`. Auto picks in CWD.
        - ax         : Matplotlib axes object, if not given, one is created.
        - skipk      : Number of kpoints to skip, default will be from IBZKPT.
        - joinPathAt : Points where kpath is broken.
        - elim       : [min,max] of energy range.
        - E_Fermi    : If not given, automatically picked from `export_vasprun`.
        - xt_indices : High symmetry kpoints indices.
        - xt_labels  : High Symmetry kpoints labels.
        - elements   : List [[0],[],[]] by default and plots s orbital of first ion. An int item picks all ions of that element type. The given list is not modified.
        - orbs       : List [[r],[g],[b]] of indices of orbitals, could be empty, but shape should be same.
        - labels     : List [str,str,str] of projection labels. Empty string should exist to maintain shape. Auto adds `↑`,`↓` for ISPIN=2.
        - max_width  : Width to scale whole projections. if `uni_width=True, width=max_width/2`. Default is None and linewidth at any point = 2.5*sum(ions+orbitals projection of all three input at that point). Linewidth is rescaled with max_width if an int or float is given.
        - figsize    : Tuple (width,height) in inches. Default (3.4,2.6) is article column's width. Not referenced in this function's body.
        - txt        : Text on figure, if None, SYSTEM's name is printed.
        - xytxt      : [x_coord,y_coord] of text relative to axes.
        - ctxt       : color of text.
        - uni_width  : If True, width of bands kept uniform.
        - spin       : Plot spin-polarized for spin {'up','down','both'}. Default is both.
        - interpolate: Default is False, if True, bands are interpolated.
        - n          : int, number of points, default is 5.
        - k          : int, order of interpolation 0,1,2,3. Default 3. `n > k` should hold.
        - scale_color: Boolean. Default True, colors are scaled to 1 at each point. If False, clips colors in range [0,1] but does not affect linewidth.
        - scale_data : Default is True and normalizes projection data to 1. Has no visual effect if scale_color = True too.
        - colorbar   : Default is True. Displays a vertical RGB colorbar.
        - color_matrix: Only works if `scale_color==True`. 3x3 or 3x4 numpy array or list to transform from RGB to another space, provided that sum(color_matrix[i,:3]) <= 1. 4th column, if given, can be used to control the saturation, contrast and brightness as s,c,b = color_matrix[:,3]. For simply changing the color intensity use np.diag([r,g,b]) with r,g,b interval in [0,1]. Try `pivotpy.color_matrix` as suggested color matrix and modify, which at s=0 returns gray scale.
    - **Returns**
        - ax : matplotlib axes object with plotted projected bands. If input can not be used, a message is printed and None is returned.
        - Registers as colormap `RGB_m` to use in DOS to plot in same colors and `RGB_f` to display bands colorbar on another axes.
    > Note: Two figures made by this function could be compared quantitatively only if `scale_data=False, max_width=None, scale_color=False` as these parameters act internally on data.
    """
    # Fix orbitals, elements and labels lengths very early.
    if len(elements) != len(orbs) or len(elements) != len(labels):
        raise ValueError("`elements`, `orbs` and `labels` expect same length, even if empty.")
    if path_evr is None:
        path_evr = './vasprun.xml'
    if isinstance(path_evr, vp.Dict2Data):
        vr = path_evr
    elif os.path.isfile(path_evr):
        vr = vp.export_vasprun(path=path_evr,skipk=skipk,elim=elim,joinPathAt=joinPathAt)
    else:
        return print("path_evr = {!r} does not exist".format(path_evr))
    # The object must be the top tree of export_vasprun(), i.e. expose bands and kpath.
    try:
        vr.bands;vr.kpath
    except (AttributeError, KeyError):
        return print("Object: \n{!r} \ndoes not match the output of `export_vasprun`.".format(vr))
    if vr.pro_bands is None:
        print(gu.color.y("Can not plot an empty eigenvalues object."))
        return print(gu.color.g("Try with large energy range."))
    if spin not in ('up','down','both'):
        raise ValueError("spin can take any of ['up','down'. 'both'] only.")
    # Elements index fixing, on a copy so the caller's list (or the default) is never mutated.
    elem_inds = vr.sys_info.ElemIndex
    max_ind = elem_inds[-1]-1 # Last index is used for range in ElemIndex, not python index.
    elements = list(elements)
    for i, elem in enumerate(elements):
        if isinstance(elem, (int, np.integer)):
            try:
                elements[i] = range(elem_inds[elem],elem_inds[elem+1])
                info = "elements[{}] = {} is converted to {} which picks all ions of {!r}.".format(
                        i,elem,elements[i],vr.sys_info.ElemName[elem])
            except (IndexError, TypeError) as err:
                raise IndexError("Wrap elements[{}] in [] and try again.".format(i)) from err
            info += "To just pick one ion at this index, wrap it in brackets []."
            print(gu.color.g(info))
    # default=0 keeps all-empty selections valid instead of failing on an empty max().
    max_e = max((e for ee in elements for e in ee), default=0)
    if max_e > max_ind:
        raise IndexError("index {} is out of bound for {} ions".format(max_e,max_ind+1))
    # Orbitals index fixing
    nfields = len(vr.pro_bands.labels)
    orbs = [[item] if isinstance(item, (int, np.integer)) else item for item in orbs] # Fix if integer given.
    max_o = max((r for rr in orbs for r in rr), default=0)
    if max_o >= nfields:
        return print("index {} is out of bound for {} orbs".format(max_o,nfields))
    # Small things
    if E_Fermi is None:
        E_Fermi = vr.bands.E_Fermi
    K = vr.kpath
    xticks = [K[i] for i in xt_indices]
    xlim = [min(K),max(K)]
    has_elim = elim is not None and len(elim) > 0
    ylim = [min(elim),max(elim)] if has_elim else []
    # Make axes if not given.
    if ax is None:
        ax = init_figure()
    # Fix color_matrix here.
    s,c,b = 1,1,0 # Saturation, contrast, brightness. Defaults
    mix_matrix = None
    if color_matrix is not None:
        color_matrix = np.array(color_matrix) # private copy, safe to edit below
        mix_matrix = color_matrix[:3,:3]
        try:
            s,c,b = color_matrix[:,3]
        except (IndexError, ValueError):
            pass # 3x3 matrix given: keep default s,c,b
    # color_matrix is documented to act only on scaled colors.
    use_matrix = bool(scale_color) and color_matrix is not None
    colorbar_scales = [1,1,1] # default for colorbar
    #=====================================================
    def _add_collection(gpd_args,mlc_args,ax):
        # Builds one RGB line collection, adds it to ax and records channel maxima.
        pros_data = get_pros_data(**gpd_args)
        if pros_data is None: # helper already printed why
            return
        if use_matrix:
            colors = pros_data['pros']
            c_max = np.max(colors,axis=0).max(axis=0) # per channel maximum, ndims = 3
            c_max[c_max == 0] = 1 # Avoid division error
            colors = colors/c_max
            colors = gu.transform_color(colors,s=s,c=c,b=b,mixing_matrix=mix_matrix)
            pros_data['pros'] = colors
            mlc_args['scale_color'] = False # No more scaling in make_line_collection.
        result = make_line_collection(**pros_data,**mlc_args)
        if result is None: # helper already printed why
            return
        line_coll,scales = result
        colorbar_scales[:] = scales
        ax.add_collection(line_coll)
        ax.autoscale_view()
    #====================================================
    # After All Fixing
    ISPIN = vr.sys_info.ISPIN
    # Arguments for get_pros_data and make_line_collection
    gpd_args = dict(elements=elements,orbs=orbs,interpolate=interpolate,n=n,k=k,scale_data=scale_data)
    mlc_args = dict(rgb=True,uni_width=uni_width,scale_color = scale_color,colors_list= None, max_width=max_width)
    if ISPIN == 1:
        gpd_args.update(dict(kpath=K, evals_set=vr.bands.evals-E_Fermi, pros_set=vr.pro_bands.pros))
        _add_collection(gpd_args,mlc_args,ax=ax)
    if ISPIN == 2:
        gpd_args1 = dict(kpath=K,evals_set=vr.bands.evals.SpinUp-E_Fermi,
                         pros_set=vr.pro_bands.pros.SpinUp,**gpd_args)
        gpd_args2 = dict(kpath=K,evals_set=vr.bands.evals.SpinDown-E_Fermi,
                         pros_set=vr.pro_bands.pros.SpinDown,**gpd_args)
        if spin in ('up','both'): # Plotting Up
            _add_collection(gpd_args1,mlc_args,ax=ax)
        if spin in ('down','both'): # Plotting Down
            _add_collection(gpd_args2,mlc_args,ax=ax)
    # Aesthetics of plot.
    if not txt:
        txt = vr.sys_info.SYSTEM
    add_text(ax=ax,xs=xytxt[0],ys=xytxt[1],txts=txt,colors=ctxt)
    modify_axes(ax=ax,xticks=xticks,xt_labels=xt_labels,xlim=xlim,ylim=ylim)
    # Colorbar and Colormap
    _colors_ = np.multiply([[1,0,1],[1,0,0],[1,1,0],[0,1,0],[0,1,1],
                            [0,0,1],[1,0,1]],colorbar_scales)
    if use_matrix: # Only apply color_matrix on scaled colors
        # colorbar_scales is a list: convert before comparing, otherwise `list == 0` is just False.
        zero_inds = np.where(np.asarray(colorbar_scales) == 0)[0]
        mix_matrix[zero_inds,:] = 0 # No mixing for zero projection.
        _colors_ = gu.transform_color(_colors_,s=s,c=c,b=b,mixing_matrix=mix_matrix)
    # Register colormaps so DOS/colorbars elsewhere can reuse the same colors.
    plt.register_cmap('RGB_f',LSC.from_list('RGB_f',_colors_)) # Register cmap for Bands
    plt.register_cmap('RGB_m',LSC.from_list('RGB_m',_colors_[1:-1])) # Register cmap for DOS
    if colorbar:
        _tls_ = list(labels) # To avoid side effects, new labels list.
        for i,label in enumerate(labels):
            if label and ISPIN == 2:
                _tls_[i] = (label+'$^↑$' if spin=='up' else label+'$^↓$' if spin=='down' else label+'$^{↑↓}$')
        w, _ = ax.get_figure().get_size_inches()
        w_f = 0.15/w # width of colorbar relative to fontsize=8
        pos = ax.get_position()
        # Make colorbar space available from given axis itself
        ax.set_position([pos.x0,pos.y0,pos.width-2.8*w_f,pos.height])
        new_pos = [pos.x0+pos.width-w_f,pos.y0,w_f,pos.height]
        axb = ax.get_figure().add_axes(new_pos)
        add_colorbar(ax=axb,vertical=True,ticklabels=_tls_,cmap_or_clist = _colors_)
    return ax
# Cell
def quick_color_lines(path_evr=None,
                      axes=None,
                      skipk=None,
                      joinPathAt=[],
                      elim=[],
                      elements=[[0]],
                      orbs=[[0]],
                      labels=['s'],
                      color_map='gist_rainbow',
                      scale_data=False,
                      max_width=None,
                      spin='both',
                      xt_indices=[0, -1],
                      xt_labels=['$\\Gamma$', 'M'],
                      E_Fermi=None,
                      showlegend=True,
                      figsize=(3.4, 2.6),
                      txt=None,
                      xytxt=[0.05, 0.85],
                      ctxt='black',
                      interpolate=False,
                      n=5,
                      k=3,
                      legend_kwargs={'ncol': 4, 'anchor': (0, 1.05),
                                     'handletextpad': 0.5, 'handlelength': 1,
                                     'fontsize': 'small', 'frameon': False},
                      **subplots_adjust_kwargs):
    """
    - Plots projected bands, one color per projection, and returns the axes they were drawn on. `elements`, `orbs` and `labels` must have the same length. By default the zeroth ion is plotted with its s-orbital.
    - **Parameters**
        - path_evr   : Path/to/vasprun.xml or output of `export_vasprun`. Defaults to './vasprun.xml'.
        - axes       : Matplotlib axes (one or many), forwarded to `plot_collection`, whose return value replaces it.
        - skipk      : Number of kpoints to skip, forwarded to `export_vasprun`.
        - joinPathAt : Points where the kpath is broken, forwarded to `export_vasprun`.
        - elim       : [min,max] of energy range. Also used as y-limits when not empty.
        - E_Fermi    : If not given, `vr.bands.E_Fermi` is used.
        - xt_indices : Indices into the kpath that become x-tick positions.
        - xt_labels  : Labels of those x-ticks.
        - elements   : List of ion-index lists, [[0],] by default. A bare int entry is treated as an element-type index and expanded to all ions of that type via `sys_info.ElemIndex`. The caller's list is never modified.
        - orbs       : List of orbital-index lists, [[0],] by default. A bare int is wrapped into a list. Entries may be empty.
        - labels     : List [str,] of projection labels. `↑`,`↓` are appended to non-empty labels for ISPIN=2.
        - color_map  : Name of a registered matplotlib colormap. Unknown names fall back to 'viridis' with a message.
        - showlegend : True by default and adds a legend to axes[0]. If False, each label is written as text on its own ax.
        - scale_data : Forwarded to `get_pros_data` arguments. Default is False.
        - max_width  : Forwarded to `make_line_collection` arguments. The legend line width is max_width/2, or 2.5 if not given.
        - figsize    : NOTE(review): accepted but not referenced in this function body.
        - txt        : NOTE(review): accepted but not referenced in this function body.
        - xytxt      : [x_coord,y_coord] of label text, used only when `showlegend` is False.
        - ctxt       : Color of that text.
        - spin       : One of {'up','down','both'}. Default is both. Only matters for ISPIN=2.
        - interpolate, n, k : Interpolation switches forwarded to `get_pros_data` arguments.
        - legend_kwargs: Dictionary of arguments forwarded to `add_legend`.
        - **subplots_adjust_kwargs : plt.subplots_adjust parameters.
    - **Returns**
        - axes : Whatever `plot_collection` returned. None (after printing a message) if the input could not be loaded or an orbital index is out of range.
    - **Raises**
        - ValueError : If lengths of `elements`, `orbs`, `labels` differ or `spin` is invalid.
        - IndexError : If an element-type index cannot be expanded or an ion index is out of range.
    """
    # Fix orbitals, elements and labels lengths very early.
    if len(elements) != len(orbs) or len(elements) != len(labels):
        raise ValueError("`elements`, `orbs` and `labels` expect same length, even if empty.")

    # Resolve the vasprun data object.
    if path_evr is None:
        path_evr = './vasprun.xml'
    if isinstance(path_evr, vp.Dict2Data):
        vr = path_evr
    elif os.path.isfile(path_evr):
        vr = vp.export_vasprun(path=path_evr, skipk=skipk, elim=elim, joinPathAt=joinPathAt)
    else:
        return print("path_evr = {!r} does not exist".format(path_evr))
    # Make sure the object looks like a top-level export.
    try:
        vr.bands
        vr.kpath
    except Exception:
        return print("Object: \n{!r} \ndoes not match the output of `export_vasprun`.".format(vr))
    if vr.pro_bands is None:
        print(gu.color.y("Can not plot an empty eigenvalues object."))
        return print(gu.color.g("Try with large energy range."))
    if spin not in ('up', 'down', 'both'):
        raise ValueError("spin can take any of ['up','down'. 'both'] only.")

    # Elements index fixing. Work on a copy so the caller's list (or the
    # shared default argument) is never modified.
    elem_inds = vr.sys_info.ElemIndex
    max_ind = elem_inds[-1] - 1  # Last entry of ElemIndex is a range stop, not an ion index.
    elements = list(elements)
    for i, elem in enumerate(elements):
        if not isinstance(elem, (int, np.integer)):
            continue
        try:
            ions = range(elem_inds[elem], elem_inds[elem + 1])
            elem_name = vr.sys_info.ElemName[elem]
        except Exception as err:
            raise IndexError("Wrap elements[{}] in [] and try again.".format(i)) from err
        elements[i] = ions
        info = "elements[{}] = {} is converted to {} which picks all ions of {!r}.".format(i, elem, ions, elem_name)
        info += "To just pick one ion at this index, wrap it in brackets []."
        print(gu.color.g(info))
    # `default` keeps fully empty selections from crashing the bounds check.
    max_e = max((e for ee in elements for e in ee), default=-1)
    if max_e > max_ind:
        raise IndexError("index {} is out of bound for {} ions".format(max_e, max_ind + 1))

    # Orbitals index fixing.
    nfields = len(vr.pro_bands.labels)
    orbs = [[item] if isinstance(item, (int, np.integer)) else item for item in orbs]
    max_o = max((r for rr in orbs for r in rr), default=-1)
    if max_o >= nfields:
        return print("index {} is out of bound for {} orbs".format(max_o, nfields))

    # Small things.
    if E_Fermi is None:
        E_Fermi = vr.bands.E_Fermi
    K = vr.kpath
    xticks = [K[i] for i in xt_indices]
    xlim = [min(K), max(K)]
    ylim = [min(elim), max(elim)] if elim else []

    # One color per projection. `plt.get_cmap` is used because
    # `matplotlib.cm.get_cmap` is no longer available in recent releases.
    if color_map in plt.colormaps():
        c_map = plt.get_cmap(color_map)
    else:
        c_map = plt.get_cmap('viridis')
        print("color_map = {!r} not exists, falling back to default color map.".format(color_map))
    colors = c_map(np.linspace(0, 1, len(orbs)))

    # Arguments for get_pros_data and make_line_collection.
    ISPIN = vr.sys_info.ISPIN
    gpd_args = dict(elements=elements, orbs=orbs, interpolate=interpolate, n=n, k=k, scale_data=scale_data)
    mlc_args = dict(rgb=False, uni_width=False, scale_color=False, colors_list=colors, max_width=max_width)
    if ISPIN == 1:
        gpd_args.update(dict(kpath=K, evals_set=vr.bands.evals - E_Fermi, pros_set=vr.pro_bands.pros))
        axes = plot_collection(gpd_args, mlc_args, axes=axes)
    if ISPIN == 2:
        gpd_up = dict(kpath=K, evals_set=vr.bands.evals.SpinUp - E_Fermi,
                      pros_set=vr.pro_bands.pros.SpinUp, **gpd_args)
        gpd_down = dict(kpath=K, evals_set=vr.bands.evals.SpinDown - E_Fermi,
                        pros_set=vr.pro_bands.pros.SpinDown, **gpd_args)
        if spin in ('up', 'both'):
            axes = plot_collection(gpd_up, mlc_args, axes=axes)
        if spin in ('down', 'both'):
            # For 'both', the axes returned by the spin-up call are reused.
            axes = plot_collection(gpd_down, mlc_args, axes=axes)

    # Aesthetics. Build a fresh labels list so the input is left untouched.
    if ISPIN == 2:
        suffix = {'up': '$^↑$', 'down': '$^↓$'}.get(spin, '$^{↑↓}$')
        _tls_ = [label + suffix if label else label for label in labels]
    else:
        _tls_ = list(labels)
    if showlegend:
        width = (max_width / 2 if max_width else 2.5)
        add_legend(ax=axes[0], colors=colors, labels=_tls_, widths=width, **legend_kwargs)
    else:
        x, y = [*xytxt]
        for ax, _tl_ in zip(axes, _tls_):
            add_text(ax=ax, xs=x, ys=y, txts=_tl_, colors=ctxt)
    for ax in axes:
        modify_axes(ax=ax, xticks=xticks, xt_labels=xt_labels, xlim=xlim, ylim=ylim)
    plt.subplots_adjust(**subplots_adjust_kwargs)
    return axes
# Cell
def select_pdos(tdos=None,
                pdos_set=None,
                ions=[0,],
                orbs=[0,],
                E_Fermi=0,
                interpolate=False,
                n=5,
                k=3,
                ):
    """
    - Returns (interpolated/original) energy, tdos and pdos of selected ions/orbitals as a tuple `(en, tdos, pdos)`.
    - **Parameters**
        - tdos     : 2D array; column 0 is read as energy and column 1 as total DOS. For spin-polarized data pass `..tdos.SpinUp` or `..tdos.SpinDown` and call this function once per spin.
        - pdos_set : 3D array indexed as [ion, energy point, field]; field 0 is read as energy and fields 1: as orbital projections.
        - ions     : Iterable of ion indices to sum over, could be `range(start,stop,step)` as well.
        - orbs     : List of orbital indices to sum over, counted after the energy field is dropped.
        - E_Fermi  : Value subtracted from the energies. Default is zero.
        - interpolate: Default is False. If True, tdos and pdos are passed through `gu.interpolate_data` and clipped to be non-negative.
        - n        : int, default is 5, forwarded to `gu.interpolate_data`.
        - k        : int, default is 3, forwarded to `gu.interpolate_data`.
    - **Returns**
        - (en, tdos, pdos). Without interpolation `en` is taken from the first ion of `pdos_set`. With interpolation it is the interpolated `tdos` energy axis. None (after printing a message) if `tdos` is empty.
    """
    # `tdos == []` is not a safe emptiness test for arrays: recent NumPy
    # raises on the shape mismatch. Check the element count instead.
    if tdos is None or np.size(tdos) == 0:
        return print(gu.color.y("Can not plot empty DOS."))
    en = tdos[:, 0] - E_Fermi
    t_dos = tdos[:, 1]
    # Sum over the selected ions first, then over the selected orbitals.
    pros = np.take(pdos_set[:, :, 1:], list(ions), axis=0).sum(axis=0)
    p_dos = np.take(pros, orbs, axis=1).sum(axis=1)
    if interpolate:
        _en, _tdos = gu.interpolate_data(en, t_dos, n=n, k=k)
        _tdos = _tdos.clip(min=0)
        # Only the interpolated values are needed here. Unpacking the clipped
        # array into two names (as was done before) fails or mis-assigns.
        _pdos = gu.interpolate_data(en, p_dos, n=n, k=k)[1].clip(min=0)
        return _en, _tdos, _pdos
    # Not interpolated: energy is read from the projected DOS of the first ion.
    return pdos_set[0, :, 0] - E_Fermi, t_dos, p_dos
# Cell
def collect_dos(path_evr=None,
                elim=[],
                elements=[[0],],
                orbs=[[0],],
                labels=['s',],
                E_Fermi=None,
                spin='both',
                interpolate=False,
                n=5,
                k=3
                ):
    """
    - Returns energy, tdos, list of pdos, list of labels and the vasprun data. `elements`, `orbs` and `labels` must have the same length. By default the zeroth ion is collected with its s-orbital.
    - **Parameters**
        - path_evr   : Path/to/vasprun.xml or output of `export_vasprun`. If None, `export_vasprun` is called with `path=None`.
        - elim       : [min,max] of energy range, forwarded to `export_vasprun`.
        - E_Fermi    : If not given, `vr.tdos.E_Fermi` is used.
        - elements   : List of ion-index lists, [[0],] by default. A bare int entry is wrapped into a one-item list.
        - orbs       : List of orbital-index lists, [[0],] by default. A bare int entry is wrapped into a one-item list.
        - labels     : List [str,] of projection labels. `↑`,`↓` are appended to non-empty labels for ISPIN=2.
        - spin       : One of {'up','down','both'}. Default is both. Only matters for ISPIN=2.
        - interpolate, n, k : Forwarded to `select_pdos`.
    - **Returns**
        - Energy : As returned by `select_pdos`.
        - tdos   : Single array, or [up, -down] if ISPIN=2 and spin='both'.
        - pdos   : List of arrays, one per projection. For ISPIN=2 with spin='both' there are two per projection and the spin-down one is negated.
        - labels : List of labels matching `pdos`.
        - vr     : Exported vasprun.
        - None is returned (after printing a message) if the input could not be loaded or holds no projected DOS.
    - **Raises**
        - ValueError : If `spin` is invalid or lengths of `elements`, `orbs`, `labels` differ.
    """
    # Resolve the vasprun data object.
    if path_evr is None:
        vr = vp.export_vasprun(path=path_evr, elim=elim)
    elif isinstance(path_evr, vp.Dict2Data):
        vr = path_evr
    elif os.path.isfile(path_evr):
        vr = vp.export_vasprun(path=path_evr, elim=elim)
    else:
        return print("path_evr = `{}` does not exist".format(path_evr))
    # Make sure the object looks like a top-level export.
    try:
        vr.tdos
        vr.pro_dos
    except Exception:
        return print("Object: \n{} \nis like a lower tree of export_vasprun(). Expects top tree.".format(vr))
    if vr.pro_dos is None:
        return print(gu.color.y("Can not plot an empty DOS object."))
    if spin not in ('up', 'down', 'both'):
        raise ValueError(
            "spin can take `up`,`down` or `both` values only.")

    # Wrap bare integers into lists. New lists are built, so the caller's
    # arguments are left untouched.
    orbs = [[item] if type(item) == int else item for item in orbs]
    elements = [[item] if type(item) == int else item for item in elements]
    # NOTE: an earlier per-element range check lived here, but it could never
    # run because every int entry has already been wrapped above.
    if E_Fermi is None:
        E_Fermi = vr.tdos.E_Fermi
    if len(elements) != len(orbs) or len(elements) != len(labels):
        raise ValueError("elements, orbs and labels expect same length even if their entries are empty.")

    ISPIN = vr.sys_info.ISPIN
    e, ts, ps, ls = None, None, [], []  # to collect all total/projected dos.
    for elem, orb, label in zip(elements, orbs, labels):
        args_dict = dict(ions=elem, orbs=orb, interpolate=interpolate, n=n, k=k, E_Fermi=E_Fermi)
        if ISPIN == 1:
            e, ts, p = select_pdos(tdos=vr.tdos.tdos, pdos_set=vr.pro_dos.pros, **args_dict)
            ps.append(p)
            ls.append(label)
        if ISPIN == 2:
            tdos1 = vr.tdos.tdos.SpinUp
            tdos2 = vr.tdos.tdos.SpinDown
            pdos_set1 = vr.pro_dos.pros.SpinUp
            pdos_set2 = vr.pro_dos.pros.SpinDown
            up_label = label + '$^↑$' if label else ''
            down_label = label + '$^↓$' if label else ''
            if spin == 'up':
                e, ts, p1 = select_pdos(tdos=tdos1, pdos_set=pdos_set1, **args_dict)
                ps.append(p1)
                ls.append(up_label)
            if spin == 'down':
                e, ts, p2 = select_pdos(tdos=tdos2, pdos_set=pdos_set2, **args_dict)
                ps.append(p2)
                ls.append(down_label)
            if spin == 'both':
                e, t1, p1 = select_pdos(tdos=tdos1, pdos_set=pdos_set1, **args_dict)
                ps.append(p1)
                ls.append(up_label)
                e, t2, p2 = select_pdos(tdos=tdos2, pdos_set=pdos_set2, **args_dict)
                # Spin-down data is negated so it is drawn mirrored.
                ps.append(-p2)
                ls.append(down_label)
                ts = [t1, -t2]
    return e, ts, ps, ls, vr
# Cell
def quick_dos_lines(path_evr=None,
                    ax=None,
                    elim=[],
                    include_dos='both',
                    elements=[[0],],
                    orbs=[[0],],
                    labels=['s',],
                    color_map='gist_rainbow',
                    tdos_color=(0.8, 0.95, 0.8),
                    linewidth=0.5,
                    fill_area=True,
                    vertical=False,
                    E_Fermi=None,
                    figsize=(3.4, 2.6),
                    txt=None,
                    xytxt=[0.05, 0.85],
                    ctxt='black',
                    spin='both',
                    interpolate=False,
                    n=5,
                    k=3,
                    showlegend=True,
                    legend_kwargs={'ncol': 4,
                                   'anchor': (0, 1),
                                   'handletextpad': 0.5,
                                   'handlelength': 1,
                                   'fontsize': 'small',
                                   'frameon': False
                                   }
                    ):
    """
    - Plots total and/or projected DOS collected by `collect_dos` and returns the axes. `elements`, `orbs` and `labels` must have the same length. By default the zeroth ion is plotted with its s-orbital.
    - **Parameters**
        - path_evr   : Path/to/vasprun.xml or output of `export_vasprun`, forwarded to `collect_dos`.
        - ax         : Matplotlib axes object, if None, one is created with `init_figure(figsize=figsize)`. NOTE(review): there is no `ax=False` mode in this function; use `collect_dos` to get the raw data.
        - include_dos: One of {'both','tdos','pdos'}.
        - elim       : [min,max] of energy range. Also used as limits of the energy axis when not empty.
        - E_Fermi    : Forwarded to `collect_dos`.
        - elements   : List of ion-index lists, [[0],] by default.
        - orbs       : List of orbital-index lists, [[0],] by default.
        - labels     : List [str,] of projection labels. `collect_dos` appends `↑`,`↓` for ISPIN=2.
        - color_map  : Name of a registered matplotlib colormap. Default is 'gist_rainbow'. Presumably 'RGB_m', which `quick_rgb_lines` registers when it runs, is the one meant for comparison with that plot - confirm.
        - tdos_color : Color of total DOS. Spin-down total DOS gets `gu.transform_color(tdos_color, c=-1)`.
        - linewidth  : Width of DOS lines.
        - fill_area  : Default is True and plots filled area for dos. If False, plots lines only.
        - vertical   : False, If True, energy is plotted along y-axis.
        - showlegend : True by default.
        - figsize    : Tuple (width,height) in inches, used only when `ax` is None.
        - txt, xytxt, ctxt : NOTE(review): accepted but not referenced in this function body.
        - spin       : One of {'up','down','both'}. Default is both.
        - interpolate, n, k : Forwarded to `collect_dos`.
        - legend_kwargs: Dictionary of arguments forwarded to `add_legend`.
    - **Returns**
        - ax : Matplotlib axes. None (after printing a message) if `include_dos` or `color_map` is invalid or no DOS could be collected.
    """
    if include_dos not in ('both', 'pdos', 'tdos'):
        return print("`include_dos` expects one of ['both','pdos','tdos'], got {}.".format(include_dos))
    cl_dos = collect_dos(path_evr=path_evr,
                         elim=elim,
                         elements=elements,
                         orbs=orbs,
                         labels=labels,
                         E_Fermi=E_Fermi,
                         spin=spin,
                         interpolate=interpolate,
                         n=n,
                         k=k)
    # collect_dos prints its own message and returns None on failure.
    if cl_dos is None:
        return print(gu.color.g("Try with large energy range."))
    en, tdos, pdos, labels, _ = cl_dos  # Labels updated

    if color_map not in plt.colormaps():
        return print("`color_map` expects one of the follwoing:\n{}".format(plt.colormaps()))
    # collect_dos hands back a two-item list only for spin='both' on
    # spin-polarized data. Testing the type avoids mistaking a two-point
    # unpolarized array for it.
    polarized = isinstance(tdos, (list, tuple)) and len(tdos) == 2
    n_colors = 2 * len(orbs) if polarized else len(orbs)
    # `plt.get_cmap` replaces `cm.get_cmap`, which recent matplotlib dropped.
    colors = plt.get_cmap(color_map)(np.linspace(0, 1, n_colors))
    # Fix for RGB comparison
    if polarized and spin == 'both' and len(orbs) == 3:
        colors[[-1, -2]] = colors[[-2, -1]]  # Flip last two colors only

    # Make additional colors for spin down. Inverted colors are better.
    t_color = mpl.colors.to_rgb(tdos_color)
    it_color = gu.transform_color(t_color, c=-1)  # -1 contrast inverts color
    if ax is None:
        ax = init_figure(figsize=figsize)

    # Total DOS first, so projections are drawn on top of it.
    if include_dos != 'pdos':
        if polarized:
            totals = [(tdos[0], t_color, r'TDOS$^↑$'), (tdos[1], it_color, r'TDOS$^↓$')]
        else:
            totals = [(tdos, t_color, 'TDOS')]
        for dos, color, label in totals:
            if fill_area:
                _fill_dos_area(ax, en, dos, vertical, color=(color), label=label, lw=0)
            else:
                _plot_dos_curve(ax, en, dos, vertical, color=(color), label=label, lw=linewidth)
    if include_dos != 'tdos':
        for p, l, c in zip(pdos, labels, colors):
            if fill_area:
                _fill_dos_area(ax, en, p, vertical, color=(mpl.colors.to_rgba(c, 0.4)), linewidth=0)
            # The outline carries the legend label in both modes.
            _plot_dos_curve(ax, en, p, vertical, color=(c), linewidth=linewidth, label=l)
    if elim:
        energy_limits = [min(elim), max(elim)]
        if vertical:
            ax.set_ylim(energy_limits)
        else:
            ax.set_xlim(energy_limits)
    if showlegend == True:
        add_legend(ax=ax, labels=[], colors=colors, widths=linewidth, **legend_kwargs)
    return ax


def _plot_dos_curve(ax, en, dos, vertical, **kwargs):
    """Draw one DOS line, with energy on the y-axis when `vertical` is true."""
    if vertical:
        ax.plot(dos, en, **kwargs)
    else:
        ax.plot(en, dos, **kwargs)


def _fill_dos_area(ax, en, dos, vertical, **kwargs):
    """Fill the area under one DOS curve, with energy on the y-axis when `vertical` is true."""
    if vertical:
        ax.fill_betweenx(en, dos, **kwargs)
    else:
        ax.fill_between(en, dos, **kwargs)
# Cell
def plt_to_html(plt_fig=None, transparent=True, dash_html=None):
    """
    - Renders a matplotlib figure to SVG and returns it as a notebook HTML display object, a raw <svg> string, or a dash `html.Img` object.
    - **Parameters**
        - plt_fig    : Matplotlib's figure instance. If None, the current figure (`plt.gcf()`) is used.
        - transparent: True of False for fig background.
        - dash_html  : Default is None which results in an image display in jupyter notebook.
            - If None, returns `IPython.display.HTML` with the <svg> markup inside Jupyter/Colab shells, otherwise falls back to `plt.show()`.
            - If True, returns html.Img object for plotly's dash.
            - If False, returns <svg> string to embed in HTML DOM.
    - The figure is cleared after rendering in every case except the `plt.show()` fallback.
    """
    if plt_fig is None:
        plt_fig = plt.gcf()
    plot_bytes = BytesIO()
    # Save and clear the figure that was asked for, not whichever figure
    # happens to be current.
    plt_fig.savefig(plot_bytes, format='svg', transparent=transparent)

    def _svg_markup():
        # Drop the XML prolog in front of the first <svg tag. maxsplit=1 keeps
        # any later '<svg' occurrences inside the document intact.
        return '<svg' + plot_bytes.getvalue().decode('utf-8').split('<svg', 1)[1]

    if dash_html is None:
        try:
            shell = get_ipython().__class__.__name__
        except Exception:
            shell = None
        if shell in ('ZMQInteractiveShell', 'Shell'):  # Shell for Colab.
            plt_fig.clf()  # Clear other display
            return HTML(_svg_markup())
        # Outside a notebook get_ipython() gives no usable shell and does not
        # raise, so the fallback has to be reached without an exception.
        return plt.show()
    if not dash_html:
        plt_fig.clf()  # Clear image
        return _svg_markup()
    import base64
    plt_fig.clf()  # Clear image
    import dash_html_components as html
    img = base64.b64encode(plot_bytes.getvalue())
    return html.Img(src="data:image/svg+xml;base64,{}".format(img.decode('utf-8')))
# Cell
def plot_potential(basis=None,
                   e_or_m=None,
                   operation='mean_z',
                   ax=None,
                   period=None,
                   lr_pos=(0.25, 0.75),
                   lr_widths=[0.5, 0.5],
                   labels=(r'$V(z)$', r'$\langle V \rangle _{roll}(z)$', r'$\langle V \rangle $'),
                   colors=((0, 0.2, 0.7), 'b', 'r'),
                   annotate=True
                   ):
    """
    - Plots a 3D grid reduced to one direction and returns tuple(ax, Data), where Data holds the resulting parameters of the averaged potential.
    - **Parameters**
        - basis  : 3 basis vectors; the norm of the one along the chosen direction scales the x-axis. Usually `export_potential().basis`.
        - e_or_m : 3D grid data, usually `export_potential().[e,m,m_x,m_y,m_z]`. If None, it is auto-picked through `gu.export_potential()` together with `basis`.
        - operation: Default is 'mean_z'. Any one of 'mean_x','min_x','max_x','mean_y','min_y','max_y','mean_z','min_z','max_z'. The part before `_` is the reduction and the part after it is the direction that is kept.
        - ax: Matplotlib axes, if not given `init_figure()` creates one.
        - period: Periodicity of potential as a fraction between 0 and 1. For example if a slab is made of 4 super cells in z-direction, period=0.25. It is converted to a number of grid points for the rolling average.
        - lr_pos: Locations around which averages are taken. Default (0.25,0.75). Provide as fractions between 0 and 1. Works only if period is given.
        - lr_widths: Default is [0.5,0.5]. A pair proportional to the left and right averaging widths e.g (1,1), (1,1.1) etc. It is normalized to sum to 1. Works only if period is given.
        - labels: List of three labels for legend. First entry is data plot, second is its rolling average and third is the left/right average.
        - colors: List of three colors for lines.
        - annotate: True by default, writes difference of right and left averages on plot.
    - **Returns**
        - (ax, dict) with only the key 'direction' if `period` is None. In this case axis label and limits are not touched.
        - (ax, Dict2Data) with keys 'direction', 'left', 'right' and 'deltav' if `period` is given.
        - None (after printing a message) if `operation` is invalid or auto-picking fails.
    - **Raises**
        - ValueError : If `period` spans less than one grid point.
    """
    check = ['mean_x', 'min_x', 'max_x', 'mean_y', 'min_y', 'max_y', 'mean_z', 'min_z', 'max_z']
    if operation not in check:
        return print("`operation` excepts any of {}, got {}".format(check, operation))
    if ax is None:
        ax = init_figure()
    if e_or_m is None:
        print('`e_or_m` not given, trying to autopick LOCPOT...')
        try:
            # NOTE(review): confirm `export_potential` is exposed by g_utils;
            # any failure here ends in the message below.
            ep = gu.export_potential()
            basis = ep.basis
            e_or_m = ep.e
        except Exception:
            return print('Could not auto fix. Make sure `basis` and `e_or_m` are provided.')

    # `operation` is '<reduction>_<direction>'. Each direction is obtained by
    # reducing the two other grid axes one after the other.
    kind, direction = operation.split('_')
    first_axis, second_axis = {'x': (2, 1), 'y': (2, 0), 'z': (0, 0)}[direction]
    pot = getattr(e_or_m, kind)(axis=first_axis)
    pot = getattr(pot, kind)(axis=second_axis)
    # Direction axis
    x = np.linalg.norm(basis['xyz'.index(direction)]) * np.linspace(0, 1, len(pot))

    # Validate the rolling window before anything is drawn.
    if period is not None:
        window = int(period * len(pot))
        if window < 1:
            raise ValueError("`period` = {} covers less than one of {} grid points.".format(period, len(pot)))

    ax.plot(x, pot, lw=0.8, c=colors[0], label=labels[0])  # Potential plot
    ret_dict = {'direction': direction}
    # Only go below if periodicity is given
    if period is None:
        return (ax, ret_dict)  # Simple Return

    arr_con = np.convolve(pot, np.ones((window,)) / window, mode='valid')
    x_con = np.linspace(0, x[-1], len(arr_con))
    ax.plot(x_con, arr_con, linestyle='dashed', lw=0.7, label=labels[1], c=colors[1])  # Convolved plot
    # Find Averages
    left, right = lr_pos
    ind_1 = int(left * len(pot))
    ind_2 = int(right * len(pot))
    total_width = sum(lr_widths)
    widths = [l / total_width for l in lr_widths]  # Normalize widths
    div_1 = int(window * widths[0])  # left half range
    div_2 = int(window * widths[1])  # right half range
    v_1 = np.mean(pot[ind_1 - div_1:ind_1 + div_1 + 1]).astype(float)
    v_2 = np.mean(pot[ind_2 - div_2:ind_2 + div_2 + 1]).astype(float)
    ret_dict.update({'left': {'v': float(v_1), 'av_range': [ind_1 - div_1, ind_1 + div_1]}})
    ret_dict.update({'right': {'v': float(v_2), 'av_range': [ind_2 - div_2, ind_2 + div_2]}})
    ret_dict.update({'deltav': float(v_2 - v_1)})
    # Level plot
    middle = int((ind_1 + ind_2) / 2)
    ax.plot([x[ind_1], x[middle], x[middle], x[ind_2]], [v_1, v_1, v_2, v_2], c=colors[2], lw=0.7)
    # Thick Plots
    ax.plot([x[ind_1 - div_1], x[ind_1 + div_1]], [v_1, v_1], c=colors[2], lw=2, label=labels[2])
    ax.plot([x[ind_2 - div_2], x[ind_2 + div_2]], [v_2, v_2], c=colors[2], lw=2)
    # Annotate
    if annotate == True:
        ax.text(0.5, 0.07, r'$\Delta _{R,L} = %9.6f$' % (np.round(v_2 - v_1, 6)), ha="center", va="center",
                bbox=dict(edgecolor='white', facecolor='white', alpha=0.5), transform=ax.transAxes)
    ax.set_xlabel('$' + ret_dict['direction'] + ' (' + u'\u212B' + ')$')
    ax.set_xlim([x[0], x[-1]])
    return (ax, vp.Dict2Data(ret_dict))