Coverage for /Users/Newville/Codes/xraylarch/larch/utils/jsonutils.py: 56%
224 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-09 10:08 -0600
1#!/usr/bin/env python
2"""
3 json utilities for larch objects
4"""
5import json
6import io
7import numpy as np
8import h5py
9from datetime import datetime
10from collections import namedtuple
11import logging
# Registry of classes (keyed by class name) that decode4js can re-create
# from a 'StatefulObject' record.  The scikit-learn estimators are optional:
# they are registered only when sklearn can be imported.
HAS_STATE = {}
try:
    from sklearn.cross_decomposition import PLSRegression
    from sklearn.linear_model import LassoLarsCV, LassoLars, Lasso
except ImportError:
    pass
else:
    HAS_STATE['PLSRegression'] = PLSRegression
    HAS_STATE['LassoLarsCV'] = LassoLarsCV
    HAS_STATE['LassoLars'] = LassoLars
    HAS_STATE['Lasso'] = Lasso
24from lmfit import Parameter, Parameters
25from lmfit.model import Model, ModelResult
26from lmfit.minimizer import Minimizer, MinimizerResult
27from lmfit.parameter import SCIPY_FUNCTIONS
29from larch import Group, isgroup, Journal, ParameterGroup
31from larch.xafs import FeffitDataSet, FeffDatFile, FeffPathGroup, TransformGroup
32from larch.xafs.feffutils import FeffCalcResults
33from larch.utils.strutils import bytes2str, str2bytes, fix_varname
34from larch.utils.logging import getLogger
35from larch.utils.logging import _levels as LoggingLevels
# Larch/Feff classes that decode4js rebuilds from a 'StatefulObject' record:
# each is instantiated with no arguments and then restored with __setstate__.
HAS_STATE.update(
    FeffCalcResults=FeffCalcResults,
    FeffDatFile=FeffDatFile,
    FeffPathGroup=FeffPathGroup,
    Journal=Journal,
)

# Classes that decode4js re-creates by calling them with the decoded
# attributes as keyword arguments (FeffDatFile is special-cased there).
LarchGroupTypes = dict(
    Group=Group,
    ParameterGroup=ParameterGroup,
    FeffitDataSet=FeffitDataSet,
    TransformGroup=TransformGroup,
    MinimizerResult=MinimizerResult,
    FeffDatFile=FeffDatFile,
    FeffPathGroup=FeffPathGroup,
)
51def encode4js(obj):
52 """return an object ready for json encoding.
53 has special handling for many Python types
54 numpy array
55 complex numbers
56 Larch Groups
57 Larch Parameters
58 """
59 if obj is None:
60 return None
61 if isinstance(obj, np.ndarray):
62 out = {'__class__': 'Array', '__shape__': obj.shape,
63 '__dtype__': obj.dtype.name}
64 out['value'] = obj.flatten().tolist()
66 if 'complex' in obj.dtype.name:
67 out['value'] = [(obj.real).tolist(), (obj.imag).tolist()]
68 elif obj.dtype.name == 'object':
69 out['value'] = [encode4js(i) for i in out['value']]
70 return out
71 elif isinstance(obj, (bool, np.bool_)):
72 return bool(obj)
73 elif isinstance(obj, (int, np.int64, np.int32)):
74 return int(obj)
75 elif isinstance(obj, (float, np.float64, np.float32)):
76 return float(obj)
77 elif isinstance(obj, str):
78 return str(obj)
79 elif isinstance(obj, bytes):
80 return obj.decode('utf-8')
81 elif isinstance(obj, datetime):
82 return {'__class__': 'Datetime', 'isotime': obj.isoformat()}
83 elif isinstance(obj,(complex, np.complex128)):
84 return {'__class__': 'Complex', 'value': (obj.real, obj.imag)}
85 elif isinstance(obj, io.IOBase):
86 out ={'__class__': 'IOBase', 'class': obj.__class__.__name__,
87 'name': obj.name, 'closed': obj.closed,
88 'readable': obj.readable(), 'writable': False}
89 try:
90 out['writable'] = obj.writable()
91 except ValueError:
92 out['writable'] = False
93 return out
94 elif isinstance(obj, h5py.File):
95 return {'__class__': 'HDF5File',
96 'value': (obj.name, obj.filename, obj.mode, obj.libver),
97 'keys': list(obj.keys())}
98 elif isinstance(obj, h5py.Group):
99 return {'__class__': 'HDF5Group', 'value': (obj.name, obj.file.filename),
100 'keys': list(obj.keys())}
101 elif isinstance(obj, slice):
102 return {'__class__': 'Slice', 'value': (obj.start, obj.stop, obj.step)}
104 elif isinstance(obj, list):
105 return {'__class__': 'List', 'value': [encode4js(item) for item in obj]}
106 elif isinstance(obj, tuple):
107 if hasattr(obj, '_fields'): # named tuple!
108 return {'__class__': 'NamedTuple',
109 '__name__': obj.__class__.__name__,
110 '_fields': obj._fields,
111 'value': [encode4js(item) for item in obj]}
112 else:
113 return {'__class__': 'Tuple', 'value': [encode4js(item) for item in obj]}
114 elif isinstance(obj, dict):
115 out = {'__class__': 'Dict'}
116 for key, val in obj.items():
117 out[encode4js(key)] = encode4js(val)
118 return out
119 elif isinstance(obj, logging.Logger):
121 level = 'DEBUG'
122 for key, val in LoggingLevels.items():
123 if obj.level == val:
124 level = key
125 return {'__class__': 'Logger', 'name': obj.name, 'level': level}
126 elif isinstance(obj, MinimizerResult):
127 out = {'__class__': 'MinimizerResult'}
128 for attr in ('aborted', 'aic', 'bic', 'call_kws', 'chisqr',
129 'covar', 'errorbars', 'ier', 'init_vals',
130 'init_values', 'last_internal_values',
131 'lmdif_message', 'message', 'method', 'ndata', 'nfev',
132 'nfree', 'nvarys', 'params', 'redchi', 'residual',
133 'success', 'var_names'):
134 out[attr] = encode4js(getattr(obj, attr, None))
135 return out
136 elif isinstance(obj, Parameters):
137 out = {'__class__': 'Parameters'}
138 o_ast = obj._asteval
139 out['unique_symbols'] = {key: encode4js(o_ast.symtable[key])
140 for key in o_ast.user_defined_symbols()}
141 out['params'] = [(p.name, p.__getstate__()) for p in obj.values()]
142 return out
143 elif isinstance(obj, Parameter):
144 return {'__class__': 'Parameter', 'name': obj.name, 'state': obj.__getstate__()}
145 elif isinstance(obj, Model):
146 return {'__class__': 'Model', 'value': obj.dumps()}
147 elif isinstance(obj, ModelResult):
148 return {'__class__': 'ModelResult', 'value': obj.dumps()}
149 elif isgroup(obj):
150 try:
151 classname = obj.__class__.__name__
152 except:
153 classname = 'Group'
154 out = {'__class__': classname}
156 if classname == 'ParameterGroup': # save in order of parameter names
157 parnames = dir(obj)
158 for par in obj.__params__.keys():
159 if par in parnames:
160 out[par] = encode4js(getattr(obj, par))
161 else:
162 for item in dir(obj):
163 out[item] = encode4js(getattr(obj, item))
164 return out
165 elif hasattr(obj, '__getstate__') and not callable(obj):
166 return {'__class__': 'StatefulObject',
167 '__type__': obj.__class__.__name__,
168 'value': encode4js(obj.__getstate__())}
169 elif isinstance(obj, type):
170 return {'__class__': 'Type', 'value': repr(obj),
171 'module': getattr(obj, '__module__', None)}
172 elif callable(obj):
173 return {'__class__': 'Method', '__name__': repr(obj)}
174 elif hasattr(obj, 'dumps'):
175 print("Encode Warning: using dumps for ", obj)
176 return {'__class__': 'DumpableObject', 'value': obj.dumps()}
177 else:
178 print("Encode Warning: generic object dump for ", repr(obj))
179 out = {'__class__': 'Object', '__repr__': repr(obj),
180 '__classname__': obj.__class__.__name__}
181 for attr in dir(obj):
182 if attr.startswith('__') and attr.endswith('__'):
183 continue
184 thing = getattr(obj, attr)
185 if not callable(thing):
186 # print("will try to encode thing ", thing, type(thing))
187 out[attr] = encode4js(thing)
188 return out
190 return obj
def decode4js(obj):
    """
    return decoded Python object from encoded object.

    This is the inverse of encode4js(): a dict carrying a non-None
    '__class__' entry is converted back into the object it describes;
    anything else is returned unchanged.  The input is not modified.
    Encodings that cannot be rebuilt print a warning and return the
    remaining fields as a dict.
    """
    if not isinstance(obj, dict) or obj.get('__class__', None) is None:
        return obj
    # work on a shallow copy so the caller's data is never mutated
    # (popping '__class__' in place made a second decode of the same data fail)
    obj = dict(obj)
    classname = obj.pop('__class__')
    out = obj

    if classname == 'Complex':
        out = obj['value'][0] + 1j*obj['value'][1]
    elif classname in ('List', 'Tuple', 'NamedTuple'):
        out = [decode4js(item) for item in obj['value']]
        if classname == 'Tuple':
            out = tuple(out)
        elif classname == 'NamedTuple':
            out = namedtuple(obj['__name__'], obj['_fields'])(*out)
    elif classname == 'Array':
        dtype = obj['__dtype__']
        if dtype.startswith('complex'):
            real = np.asarray(obj['value'][0], dtype='double')
            imag = np.asarray(obj['value'][1], dtype='double')
            # restore the stored precision (e.g. complex64)
            out = (real + 1j*imag).astype(dtype)
        elif dtype.startswith('object'):
            values = [decode4js(v) for v in obj['value']]
            # fill element by element: np.array() would merge elements that
            # are themselves sequences into extra dimensions
            out = np.empty(len(values), dtype=dtype)
            for i, val in enumerate(values):
                out[i] = val
        else:
            out = np.asarray(obj['value'], dtype=dtype)
        out = out.reshape(obj['__shape__'])
    elif classname in ('Dict', 'dict'):
        out = {key: decode4js(val) for key, val in obj.items()}
    elif classname == 'Datetime':
        out = datetime.fromisoformat(obj['isotime'])
    elif classname == 'IOBase':
        mode = 'r'
        if obj['readable'] and obj['writable']:
            mode = 'a'
        elif not obj['readable'] and obj['writable']:
            mode = 'w'
        # NOTE(review): mode 'w' truncates an existing file on decode
        out = open(obj['name'], mode=mode)
        if obj['closed']:
            out.close()
    elif classname == 'Parameters':
        out = Parameters()
        out.clear()
        unique_symbols = {key: decode4js(val)
                          for key, val in obj['unique_symbols'].items()}
        state = {'unique_symbols': unique_symbols, 'params': []}
        for name, parstate in obj['params']:
            par = Parameter(decode4js(name))
            par.__setstate__(decode4js(parstate))
            state['params'].append(par)
        out.__setstate__(state)
    elif classname in ('Parameter', 'parameter'):
        name = decode4js(obj['name'])
        state = decode4js(obj['state'])
        out = Parameter(name)
        out.__setstate__(state)
    elif classname == 'Model':
        mod = Model(lambda x: x)
        out = mod.loads(decode4js(obj['value']))
    elif classname == 'ModelResult':
        params = Parameters()
        res = ModelResult(Model(lambda x: x, None), params)
        out = res.loads(decode4js(obj['value']))
    elif classname == 'Logger':
        out = getLogger(obj['name'], level=obj['level'])
    elif classname == 'StatefulObject':
        dtype = obj.get('__type__')
        if dtype in HAS_STATE:
            out = HAS_STATE[dtype]()
            out.__setstate__(decode4js(obj.get('value')))
        else:
            print(f"Warning: cannot re-create stateful object of type '{dtype}'")
    elif classname in LarchGroupTypes:
        kws = {}
        for key, val in obj.items():
            if (isinstance(val, dict) and
                    val.get('__class__', None) == 'Method' and
                    val.get('__name__', None) is not None):
                continue  # ignore class methods for subclassed Groups
            kws[key] = decode4js(val)
        if classname == 'FeffDatFile':
            out = FeffDatFile()
            out._set_from_dict(**kws)
        else:
            out = LarchGroupTypes[classname](**kws)
    elif classname == 'Method':
        mname = obj.get('__name__', '')
        if 'ufunc' in mname:
            mname = mname.replace('<ufunc', '').replace('>', '').replace("'", "").strip()
        out = SCIPY_FUNCTIONS.get(mname, None)
    else:
        print("cannot decode ", classname, obj)
    return out