from ..output_files_definitions import (
OutputFileValueDefinition as V,
create_output_file_definition,
OutputFileDefinition,
)
from typing import Optional
import numpy as np
import matplotlib
import copy
from ..output_files import CommonOutputFile, Arithmetic
from ...common.grammar_types import NumpyArray
from ...common.decorators import cached_property
from ...common.generated_configuration_definitions import NumpyViewDefinition as NV
from ...gui.plot import Multiplot, set_up_common_plot
from ase.units import Rydberg
from packaging.version import Version
class DOS(Arithmetic):
    """Density of states of one atomic type (or a sum of them, e.g. the total DOS).

    ``spin`` and ``l`` record to which spin channel (0 = up, 1 = down) and to
    which orbital (0 = s, 1 = p, 2 = d, 3 = f) the data were restricted;
    ``None`` means "all of them".
    """

    def __init__(self, energy, dos, type=None, id=None, spin=None, l=None):  # NOQA
        """Object that holds the DOS of one atomic type.

        Parameters
        ----------
        energy
            The energy grid the DOS values are sampled on.
        dos
            The DOS values. As produced by ``DOSOutputFile._create_dos`` they
            are indexed ``[spin, l, energy]``, with the ``spin`` and/or ``l``
            axes removed when the corresponding argument is given.
        type
            Name of the atomic type (``"total"`` for the summed DOS).
        id
            Index of the atomic type in the DOS file.
        spin
            Selected spin channel, or ``None``.
        l
            Selected orbital quantum number (s, p, d, f ...), or ``None``.
        """
        self.id = id
        self.type = type
        self.energy = energy
        self.spin = spin
        self.l = l  # NOQA
        self.dos = dos

    def copy(self, copy_values=True):
        """Return a shallow copy of the object.

        The ``dos`` array is copied as well, unless ``copy_values`` is False;
        the ``energy`` array is always shared with the original.
        """
        out = copy.copy(self)
        if copy_values:
            out.dos = out.dos.copy()
        return out

    @property
    def shape(self):
        """Shape of the DOS values array."""
        return self.dos.shape

    def __getitem__(self, key):
        """Index the DOS values array."""
        return self.dos[key]

    def plot(self, axis=None, legend_ncols=1, legend_height=0.3, legend_fontsize=10, **kwargs):
        """Plot the DOS into a matplotlib axis.

        Without a selected orbital, one line per orbital (s, p, d, f) plus a
        black "total" line (their sum) is drawn; with ``self.l`` set, only
        that orbital is drawn. If both spin channels are present, the second
        one (spin down) is drawn negated, i.e. mirrored below the energy axis.
        A legend is added (for the first spin channel only) when all orbitals
        are drawn.

        Parameters
        ----------
        axis
            The matplotlib axis to draw into. If ``None``, the current axis
            (``matplotlib.pyplot.gca()``) is used.
        legend_ncols
            Number of columns of the legend.
        legend_height
            Height of the legend handles (``handleheight``).
        legend_fontsize
            Font size of the legend.
        **kwargs
            Passed to ``set_up_common_plot``.
        """
        if axis is None:
            import matplotlib.pyplot as plt

            axis = plt.gca()
        axis.set_xlabel(r"$E-E_{\rm F}$ (eV)")
        axis.set_ylabel(r"DOS (states/eV)")

        orbital_names = {0: "s", 1: "p", 2: "d", 3: "f"}
        title = []
        if self.type:
            title = [self.type]
        if self.spin is not None:
            title.append("spin {}".format({0: "up", 1: "down"}.get(self.spin, self.spin)))
        if self.l is not None:
            orbital = orbital_names.get(self.l, None)
            # The title parts are joined as strings, so an unknown l must be
            # converted to a string too.
            title.append(f"{orbital} orbitals" if orbital else f"l={self.l}")
        if title:
            axis.title.set_text(", ".join(title))

        params = {
            "s": {"color": "blue"},
            "p": {"color": "green"},
            "d": {"color": "orange"},
            "f": {"color": "cyan"},
            "total": {"color": "black"},
        }
        set_up_common_plot(axis, **kwargs)

        def plot_l(data, l):  # NOQA
            # ``l`` is either a label ("s", ..., "total") or an orbital number;
            # orbital numbers get the colour of the corresponding orbital.
            args = params.get(orbital_names.get(l, l), params["total"])
            (line,) = axis.plot(self.energy, data, label=l, **args)
            return line

        def plot_spin(data, spin, legend):
            if spin == -1:
                # mirror the spin-down channel below the energy axis
                data = data * spin
            if self.l is not None:
                # A single orbital: ``data`` is one curve, there is nothing
                # to sum into a "total" line.
                handles = [plot_l(data, self.l)]
            else:
                handles = [plot_l(d, l) for d, l in zip(data, ("s", "p", "d", "f"))]
                if handles:
                    handles.append(plot_l(np.sum(data, axis=0), "total"))
            if legend and self.l is None:
                # matplotlib 3.6 renamed the ``ncol`` legend argument to ``ncols``
                ncols = "ncols" if Version(matplotlib.__version__) >= Version("3.6") else "ncol"
                axis.legend(
                    handles=handles,
                    loc="best",
                    fontsize=legend_fontsize,
                    **{ncols: legend_ncols},
                    handleheight=legend_height,
                )

        if self.spin is not None:
            # a single spin channel has been selected
            plot_spin(self.dos, self.spin or 1, True)
        else:
            legend = True
            for d, s in zip(self.dos, (1, -1)):
                plot_spin(d, s, legend=legend)
                legend = False

    def _do_arithmetic(self, func, other):
        """Apply the numpy array method named ``func`` to ``dos`` with ``other``.

        The return value of the call is discarded, so ``func`` is presumably an
        in-place operator (e.g. ``__iadd__``). The result no longer belongs to
        a single atomic type, so ``id`` and ``type`` are reset.
        """
        if isinstance(other, DOS):
            other = other.dos
        getattr(self.dos, func)(other)
        self.id = None
        self.type = None

    def _check_arithmetic(self, other):
        """Check that ``other`` is sampled on the same energy grid."""
        # A shared energy array (the common case) needs no numerical comparison.
        if self.energy is other.energy:
            return
        assert np.allclose(self.energy, other.energy)

    def __repr__(self):
        if self.type:
            return f"DOS of {self.type}"
        else:
            return "DOS"
class DOSOutputFile(CommonOutputFile):
    """Content of a DOS (density of states) output file.

    Besides the values read from the file, it gives access to the DOS of the
    individual atomic types (``dos[0]``, ``dos['Te']``, iteration,
    :meth:`dos_for_site_type`), to their weighted sum (:meth:`total_dos`)
    and plotting (:meth:`plot`).
    """

    def __init__(self, definition, container=None):
        super().__init__(definition, container)
        # The cached ``energy`` is derived from ENERGY and EFERMI; the hooks
        # (presumably fired when those values change) invalidate it.
        self.ENERGY.add_hook(self._clear_computed)
        self.EFERMI.add_hook(self._clear_computed)

    def _clear_computed(self, _):
        """Drop the cached :attr:`energy` (the hook argument is ignored)."""
        if "energy" in self.__dict__:
            del self.energy

    @cached_property
    def energy(self):
        """Energies relative to the Fermi energy, multiplied by ``ase.units.Rydberg``
        (presumably converting Ry to eV, the unit used in the plots)."""
        return (self.ENERGY() - self.EFERMI()) * Rydberg

    def plot(
        self,
        spin=None,
        l=None,
        layout=2,
        figsize=(6, 4),
        latex=None,  # NOQA
        filename: Optional[str] = None,
        show: Optional[bool] = None,
        dpi=600,
        **kwargs,
    ):
        """Plot the DOS of every atomic type into its own subplot.

        If there is more than one atomic type, an additional subplot with the
        total DOS is added.

        Parameters
        ----------
        spin
            Spin channel to plot (0/1 or "up"/"down"), or ``None`` for both.
        l
            Orbital to plot, or ``None`` for all of them.
        layout
            An ``int`` gives the number of subplot columns (the number of rows
            is computed); any other value is passed to ``Multiplot`` as is.
        figsize, latex, filename, show, dpi, **kwargs
            Passed to ``Multiplot``.
        """
        n_types = self.n_types()
        if n_types > 1:
            # an additional subplot for the total DOS
            n_types += 1
        if isinstance(layout, int):
            # presumably (rows, columns): ``layout`` columns, as many rows as needed
            layout = ((n_types - 1) // layout + 1, min(layout, n_types))
        with Multiplot(
            layout=layout, figsize=figsize, latex=latex, filename=filename, show=show, dpi=dpi, **kwargs
        ) as mp:
            for dos in self.iterate_dos(spin, l, total=n_types > 1):
                mp.plot(dos)

    def __getitem__(self, name):
        """
        In addition to the general Container __getitem__, the following
        is possible:
        ``dos[0].plot()`` - plot the first atomic type
        ``dos['Te'].plot()`` - plot the Te atomic type

        Raises
        ------
        KeyError
            If ``name`` is neither a value of the file nor an atomic type.
        """
        try:
            return super().__getitem__(name)
        except KeyError as ke:
            if not np.issubdtype(name.__class__, np.integer):
                # translate the name of an atomic type to its index
                for i, type_record in enumerate(self.TYPES()):
                    if type_record["TXT_T"] == name:
                        name = i
                        break
                else:
                    raise KeyError(f"There is no atomic type {name} nor such value in the DOS file") from ke
            for i, slc in enumerate(self.iterate_data_slices()):
                if i == name:
                    return self._create_dos(slc, i)
            raise KeyError(f"There is no {name}th atomic type in the DOS file") from ke

    def total_dos(self, spin=None, ll=None):
        """Return the total DOS.

        This is the sum over all atomic types, each weighted by its
        concentration times its number of sites. It is the last item yielded
        by :meth:`iterate_dos` with ``total=True``.
        """
        return list(self.iterate_dos(spin, ll, total=True))[-1]

    def __iter__(self):
        """Iterate over the DOS of the atomic types (without the total DOS)."""
        return self.iterate_dos(total=False)

    @staticmethod
    def _resolve_spin(spin):
        """Map "up"/"down" (case insensitive) to 0/1; other values are returned unchanged."""
        if isinstance(spin, str):
            return {"up": 0, "down": 1}[spin.lower()]
        return spin

    def iterate_data_slices(self):  # NOQA
        """Yield, for each atomic type, the slice of the rows of DOS that belongs to it.

        Each type occupies ``n_orbitals * n_spins`` consecutive rows.
        """
        spins = self.n_spins()
        end = 0
        for t in self.TYPES():
            start = end
            end += self.n_orbitals_for(t) * spins
            yield slice(start, end)

    def iterate_dos(self, spin=None, l=None, total=True):  # NOQA
        """Yield the :class:`DOS` of each atomic type, optionally followed by the total DOS.

        Parameters
        ----------
        spin
            Restrict the data to the given spin channel (0/1 or "up"/"down").
        l
            Restrict the data to the given orbital.
        total
            If true, after the per-type DOS, yield their sum weighted by
            ``CONC * len(IQAT)`` of each type, with ``type == "total"``.
        """
        spin = self._resolve_spin(spin)
        # Kept separate from the ``total`` flag, so that no types means "no
        # total DOS" instead of an AttributeError on a bool.
        accumulated = None
        for i, slic in enumerate(self.iterate_data_slices()):
            out = self._create_dos(slic, i, spin, l)
            yield out
            if not total:
                continue
            type_record = self.TYPES[i]
            ratio = type_record["CONC"] * len(type_record["IQAT"])
            if accumulated is None:
                accumulated = out * ratio
            else:
                # NOTE(review): summing requires the same array shape for all
                # types (e.g. the same number of orbitals when l is None) - confirm.
                accumulated.dos += out.dos * ratio
        if accumulated is not None:
            accumulated.type = "total"
            yield accumulated

    def index_of_dos_for_site_type(self, atom):
        """Return the slice of the DOS array selecting the data of a given site type.

        Parameters
        ----------
        atom
            Index of the atomic type, or its name (resolved by ``site_type_index``).

        Raises
        ------
        KeyError
            If there is no such atomic type.
        """
        index = self.site_type_index(atom) if isinstance(atom, str) else atom
        for i, slic in enumerate(self.iterate_data_slices()):
            if i == index:
                return slic
        # Returning None here would make ``DOS[None]`` silently add a new axis.
        raise KeyError(f"There is no atomic type {atom} in the DOS file")

    def dos_for_site_type(self, atom, spin=None, l=None):  # NOQA
        """Return the density of states for a given atomic type.

        The type is indexed either by an integer index or by its name. The
        resulting data are indexed by ``[spin, l, energy]``. They can be
        restricted to a given spin and/or l by the arguments, which removes
        the corresponding axis.
        """
        spin = self._resolve_spin(spin)
        if isinstance(atom, str):
            atom = self.site_type_index(atom)
        key = self.index_of_dos_for_site_type(atom)
        return self._create_dos(key, atom, spin, l)

    def _create_dos(self, key, id, spin=None, l=None):  # NOQA
        """Create the :class:`DOS` object of the atomic type with index ``id``.

        ``key`` selects the rows of the DOS array that belong to the type (see
        :meth:`iterate_data_slices`). They are reshaped to ``[spin, l, energy]``
        and optionally restricted to the given ``l`` and ``spin``. The values
        are divided by ``ase.units.Rydberg`` (presumably states/Ry to states/eV).
        """
        type_record = self.TYPES[id]
        out = self.DOS[key]
        out = out.reshape((-1, self.n_orbitals_for(type_record), out.shape[1]))
        if l is not None:
            out = out[:, l]
        if spin is not None:
            out = out[spin]
        return DOS(self.energy, out / Rydberg, type_record["TXT_T"], id, spin, l)

    def n_orbitals_for(self, type):
        """
        Return the number of orbitals for the given type record.

        It is the maximum ``NLQ`` over the sites of the type; ``IQAT`` holds
        1-based indices into ``ORBITALS``.
        """
        return max(self.ORBITALS[iq - 1]["NLQ"] for iq in type["IQAT"])

    def n_spins(self):
        """
        Return the number of spins for each orbital.

        It is the number of rows of the DOS data divided by the total number
        of orbitals of all atomic types.
        """
        ln = len(self.DOS())
        orbitals = sum(self.n_orbitals_for(t) for t in self.TYPES())
        return ln // orbitals
class DOSDefinition(OutputFileDefinition):
    """Definition of the DOS output file.

    Its ``result_class`` attribute is bound to :class:`DOSOutputFile`.
    """

    result_class = DOSOutputFile
def create_definition():
    """Create the definition of the DOS output file.

    The raw numeric data are read into ``RAW_DATA``. ``ENERGY``, ``Y`` and
    ``DOS`` are views into it.
    """

    def i(j):
        # index selecting ``j`` from every row: ``[:, j]``
        return slice(None), j

    definition = create_output_file_definition(
        "DOS",
        [
            V("DOS-FMT", str, written_name="DOS-FMT:"),
            V(
                "RAW_DATA",
                NumpyArray(written_shape=(-1, 8), delimiter=10, indented=(80, 10), item_format="%8.4E"),
                name_in_grammar=False,
            ),
            NV("ENERGY", "RAW_DATA", i(0), info="Energies"),
            NV("Y", "RAW_DATA", i(1), info="Y"),
            NV(
                "DOS",
                "RAW_DATA",
                i(slice(2, None)),
                ("NE", -1),
                transpose=True,
                # a string key (name of an atomic type) selects the rows of that type
                transform_key=lambda k, dos: dos.index_of_dos_for_site_type(k) if isinstance(k, str) else k,
                info="Density of states",
                description="Density of states, the leading dimension iterates: "
                "1..n_types, 1..n_spins, 1..n_orbitals(type)",
            ),
        ],
        cls=DOSDefinition,
        info="Result of a DOS (density of states) calculation.",
    )
    return definition


definition = create_definition()