# AUTOGENERATED! DO NOT EDIT! File to edit: StructureIO.ipynb (unless otherwise specified).
# Public names exported by `from <this module> import *`.
__all__ = ['save_mp_API', 'load_mp_data', 'get_crystal', 'get_poscar', 'get_kpath', 'get_kmesh', 'intersect_3p_p_3v',
'centroid', 'order', 'in_vol_sector', 'out_bz_plane', 'to_xy', 'rad_angle', 'arctan_full', 'get_bz',
'plot_bz']
# Cell
import os
import json
import math
import numpy as np
from pathlib import Path
import requests as req
from collections import namedtuple
from itertools import permutations
from itertools import product
from scipy.spatial import ConvexHull
from scipy.spatial import Voronoi
import plotly.graph_objects as go
import pivotpy.vr_parser as vp
# Cell
def save_mp_API(api_key):
    """
    - Save materials project api key for autoload in functions.
    - The key is written as the first line `MP_API_KEY = <api_key>` of `~/.pivotpyrc`; any previous `MP_API_KEY` line is dropped and every other line is kept unchanged.
    - **Parameters**
        - api_key : API key for your account from materials project site.
    """
    file = os.path.join(str(Path.home()), '.pivotpyrc')
    kept = []
    if os.path.isfile(file):
        with open(file, 'r') as fr:
            kept = [line for line in fr if 'MP_API_KEY' not in line]
    with open(file, 'w') as fw:
        # Terminate the key line so that the next kept line is not glued onto the key value.
        fw.write("MP_API_KEY = {}\n".format(api_key))
        fw.writelines(kept)
# Cell
def load_mp_data(formula,api_key=None,mp_id=None,max_sites = None):
    """
    - Returns fetched data using request api of python form materials project website.
    - **Parameters**
        - formula  : Material formula such as 'NaCl'.
        - api_key  : API key for your account from material project site. Auto picks if you already used `save_mp_API` function.
        - mp_id    : Optional, you can specify material ID to filter results. If no result has this ID, all (size-filtered) results are returned.
        - max_sites : Optional, keep only structures with at most this many sites.
    - **Returns**
        - List of result dictionaries, or None (after printing a message) when no API key is available.
    """
    if api_key is None:
        file = os.path.join(str(Path.home()), '.pivotpyrc')
        try:
            with open(file, 'r') as f:
                for line in f:
                    if 'MP_API_KEY' in line:
                        # Split once so keys containing '=' survive; the last matching line wins.
                        api_key = line.split('=', 1)[1].strip()
        except (OSError, ValueError, IndexError):
            api_key = None
        # Also covers a readable rc file that simply has no MP_API_KEY entry.
        if not api_key:
            return print("api_key not given. provide in argument or generate in file using `save_mp_API(your_mp_api_key)")
    url = "https://www.materialsproject.org/rest/v2/materials/_____/vasp?API_KEY=|||||"
    url = url.replace('_____',formula).replace('|||||',api_key)
    resp = req.request(method='GET',url=url)
    all_res = json.loads(resp.text)['response']
    if max_sites is not None:
        all_res = [res for res in all_res if res['nsites'] <= max_sites]
    # Filter to mp_id at last (after size filtering) so both filters can be combined.
    if mp_id is not None:
        for res in all_res:
            if res['material_id'] == mp_id:
                return [res]
    return all_res
# Cell
def get_crystal(formula,api_key=None,mp_id=None,max_sites = None):
    """
    - Returns crystal information dictionary including cif data format.
    - **Parameters**
        - formula  : Material formula such as 'NaCl'.
        - api_key  : API key for your account from material project site. Auto picks if you already used `save_mp_API` function.
        - mp_id    : Optional, you can specify material ID to filter results.
        - max_sites : Optional, keep only structures with at most this many sites.
    - **Returns**
        - List of `vp.Dict2Data` objects with attributes mp_id, symbol, crystal, unit, cif. Empty list when nothing was fetched.
    """
    all_res = load_mp_data(formula=formula, api_key=api_key, mp_id=mp_id, max_sites=max_sites)
    # load_mp_data returns None when no API key is available; avoid iterating over None.
    if not all_res:
        return []
    cifs = []
    for res in all_res:
        crs = dict(mp_id=res['material_id'],
                   symbol=res['spacegroup']['symbol'],
                   crystal=res['spacegroup']['crystal_system'],
                   unit=res['unit_cell_formula'],
                   cif=res['cif'])
        cifs.append(vp.Dict2Data(crs))
    return cifs
# Cell
def get_poscar(formula ,api_key=None,mp_id=None,max_sites = None):
    """
    - Returns poscar information dictionary including cif data format.
    - **Parameters**
        - formula  : Material formula such as 'NaCl'.
        - api_key  : API key for your account from material project site. Auto picks if you already used `save_mp_API` function.
        - mp_id    : Optional, you can specify material ID to filter results.
        - max_sites : Optional, keep only structures with at most this many sites.
    - **Returns**
        - List of `vp.Dict2Data` objects with attributes system, symbol, crystal, mp_id, poscar (POSCAR text with a lattice scaled by `a`).
    - **Usage**
        - `get_poscar('GaAs',api_key,**kwargs)`. Same result is returned from `Get-POSCAR` command in PowerShell terminal if Vasp2Visual module is installed.
    """
    crys = get_crystal(formula=formula, api_key=api_key, mp_id=mp_id, max_sites=max_sites)
    to_rad = np.pi / 180  # exact degree -> radian factor
    vec_fmt = "{0:>22.16f} {1:>22.16f} {2:>22.16f}"
    poscars = []
    for cr in crys:
        # Drop every blank line: a blank line among the atomic sites would fail on `s_p[0]` below.
        lines = [line for line in cr.cif.split('\n') if line.strip()]
        abc, abc_ang = [], []
        for ys in lines:
            if '_cell' in ys:
                if '_length' in ys:
                    abc.append(float(ys.split()[1]))
                if '_angle' in ys:
                    abc_ang.append(float(ys.split()[1]))
                if '_volume' in ys:
                    volume = float(ys.split()[1])
            if '_structural' in ys:
                system = ys.split()[1]
        index = 0
        for i, ys in enumerate(lines):
            if '_atom_site_occupancy' in ys:
                index = i + 1  # atomic sites start right after this loop header
        # Assumes the pymatgen CIF site layout: symbol in column 0, fractional x, y, z in columns 3-5.
        pos_str = ''.join("{0:>12} {1:>12} {2:>12} {3}\n".format(*s_p[3:6], s_p[0])
                          for s_p in (pos.split() for pos in lines[index:]))
        # Lattice vectors from (a, b, c, alpha, beta, gamma) with a along x and b in the xy-plane,
        # all divided by a because a is written as the POSCAR scale factor.
        alpha, beta, gamma = [ang * to_rad for ang in abc_ang]
        bx, by = abc[1] * np.cos(gamma), abc[1] * np.sin(gamma)
        cx = abc[2] * np.cos(beta)
        cy = (abc[1] * abc[2] * np.cos(alpha) - bx * cx) / by
        cz = volume / (abc[0] * by)
        a = vec_fmt.format(1.0, 0.0, 0.0)
        b = vec_fmt.format(bx / abc[0], by / abc[0], 0.0)
        c = vec_fmt.format(cx / abc[0], cy / abc[0], cz / abc[0])
        unit = cr.unit.to_dict()
        elems = '\t'.join(unit.keys())
        nums = '\t'.join(str(int(count)) for count in unit.values())
        symbol = cr.symbol
        top = system + "\t# [{0}] Generated by PivotPy using Materials Project Database.".format(symbol)
        poscar = "{0}\n {1}\n {2}\n {3}\n {4}\n {5}\n {6}\nDirect\n{7}".format(top, abc[0], a, b, c, elems, nums, pos_str)
        net_out = dict(system=system, symbol=symbol, crystal=cr.crystal, mp_id=cr.mp_id, poscar=poscar)
        poscars.append(vp.Dict2Data(net_out))
    return poscars
# Cell
def get_kpath(hsk_list=None, labels=None, n=5, weight=None, ibzkpt=None, outfile=None):
    """
    - Generate list of kpoints along high symmetry path. Options are write to file or return KPOINTS list. It generates uniformly spaced points with input `n` as just a scale factor of number of points per unit length. You can also specify custom number of kpoints in an interval by putting number of kpoints as 4th entry in left kpoint.
    - **Parameters**
        - hsk_list : N x 3 list of N high symmetry points, if broken path then [[N x 3],[M x 3],...]. Optionally you can put a 4 values point where 4th entry will decide number of kpoints in current interval. Make sure that points in a connected path patch are at least two i.e. `[[x1,y1,z1],[x2,y2,z2]]` or `[[x1,y1,z1,N],[x2,y2,z2]]`.
        - labels  : High symmetry points labels. Good for keeping record of labels and points indices for later use. - Note: If you do not want to label a point, label it as 'skip' at its index and it will be removed.
        - n       : int, number per unit length, this makes uniform steps based on distance between points.
        - weight  : Float, if None, auto generates weights as 1/(number of generated points).
        - ibzkpt  : Path to ibzkpt file, required for HSE calculations. Its weighted points are put before the path points.
        - outfile : Path/to/file to write kpoints.
    - **Attributes**
        - If `outfile = None`, a tuple is returned which consists of:
            - nkpts   : get_kpath().nkpts (includes IBZKPT points if given).
            - kpoints : get_kpath().kpoints, (M, 3) array of path points.
            - weights : get_kpath().weights.
    """
    hsk_list = [] if hsk_list is None else hsk_list
    labels = [] if labels is None else list(labels)
    if len(hsk_list):
        try:
            hsk_list[0][0][0]
        except (TypeError, IndexError):
            hsk_list = [hsk_list]  # Make overall 3 dimensions to include breaks in path
    xs, ys, zs, inds, joinat = [], [], [], [0], []  # 0 in inds list is important
    for j, patch in enumerate(hsk_list):
        for i in range(len(patch) - 1):
            start, end = patch[i], patch[i + 1]
            if len(start) > 3:
                _m = int(start[3])  # number of points given explicitly.
            else:
                _vec = [_a - _b for _a, _b in zip(start[:3], end[:3])]
                _m = int(np.rint(np.linalg.norm(_vec) * n))
            inds.append(inds[-1] + _m)
            if j != 0 and i == 0:
                joinat.append(inds[-2])  # first point of a new (broken) path patch
            xs.extend(np.linspace(start[0], end[0], _m))
            ys.extend(np.linspace(start[1], end[1], _m))
            zs.extend(np.linspace(start[2], end[2], _m))
    if weight is None:
        weight = 1 / len(xs)
    out_str = '\n'.join("{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x, y, z, weight)
                        for x, y, z in zip(xs, ys, zs))
    N = len(xs)
    if ibzkpt is not None and os.path.isfile(ibzkpt):
        with open(ibzkpt, 'r') as f:
            lines = f.readlines()
        n_ibz = int(lines[1].split()[0])
        N += n_ibz  # Update N.
        # Take exactly the n_ibz k-point lines (nothing after them, e.g. a tetrahedra section),
        # joined without a trailing newline so no blank line separates them from the path points.
        ibz_str = '\n'.join(line.rstrip('\r\n') for line in lines[3:3 + n_ibz])
        out_str = "{}\n{}".format(ibz_str, out_str)
    inds[-1] = -1  # last index to -1
    if labels:
        # Remove indices and labels where 'skip' appears
        keep = ['skip' not in lab for lab in labels]
        inds = [ind for ind, k in zip(inds, keep) if k]
        labels = [lab for lab, k in zip(labels, keep) if k]
    top_str = "Automatically generated using PivotPy with HSK-INDS = {}, LABELS = {}, SEG-INDS = {}\n\t{}\nReciprocal Lattice".format(inds,labels,joinat,N)
    out_str = "{}\n{}".format(top_str, out_str)
    if outfile is not None:
        with open(outfile, 'w') as f:
            f.write(out_str)
    else:
        mesh = namedtuple('Mesh', ['nkpts', 'kpoints', 'weights'])
        return mesh(N, np.array([[x, y, z] for x, y, z in zip(xs, ys, zs)]), [weight] * len(xs))
# Cell
def get_kmesh(n_xyz=(5, 5, 5), weight=None, gamma=True, ibzkpt=None, poscar=None, outfile=None, plot=False):
    """
    - Generate uniform mesh of kpoints. Options are write to file, plot or return KPOINTS list.
    - **Parameters**
        - n_xyz  : List of [nx ny nz] or integer. If an integer is given and `poscar` is given, the mesh is autoscaled by the BZ extents.
        - weight : Float, if None, auto generates weights as 1/(number of points).
        - gamma  : Default True, shifts mesh so that a grid point falls at gamma.
        - ibzkpt : Path to ibzkpt file, required for HSE calculations. Its weighted points are put before the mesh points.
        - poscar : POSCAR file or real space lattice vectors, if None, cubic symmetry is used and it is fast. Points are then filtered to the 1st BZ.
        - outfile: Path/to/file to write kpoints.
        - plot   : If True, returns interactive plot. You can look at mesh before you start calculation.
    - **Attributes**
        - If `plot = False`, a tuple is returned which consists of:
            - nkpts   : get_kmesh().nkpts (includes IBZKPT points if given).
            - kpoints : get_kmesh().kpoints, in (b1, b2, b3) coordinates.
            - weight  : get_kmesh().weight, its one float number, provided or calculated.
        - Returns None (after printing a message) if no point survives.
    """
    auto_scale = isinstance(n_xyz, (int, np.integer))
    if auto_scale:
        nx = ny = nz = int(n_xyz)
    else:
        nx, ny, nz = [int(n) for n in n_xyz]
    sx, sy, sz = 0.5, 0.5, 0.5  # half-width of the mesh along each axis
    bz, vs = None, None
    if poscar is not None:
        bz = get_bz(poscar)
        bs = bz.basis
        vs = np.array([np.linalg.solve(bs.T, v) for v in bz.vertices])  # BZ vertices in b1,b2,b3 space
        sx, sy, sz = np.max(vs, axis=0)
        if auto_scale:
            n0 = int(n_xyz)
            nb1, nb2, nb3 = abs(sx), abs(sy), abs(sz)
            nb_min = np.min([nb1, nb2, nb3])
            # The shortest extent gets n0 points; the others are scaled proportionally.
            if nb_min == nb3:
                nx, ny, nz = int(np.rint(nb1 / nb3 * n0)), int(np.rint(nb2 / nb3 * n0)), n0
            elif nb_min == nb2:
                nx, ny, nz = int(np.rint(nb1 / nb2 * n0)), n0, int(np.rint(nb3 / nb2 * n0))
            elif nb_min == nb1:
                nx, ny, nz = n0, int(np.rint(nb2 / nb1 * n0)), int(np.rint(nb3 / nb1 * n0))
    # Make center at gamma if True: shift by the smallest |grid value| so a grid point lands on 0.
    tx, ty, tz = 0, 0, 0  # Translations in each dirs.
    if gamma:
        tx, ty, tz = [np.min(np.abs(np.linspace(-s, s, m))) for s, m in ((sx, nx), (sy, ny), (sz, nz))]
    points = []
    for i in np.linspace(-sx + tx, sx + tx, nx):
        if nx == 1:
            i = i + sx - tx
        for j in np.linspace(-sy + ty, sy + ty, ny):
            if ny == 1:
                j = j + sy - ty
            for k in np.linspace(-sz + tz, sz + tz, nz):
                if nz == 1:
                    k = k + sz - tz
                # Handle BZ when no POSCAR given and grid shifted for gamma = True
                if i <= sx and j <= sy and k <= sz:
                    points.append([i, j, k])
    points = np.array(points)
    points[np.abs(points) < 1e-10] = 0
    top_info = ''
    if bz is not None:
        top_info = ' filtered in 1st BZ of rec_basis = [{}, {}, {}]'.format(*bz.basis)  # first space is must.
        hull_volume = ConvexHull(vs).volume
        # A point is inside/on the BZ iff adding it does not enlarge the hull volume.
        points = np.array([p for p in points if math.isclose(ConvexHull([*vs, p]).volume, hull_volume)])
    if len(points) == 0:
        return print('No KPOINTS in BZ from given input. Try larger input!')
    if weight is None:
        weight = float(1 / len(points))
    out_str = '\n'.join("{0:>16.10f}{1:>16.10f}{2:>16.10f}{3:>12.6f}".format(x, y, z, weight) for x, y, z in points)
    N = len(points)
    if ibzkpt is not None and os.path.isfile(ibzkpt):
        with open(ibzkpt, 'r') as f:
            lines = f.readlines()
        n_ibz = int(lines[1].split()[0])
        N += n_ibz  # Update N.
        # Take exactly the n_ibz k-point lines (nothing after them, e.g. a tetrahedra section),
        # joined without a trailing newline so no blank line separates them from the mesh points.
        ibz_str = '\n'.join(line.rstrip('\r\n') for line in lines[3:3 + n_ibz])
        out_str = "{}\n{}".format(ibz_str, out_str)
    top_str = "Automatically generated uniform mesh using PivotPy with {}x{}x{} grid{}\n\t{}\nReciprocal Lattice".format(nx,ny,nz,top_info,N)
    out_str = "{}\n{}".format(top_str, out_str)
    if outfile is not None:
        with open(outfile, 'w') as f:
            f.write(out_str)
    if plot:
        print('NKPTS: ', N)
        if bz is None:
            bz = get_bz([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        bs = bz.basis
        data = np.array([d[0] * bs[0] + d[1] * bs[1] + d[2] * bs[2] for d in points])  # to Cartesian
        fig = go.Figure()
        for i, b in enumerate(bs):
            fig.add_trace(go.Scatter3d(x=[0,b[0]], y=[0,b[1]],z=[0,b[2]],mode='lines',name='b<sub>{}</sub>'.format(i+1)))
        fig.add_trace(go.Scatter3d(x=data[:,0], y=data[:,1],z=data[:,2],mode='markers',name='KPOINTS',marker_size=3))
        for j, pts in enumerate(bz.faces):
            pts = np.array(pts)
            fig.add_trace(go.Scatter3d(x=pts[:,0], y=pts[:,1],z=pts[:,2],mode='lines',surfaceaxis=1,line_color='rgba(130,210,110,0.6)',name = 'F<sub>{}</sub>'.format(j+1) ))
        camera = dict(center=dict(x=0.1, y=0.1, z=0.1))
        fig.update_layout(scene_camera=camera,plot_bgcolor='rgb(255,255,255)')
        return fig
    mesh = namedtuple('Mesh', ['nkpts', 'kpoints', 'weight'])
    return mesh(N, points, weight)
# Cell
def intersect_3p_p_3v(a,b,c):
    """
    - Returns the common point x of the three planes x.v = |v|^2 for v in (a, b, c), i.e. the planes perpendicular to each vector through its tip.
    - **Parameters**
        - a,b,c : three vectors in 3D, whose perpendicular planes will be intersected.
    - **Returns**
        - A (1, 3) array, or an empty list when the vectors are linearly dependent (zero determinant).
    """
    normals = np.array([a, b, c])
    if np.linalg.det(normals) == 0:
        return []
    rhs = np.array([np.linalg.norm(a)**2, np.linalg.norm(b)**2, np.linalg.norm(c)**2]).reshape(3, 1)
    return np.linalg.solve(normals, rhs).T
def centroid(points):
    """
    - Returns the arithmetic mean of a list of 3D points as a length-3 numpy array.
    - **Parameters**
        - points: List[List(len=3)]
    """
    count = len(points)
    return np.array([sum(p[axis] for p in points) / count for axis in range(3)])
def order(points):
    """
    - Returns the vertices of a planar polygon in 3D sorted counterclockwise by angle in an in-plane frame (x along first vertex - centroid, y = centroid x that), with the first sorted vertex appended at the end to close the loop.
    - **Parameters**
        - points: List[List(len=3)]
    """
    n_pts = len(points)
    center = np.array([sum(p[axis] for p in points) / n_pts for axis in range(3)])
    ex = np.array(points[0]) - center
    ey = np.cross(center, ex)
    ex = ex / np.linalg.norm(ex)  # in-plane unit x
    ey = ey / np.linalg.norm(ey)  # in-plane unit y
    angles = []
    for pt in points:
        rel = np.array(pt) - center
        vx = np.dot(rel, ex)
        vy = np.dot(rel, ey)
        theta = np.arctan(abs(vy / vx))
        # Map (sign(vx), sign(vy)) to a full angle in (0, 2*pi]; unmatched cases keep theta.
        by_signs = {
            (1, 1): theta,
            (-1, 1): np.pi - theta,
            (-1, -1): np.pi + theta,
            (1, -1): 2 * np.pi - theta,
            (0, -1): 3 * np.pi / 2,
            (0, 1): np.pi / 2,
            (-1, 0): np.pi,
            (1, 0): 2 * np.pi,
        }
        angles.append(by_signs.get((np.sign(vx), np.sign(vy)), theta))
    ranking = np.argsort(np.array(angles))
    loop = [np.array(points[int(k)]) for k in ranking]
    loop.append(np.array(loop[0]))
    return np.array(loop)
def in_vol_sector(test_point,p1,p2,p3):
    """
    - Returns True unless the normalized test_point lies behind (by more than 1e-5) the plane through the centroid c of the three normalized vectors, with normal c. Always True when the three vectors are coplanar (zero triple product).
    - **Parameters**
        - test_point: Point in 3D.
        - p1,p2,p3: Three vectors points in 3D.
    """
    unit_t, unit_1, unit_2, unit_3 = [np.array(v) / np.linalg.norm(v) for v in (test_point, p1, p2, p3)]
    if np.dot(unit_1, np.cross(unit_2, unit_3)) == 0:
        return True
    c = sum((unit_1, unit_2, unit_3)) / 3  # centroid of the three unit vectors
    return not np.dot(unit_t - c, c) < -1e-5
def out_bz_plane(test_point,plane):
    """
    - Returns False if test_point lies on the origin side of the plane through the centroid c of `plane` with normal c (by more than 1e-5), otherwise True (point on or beyond that plane). Could be used to sample BZ mesh in place of ConvexHull.
    - NOTE(review): the plane tested is the one through c perpendicular to c; this coincides with `plane` only when c is the foot of the perpendicular from the origin, as assumed for BZ faces.
    - **Parameters**
        - test_point: 3D point.
        - plane     : List of at least three coplanar points.
    """
    pts = np.array(plane)
    c = sum(pts) / len(pts)  # centroid of the plane's points
    return not np.dot(np.array(test_point) - c, c) < -1e-5
def to_xy(v):
    """
    - Rotate a 3D vector v about the x-axis into the xy-plane, returning [v_x, sqrt(v_y^2 + v_z^2), 0].
    - **Parameters**
        - v: Point in 3D.
    """
    d = np.sqrt(v[1]**2 + v[2]**2)
    if d == 0:
        d = 1  # v is on the x-axis already; any rotation about x leaves it unchanged
    Rx = [[1, 0, 0], [0, v[1] / d, v[2] / d], [0, -v[2] / d, v[1] / d]]
    return np.dot(Rx, v)  # v_out = [x, y, 0]
def rad_angle(v1,v2):
    """
    - Returns interior angle (radians, in [0, pi]) between two vectors. The cosine is rounded to 12 decimals before arccos, so round-off just above 1 cannot produce NaN.
    - **Parameters**
        - v1,v2 : Two vectors/points in 3D.
    """
    first, second = np.array(v1), np.array(v2)
    cos_t = np.round(np.dot(first, second) / (np.linalg.norm(first) * np.linalg.norm(second)), 12)
    return np.arccos(cos_t)
def arctan_full(perp,base):
    """
    - Returns full angle from x-axis counter clockwise, in (0, 2*pi]; a vector along +x gives 2*pi, and 0 is returned for unhandled cases (e.g. both components zero).
    - **Parameters**
        - perp: Perpendicular component of vector including sign.
        - base: Base component of vector including sign.
    """
    sign_y = np.sign(perp)
    on_y_axis = {-1: 3*np.pi/2, 1: np.pi/2}
    if base == 0 and sign_y in on_y_axis:
        return on_y_axis[sign_y]
    theta = abs(np.arctan(perp/base))
    by_signs = {
        (1, 1): theta,
        (-1, 1): np.pi - theta,
        (-1, -1): np.pi + theta,
        (1, -1): 2*np.pi - theta,
        (-1, 0): np.pi,
        (1, 0): 2*np.pi,
    }
    return by_signs.get((np.sign(base), sign_y), 0)
# Cell
def get_bz(poscar = None,loop = True,digits=8):
    """
    - Return required information to construct first Brillouin zone in form of tuple (basis, normals, vertices, faces, specials).
    - **Parameters**
        - poscar : POSCAR file or list of 3 vectors in 3D as list[list,list,list]. Defaults to './POSCAR'.
        - loop   : If True, joins the last vertex of a BZ plane to starting vertex in order to complete loop.
        - digits : int, rounding off decimal places, no effect on intermediate calculations, just for pretty final results.
    - **Attributes**
        - basis    : get_bz().basis, reciprocal lattice basis, scaled so that |b1| = 1.
        - normals  : get_bz().normals, all vectors that are perpendicular BZ faces/planes.
        - vertices : get_bz().vertices, all vertices of BZ, could be input into ConvexHull for constructing 3D BZ.
        - faces    : get_bz().faces, vertices arranged into faces, could be input to Poly3DCollection of matplotlib for creating BZ from faces' patches. A 3D array when all faces have the same vertex count, otherwise a 1D object array of (k, 3) arrays.
        - specials : get_bz().specials, Dictionary of high symmetry KPOINTS with keys as points relative to basis and values are corresponding positions in reciprocal coordinates space.
    - **Raises**
        - ValueError if poscar is None and './POSCAR' does not exist; FileNotFoundError if poscar is neither a file nor a 3x3 list.
    """
    if poscar is None:
        if not os.path.isfile('./POSCAR'):
            raise ValueError("Argument 'poscar' expects file 'POSCAR' or 3 basis vectors.")
        poscar = './POSCAR'
    if np.ndim(poscar) == 2:
        a1, a2, a3 = poscar[0], poscar[1], poscar[2]
    elif os.path.isfile(poscar):
        with open(poscar, 'r') as f:
            lines = f.readlines()
        # Lines 3-5 hold the lattice vectors; extra trailing tokens are ignored.
        a1, a2, a3 = [[float(x) for x in lines[k].split()[:3]] for k in (2, 3, 4)]
    else:
        raise FileNotFoundError("'{}' does not exist or not 3 by 3 list.".format(poscar))
    # Reciprocal basis (without 2*pi), normalized so |b1| = 1.
    V = np.dot(a1, np.cross(a2, a3))
    b1 = np.cross(a2, a3) / V
    b2 = np.cross(a3, a1) / V
    b3 = np.cross(a1, a2) / V
    s_f = np.sqrt(np.dot(b1, b1))
    b1, b2, b3 = b1 / s_f, b2 / s_f, b3 / s_f
    basis = np.array([b1, b2, b3])
    # Origin (index 0) plus its 26 neighbours; the origin's Voronoi cell is the 1st BZ.
    vectors = np.array([i * b1 + j * b2 + k * b3 for i, j, k in product([0, 1, -1], repeat=3)])
    vor = Voronoi(vectors)
    # Ridges touching point 0 are the faces of the origin's cell. Kept as a list: faces may
    # have different vertex counts, and ragged np.array() raises in recent numpy.
    faces = [np.array([vor.vertices[i] for i in ridge])
             for pair, ridge in vor.ridge_dict.items() if 0 in pair]
    verts = np.unique(np.concatenate(faces), axis=0)
    face_vectors = [2 * centroid(f) for f in faces]
    if loop:
        faces = [order(face) for face in faces]  # order in a loop
    # High symmetry KPOINTS: face centers, selected edge midpoints and vertices.
    mid_faces = np.array([centroid(np.unique(face, axis=0)) for face in faces])
    mid_edges = []
    for f in faces:
        for p, q in zip(f[:-1], f[1:]):
            # Only edges whose end points are equidistant from the origin get a midpoint.
            if np.isclose(np.linalg.norm(p), np.linalg.norm(q)):
                mid_edges.append(centroid([p, q]))
    if mid_edges:
        mid_edges = np.unique(mid_edges, axis=0)  # because faces share edges
        mid_faces = np.concatenate([mid_faces, mid_edges])
    mid_all = np.concatenate([[[0, 0, 0]], mid_faces, verts])
    mid_basis_all = np.array([np.linalg.solve(basis.T, v) for v in mid_all])
    # Round off results (after all computations).
    one_to_one = {tuple(x): tuple(y) for x, y in zip(np.round(mid_basis_all, digits), np.round(mid_all, digits))}
    faces = [np.round(face, digits) for face in faces]
    if len({len(face) for face in faces}) == 1:
        faces_out = np.array(faces)
    else:
        faces_out = np.empty(len(faces), dtype=object)
        for idx, face in enumerate(faces):
            faces_out[idx] = face
    BZ = namedtuple('BZ', ['basis', 'normals', 'vertices', 'faces', 'specials'])
    return BZ(np.round(basis, digits), np.round(face_vectors, digits), np.round(verts, digits), faces_out, one_to_one)
# Cell
def plot_bz(poscar_or_bz = None,fill = True,color = 'rgba(168,204,216,0.4)',background = 'rgb(255,255,255)'):
    """
    - Plots interactive figure showing axes,BZ surface, special points and basis, each of which could be hidden or shown.
    - **Parameters**
        - poscar_or_bz : POSCAR or 3 basis vectors' list forming POSCAR. Auto picks in working directory. Output of get_bz() also works.
        - fill       : True by default, determines whether to fill surface of BZ or not. If False, faces are drawn as black lines.
        - color      : color to fill surface 'rgba(168,204,216,0.4)' by default.
        - background : Plot background color, default is 'rgb(255,255,255)'.
    - **Returns**
        - fig : plotly.graph_object's Figure instance.
    """
    # Anything with a `.basis` is taken as a ready get_bz() result; `is None` avoids the
    # elementwise `== None` comparison that fails for numpy-array inputs.
    if poscar_or_bz is not None and hasattr(poscar_or_bz, 'basis'):
        bz = poscar_or_bz
    else:
        bz = get_bz(poscar_or_bz)
    fig = go.Figure()
    # Axes
    fig.add_trace(go.Scatter3d(x=[0.25,0,0,0,0],y=[0,0,0.25,0,0],z=[0,0,0,0,0.25],
        mode='lines+text',
        text= ["<b>k</b><sub>x</sub>","","<b>k</b><sub>y</sub>","","<b>k</b><sub>z</sub>"],
        line_color='green', legendgroup='Axes',name='Axes'))
    fig.add_trace(go.Cone(x=[0.18,0,0],y=[0,0.18,0],z=[0,0,0.18],
        u=[0.00001,0,0],v=[0,0.00001,0],w=[0,0, 0.00001],showscale=False,
        colorscale='Greens',legendgroup='Axes',name='Axes'))
    # Basis
    for i, b in enumerate(bz.basis):
        group = "b<sub>{}</sub>".format(i+1)
        label = "<b>b</b><sub>{}</sub>".format(i+1)
        fig.add_trace(go.Scatter3d(x=[0,b[0]], y=[0,b[1]],z=[0,b[2]],
            mode='lines+text',legendgroup=group, line_color='red',
            name=label,text=["",label]))
        fig.add_trace(go.Cone(x=[0.95*b[0]],y=[0.95*b[1]],z=[0.95*b[2]],
            u=[0.2*b[0]],v=[0.2*b[1]],w=[0.2*b[2]],showscale=False,colorscale='Reds',
            legendgroup=group,name=label))
    # Faces
    fill_axis = None
    for k, pts in enumerate(bz.faces):
        if not fill:
            color = 'black'
            fill_axis = None
        else:
            # Fill along the axis the face center is most aligned with (sign ignored).
            face_dir = np.abs(centroid(np.unique(pts, axis=0)))
            fill_axis = int(np.argmax(face_dir))
        fig.add_trace(go.Scatter3d(x=pts[:,0], y=pts[:,1],z=pts[:,2],
            mode='lines',line_color=color, legendgroup='BZ',name='BZ',
            showlegend=(k == 0),surfaceaxis=fill_axis))
    # Special Points
    texts, values = [], []
    for key, value in bz.specials.items():
        norm = np.round(np.linalg.norm(value), 5)
        texts.append("P{}</br>d = {}".format(key,norm))
        values.append([*value, norm])
    values = np.array(values).reshape((-1, 4))
    norm_max = np.max(values[:,3])
    c_vals = np.array([int(v*255/norm_max) for v in values[:,3]])
    # Farther points get redder, nearer bluer; the smallest distance class keeps 0 (Gamma).
    colors = [0 for _ in c_vals]
    _unique = np.unique(np.sort(c_vals))[::-1]
    _lnp = np.linspace(0,255,len(_unique)-1)
    _u_colors = ["rgb({},0,{})".format(r,b) for b,r in zip(_lnp,_lnp[::-1])]
    for _un, _uc in zip(_unique[:-1], _u_colors):
        for _ind in np.where(c_vals == _un)[0]:
            colors[_ind] = _uc
    colors[0] = "rgb(255,215,0)"  # Gold color at Gamma!.
    fig.add_trace(go.Scatter3d(x=values[:,0], y=values[:,1],z=values[:,2],
        hovertext=texts,name="HSK",marker_color=colors,mode='markers'))
    camera = dict(center=dict(x=0.1, y=0.1, z=0.1))
    fig.update_layout(scene_camera=camera,paper_bgcolor=background,
        font_family="Times New Roman",font_size= 14,
        scene = dict(aspectmode='data',xaxis = dict(showbackground=False,visible=False),
                     yaxis = dict(showbackground=False,visible=False),
                     zaxis = dict(showbackground=False,visible=False)),
        margin=dict(r=10, l=10,b=10, t=30))
    return fig