Coverage for debye_calculator.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import os
2import sys
3import base64
4import yaml
5import pkg_resources
6import warnings
7from glob import glob
8from datetime import datetime
9from typing import Union, Tuple, Any, List
11import torch
12from torch import cdist
13from torch.nn.functional import pdist
15import numpy as np
16import matplotlib.pyplot as plt
18from ase import Atoms
19from ase.io import read, write
20from ase.build import make_supercell
21from ase.build.tools import sort as ase_sort
23from DebyeCalculator.utility.profiling import Profiler
25import ipywidgets as widgets
26from IPython.display import display, HTML, clear_output
27from ipywidgets import HBox, VBox, Layout
28from tqdm.auto import tqdm
30import collections
31import timeit
33class DebyeCalculator:
34 """
35 Calculate the scattering intensity I(q) through the Debye scattering equation, the Total Scattering Structure Function S(q),
36 the Reduced Total Scattering Function F(q), and the Reduced Atomic Pair Distribution Function G(r) for a given atomic structure.
39 Parameters:
40 qmin (float): Minimum q-value for the scattering calculation. Default is 1.0.
41 qmax (float): Maximum q-value for the scattering calculation. Default is 30.0.
42 qstep (float): Step size for the q-values in the scattering calculation. Default is 0.1.
43 qdamp (float): Damping parameter caused by the truncated Q-range of the Fourier transformation. Default is 0.04.
44 rmin (float): Minimum r-value for the pair distribution function (PDF) calculation. Default is 0.0.
45 rmax (float): Maximum r-value for the PDF calculation. Default is 20.0.
46 rstep (float): Step size for the r-values in the PDF calculation. Default is 0.01.
47 rthres (float): Threshold value for exclusion of distances below this value in the scattering calculation. Default is 0.0.
48 biso (float): Debye-Waller isotropic atomic displacement parameter. Default is 0.3.
49 device (str): Device to use for computation (e.g., 'cuda' for GPU or 'cpu' for CPU). Default is 'cuda' if the computer has a GPU.
50 batch_size (int or None): Batch size for computation. If None, the batch size is set to _max_batch_size when a calculation runs. Default is 10000.
51 lorch_mod (bool): Flag to enable Lorch modification. Default is False.
52 radiation_type (str): Type of radiation for form factor calculations ('xray' or 'neutron'). Default is 'xray'.
53 profile (bool): Activate profiler. Default is False.
54 """
56 def __init__(
57 self,
58 qmin: float = 1.0,
59 qmax: float = 30.0,
60 qstep: float = 0.1,
61 qdamp: float = 0.04,
62 rmin: float = 0.0,
63 rmax: float = 20.0,
64 rstep: float = 0.01,
65 rthres: float = 0.0,
66 biso: float = 0.3,
67 device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
68 batch_size: Union[int, None] = 10000,
69 lorch_mod: bool = False,
70 radiation_type: str = 'xray',
71 profile: bool = False,
72 _max_batch_size: int = 4000,
73 ) -> None:
75 self.profile = profile
76 if self.profile:
77 self.profiler = Profiler()
79 # Initial parameters
80 self.device = device
81 self.batch_size = batch_size
82 self.lorch_mod = lorch_mod
83 self.radiation_type = radiation_type
85 # Standard Debye parameters
86 self.qmin = qmin
87 self.qmax = qmax
88 self.qstep = qstep
89 self.qdamp = qdamp
90 self.rmin = rmin
91 self.rmax = rmax
92 self.rstep = rstep
93 self.rthres = rthres
94 self.biso = biso
96 # Initialise ranges
97 self.q = torch.arange(self.qmin, self.qmax, self.qstep).unsqueeze(-1).to(device=self.device)
98 self.r = torch.arange(self.rmin, self.rmax, self.rstep).unsqueeze(-1).to(device=self.device)
100 # Form factor coefficients
101 with open(pkg_resources.resource_filename(__name__, 'form_factor_coef.yaml'), 'r') as yaml_file:
102 self.FORM_FACTOR_COEF = yaml.safe_load(yaml_file)
104 # Formfactor retrieval lambda
105 for k,v in self.FORM_FACTOR_COEF.items():
106 if None in v:
107 v = [value if value is not None else np.nan for value in v]
108 self.FORM_FACTOR_COEF[k] = torch.tensor(v).to(device=self.device, dtype=torch.float32)
109 if radiation_type.lower() in ['xray', 'x']:
110 self.form_factor_func = lambda p: torch.sum(p[:5] * torch.exp(-1*p[6:11] * (self.q / (4*torch.pi)).pow(2)), dim=1) + p[5]
111 elif radiation_type.lower() in ['neutron', 'n']:
112 self.form_factor_func = lambda p: p[11].unsqueeze(-1)
114 # Batch size
115 self._max_batch_size = _max_batch_size
117 def __repr__(
118 self,
119 ):
120 parameters = {'qmin': self.qmin, 'qmax': self.qmax, 'qdamp': self.qdamp, 'qstep': self.qstep,
121 'rmin': self.rmin, 'rmax': self.rmax, 'rstep': self.rstep, 'rthres': self.rthres,
122 'biso': self.biso}
124 return f"DebyeCalculator{parameters}"
def update_parameters(
    self,
    **kwargs: Any,
) -> None:
    """
    Set or update the parameters of the DebyeCalculator.

    The update is all-or-nothing: if any keyword is not an existing attribute of
    the calculator, a failure message is printed and nothing is changed.

    Parameters:
        **kwargs: Attribute names and their new values (e.g. qmin=0.5, device='cpu').
    """
    # setattr() happily creates new attributes, so a misspelled name would never
    # fail on its own; check the names explicitly before touching any state.
    unknown = [name for name in kwargs if not hasattr(self, name)]
    if unknown:
        print("Failed to update parameters because of unexpected parameter names")
        return

    for name, value in kwargs.items():
        setattr(self, name, value)

    # Keep the coefficient tensors on the same device as the grids they are combined with
    if 'device' in kwargs:
        coefficients = getattr(self, 'FORM_FACTOR_COEF', None)
        if coefficients:
            self.FORM_FACTOR_COEF = {
                element: coef.to(device=self.device) for element, coef in coefficients.items()
            }

    # Re-initialise ranges when their bounds, step or device changed
    grid_keys = {'qmin', 'qmax', 'qstep', 'rmin', 'rmax', 'rstep', 'device'}
    if not grid_keys.isdisjoint(kwargs):
        self.q = torch.arange(self.qmin, self.qmax, self.qstep).unsqueeze(-1).to(device=self.device)
        self.r = torch.arange(self.rmin, self.rmax, self.rstep).unsqueeze(-1).to(device=self.device)
148 def _initialise_structures(
149 self,
150 structure_path: Union[str, Atoms, List[Atoms]],
151 radii: Union[List[float], float, None] = None,
152 disable_pbar: bool = False,
153 ) -> None:
155 """
156 Initialise atomic structures and unique elements form factors from an input file.
158 Parameters:
159 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
160 radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF
161 """
162 # Check if input is a file or ASE Atoms object
163 if isinstance(structure_path, str):
164 # Check file and extention
165 structure_ext = structure_path.split('.')[-1]
166 if structure_ext not in ['xyz', 'cif']:
167 raise TypeError('FAILED: Invalid file/file-extention, accepts only .xyz or .cif data files')
168 elif isinstance(structure_path, Atoms) or all(isinstance(lst_elm, Atoms) for lst_elm in structure_path):
169 structure_ext = 'ase'
170 else:
171 raise TypeError('FAILED: Invalid structure format, accepts only .xyz, .cif data files or ASE Atoms objects')
173 # If cif, check for radii and generate particles
174 if structure_ext == 'cif':
175 if radii is not None:
176 ase_structures, _ = self.generate_nanoparticles(structure_path, radii, disable_pbar=disable_pbar)
177 self.num_structures = len(ase_structures)
178 else:
179 raise ValueError('FAILED: When providing .cif data file, please provide radii (Union[List[float], float]) to generate from.')
181 self.struc_elements = []
182 self.struc_size = []
183 self.struc_occupancy = []
184 self.struc_xyz = []
186 for structure in ase_structures:
187 elements = structure.get_chemical_symbols()
188 size = len(elements)
189 occupancy = torch.ones((size), dtype=torch.float32).to(device=self.device)
190 xyz = torch.tensor(np.array(structure.get_positions())).to(device=self.device, dtype=torch.float32)
192 self.struc_elements.append(elements)
193 self.struc_size.append(size)
194 self.struc_occupancy.append(occupancy)
195 self.struc_xyz.append(xyz)
197 elif structure_ext == 'xyz':
199 self.num_structures = 1
200 struc = np.genfromtxt(structure_path, dtype='str', skip_header=2) # Gen
201 self.struc_elements = [struc[:,0]] # Identities
202 self.struc_size = [len(self.struc_elements[0])] # Size
204 # Append occupancy if nothing is provided
205 if struc.shape[1] == 5:
206 self.struc_occupancy = [torch.from_numpy(struc[:,-1]).to(device=self.device, dtype=torch.float32)]
207 self.struc_xyz = [torch.tensor(struc[:,1:-1].astype('float')).to(device=self.device, dtype=torch.float32)]
208 else:
209 self.struc_occupancy = [torch.ones((self.struc_size[0]), dtype=torch.float32).to(device=self.device)]
210 self.struc_xyz = [torch.tensor(struc[:,1:].astype('float')).to(device=self.device, dtype=torch.float32)]
211 elif structure_ext == 'ase':
212 if isinstance(structure_path, Atoms):
213 ase_structures = [structure_path]
214 else:
215 ase_structures = structure_path
217 self.num_structures = len(ase_structures)
219 self.struc_elements = []
220 self.struc_size = []
221 self.struc_occupancy = []
222 self.struc_xyz = []
224 for structure in ase_structures:
225 elements = structure.get_chemical_symbols()
226 size = len(elements)
227 occupancy = torch.ones((size), dtype=torch.float32).to(device=self.device)
228 xyz = torch.tensor(np.array(structure.get_positions())).to(device=self.device, dtype=torch.float32)
230 self.struc_elements.append(elements)
231 self.struc_size.append(size)
232 self.struc_occupancy.append(occupancy)
233 self.struc_xyz.append(xyz)
234 else:
235 raise TypeError('FAILED: Invalid structure format, accepts only .xyz, .cif data files or ASE Atoms objects')
237 # Unique elements and their counts
238 self.triu_indices = []
239 self.unique_inverse = []
240 self.struc_unique_form_factors = []
241 self.struc_form_avg_sq = []
242 self.struc_inverse = []
244 for i in range(self.num_structures):
246 # Get unique elements and construc form factor stacks
247 unique_elements, inverse, counts = np.unique(self.struc_elements[i], return_counts=True, return_inverse=True)
249 triu_indices = torch.triu_indices(self.struc_size[i], self.struc_size[i], 1)
250 unique_inverse = torch.from_numpy(inverse[triu_indices]).to(device=self.device)
251 struc_unique_form_factors = torch.stack([self.form_factor_func(self.FORM_FACTOR_COEF[el]) for el in unique_elements])
253 self.triu_indices.append(triu_indices)
254 self.unique_inverse.append(unique_inverse)
255 self.struc_unique_form_factors.append(struc_unique_form_factors)
257 # Calculate average squared form factor and self scattering inverse indices
258 counts = torch.from_numpy(counts).to(device=self.device)
259 compositional_fractions = counts / torch.sum(counts)
260 struc_form_avg_sq = torch.sum(compositional_fractions.reshape(-1,1) * struc_unique_form_factors, dim=0)**2
261 struc_inverse = torch.from_numpy(np.array([inverse[i] for i in range(self.struc_size[i])])).to(device=self.device)
263 self.struc_form_avg_sq.append(struc_form_avg_sq)
264 self.struc_inverse.append(struc_inverse)
266 def iq(
267 self,
268 structure_path: Union[str, Atoms, List[Atoms]],
269 radii: Union[List[float], float, None] = None,
270 keep_on_device: bool = False,
271 _total_scattering: bool = False,
272 ) -> Union[Tuple[np.float32, Union[List[np.float32], np.float32]], Tuple[torch.FloatTensor, Union[List[torch.FloatTensor], torch.FloatTensor]]]:
274 """
275 Calculate the scattering intensity I(Q) for the given atomic structure.
277 Parameters:
278 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
279 radii (Union[List[float], float, None]): List/float of radii/radius of particle(s) to generate with parsed CIF
280 keep_on_device (bool): Flag to keep the results on the class device. Default is False.
281 _total_scattering (bool): Flag to return the scattering intensity I(Q) without the self-scattering contribution. Default is False.
283 Returns:
284 Tuple of torch tensors containing Q-values and scattering intensity I(Q) if keep_on_device is True, otherwise, numpy arrays on CPU.
285 """
287 # Raises errors if wrong path or parameters
288 if not os.path.exists(structure_path):
289 raise FileNotFoundError(f"{structure_path} not found.")
290 if self.qmin < 0:
291 raise ValueError("qmin must be non-negative.")
292 if self.qmax < 0:
293 raise ValueError("qmax must be non-negative.")
294 if self.qstep < 0:
295 raise ValueError("qstep must be non-negative.")
296 if self.qdamp < 0:
297 raise ValueError("qdamp must be non-negative.")
298 if self.rmin < 0:
299 raise ValueError("rmin must be non-negative.")
300 if self.rmax < 0:
301 raise ValueError("rmax must be non-negative.")
302 if self.rstep < 0:
303 raise ValueError("rstep must be non-negative.")
304 if self.rthres < 0:
305 raise ValueError("rthres must be non-negative.")
306 if self.biso < 0:
307 raise ValueError("biso must be non-negative.")
308 if self.batch_size is not None and self.batch_size < 0:
309 raise ValueError("batch_size must be non-negative.")
311 # Initialise structure
312 self._initialise_structures(structure_path, radii, disable_pbar = True)
314 if self.profile:
315 self.profiler.time('Setup structures and form factors')
317 # Calculate I(Q) for all initialised structures
318 iq_output = []
319 for i in range(self.num_structures):
321 # Calculate distances and batch
322 if self.batch_size is None:
323 self.batch_size = self._max_batch_size
324 dists = pdist(self.struc_xyz[i]).split(self.batch_size)
325 indices = self.triu_indices[i].split(self.batch_size, dim=1)
326 inverse_indices = self.unique_inverse[i].split(self.batch_size, dim=1)
328 if self.profile:
329 self.profiler.time('Batching and Distances')
331 # Calculate scattering using Debye Equation
332 iq = torch.zeros((len(self.q))).to(device=self.device, dtype=torch.float32)
333 for d, inv_idx, idx in zip(dists, inverse_indices, indices):
334 mask = d >= self.rthres
335 occ_product = self.struc_occupancy[i][idx[0]] * self.struc_occupancy[i][idx[1]]
336 sinc = torch.sinc(d[mask] * self.q / torch.pi)
337 ffp = self.struc_unique_form_factors[i][inv_idx[0]] * self.struc_unique_form_factors[i][inv_idx[1]]
338 iq += torch.sum(occ_product.unsqueeze(-1) * ffp * sinc.permute(1,0), dim=0)
340 # Apply Debye-Weller Isotropic Atomic Displacement
341 if self.biso != 0.0:
342 iq *= torch.exp(-self.q.squeeze(-1).pow(2) * self.biso/(8*torch.pi**2))
344 # For total scattering
345 if _total_scattering:
346 if self.profile:
347 self.profiler.time('I(Q)')
348 iq_output.append(iq) # TODO Times 2
349 continue
351 # Self-scattering contribution
352 sinc = torch.ones((self.struc_size[i], len(self.q))).to(device=self.device)
353 iq += torch.sum((self.struc_occupancy[i].unsqueeze(-1) * self.struc_unique_form_factors[i][self.struc_inverse[i]])**2 * sinc, dim=0) / 2
354 iq *= 2
356 if self.profile:
357 self.profiler.time('I(Q)')
359 iq_output.append(iq)
361 if _total_scattering:
362 return self.q.squeeze(-1), iq_output
364 if keep_on_device:
365 if self.num_structures == 1:
366 return self.q.squeeze(-1), iq_output[0]
367 else:
368 return self.q.squeeze(-1), iq_output
369 else:
370 if self.num_structures == 1:
371 return self.q.squeeze(-1).cpu().numpy(), iq_output[0].cpu().numpy()
372 else:
373 return self.q.squeeze(-1).cpu().numpy(), [iq.cpu().numpy() for iq in iq_output]
375 def sq(
376 self,
377 structure_path: Union[str, Atoms, List[Atoms]],
378 radii: Union[List[float], float, None] = None,
379 keep_on_device: bool = False,
380 ) -> Union[Tuple[np.float32, Union[List[np.float32], np.float32]], Tuple[torch.FloatTensor, Union[List[torch.FloatTensor], torch.FloatTensor]]]:
382 """
383 Calculate the structure function S(Q) for the given atomic structure.
385 Parameters:
386 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
387 keep_on_device (bool): Flag to keep the results on the class device. Default is False.
389 Returns:
390 Tuple of torch tensors containing Q-values and structure function S(Q) if keep_on_device is True, otherwise, numpy arrays on CPU.
391 """
393 # Calculate Scattering S(Q)
394 _, iq = self.iq(structure_path, radii, keep_on_device=True, _total_scattering=True)
396 sq_output = []
397 for i in range(self.num_structures):
398 sq = iq[i]/self.struc_form_avg_sq[i]/self.struc_size[i]
399 sq_output.append(sq)
401 if keep_on_device:
402 if self.num_structures == 1:
403 return self.q.squeeze(-1), sq_output[0]
404 else:
405 return self.q.squeeze(-1), sq_output
406 else:
407 if self.num_structures == 1:
408 return self.q.squeeze(-1).cpu().numpy(), sq_output[0].cpu().numpy()
409 else:
410 return self.q.squeeze(-1).cpu().numpy(), [sq.cpu().numpy() for sq in sq_output]
412 def fq(
413 self,
414 structure_path: Union[str, Atoms, List[Atoms]],
415 radii: Union[List[float], float, None] = None,
416 keep_on_device: bool = False,
417 ) -> Union[Tuple[np.float32, Union[List[np.float32], np.float32]], Tuple[torch.FloatTensor, Union[List[torch.FloatTensor], torch.FloatTensor]]]:
418 """
419 Calculate the reduced structure function F(Q) for the given atomic structure.
421 Parameters:
422 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
423 keep_on_device (bool): Flag to keep the results on the class device. Default is False.
425 Returns:
426 Tuple of torch tensors containing Q-values and reduced structure function F(Q) if keep_on_device is True, otherwise, numpy arrays on CPU.
427 """
428 # Calculate Scattering S(Q)
429 _, iq = self.iq(structure_path, radii, keep_on_device=True, _total_scattering=True)
431 fq_output = []
432 for i in range(self.num_structures):
433 sq = iq[i]/self.struc_form_avg_sq[i]/self.struc_size[i]
434 fq = self.q.squeeze(-1) * sq
435 fq_output.append(fq)
437 if keep_on_device:
438 if self.num_structures == 1:
439 return self.q.squeeze(-1), fq_output[0]
440 else:
441 return self.q.squeeze(-1), fq_output
442 else:
443 if self.num_structures == 1:
444 return self.q.squeeze(-1).cpu().numpy(), fq_output[0].cpu().numpy()
445 else:
446 return self.q.squeeze(-1).cpu().numpy(), [fq.cpu().numpy() for fq in fq_output]
448 def gr(
449 self,
450 structure_path: Union[str, Atoms, List[Atoms]],
451 radii: Union[List[float], float, None] = None,
452 keep_on_device: bool = False,
453 ) -> Union[Tuple[np.float32, Union[List[np.float32], np.float32]], Tuple[torch.FloatTensor, Union[List[torch.FloatTensor], torch.FloatTensor]]]:
455 """
456 Calculate the reduced pair distribution function G(r) for the given atomic structure.
458 Parameters:
459 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
460 keep_on_device (bool): Flag to keep the results on the class device. Default is False.
462 Returns:
463 Tuple of torch tensors containing Q-values and PDF G(r) if keep_on_device is True, otherwise, numpy arrays on CPU.
464 """
465 if self.profile:
466 self.profiler.reset()
468 # Calculate Scattering I(Q), S(Q), F(Q)
469 _, iq = self.iq(structure_path, radii, keep_on_device=True, _total_scattering=True)
471 gr_output = []
472 for i in range(self.num_structures):
473 sq = iq[i]/self.struc_form_avg_sq[i]/self.struc_size[i]
474 if self.profile:
475 self.profiler.time('S(Q)')
476 fq = self.q.squeeze(-1) * sq
477 if self.profile:
478 self.profiler.time('F(Q)')
480 # Calculate total scattering, G(r)
481 damp = 1 if self.qdamp == 0.0 else torch.exp(-(self.r.squeeze(-1) * self.qdamp).pow(2) / 2)
482 lorch_mod = 1 if self.lorch_mod == None else torch.sinc(self.q * self.lorch_mod*(torch.pi / self.qmax))
483 if self.profile:
484 self.profiler.time('Modifications, Qdamp/Lorch')
485 gr = (2 / torch.pi) * torch.sum(fq.unsqueeze(-1) * torch.sin(self.q * self.r.permute(1,0))*self.qstep * lorch_mod, dim=0) * damp
486 if self.profile:
487 self.profiler.time('G(r)')
489 gr_output.append(gr)
491 if keep_on_device:
492 if self.num_structures == 1:
493 return self.r.squeeze(-1), gr_output[0]
494 else:
495 return self.r.squeeze(-1), gr_output
496 else:
497 if self.num_structures == 1:
498 return self.r.squeeze(-1).cpu().numpy(), gr_output[0].cpu().numpy()
499 else:
500 return self.r.squeeze(-1).cpu().numpy(), [gr.cpu().numpy() for gr in gr_output]
502 def _get_all(
503 self,
504 structure_path: Union[str, Atoms, List[Atoms]],
505 radii: Union[List[float], float, None] = None,
506 keep_on_device: bool = False,
507 ) -> Union[Tuple[np.float32,np.float32,Union[List[np.float32], np.float32],Union[List[np.float32],np.float32], Union[List[np.float32],np.float32], Union[List[np.float32], np.float32]],
508 Tuple[torch.FloatTensor,torch.FloatTensor,Union[List[torch.FloatTensor], torch.FloatTensor],Union[List[torch.FloatTensor],torch.FloatTensor], Union[List[torch.FloatTensor],torch.FloatTensor], Union[List[torch.FloatTensor], torch.FloatTensor]]]:
510 """
511 Calculate I(Q), S(Q), F(Q) and G(r) for the given atomic structure and return all.
513 Parameters:
514 structure_path (Union[str, Atoms, List[Atoms]]): Path to the atomic structure file in XYZ/CIF format or stored ASE Atoms objects.
515 keep_on_device (bool): Flag to keep the results on the class device. Default is False.
517 Returns:
518 Tuple of torch tensors containing of r-values, Q-values and Union[List[float_vec], float_vec] of I(Q), S(Q), F(Q) and G(r) if keep_on_device is True, otherwise, numpy arrays on CPU.
519 """
521 # Initialise structure
522 self._initialise_structures(structure_path, radii, disable_pbar = True)
524 # Calculate I(Q) for all initialised structures
525 iq_output = []
526 sq_output = []
527 fq_output = []
528 gr_output = []
529 for i in range(self.num_structures):
531 # Calculate distances and batch
532 if self.batch_size is None:
533 self.batch_size = self._max_batch_size
534 dists = pdist(self.struc_xyz[i]).split(self.batch_size)
535 indices = self.triu_indices[i].split(self.batch_size, dim=1)
536 inverse_indices = self.unique_inverse[i].split(self.batch_size, dim=1)
538 # Calculate scattering using Debye Equation
539 iq = torch.zeros((len(self.q))).to(device=self.device, dtype=torch.float32)
540 for d, inv_idx, idx in zip(dists, inverse_indices, indices):
541 mask = d >= self.rthres
542 occ_product = self.struc_occupancy[i][idx[0]] * self.struc_occupancy[i][idx[1]]
543 sinc = torch.sinc(d[mask] * self.q / torch.pi)
544 ffp = self.struc_unique_form_factors[i][inv_idx[0]] * self.struc_unique_form_factors[i][inv_idx[1]]
545 iq += torch.sum(occ_product.unsqueeze(-1) * ffp * sinc.permute(1,0), dim=0)
547 # Apply Debye-Weller Isotropic Atomic Displacement
548 if self.biso != 0.0:
549 iq *= torch.exp(-self.q.squeeze(-1).pow(2) * self.biso/(8*torch.pi**2))
551 # Calculate S(Q), F(Q) and G(r)
552 sq = iq/self.struc_form_avg_sq[i]/self.struc_size[i]
553 sq_output.append(sq)
555 fq = self.q.squeeze(-1) * sq
556 fq_output.append(fq)
558 damp = 1 if self.qdamp == 0.0 else torch.exp(-(self.r.squeeze(-1) * self.qdamp).pow(2) / 2)
559 lorch_mod = 1 if self.lorch_mod == None else torch.sinc(self.q * self.lorch_mod*(torch.pi / self.qmax))
560 gr = (2 / torch.pi) * torch.sum(fq.unsqueeze(-1) * torch.sin(self.q * self.r.permute(1,0))*self.qstep * lorch_mod, dim=0) * damp
561 gr_output.append(gr)
563 # Self-scattering contribution
564 sinc = torch.ones((self.struc_size[i], len(self.q))).to(device=self.device)
565 iq += torch.sum((self.struc_occupancy[i].unsqueeze(-1) * self.struc_unique_form_factors[i][self.struc_inverse[i]])**2 * sinc, dim=0) / 2
566 iq *= 2
568 iq_output.append(iq)
570 if keep_on_device:
571 if self.num_structures == 1:
572 return self.r.squeeze(-1), self.q.squeeze(-1), iq_output[0], sq_output[0], fq_output[0], gr_output[0]
573 else:
574 return self.r.squeeze(-1), self.q.squeeze(-1), iq_output, sq_output, fq_output, gr_output
575 else:
576 if self.num_structures == 1:
577 return self.r.squeeze(-1).cpu().numpy(), self.q.squeeze(-1).cpu().numpy(), iq_output[0].cpu().numpy(), sq_output[0].cpu().numpy(), fq_output[0].cpu().numpy(), gr_output[0].cpu().numpy()
578 else:
579 return self.r.squeeze(-1).cpu().numpy(), self.q.squeeze(-1).cpu().numpy(), [iq.cpu().numpy() for iq in iq_output], [sq.cpu().numpy() for sq in sq_output], [fq.cpu().numpy() for fq in fq_output], [gr.cpu().numpy() for gr in gr_output]
def generate_nanoparticles(
    self,
    structure_path: str,
    radii: Union[List[float], float],
    sort_atoms: bool = True,
    disable_pbar: bool = False,
    _override_device: bool = True,
) -> Tuple[Union[List[Atoms], Atoms], Union[List[float], float]]:
    """
    Generate nanoparticles from a given structure and list of radii.

    Particles are cut from one shared supercell, largest radius first, so both
    returned lists are ordered by decreasing radius rather than by input order.

    Args:
        structure_path (str): Path to the input structure file.
        radii (Union[List[float], float]): List of floats or float of radii for nanoparticles to be generated.
            Tuples, NumPy arrays and NumPy scalars are accepted as well.
        sort_atoms (bool, optional): Whether to sort atoms in the nanoparticle. Defaults to True.
        disable_pbar (bool, optional): Disable the progress bar. Defaults to False.
        _override_device (bool): Ignore object device and run in CPU

    Returns:
        list: List of ASE Atoms objects representing the generated nanoparticles.
        list: List of nanoparticle sizes (diameter) corresponding to each generated nanoparticle.

    Raises:
        ValueError: If radii is empty or of an unsupported type, or if every atom
            of the structure is in the ligand list so no centre atom can be chosen.
    """
    # Fix radii type
    if isinstance(radii, (list, tuple, np.ndarray)):
        radii = [float(r) for r in radii]
    elif isinstance(radii, (int, float, np.integer, np.floating)):
        radii = [float(radii)]
    else:
        raise ValueError('FAILED: Please provide valid radii for generation of nanoparticles')
    if not radii:
        raise ValueError('FAILED: Please provide valid radii for generation of nanoparticles')

    # DEV: Override device
    device = 'cpu' if _override_device else self.device

    # Read the input unit cell structure
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        unit_cell = read(structure_path)
    cell_dims = np.array(unit_cell.cell.cellpar()[:3])
    r_max = np.amax(radii)

    # Create a supercell to encompass the entire range of nanoparticles and center it
    supercell_matrix = np.diag((np.ceil(r_max / cell_dims)) * 2 + 2)
    cell = make_supercell(prim=unit_cell, P=supercell_matrix)
    cell.center(about=0.)

    # Convert positions to torch and send to device
    positions = torch.from_numpy(cell.get_positions()).to(dtype = torch.float32, device = device)

    # Find all metals and center around the nearest metal
    ligands = ['O', 'Cl', 'H'] # Placeholder
    metal_filter = torch.BoolTensor([a not in ligands for a in cell.get_chemical_symbols()]).to(device = device)
    if not bool(metal_filter.any()):
        raise ValueError('FAILED: Structure contains only ligand atoms, cannot choose a centre atom for the nanoparticle')
    center_dists = torch.norm(positions, dim=1)
    positions -= positions[metal_filter][torch.argmin(center_dists[metal_filter])]
    center_dists = torch.norm(positions, dim=1)
    min_metal_dist = torch.min(pdist(positions[metal_filter]))
    if bool((~metal_filter).any()):
        min_bond_dist = torch.amin(cdist(positions[metal_filter], positions[~metal_filter]))
    else:
        # No ligand atoms at all: the reduction over an empty distance matrix would
        # fail, so fall back to the shortest metal-metal distance as bond length
        min_bond_dist = min_metal_dist

    # Update the cell positions
    cell.positions = positions.cpu()

    # Initialize nanoparticle lists and progress bar
    nanoparticle_list = []
    nanoparticle_sizes = []
    pbar = tqdm(desc=f'Generating nanoparticles in range: [{np.amin(radii)},{np.amax(radii)}]', leave=False, total=len(radii), disable=disable_pbar)

    try:
        # Generate nanoparticles for each radius, largest first, so every smaller
        # particle is carved out of the previous (already reduced) one
        for r in sorted(radii, reverse=True):

            # Mask all atoms within radius; ligands get an extra shell of one metal-metal distance
            incl_mask = (center_dists <= r) | ((center_dists <= r + min_metal_dist) & ~metal_filter)

            # Modify objects based on mask
            cell = cell[incl_mask.cpu()]
            center_dists = center_dists[incl_mask]
            metal_filter = metal_filter[incl_mask]
            positions = positions[incl_mask]

            # Find interdistances from all included atoms and remove 0's from diagonal
            interface_dists = cdist(positions, positions).fill_diagonal_(min_metal_dist*2)

            # Remove floating atoms, i.e. atoms without a neighbour within 1.2 bond lengths
            interaction_mask = torch.amin(interface_dists, dim=0) < min_bond_dist*1.2

            # Modify objects based on mask
            cell = cell[interaction_mask.cpu()]
            center_dists = center_dists[interaction_mask]
            metal_filter = metal_filter[interaction_mask]
            positions = positions[interaction_mask]

            # Determine NP size
            nanoparticle_size = torch.amax(center_dists) * 2

            # Sort the atoms
            if sort_atoms:
                sorted_cell = ase_sort(cell)
                # Flip the order when a ligand element sorts first
                if sorted_cell.get_chemical_symbols()[0] in ligands:
                    sorted_cell = sorted_cell[::-1]

                # Append nanoparticle
                nanoparticle_list.append(sorted_cell)
            else:
                # Append nanoparticle
                nanoparticle_list.append(cell)

            # Append size
            nanoparticle_sizes.append(nanoparticle_size.item())

            pbar.update(1)
    finally:
        # Always release the progress bar, also when a step above fails
        pbar.close()

    return nanoparticle_list, nanoparticle_sizes
695 def _is_notebook(
696 self,
697 ) -> bool:
699 """
700 Checks if the code is running within a Jupyter Notebook, Google Colab, or other interactive environment.
702 Returns:
703 bool: True if running in a Jupyter Notebook or Google Colab, False otherwise.
704 """
705 try:
706 shell = get_ipython().__class__.__name__
707 if shell == 'ZMQInteractiveShell':
708 return True # Jupyter notebook or qtconsole
709 elif shell == 'google.colab._shell':
710 return True # Google Colab
711 elif shell == 'Shell':
712 return True # Apparently also Colab?
713 else:
714 return False # Other cases
715 except NameError:
716 return False # Standard Python Interpreter
718 def interact(
719 self,
720 _cont_updates: bool = False
721 ) -> None:
723 """
724 Initiates an interactive visualization and data analysis tool within a Jupyter Notebook or Google Colab environment.
726 Args:
727 _cont_updates (bool, optional): If True, enables continuous updates for interactive widgets. Defaults to False.
728 """
730 # Check if interaction is valid
731 if not self._is_notebook():
732 print("FAILED: Interactive mode is exlusive to Jupyter Notebook or Google Colab")
733 return
735 # Scattering parameters
736 qmin = self.qmin
737 qmax = self.qmax
738 qstep = self.qstep
739 qdamp = self.qdamp
740 rmin = self.rmin
741 rmax = self.rmax
742 rstep = self.rstep
743 rthres = self.rthres
744 biso = self.biso
745 device = 'cuda' if torch.cuda.is_available() else self.device
746 batch_size = self.batch_size
747 lorch_mod = self.lorch_mod
748 radiation_type = self.radiation_type
749 profile = False
751 with open('display_assets/choose_hardware.png', 'rb') as f:
752 choose_hardware_img = f.read()
753 with open('display_assets/batch_size.png', 'rb') as f:
754 batch_size_img = f.read()
756 """ Utility widgets """
758 # Spacing widget
759 spacing_10px = widgets.Text(description='', layout=Layout(visibility='hidden', height='10px'), disabled=True)
760 spacing_5px = widgets.Text(description='', layout=Layout(visibility='hidden', height='5px'), disabled=True)
762 """ File Selection Tab """
764 # Load diplay display_assets
765 with open('display_assets/enter_path.png', 'rb') as f:
766 enter_path_img = f.read()
767 with open('display_assets/select_files.png', 'rb') as f:
768 select_files_img = f.read()
769 with open('display_assets/radius_a.png', 'rb') as f:
770 radius_a_img = f.read()
771 with open('display_assets/file_1.png', 'rb') as f:
772 file_1_img = f.read()
773 with open('display_assets/file_2.png', 'rb') as f:
774 file_2_img = f.read()
776 # Layout
777 file_tab_layout = Layout(
778 display='flex',
779 flex_flow='column',
780 align_items='stretch',
781 order='solid',
782 width='90%',
783 )
785 # File selection sizes
786 header_widths = [105*1.8, 130*1.8]
787 header_widths = [str(i)+'px' for i in header_widths]
789 # Folder selection
790 folder = widgets.Text(description='', placeholder='Enter data directory', disabled=False, layout=Layout(width='650px'))
792 # Dropdown file sections
793 DEFAULT_MSGS = ['No valid files in entered directory', 'Select data file']
794 select_file_1 = widgets.Dropdown(options=DEFAULT_MSGS, value=DEFAULT_MSGS[0], disabled=True, layout=Layout(width='650px'))
795 select_file_2 = widgets.Dropdown(options=DEFAULT_MSGS, value=DEFAULT_MSGS[0], disabled=True, layout=Layout(width='650px'))
797 # File 1
798 select_file_desc_1 = HBox([widgets.Image(value=file_1_img, format='png', layout=Layout(object_fit='contain', object_position='20px', width='32px'))], layout=Layout(width='88px'))
799 select_radius_desc_1 = HBox([widgets.Image(value=radius_a_img, format='png', layout=Layout(object_fit='contain', object_position='20px', width='60px', visibility='hidden'))], layout=Layout(width='88px'))
800 select_radius_1 = widgets.FloatText(min = 0, max = 50, step=0.01, value=5, disabled = False, layout = Layout(width='50px', visibility='hidden'))
801 cif_text_1 = widgets.Text(
802 value='Given radius, generate spherical nanoparticles (NP) from crystallographic information files (CIFs)',
803 disabled=True,
804 layout=Layout(width='595px', visibility='hidden')
805 )
807 # File 2
808 select_file_desc_2 = HBox([widgets.Image(value=file_2_img, format='png', layout=Layout(object_fit='contain', object_position='20px', width='32px'))], layout=Layout(width='88px'))
809 select_radius_desc_2 = HBox([widgets.Image(value=radius_a_img, format='png', layout=Layout(object_fit='contain', object_position='20px', width='60px', visibility='hidden'))], layout=Layout(width='88px'))
810 select_radius_2 = widgets.FloatText(min = 0, max = 50, step=0.01, value=5, disabled = False, layout = Layout(width='50px', visibility='hidden'))
811 cif_text_2 = widgets.Text(
812 value='Given radius, generate spherical nanoparticles (NP) from crystallographic information files (CIFs)',
813 disabled=True,
814 layout=Layout(width='595px', visibility='hidden')
815 )
817 # File selection Tab
818 file_tab = VBox([
819 # Enter path
820 widgets.Image(value=enter_path_img, format='png', layout=Layout(object_fit='contain', width=header_widths[0])),
821 folder,
823 spacing_10px,
825 # Select file(s)
826 widgets.Image(value=select_files_img, format='png', layout=Layout(object_fit='contain', width=header_widths[1])),
828 # Select file 1
829 HBox([select_file_desc_1, select_file_1]),
831 # if CIF, radius options
832 HBox([select_radius_desc_1, select_radius_1, cif_text_1]),
834 spacing_10px,
836 # Select file 2
837 HBox([select_file_desc_2, select_file_2]),
839 # If CIF, radius options
840 HBox([select_radius_desc_2, select_radius_2, cif_text_2]),
841 ], layout = file_tab_layout)
843 """ Scattering Options Tab """
845 # Load display_assets
846 with open('display_assets/qslider.png', 'rb') as f:
847 qslider_img = f.read()
848 with open('display_assets/rslider.png', 'rb') as f:
849 rslider_img = f.read()
850 with open('display_assets/qdamp.png', 'rb') as f:
851 qdamp_img = f.read()
852 with open('display_assets/global_biso.png', 'rb') as f:
853 global_biso_img = f.read()
854 with open('display_assets/a.png', 'rb') as f:
855 a_img = f.read()
856 with open('display_assets/a_inv.png', 'rb') as f:
857 a_inv_img = f.read()
858 with open('display_assets/a_sq.png', 'rb') as f:
859 a_sq_img = f.read()
860 with open('display_assets/qstep.png', 'rb') as f:
861 qstep_img = f.read()
862 with open('display_assets/rstep.png', 'rb') as f:
863 rstep_img = f.read()
864 with open('display_assets/rthres.png', 'rb') as f:
865 rthres_img = f.read()
866 with open('display_assets/radiation_type.png', 'rb') as f:
867 radiation_type_img = f.read()
868 with open('display_assets/scattering_parameters.png', 'rb') as f:
869 scattering_parameters_img = f.read()
870 with open('display_assets/presets.png', 'rb') as f:
871 presets_img = f.read()
873 # Radiation
874 radtype_button = widgets.ToggleButtons(
875 options=['xray', 'neutron'],
876 value=radiation_type,
877 layout = Layout(width='800px'),
878 button_style='primary'
879 )
881 # Q value slider
882 qslider = widgets.FloatRangeSlider(
883 value=[qmin, qmax],
884 min=0.0, max=50.0, step=0.01,
885 orientation='horizontal',
886 readout=True,
887 style={'font_weight':'bold', 'slider_color': 'white', 'description_width': '100px'},
888 layout = widgets.Layout(width='80%'),
889 )
891 # r value slider
892 rslider = widgets.FloatRangeSlider(
893 value=[rmin, rmax],
894 min=0, max=100.0, step=rstep,
895 orientation='horizontal',
896 readout=True,
897 style={'font_weight':'bold', 'slider_color': 'white', 'description_width': '100px'},
898 layout = widgets.Layout(width='80%'),
899 )
901 # Qdamp box
902 qdamp_box = widgets.FloatText(
903 min=0.00,max=0.10, step=0.01,
904 value=qdamp,
905 layout = widgets.Layout(width='50px'),
906 )
908 # B iso box
909 biso_box = widgets.FloatText(
910 min=0.00, max=1.00, step=0.01,
911 value=biso,
912 layout = widgets.Layout(width='50px'),
913 )
915 # Qstep box
916 qstep_box = widgets.FloatText(
917 min = 0.001, max = 1, step=0.001,
918 value=qstep,
919 layout=Layout(width='50px'),
920 )
922 # rstep box
923 rstep_box = widgets.FloatText(
924 min = 0.001, max = 1, step=0.001,
925 value=rstep,
926 layout=Layout(width='50px'),
927 )
929 # rthreshold box
930 rthres_box = widgets.FloatText(
931 min = 0.001, max = 1, step=0.001,
932 value=rthres,
933 layout=Layout(width='50px'),
934 )
936 # Lorch modification button
937 lorch_mod_button = widgets.ToggleButton(
938 value=lorch_mod,
939 description='Lorch modification (OFF)',
940 layout=Layout(width='250px'),
941 button_style='primary',
942 )
944 # SAS preset button
945 sas_preset_button = widgets.Button(
946 description = 'Small Angle Scattering preset',
947 layout=Layout(width='300px'),
948 button_style='primary',
949 )
951 # Powder diffraction preset
952 pd_preset_button = widgets.Button(
953 description = 'Powder Diffraction preset',
954 layout=Layout(width='300px'),
955 button_style='primary',
956 )
958 # Total scattering preset
959 ts_preset_button = widgets.Button(
960 description = 'Total Scattering preset',
961 layout=Layout(width='300px'),
962 button_style='primary',
963 )
965 # Reset button
966 reset_button = widgets.Button(
967 description = 'Reset scattering options',
968 layout=Layout(width='300px'),
969 button_style='danger',
970 )
972 # Scattering Tab sizes
973 header_widths = [90*1.15, 135*1.15, 110*1.15]
974 header_widths = [str(i)+'px' for i in header_widths]
975 a_inv_width = '27px'
976 a_width = '12px'
977 a_sq_width = '19px'
979 # Scattering tab
980 scattering_tab = VBox([
981 # Radiation button
982 widgets.Image(value=radiation_type_img, format='png', layout=Layout(object_fit='contain', width=header_widths[0])),
983 radtype_button,
985 spacing_10px,
987 # Scattering parameters
988 widgets.Image(value=scattering_parameters_img, format='png', layout=Layout(object_fit='contain', width=header_widths[1])),
990 # Q slider
991 HBox([
992 # Q slider img
993 HBox([widgets.Image(value=qslider_img, format='png', layout=Layout(object_fit='contain', width='120px'))], layout=Layout(width='150px')),
994 # Q slider
995 qslider,
996 # Unit
997 #¤HBox([widgets.Image(value=a_inv_img, format='png', layout=Layout(object_fit='contain', width=a_inv_width))], layout=Layout(width='50px')),
998 ]),
1000 spacing_5px,
1002 # r slider
1003 HBox([
1004 # r slider img
1005 HBox([widgets.Image(value=rslider_img, format='png', layout=Layout(object_fit='contain', width='110px'))], layout=Layout(width='150px')),
1006 # r slider
1007 rslider,
1008 # Unit
1009 #HBox([widgets.Image(value=a_img, format='png', layout=Layout(object_fit='contain', width=a_width))], layout=Layout(width='50px')),
1010 ]),
1012 spacing_5px,
1014 # Other
1015 HBox([
1016 # Qstep img
1017 HBox([widgets.Image(value=qstep_img, format='png', layout=Layout(object_fit='contain', object_position='', width='65px'))], layout=Layout(width='75px')),
1018 # Qstep box
1019 qstep_box,
1020 # Unit
1021 #HBox([widgets.Image(value=a_inv_img, format='png', layout=Layout(object_fit='contain', width=a_inv_width))], layout=Layout(width='50px')),
1023 # r step img
1024 widgets.Text(description='', layout=Layout(visibility='hidden', width='60px'), disabled=True),
1025 HBox([widgets.Image(value=rstep_img, format='png', layout=Layout(object_fit='contain', object_position='', width='55px'))], layout=Layout(width='65px')),
1026 # r step box
1027 rstep_box,
1028 # Unit
1029 #HBox([widgets.Image(value=a_img, format='png', layout=Layout(object_fit='contain', width=a_width))], layout=Layout(width='50px')),
1031 # Q damp img
1032 widgets.Text(description='', layout=Layout(visibility='hidden', width='60px'), disabled=True),
1033 HBox([widgets.Image(value=qdamp_img, format='png', layout=Layout(object_fit='contain', object_position='', width='75px'))], layout=Layout(width='85px')),
1034 # Q damp box
1035 qdamp_box,
1036 # Unit
1037 #HBox([widgets.Image(value=a_inv_img, format='png', layout=Layout(object_fit='contain', width=a_inv_width))], layout=Layout(width='80px')),
1039 # r thres img
1040 widgets.Text(description='', layout=Layout(visibility='hidden', width='60px'), disabled=True),
1041 HBox([widgets.Image(value=rthres_img, format='png', layout=Layout(object_fit='contain', object_position='', width='55px'))], layout=Layout(width='65px')),
1042 # r thres
1043 rthres_box,
1044 # Unit
1045 #HBox([widgets.Image(value=a_img, format='png', layout=Layout(object_fit='contain', width=a_width))], layout=Layout(width='50px')),
1047 # Global B iso img
1048 widgets.Text(description='', layout=Layout(visibility='hidden', width='60px'), disabled=True),
1049 HBox([widgets.Image(value=global_biso_img, format='png', layout=Layout(object_fit='contain', object_position='', width='95px'))], layout=Layout(width='105px')),
1050 # Global B iso box
1051 biso_box,
1052 # Unit
1053 #HBox([widgets.Image(value=a_sq_img, format='png', layout=Layout(object_fit='contain', width=a_sq_width))], layout=Layout(width='50px')),
1054 ]),
1056 spacing_5px,
1058 # Lorch modification
1059 HBox([
1060 # Lorch mod button
1061 lorch_mod_button,
1062 # Unit
1063 HBox([widgets.Image(value=a_img, format='png', layout=Layout(object_fit='contain', width=a_width, visibility='hidden'))], layout=Layout(width='50px')),
1064 ]),
1066 spacing_10px,
1068 # Presets
1069 widgets.Image(value=presets_img, format='png', layout=Layout(object_fit='contain', width=header_widths[2])),
1070 HBox([sas_preset_button, pd_preset_button, ts_preset_button, reset_button]),
1071 ])
1073 """ Plotting Options """
1075 # Load display display_assets
1076 with open('display_assets/iq_scaling.png', 'rb') as f:
1077 iq_scaling_img = f.read()
1078 with open('display_assets/show_hide.png', 'rb') as f:
1079 show_hide_img = f.read()
1080 with open('display_assets/max_norm.png', 'rb') as f:
1081 max_norm_img = f.read()
1082 with open('display_assets/iq.png', 'rb') as f:
1083 iq_img = f.read()
1084 with open('display_assets/sq.png', 'rb') as f:
1085 sq_img = f.read()
1086 with open('display_assets/fq.png', 'rb') as f:
1087 fq_img = f.read()
1088 with open('display_assets/gr.png', 'rb') as f:
1089 gr_img = f.read()
1091 # Y-axis I(Q) scale button
1092 scale_type_button = widgets.ToggleButtons( options=['linear', 'logarithmic'], value='linear', button_style='primary', layout=Layout(width='600'))
1094 # Show/Hide buttons
1095 show_iq_button = widgets.Checkbox(value = True)
1096 show_sq_button = widgets.Checkbox(value = True)
1097 show_fq_button = widgets.Checkbox(value = True)
1098 show_gr_button = widgets.Checkbox(value = True)
1100 # Max normalize buttons
1101 normalize_iq = widgets.Checkbox(value = False)
1102 normalize_sq = widgets.Checkbox(value = False)
1103 normalize_fq = widgets.Checkbox(value = False)
1104 normalize_gr = widgets.Checkbox(value = False)
1106 # Plotting tab sizes
1107 function_offset = '-90px 3px'
1108 function_size = 35
1109 header_scale = 0.95
1110 header_widths = [130, 120, 147]
1111 header_widths = [str(i * header_scale)+'px' for i in header_widths]
1113 # Plotting tab
1114 plotting_tab = VBox([
1115 # I(Q) scaling img
1116 widgets.Image(value=iq_scaling_img, format='png', layout=Layout(object_fit='contain', width=header_widths[0])),
1117 scale_type_button,
1119 spacing_10px,
1121 # Show / Hide img
1122 widgets.Image(value=show_hide_img, format='png', layout=Layout(object_fit='contain', width=header_widths[1])),
1124 # Options
1125 HBox([
1126 HBox([show_iq_button, widgets.Image(value=iq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1127 HBox([show_sq_button, widgets.Image(value=sq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1128 HBox([show_fq_button, widgets.Image(value=fq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1129 HBox([show_gr_button, widgets.Image(value=gr_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1130 ]),
1132 spacing_10px,
1134 # Max normalization img
1135 widgets.Image(value=max_norm_img, format='png', layout=Layout(object_fit='contain', width=header_widths[2])),
1137 # Options
1138 HBox([
1139 HBox([normalize_iq, widgets.Image(value=iq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1140 HBox([normalize_sq, widgets.Image(value=sq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1141 HBox([normalize_fq, widgets.Image(value=fq_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1142 HBox([normalize_gr, widgets.Image(value=gr_img, format='png', width=function_size, layout=Layout(object_fit='contain', object_position=function_offset))]),
1143 ]),
1144 ])
1147 """ Hardware Options Tab """
1149 # Hardware button
1150 hardware_button = widgets.ToggleButtons(options=['cpu', 'cuda'], value=device, button_style='primary')
1152 # Distance batch-size box
1153 batch_size_box = widgets.IntText(min = 100, max = 10000, value=batch_size)
1155 # Hardware tab sizes
1156 header_scale = 1
1157 header_widths = [120, 175]
1158 header_widths = [str(i * header_scale)+'px' for i in header_widths]
1160 # Hardware tab
1161 hardware_tab = VBox([
1162 # Choose hardware img
1163 widgets.Image(value=choose_hardware_img, format='png', layout=Layout(object_fit='contain', width=header_widths[0])),
1165 # Hardware box
1166 hardware_button,
1168 spacing_10px,
1170 # Distance batch_size img
1171 widgets.Image(value=batch_size_img, format='png', layout=Layout(object_fit='contain', width=header_widths[1])),
1173 # Distance batch size box
1174 batch_size_box,
1175 ])
1178 """ Display tabs """
1180 # Display Tabs
1181 tabs = widgets.Tab([
1182 file_tab,
1183 scattering_tab,
1184 plotting_tab,
1185 hardware_tab,
1186 ])
1188 # Set tab titles
1189 tabs.set_title(0, 'File Selection')
1190 tabs.set_title(1, 'Scattering Options')
1191 tabs.set_title(2, 'Plotting Options')
1192 tabs.set_title(3, 'Hardware Options')
1194 # Plot button and Download buttons
1195 plot_button = widgets.Button(description='Plot data', layout=Layout(width='50%', height='90%'), button_style='primary', icon='fa-pencil-square-o')
1196 download_button = widgets.Button(description="Download- and plot data", layout=Layout(width='50%', height='90%'), button_style='success', icon='fa-download')
def display_tabs():
    """Show the tab widget with the plot and download buttons in a row underneath."""
    button_row = HBox(
        [plot_button, download_button],
        layout=Layout(width='100%', height='50px'),
    )
    display(VBox([tabs, button_row]))
1202 """ Download utility """
1204 # Download options
def create_download_link(select_file, select_radius, filename_prefix, data, header=None):
    """
    Build an HTML link that downloads ``data`` as a CSV file.

    The file content is one ``key:value`` line per current widget setting, a blank
    line, the optional ``header`` line and then the comma-separated rows of ``data``.
    The content is embedded in the link as a base64 ``data:`` URI.

    Parameters:
        select_file: Widget whose ``value`` is the path of the selected structure file.
        select_radius: Widget whose ``value`` is added to the file name as ``_radius<value>``,
            or None to leave the radius out of the file name.
        filename_prefix (str): Tag (e.g. 'iq') used as file name prefix and as the label in front of the link.
        data: Iterable of numeric rows.
        header (str, optional): Column header written above the rows.

    Returns:
        str: HTML snippet ``<prefix>:\\t<a href=... download=...>...</a>``.
    """
    # Collect Metadata
    metadata = {
        'qmin': qslider.value[0],
        'qmax': qslider.value[1],
        'qdamp': qdamp_box.value,
        'qstep': qstep_box.value,
        'rmin': rslider.value[0],
        'rmax': rslider.value[1],
        'rstep': rstep_box.value,
        'rthres': rthres_box.value,
        'biso': biso_box.value,
        'device': hardware_button.value,
        'batch_size': batch_size_box.value,
        'lorch_mod': lorch_mod_button.value,
        'radiation_type': radtype_button.value
    }

    # The number of decimals is the length of the textual q-step (e.g. '0.01' -> 4)
    decimals = len(str(qstep_box.value))
    content = "\n".join(",".join(map(str, np.around(row, decimals))) for row in data)

    # Join content
    parts = [f'{k}:{v}\n' for k, v in metadata.items()]
    parts.append('\n')
    if header:
        parts.append(header + '\n')
    parts.append(content)
    output = ''.join(parts)

    # Encode as base64
    b64 = base64.b64encode(output.encode()).decode()

    # Add Time (MMDDYY_HHMMSS)
    stamp = datetime.now().strftime('%m%d%y_%H%M%S')

    # Make filename. basename/splitext instead of splitting on '/' and '.', so that
    # Windows paths and file names containing extra dots keep their full stem.
    stem = os.path.splitext(os.path.basename(select_file.value))[0]
    radius_tag = '' if select_radius is None else '_radius' + str(select_radius.value)
    filename = f'{filename_prefix}_{stem}{radius_tag}_{stamp}.csv'

    # Make href and return
    href = filename_prefix + ':\t' + f'<a href="data:text/csv;base64,{b64}" download="{filename}">{filename}</a>'
    return href
def create_structure_download_link(select_file, select_radius, filename_prefix, ase_atoms):
    """
    Build an HTML link that downloads the given structure as an ``.xyz`` text file.

    The content is the atom count, an empty line, and one tab-separated line per atom
    holding the chemical symbol and the position rounded to three decimals. It is
    embedded in the link as a base64 ``data:`` URI.

    Parameters:
        select_file: Widget whose ``value`` is the path of the selected structure file.
        select_radius: Widget whose ``value`` is added to the file name as ``_radius<value>``.
        filename_prefix (str): Tag used as file name prefix and as the label in front of the link.
        ase_atoms: Object providing ``get_positions()``, ``get_chemical_symbols()`` and ``len()``.

    Returns:
        str: HTML snippet ``<prefix>:\\t<a href=... download=...>...</a>``.
    """
    # Get atomic properties
    positions = ase_atoms.get_positions()
    elements = ase_atoms.get_chemical_symbols()

    # Make header (atom count followed by an empty line)
    header = str(len(ase_atoms)) + "\n\n"

    # Join content
    atom_lines = (
        el + '\t' + "\t".join(map(str, np.around(row, 3)))
        for row, el in zip(positions, elements)
    )
    content = header + "\n".join(atom_lines)

    # Encode as base64
    b64 = base64.b64encode(content.encode()).decode()

    # Add Time (MMDDYY_HHMMSS)
    stamp = datetime.now().strftime('%m%d%y_%H%M%S')

    # Make filename. basename/splitext instead of splitting on '/' and '.', so that
    # Windows paths and file names containing extra dots keep their full stem.
    stem = os.path.splitext(os.path.basename(select_file.value))[0]
    filename = f'{filename_prefix}_{stem}_radius{select_radius.value}_{stamp}.xyz'

    # Make href and return
    href = filename_prefix + ':\t' + f'<a href="data:text/xyz;base64,{b64}" download="{filename}">{filename}</a>'
    return href
@download_button.on_click
def on_download_button_click(button):
    """
    Compute the scattering functions for the selected file(s), show download links and plot the results.

    For every dropdown holding an ``.xyz`` or ``.cif`` path a DebyeCalculator is built from
    the current widget values and ``_get_all`` is called. Files that fail are reported and
    skipped. For each successful file, CSV links for I(Q), S(Q), F(Q) and G(r) are displayed;
    when the radius widget is visible a structure link is displayed as well. The results are
    stored in the global ``debye_outputs`` and passed to ``update_figure``.

    Parameters:
        button: The clicked button (unused).
    """
    global debye_outputs

    # clear and display
    clear_output(wait=True)
    display_tabs()

    # Keep every result together with the widgets it came from, so that a file that
    # fails to load cannot shift the pairing between results and dropdowns.
    computed = []
    for select_file, select_radius in zip([select_file_1, select_file_2], [select_radius_1, select_radius_2]):
        path = select_file.value
        # Skip empty/placeholder selections and unsupported extensions
        if not isinstance(path, str) or path in DEFAULT_MSGS:
            continue
        if path.split('.')[-1] not in ['xyz', 'cif']:
            continue
        try:
            debye_calc = DebyeCalculator(
                device=hardware_button.value,
                batch_size=batch_size_box.value,
                radiation_type=radtype_button.value,
                qmin=qslider.value[0],
                qmax=qslider.value[1],
                qstep=qstep_box.value,
                qdamp=qdamp_box.value,
                rmin=rslider.value[0],
                rmax=rslider.value[1],
                rstep=rstep_box.value,
                rthres=rthres_box.value,
                biso=biso_box.value,
                lorch_mod=lorch_mod_button.value
            )
            if (select_radius.layout.visibility != 'hidden') and (select_radius.value > 8):
                print(f'Generating nanoparticle of radius {select_radius.value} using {os.path.basename(path)} ...')
            computed.append((select_file, select_radius, debye_calc._get_all(path, select_radius.value)))
        except Exception:
            # Best effort: report the failing file and carry on with the other one
            print(f'FAILED: Could not load data file: {path}', end='\r')

    debye_outputs = [output for _, _, output in computed]
    if len(debye_outputs) < 1:
        print('FAILED: Please select data file(s)', end="\r")
        return

    for select_file, select_radius, (r, q, iq, sq, fq, gr) in computed:

        # Print
        print('Download links for ' + os.path.basename(select_file.value) + ':')

        datasets = [
            ('iq', np.column_stack([q, iq]), "q,I(Q)"),
            ('sq', np.column_stack([q, sq]), "q,S(Q)"),
            ('fq', np.column_stack([q, fq]), "q,F(Q)"),
            ('gr', np.column_stack([r, gr]), "r,G(r)"),
        ]

        # A visible radius widget means the file is a CIF: also offer the generated structure
        if select_radius.layout.visibility == 'visible':
            ase_atoms, _ = DebyeCalculator().generate_nanoparticles(select_file.value, select_radius.value)
            display(HTML(create_structure_download_link(select_file, select_radius, 'structure', ase_atoms[0])))
            link_radius = select_radius
        else:
            link_radius = None

        # Display download links
        for prefix, data, header in datasets:
            display(HTML(create_download_link(select_file, link_radius, prefix, data, header)))
        print('\n')

    update_figure(debye_outputs)
1368 """ Observer utility """
1370 # Define a function to update the scattering patterns based on the selected parameters
def update_options(change):
    """
    Refill both file dropdowns from the directory entered in the folder text box.

    The new text (``change.new``) is searched for ``*.xyz`` and ``*.cif`` files. When any
    are found, both dropdowns are enabled and list them behind a 'Select data file'
    entry; otherwise both are disabled and show the "no valid files" message.
    """
    directory = change.new
    found = sorted(glob(os.path.join(directory, '*.xyz')) + glob(os.path.join(directory, '*.cif')))

    if len(found):
        choices, locked = ['Select data file'] + found, False
    else:
        choices, locked = [DEFAULT_MSGS[0]], True

    for dropdown in (select_file_1, select_file_2):
        dropdown.options = list(choices)
        dropdown.value = choices[0]
        dropdown.disabled = locked
def update_options_radius_1(change):
    """Show the radius controls of file 1 when the selected file ends in 'cif', hide them otherwise."""
    extension = select_file_1.value.split('.')[-1]
    visibility = 'visible' if extension == 'cif' else 'hidden'

    # Radius image, radius input and explanatory text are toggled together
    select_radius_desc_1.children[0].layout.visibility = visibility
    select_radius_1.layout.visibility = visibility
    cif_text_1.layout.visibility = visibility
def update_options_radius_2(change):
    """Show the radius controls of file 2 when the selected file ends in 'cif', hide them otherwise."""
    extension = select_file_2.value.split('.')[-1]
    visibility = 'visible' if extension == 'cif' else 'hidden'

    # Radius image, radius input and explanatory text are toggled together
    select_radius_desc_2.children[0].layout.visibility = visibility
    select_radius_2.layout.visibility = visibility
    cif_text_2.layout.visibility = visibility
1418 # Link the update functions to the dropdown widget's value change event
1419 folder.observe(update_options, names='value')
1420 select_file_1.observe(update_options_radius_1, names='value')
1421 select_file_2.observe(update_options_radius_2, names='value')
1424 """ Plotting utility """
def togglelorch(change):
    """Update the caption of the Lorch button to match its new toggle state (``change['new']``)."""
    lorch_mod_button.description = (
        'Lorch modification (ON)' if change['new'] else 'Lorch modification (OFF)'
    )
1432 lorch_mod_button.observe(togglelorch, 'value')
@sas_preset_button.on_click
def sas_preset(b=None):
    """Small-angle scattering preset: logarithmic scale, only I(Q) shown, Q from 0.0 to 3.0 with step 0.01."""
    # Change scale type
    scale_type_button.value = 'logarithmic'

    # Hide all but IQ
    visibility = (
        (show_iq_button, True),
        (show_fq_button, False),
        (show_sq_button, False),
        (show_gr_button, False),
    )
    for checkbox, shown in visibility:
        checkbox.value = shown

    # Set qmin, qmax and qstep
    qslider.value = [0.0, 3.0]
    qstep_box.value = 0.01
@pd_preset_button.on_click
def pd_preset(b=None):
    """Powder diffraction preset: linear scale, only I(Q) shown, Q from 1.0 to 8.0 with step 0.1."""
    # Change scale type
    scale_type_button.value = 'linear'

    # Hide all but IQ
    visibility = (
        (show_iq_button, True),
        (show_fq_button, False),
        (show_sq_button, False),
        (show_gr_button, False),
    )
    for checkbox, shown in visibility:
        checkbox.value = shown

    # Set qmin, qmax and qstep
    qslider.value = [1.0, 8.0]
    qstep_box.value = 0.1
@ts_preset_button.on_click
def ts_preset(b=None):
    """Total scattering preset: linear scale, I(Q), F(Q) and G(r) shown, Q from 1.0 to 30.0 with step 0.1."""
    # Change scale type
    scale_type_button.value = 'linear'

    # Show I(Q), F(Q) and G(r); hide S(Q)
    visibility = (
        (show_iq_button, True),
        (show_fq_button, True),
        (show_sq_button, False),
        (show_gr_button, True),
    )
    for checkbox, shown in visibility:
        checkbox.value = shown

    # Set qmin, qmax and qstep
    qslider.value = [1.0, 30.0]
    qstep_box.value = 0.1
@reset_button.on_click
def reset(b=None):
    """Restore the scale type, the show/hide checkboxes and the numeric scattering options to fixed values."""
    # Change scale type
    scale_type_button.value = 'linear'

    # Show all four functions
    for checkbox in (show_iq_button, show_fq_button, show_sq_button, show_gr_button):
        checkbox.value = True

    # Q- and r-ranges, step sizes and remaining scattering parameters
    numeric_settings = (
        (qslider, [1.0, 30.0]),
        (rslider, [0.0, 20.0]),
        (qstep_box, 0.1),
        (rstep_box, 0.01),
        (biso_box, 0.3),
        (qdamp_box, 0.04),
        (rthres_box, 0.0),
    )
    for widget, setting in numeric_settings:
        widget.value = setting
def update_figure(debye_outputs, _unity_sq=True):
    """
    Plot the enabled scattering functions for every computed structure.

    One subplot is created per enabled "show" checkbox (2x2 grid for four, a single
    column for two or three, one axis for one). Each structure is drawn as one curve
    per subplot, labelled with its file name and, if the radius widget is not hidden,
    its radius. Nothing is drawn when no file is selected or no function is enabled.

    Parameters:
        debye_outputs: Sequence with one entry per structure; each entry is indexed as
            (r, q, I(Q), S(Q), F(Q), G(r)).
        _unity_sq (bool): Not used by this function.
    """
    # Raw strings: '\A' is not a valid escape sequence in an ordinary string literal
    q_axis = r'$Q$ [$\AA^{-1}$]'
    r_axis = r'$r$ [$\AA$]'

    def unit_suffix(normalize_box, plain=''):
        # Each y-label follows its OWN max-normalize checkbox
        return ' [normalized]' if normalize_box.value else plain

    iq_scale = 'log' if scale_type_button.value == 'logarithmic' else 'linear'

    # (show checkbox, x index, y index, normalize checkbox, x label, y label, axis scale, title)
    panels = [
        (show_iq_button, 1, 2, normalize_iq, q_axis,
         '$I(Q)$' + unit_suffix(normalize_iq, ' [counts]'), iq_scale,
         'Scattering Intensity, I(Q)'),
        (show_sq_button, 1, 3, normalize_sq, q_axis,
         '$S(Q)$' + unit_suffix(normalize_sq), 'linear',
         'Structure Function, S(Q)'),
        (show_fq_button, 1, 4, normalize_fq, q_axis,
         '$F(Q)$' + unit_suffix(normalize_fq), 'linear',
         'Reduced Structure Function, F(Q)'),
        (show_gr_button, 0, 5, normalize_gr, r_axis,
         '$G(r)$' + unit_suffix(normalize_gr), 'linear',
         'Reduced Pair Distribution Function, G(r)'),
    ]
    panels = [panel for panel in panels if panel[0].value]
    num_plots = len(panels)

    # File names for the figure title and per-curve legend labels
    sup_title = []
    labels = []
    for select_file, select_radius in ((select_file_1, select_radius_1), (select_file_2, select_radius_2)):
        if select_file.value in ['Select data file', 'No valid files in entered directory']:
            continue
        name = os.path.basename(select_file.value)
        sup_title.append(name)
        if select_radius.layout.visibility == 'hidden':
            labels.append(name)
        else:
            labels.append(name + ', rad.: ' + str(select_radius.value) + ' Å')

    if len(labels) == 0 or num_plots == 0:
        return

    if num_plots == 4:
        fig, axs = plt.subplots(2, 2, figsize=(12, 8), dpi=75)
        axs = axs.ravel()
    elif num_plots == 1:
        fig, axs = plt.subplots(figsize=(12, 6), dpi=75)
        axs = [axs]
    else:
        fig, axs = plt.subplots(num_plots, 1, figsize=(12, 8), dpi=75)

    # One curve per structure in every enabled panel
    for do, label in zip(debye_outputs, labels):
        for ax, (_, x_idx, y_idx, normalize_box, xl, yl, axis_scale, panel_title) in zip(axs, panels):
            y = do[y_idx] / max(do[y_idx]) if normalize_box.value else do[y_idx]
            # The selected scale is applied to both axes of the panel
            ax.set_xscale(axis_scale)
            ax.set_yscale(axis_scale)
            ax.plot(do[x_idx], y, label=label)
            ax.set(xlabel=xl, ylabel=yl, title=panel_title)
            ax.relim()
            ax.autoscale_view()
            ax.grid(alpha=0.2, which='both')
            ax.legend()

    if len(sup_title) == 1:
        title = f"Showing files: {sup_title[0]}"
    else:
        title = f"Showing files: {sup_title[0]} and {sup_title[1]}"
    fig.suptitle(title)
    fig.tight_layout()
@plot_button.on_click
def update_parameters(b=None):
    """
    Recompute the scattering functions for the selected file(s) and redraw the figure.

    For every dropdown holding an ``.xyz`` or ``.cif`` path a DebyeCalculator is built from
    the current widget values and ``_get_all`` is called. Files that fail are reported and
    skipped. The results are stored in the global ``debye_outputs``; the output area is then
    cleared, the tabs are shown again and ``update_figure`` is called.

    Parameters:
        b: The clicked button (unused).
    """
    global debye_outputs

    debye_outputs = []
    for select_file, select_radius in zip([select_file_1, select_file_2], [select_radius_1, select_radius_2]):
        path = select_file.value
        # Skip empty/placeholder selections and unsupported extensions
        if not isinstance(path, str) or path in DEFAULT_MSGS:
            continue
        if path.split('.')[-1] not in ['xyz', 'cif']:
            continue
        try:
            # TODO if not changed, dont make new object
            debye_calc = DebyeCalculator(
                device=hardware_button.value,
                batch_size=batch_size_box.value,
                radiation_type=radtype_button.value,
                qmin=qslider.value[0],
                qmax=qslider.value[1],
                qstep=qstep_box.value,
                qdamp=qdamp_box.value,
                rmin=rslider.value[0],
                rmax=rslider.value[1],
                rstep=rstep_box.value,
                rthres=rthres_box.value,
                biso=biso_box.value,
                lorch_mod=lorch_mod_button.value
            )
            # The radius only applies while its widget is shown (CIF selected); the widget
            # is never disabled, so visibility is the state to test here.
            if (select_radius.layout.visibility != 'hidden') and (select_radius.value > 8):
                print(f'Generating nanoparticle of radius {select_radius.value} using {os.path.basename(path)} ...')
            debye_outputs.append(debye_calc._get_all(path, select_radius.value))
        except Exception:
            # Best effort: report the failing file and carry on with the other one
            print(f'FAILED: Could not load data file: {path}', end='\r')

    # Clear and display
    clear_output(wait=True)
    display_tabs()

    if len(debye_outputs) < 1:
        print('FAILED: Please select data file(s)', end="\r")
        return

    update_figure(debye_outputs)
1653 # Display tabs when function is called
1654 display_tabs()