Source code for reflectorch.inference.preprocess_exp.preprocess

from dataclasses import dataclass

import numpy as np

from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
from reflectorch.utils import angle_to_q


[docs] def standard_preprocessing( intensity: np.ndarray, scattering_angle: np.ndarray, attenuation: np.ndarray, q_interp: np.ndarray, wavelength: float, beam_width: float, sample_length: float, min_intensity: float = 1e-10, beam_shape: BEAM_SHAPE = "gauss", normalize_mode: NORMALIZE_MODE = "max", incoming_intensity: float = None, max_q: float = None, # if provided, max_angle is ignored max_angle: float = None, ) -> dict: """Preprocesses a raw experimental reflectivity curve by applying attenuation correction, footprint correction, cutting at a maximum q value and interpolation Args: intensity (np.ndarray): array of intensities of the reflectivity curve scattering_angle (np.ndarray): array of scattering angles attenuation (np.ndarray): attenuation factors for each measured point q_interp (np.ndarray): reciprocal space points used for the interpolation wavelength (float): the wavelength of the beam beam_width (float): the beam width sample_length (float): the sample length min_intensity (float, optional): intensities lower than this value are removed. Defaults to 1e-10. beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss". normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or "incoming_intensity". Defaults to "max". incoming_intensity (float, optional): array of intensities for the "incoming_intensity" normalization. Defaults to None. max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None. max_angle (float, optional): the maximum scattering angle at which the curve is cut; only used if max_q is not provided. Defaults to None. Returns: dict: dictionary containing the interpolated reflectivity curve, the curve before interpolation, the q values before interpolation, the q values after interpolation and the q ratio of the cutting """ intensity = apply_attenuation_correction( intensity, attenuation, scattering_angle, ) intensity = apply_footprint_correction( intensity, scattering_angle, beam_width=beam_width, sample_length=sample_length, beam_shape=beam_shape ) curve = intensity2reflectivity(intensity, normalize_mode, incoming_intensity) curve, scattering_angle = remove_low_statistics(curve, scattering_angle, thresh=min_intensity) q = angle_to_q(scattering_angle, wavelength) q, curve, q_ratio = cut_curve(q, curve, max_q, max_angle, wavelength) curve_interp = interp_reflectivity(q_interp, q, curve) assert np.all(np.isfinite(curve_interp)) assert np.all(np.isfinite(curve)) assert np.all(np.isfinite(q)) assert np.all(np.isfinite(q_interp)) assert np.all(curve > 0.) assert np.all(curve_interp > 0.) return { "curve_interp": curve_interp, "curve": curve, "q_values": q, "q_interp": q_interp, "q_ratio": q_ratio, }
def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7): indices = (curve > thresh) & np.isfinite(curve) return curve[indices], scattering_angle[indices] @dataclass class StandardPreprocessing: q_interp: np.ndarray = None wavelength: float = 1. beam_width: float = None sample_length: float = None beam_shape: BEAM_SHAPE = "gauss" normalize_mode: NORMALIZE_MODE = "max" incoming_intensity: float = None def preprocess(self, intensity: np.ndarray, scattering_angle: np.ndarray, attenuation: np.ndarray, **kwargs ) -> dict: attrs = self._get_updated_attrs(**kwargs) return standard_preprocessing( intensity, scattering_angle, attenuation, **attrs ) __call__ = preprocess def set_parameters(self, **kwargs) -> None: for k, v in kwargs.items(): if k in self.__annotations__: setattr(self, k, v) else: raise KeyError(f'Unknown parameter {k}.') def _get_updated_attrs(self, **kwargs): current_attrs = {k: getattr(self, k) for k in self.__annotations__.keys()} current_attrs.update(kwargs) return current_attrs