"""Definition of the JXC output file (Jij / Dij exchange couplings).

Module ``ase2sprkkr.output_files.definitions.jxc``.
"""

from ..output_files_definitions import OutputFileValueDefinition as V, OutputFileDefinition
from typing import Optional
import numpy as np
import re
import itertools
from enum import Enum
from matplotlib import cm
import warnings
from collections import namedtuple

from ..output_files import OutputFile
from ...common.decorators import cached_property, cached_class_property
from ...common.grammar_types import RawData, Table, Integer, Real, Array, Sequence, String, LineString, NumpyArray
from ...common.configuration_definitions import SeparatorDefinition
from ...gui.plot import Multiplot


class Coordinates(Enum):
    """Coordinate system selector for exported data.

    Passed through by ``JXCOutputFile.write_uppasd_file`` to the UppASD writers.
    """

    cartesian = "cartesian"
    lattice = "lattice"
# Pair ``(iq, it)`` of selected sites and atomic types. Each field is either
# ``...`` (no restriction) or a collection of indexes.
Selector = namedtuple("Selector", "iq it")
class JXCOutputFile(OutputFile):
    """Parsed JXC output file with exchange couplings between atomic types.

    DATA holds either isotropic couplings (Jij; columns ``JXX``, ``JYY``,
    ``JXY``, ``JYX``) or Dzyaloshinskii-Moriya vectors (Dij; columns ``DX``,
    ``DY``, ``DZ``). Both layouts are preceded by the same eleven index and
    geometry columns. The layout is detected from the number of columns.
    """

    # Names of the :meth:`plot` keyword arguments (presumably exposed by the
    # generic plotting machinery of the base class - confirm).
    plot_parameters = {"exchange_radius", "iq", "it", "exclude_it", "exclude_vc", "font_size", "axis", "separate_plots"}

    def is_Jij(self):
        """Return True if DATA has the Jij layout (15 columns)."""
        return len(self.DATA().dtype.names) == 15

    def is_Dij(self):
        """Return True if DATA has the Dij layout (14 columns)."""
        return len(self.DATA().dtype.names) == 14

    @cached_property
    def iqs(self):
        """Map each site index IQ to ``{IT: (label, occupation)}`` of its types."""
        return {row[0]: {entry[0]: (entry[1], entry[2]) for entry in row[2]} for row in self.OCCUPATION()}

    @cached_property
    def it_labels(self):
        """Map each atomic type index IT to its label."""
        # Later occurrences of an IT override earlier ones, as dict() did.
        return {it: label for types in self.iqs.values() for it, (label, _occ) in types.items()}

    def iq_to_its(self, iq):
        """Return the type indexes IT occupying site ``iq``."""
        return self.iqs[iq].keys()

    def it_to_iqs(self, it):
        """Return the set of site indexes IQ on which type ``it`` occurs."""
        return {iq for iq, types in self.iqs.items() if it in types}

    @cached_property
    def labels_to_it(self):
        """Inverse of :attr:`it_labels`: map a type label to its IT index."""
        return {label: it for it, label in self.it_labels.items()}

    def label_to_it(self, label):
        """Return the IT index of the type labelled ``label``.

        Raises
        ------
        ValueError
          If no type has this label.
        """
        try:
            return self.labels_to_it[label]
        except KeyError as exc:
            raise ValueError(
                f"Unknown type label '{label}'. Valid labels are: {', '.join(self.labels_to_it.keys())}"
            ) from exc

    def element_to_its(self, label: str):
        """Return the IT indexes labelled ``label`` or ``label`` followed by ``_<suffix>``."""
        return [
            it
            for it, name in self.it_labels.items()
            if name.startswith(label) and (name == label or name[len(label)] == "_")
        ]

    @cached_class_property
    def _it_selector_regex():
        # Group 1: a type label such as "Fe" or "Fe_2"; group 2: the optional
        # "_<n>" suffix. The suffix must be optional - with a mandatory suffix
        # plain element names ("Fe") could never be used as selectors.
        return re.compile(r"\s*([A-Z][a-z]*(_(\d+))?)\s*")

    def create_selector(self, selector=None, iq=None, it=None, exclude_it=None, exclude_vc=True):
        """Build a :class:`Selector` ``(iq, it)`` from user-friendly arguments.

        Parameters
        ----------
        selector
          An already built selector; if given, it is returned unchanged.
        iq
          Site indexes: an int, a comma-separated string or an iterable of ints.
          ``None`` or ``""`` means all sites (represented by ``...``).
        it
          Atomic types: an int index, a type label (``"Fe_1"``), an element
          name (``"Fe"`` selects every ``Fe``/``Fe_<n>`` type), a
          comma-separated string or an iterable of those. ``None``/``""``
          means all types.
        exclude_it
          Types removed from the selection, same format as ``it``.
        exclude_vc
          If true, the ``Vc`` types are excluded as well.

        Returns
        -------
        Selector
          ``iq`` is ``...`` or a list of ints, ``it`` is a set of type indexes.

        Raises
        ------
        ValueError
          If a selector cannot be parsed, is out of range or matches no type.
        """
        if selector is not None:
            return selector

        def parse_list(values):
            # Normalize str / int / iterable input to an iterable of items.
            if isinstance(values, str):
                return values.split(",")
            if isinstance(values, int):
                return [values]
            return values

        def parse_it(values):
            # Return ``...`` for "no restriction", otherwise a set of IT indexes.
            if values is None or values == "":
                return ...
            out = set()
            for v in parse_list(values):
                if isinstance(v, int):
                    out.add(v)
                    continue
                match = self._it_selector_regex.fullmatch(v)
                if not match:
                    try:
                        index = int(v)
                    except ValueError as exc:
                        raise ValueError(f"Invalid selector: {v}") from exc
                    # Checked outside the try, so this message is not replaced
                    # by the generic "Invalid selector" one.
                    if index not in self.it_labels:
                        raise ValueError(f"Selector index out of range: {v}")
                    out.add(index)
                elif match.group(2):
                    out.add(self.label_to_it(match.group(1)))
                else:
                    its = self.element_to_its(match.group(1))
                    if not its:
                        raise ValueError(f"Selector {v} does not match any atomic type")
                    out.update(its)
            return out

        def parse_iq(values):
            # Return ``...`` for "no restriction", otherwise a list of IQ indexes.
            if values is None or values == "":
                return ...
            values = parse_list(values)
            try:
                return [int(v) for v in values]
            except (TypeError, ValueError, OverflowError) as exc:
                raise ValueError(
                    f"Invalid IQ selector: {values}. Must be (a comma-separated) list of integers."
                ) from exc

        iq = parse_iq(iq)
        it = parse_it(it)
        exclude_it = parse_it(exclude_it)
        if exclude_vc:
            vacuum = set(self.element_to_its("Vc"))
            exclude_it = vacuum if exclude_it is ... else exclude_it | vacuum
        if it is ...:
            it = set(range(1, self.NT() + 1))
        # ``exclude_it`` stays ``...`` when nothing is excluded; subtracting
        # the Ellipsis from a set would raise TypeError.
        if exclude_it is not ...:
            it = it - exclude_it
        return Selector(iq, it)

    def it_selector(self, selector=None, iq=None, it=None, exclude_it=None, exclude_vc=True):
        """Return the set of selected type indexes IT, or ``...`` if unrestricted.

        The arguments are those of :meth:`create_selector`. The result contains
        the types lying on the selected sites that are also in the type
        selection.
        """
        if selector is None:
            selector = self.create_selector(iq=iq, it=it, exclude_vc=exclude_vc, exclude_it=exclude_it)
        sel_iq, sel_it = selector
        out = ...
        if sel_iq is not ...:
            # Materialize: a bare iterator would be exhausted after one use.
            out = set(itertools.chain.from_iterable(self.iq_to_its(i) for i in sel_iq))
        if sel_it is not ...:
            out = set(sel_it) if out is ... else out.intersection(sel_it)
        return out

    def iq_selector(self, selector=None, iq=None, it=None, exclude_it=None, exclude_vc=True):
        """Return the set of sites IQ occupied by the selected types, or ``...``."""
        its = self.it_selector(selector=selector, iq=iq, it=it, exclude_it=exclude_it, exclude_vc=exclude_vc)
        if its is ...:
            return ...
        return set().union(*(self.it_to_iqs(t) for t in its))

    def filtered_data(self, selector=None, iq=None, it=None, exclude_it=None, exclude_vc=True, exchange_radius=4.0):
        """Return the DATA rows selected by types and pair distance.

        A row is kept if both its IT and JT are among the selected types (see
        :meth:`create_selector`) and ``DR <= exchange_radius``. Pass
        ``exchange_radius=None`` to disable the distance cut.
        """
        its = self.it_selector(selector=selector, iq=iq, it=it, exclude_it=exclude_it, exclude_vc=exclude_vc)
        data = self.DATA()
        mask = None
        if its is not ...:
            its = list(its)
            mask = np.isin(data["IT"], its) & np.isin(data["JT"], its)
        if exchange_radius is not None:
            in_range = data["DR"] <= exchange_radius
            mask = in_range if mask is None else mask & in_range
        return data if mask is None else data[mask]

    def _spin_moments(self):
        """Return the spin moments of the distinct atomic types of ``self.atoms``.

        Types are taken in order of their first occurrence over the sites.
        ``None`` is returned if any of them lacks a ``spin_moment``.
        """
        # NOTE(review): :meth:`plot` indexes this list by ``IT - 1``, i.e.
        # assumes this order matches the IT numbering - confirm.
        seen = set()
        moments = []
        for site in self.atoms.sites:
            for atomic_type in site.occupation:
                if atomic_type in seen:
                    continue
                seen.add(atomic_type)
                moment = getattr(atomic_type.moments, "spin_moment", None)
                if moment is None:
                    return None
                moments.append(moment)
        return moments

    @staticmethod
    def _dij_components(component_axis):
        """Return ``(columns, axis_labels)`` of the DMI component(s) to plot.

        ``component_axis`` is ``"x"``, ``"y"``, ``"z"``, ``"all"`` or
        equivalently 0, 1, 2, ``None``.

        Raises
        ------
        ValueError
          For any other value.
        """
        components = {
            "x": (["DX"], ["x"]),
            "y": (["DY"], ["y"]),
            "z": (["DZ"], ["z"]),
            "all": (["DX", "DY", "DZ"], ["x", "y", "z"]),
        }
        component_axis = {0: "x", 1: "y", 2: "z", None: "all"}.get(component_axis, component_axis)
        try:
            return components[component_axis]
        except KeyError as exc:
            raise ValueError(f"Invalid DMI axis '{component_axis}'. Use one of: all, x, y, z.") from exc

    def plot(
        self,
        layout=2,
        figsize=None,
        latex=None,
        filename: Optional[str] = None,
        show: Optional[bool] = None,
        dpi=300,
        label_spacing=0.3,
        selector=None,
        iq=None,
        it=None,
        exclude_it=None,
        exclude_vc=True,
        exchange_radius=4.0,
        font_size=10,
        axis="all",
        separate_plots=False,
        layout_kind="constrained",
        **kwargs,
    ):
        """Plot Jij (or Dij components) as a function of the pair distance DR.

        One panel is drawn per selected atomic type IT and plotted component.
        Each panel shows one line per partner type JT.

        Parameters
        ----------
        layout
          An explicit layout passed to :class:`Multiplot`, or an int. For a
          single plotted component the int is the maximal number of columns.
          For several DMI components each row holds one type's components.
        label_spacing
          Legend entry spacing; a falsy value disables the legend.
        selector, iq, it, exclude_it, exclude_vc, exchange_radius
          Data selection, see :meth:`filtered_data`.
        font_size
          Font size of axis labels and ticks (the legend uses ``font_size - 2``).
        axis
          Dij only: ``"x"``, ``"y"``, ``"z"`` (or 0, 1, 2) or ``"all"`` (or None).
        figsize, latex, filename, show, dpi, separate_plots, layout_kind, kwargs
          Passed to :class:`Multiplot`.

        Returns
        -------
        bool
          False if there is nothing to plot, True otherwise.
        """
        labels = self.it_labels
        data = self.filtered_data(
            selector=selector,
            iq=iq,
            it=it,
            exclude_it=exclude_it,
            exclude_vc=exclude_vc,
            exchange_radius=exchange_radius,
        )
        # Lines are drawn for every known partner type; panels only for the
        # types that still have data after filtering (the layout is sized by them).
        partner_types = sorted(labels)
        present = set(np.unique(data["IT"]).tolist())
        panel_types = [t for t in partner_types if t in present]
        if not panel_types:
            return False

        is_jij = self.is_Jij()
        if is_jij:
            values = data[["JXX"]]
            component_labels = None
            spin_mom = self._spin_moments()
            if spin_mom is None:
                warnings.warn("Spin moments not found.")
            elif len(spin_mom) != self.NT():
                warnings.warn(
                    "Spin moments count does not match number of types - the potential do not matches the Jij data."
                )
                spin_mom = None
        else:
            columns, component_labels = self._dij_components(axis)
            values = data[columns]
            spin_mom = None

        components = values.dtype.names
        extremum = max(np.abs(values[name]).max() for name in components)
        extremum += max(0.1, extremum * 0.1)
        colorspace = cm.Paired(np.linspace(0, 1, 2 * self.NT() + 2))

        if isinstance(layout, int):
            if len(components) > 1:
                # One row per type, one column per DMI component.
                layout = (len(panel_types), len(components))
            else:
                # Rows of at most ``layout`` panels (previously the column
                # count was capped by the number of components, i.e. 1).
                n_cols = max(1, min(layout, len(panel_types)))
                n_rows = -(-len(panel_types) // n_cols)
                layout = (n_rows, n_cols)

        prefix = "Jij" if is_jij else "Dij"

        def make_plot_function(type_index, colindex, distances, couplings, partners):
            # The panel data are bound here, so the returned function stays
            # correct even if Multiplot calls it after the loop has moved on.
            def plot(axis):
                colors = iter(colorspace)
                for partner_index in partner_types:
                    mask = partners == partner_index
                    if not np.any(mask):
                        continue
                    xs = distances[mask]
                    ys = couplings[mask]
                    if spin_mom is not None:
                        # Multiply by the product of the signs of both spin moments.
                        ys = ys * (np.sign(spin_mom[type_index - 1]) * np.sign(spin_mom[partner_index - 1]))
                    color = next(colors)
                    if is_jij:
                        line_label = f"${labels[type_index]}$-${labels[partner_index]}$"
                    else:
                        line_label = f"{labels[type_index]}-{labels[partner_index]}"
                    axis.plot(xs, ys, alpha=0.75, lw=2.0, color=color, label=line_label)
                    axis.scatter(xs, ys, color=color, alpha=0.75, s=45, lw=1.0, edgecolor="black")
                if is_jij:
                    value_label = r"$J_{ij}$"
                else:
                    value_label = rf"$D_{{ij}}^{{{component_labels[colindex]}}}$"
                axis.set_xlabel(r"$r_{ij}/a_{\mathrm{lat}}$", fontsize=font_size)
                axis.set_ylabel(value_label + r" $[meV]$", fontsize=font_size)
                axis.tick_params(axis="x", colors="black", labelsize=font_size)
                axis.tick_params(axis="y", colors="black", labelsize=font_size)
                axis.axhline(0, color="black", linestyle="--")
                if label_spacing:
                    axis.legend(fontsize=font_size - 2, loc="best", labelspacing=label_spacing)
                axis.grid(False)
                axis.set_ylim(-extremum, extremum)

            return plot

        with Multiplot(
            layout=layout,
            figsize=figsize,
            latex=latex,
            filename=filename,
            show=show,
            dpi=dpi,
            separate_plots=separate_plots,
            layout_kind=layout_kind,
            **kwargs,
        ) as mp:
            for type_index in panel_types:
                rows = data["IT"] == type_index
                distances = data["DR"][rows]
                partners = data["JT"][rows]
                subset = values[rows]
                for colindex, component in enumerate(components):
                    mp.plot(
                        self,
                        name=f"{prefix}_{labels[type_index]}",
                        plot_function=make_plot_function(
                            type_index, colindex, distances, subset[component], partners
                        ),
                    )
        return True

    def write_uppasd_file(
        self,
        file_name=None,
        directory=None,
        selector=None,
        it=None,
        iq=None,
        exclude_it=None,
        exclude_vc=True,
        exchange_radius=None,
        coordinates: Coordinates = Coordinates.lattice,
    ):
        """Write the Jij or Dij data to a file in the format used by UppASD.

        Jij data are written by ``write_jfile`` and Dij data by ``write_dmfile``
        (both from the ``bindings.uppasd`` module). If ``file_name`` is not
        given, the standard name for the given file type is used.

        The selection arguments have the same meaning as in
        :meth:`filtered_data`. ``coordinates`` is passed to the writer.
        """
        if self.is_Jij():
            from ...bindings.uppasd import write_jfile as write
        else:
            from ...bindings.uppasd import write_dmfile as write
        write(
            self,
            file_name,
            directory=directory,
            selector=selector,
            iq=iq,
            it=it,
            exclude_it=exclude_it,
            exclude_vc=exclude_vc,
            exchange_radius=exchange_radius,
            coordinates=coordinates,
        )

    @classmethod
    def from_atoms(cls, atoms, data=None):
        """Create a JXC output file whose OCCUPATION is built from ``atoms``.

        Every distinct atomic type gets an IT index (1, 2, ...) in order of
        first appearance over ``atoms.sites``. A type is labelled by its
        chemical symbol. If several distinct types share a symbol, they are
        labelled ``<symbol>_1``, ``<symbol>_2``, ... in order of appearance.

        Parameters
        ----------
        atoms
          Object with ``sites``, each having an ``occupation`` mapping
          ``atomic_type -> occupation`` (as used below).
        data
          Optional DATA table stored in the created file.
        """
        out = cls(definition=definition)
        its = {}

        def it(atomic_type):
            # Assign IT indexes 1, 2, ... in order of first request.
            if atomic_type not in its:
                its[atomic_type] = len(its) + 1
            return its[atomic_type]

        first = {}
        duplicates = {}
        seen = set()
        for site in atoms.sites:
            for atomic_type in site.occupation:
                # The same type may occupy several sites: number it only once,
                # otherwise a single type would be relabelled (e.g. "Fe_2").
                if atomic_type in seen:
                    continue
                seen.add(atomic_type)
                symbol = atomic_type.symbol
                if symbol not in first:
                    first[symbol] = atomic_type
                    continue
                numbering = duplicates.setdefault(symbol, {first[symbol]: 1})
                numbering[atomic_type] = len(numbering) + 1

        def label(atomic_type):
            numbering = duplicates.get(atomic_type.symbol)
            if numbering is None:
                return atomic_type.symbol
            return f"{atomic_type.symbol}_{numbering[atomic_type]}"

        out.OCCUPATION = [
            (
                i + 1,
                len(site.occupation),
                [(it(atomic_type), label(atomic_type), occ) for atomic_type, occ in site.occupation.items()],
            )
            for i, site in enumerate(atoms.sites)
        ]
        # ``if data:`` would raise for a numpy array with more than one row.
        if data is not None:
            out.DATA = data
        return out
class JXCOutputFileDefinition(OutputFileDefinition):
    """Definition of the JXC output file format (Jij / Dij couplings)."""

    # Class used for the parsed results of this definition (presumably
    # instantiated by the base class when a file is read - confirm).
    result_class = JXCOutputFile
# NOTE(review): importing this module switches pyparsing to verbose stack
# traces for the whole process (a class-level attribute of ParserElement).
# It looks like a debugging aid - confirm it is intended to stay enabled.
import pyparsing as pp
pp.ParserElement.verbose_stacktrace = True
def create_definition():
    """Build the :class:`JXCOutputFileDefinition` of a JXC output file.

    The file consists of these parts, in order:

    * a free-form header (ended by the "number of sites" line);
    * the numbers of sites NQ and types NT;
    * the site occupation table;
    * a column-header line;
    * the numeric DATA table.

    DATA has either the Jij layout (``JXX``, ``JYY``, ``JXY``, ``JYX``) or the
    Dij layout (``DX``, ``DY``, ``DZ``). Both share eleven leading index and
    geometry columns.
    """

    def table_header(container):
        # The column-header line depends on the layout of the DATA table.
        if not container.is_Jij():
            return "IT IQ JT JQ N1 N2 N3 DRX DRY DRZ DR DX_ij [meV] DY_ij [meV] DZ_ij [meV]"
        return "IT IQ JT JQ N1 N2 N3 DRX DRY DRZ DR J_xx [meV] J_yy [meV] J_xy [meV] J_yx [meV]"

    # Index (int) and geometry (float) columns shared by both layouts.
    shared_columns = [(name, int) for name in ("IT", "IQ", "JT", "JQ", "N1", "N2", "N3")]
    shared_columns += [(name, float) for name in ("DRX", "DRY", "DRZ", "DR")]
    jij_dtype = shared_columns + [(name, float) for name in ("JXX", "JYY", "JXY", "JYX")]
    dij_dtype = shared_columns + [(name, float) for name in ("DX", "DY", "DZ")]

    # The parts are constructed in the order in which they appear in the file.
    header = V(
        "HEADER",
        RawData(
            ends_with=re.compile("\n[ \t]*number of sites"),
            default_value=""" ******************************************************************************* <XCPLTENSOR>: Dzyaloshinski-Moriya couplings Dij according to Phys. Rev. B 79, 045209 (2009) ******************************************************************************* """,
        ),
        name_in_grammar=False,
    )
    site_count = V("NQ", int, written_name="number of sites NQ", delimiter=" = ", delimiter_grammar="=")
    type_count = V(
        "NT",
        int,
        written_name="number of types NT",
        delimiter=" = ",
        delimiter_grammar="=",
        indent=10 * " ",
    )
    occupation_separator = SeparatorDefinition(" site occupation:")
    occupation = V(
        "OCCUPATION",
        Table(
            IQ=Integer(prefix="IQ = ", format="{:>3}"),
            NOQ=int,
            DATA=Array(
                Sequence(int, String(prefix="-", format="{:>10}"), Real(prefix="x = ", format="{:6.3f}")),
                prefix="IT:",
            ),
            header=False,
        ),
    )
    column_header = V(
        "TABLE_HEADER",
        LineString(),
        default_value_from_container=table_header,
        is_hidden=True,
        name_in_grammar=False,
    )
    couplings = V(
        "DATA",
        NumpyArray(dtypes=[jij_dtype, dij_dtype]),
        name_in_grammar=False,
    )

    result = JXCOutputFileDefinition(
        "JXC",
        [header, site_count, type_count, occupation_separator, occupation, column_header, couplings],
        info="Dzyaloshinski-Moriya couplings Dij or Jij",
    )
    # Written straight into the instance __dict__, as in the original code.
    result.__dict__["extension"] = "dat"
    return result
# Module-level JXC file definition; ``JXCOutputFile.from_atoms`` uses it to
# create new (not parsed) JXC output files.
definition = create_definition()