Coverage for /Users/Newville/Codes/xraylarch/larch/utils/jsonutils.py: 56%

224 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-09 10:08 -0600

1#!/usr/bin/env python 

2""" 

3 json utilities for larch objects 

4""" 

5import json 

6import io 

7import numpy as np 

8import h5py 

9from datetime import datetime 

10from collections import namedtuple 

11import logging 

12 

# Registry mapping a class name to the class itself, used by decode4js() to
# rebuild objects that were encoded as 'StatefulObject' (via __getstate__).
HAS_STATE = {}

# scikit-learn is optional: its estimators are registered only when the
# package can be imported.
try:
    from sklearn.cross_decomposition import PLSRegression
    from sklearn.linear_model import LassoLarsCV, LassoLars, Lasso
except ImportError:
    pass
else:
    HAS_STATE.update({'PLSRegression': PLSRegression,
                      'LassoLarsCV': LassoLarsCV,
                      'LassoLars': LassoLars,
                      'Lasso': Lasso})

23 

24from lmfit import Parameter, Parameters 

25from lmfit.model import Model, ModelResult 

26from lmfit.minimizer import Minimizer, MinimizerResult 

27from lmfit.parameter import SCIPY_FUNCTIONS 

28 

29from larch import Group, isgroup, Journal, ParameterGroup 

30 

31from larch.xafs import FeffitDataSet, FeffDatFile, FeffPathGroup, TransformGroup 

32from larch.xafs.feffutils import FeffCalcResults 

33from larch.utils.strutils import bytes2str, str2bytes, fix_varname 

34from larch.utils.logging import getLogger 

35from larch.utils.logging import _levels as LoggingLevels 

36 

# Larch classes that decode4js() can rebuild from a saved __getstate__() value.
HAS_STATE.update({'FeffCalcResults': FeffCalcResults,
                  'FeffDatFile': FeffDatFile,
                  'FeffPathGroup': FeffPathGroup,
                  'Journal': Journal})

# Group-like classes, keyed by the '__class__' tag written by encode4js();
# decode4js() uses this table to pick the constructor for an encoded group.
LarchGroupTypes = dict(Group=Group,
                       ParameterGroup=ParameterGroup,
                       FeffitDataSet=FeffitDataSet,
                       TransformGroup=TransformGroup,
                       MinimizerResult=MinimizerResult,
                       FeffDatFile=FeffDatFile,
                       FeffPathGroup=FeffPathGroup)

50 

def encode4js(obj):
    """Return a representation of ``obj`` that is ready for JSON encoding.

    ``None``, booleans, integers, floats and strings are returned as plain
    Python values, and ``bytes`` are decoded as UTF-8.  Every other supported
    object is returned as a dict whose ``'__class__'`` entry names the kind
    of object that was encoded, which is the tag ``decode4js`` dispatches on.

    Special handling exists for:
       numpy arrays and numpy scalars (any integer / float / complex width)
       complex numbers, datetimes, slices
       lists, tuples, named tuples, dicts
       open or closed file-like objects, h5py files and groups
       loggers
       lmfit Parameter, Parameters, Model, ModelResult, MinimizerResult
       Larch Groups
       objects providing ``__getstate__`` or ``dumps``

    Anything else is dumped generically: a warning is printed and the
    non-callable, non-dunder attributes are encoded one by one.

    Arguments
    ---------
      obj     object to encode

    Returns
    -------
      ``None``, a bool, int, float or str, or a dict tagged with ``'__class__'``.
    """
    if obj is None:
        return None

    if isinstance(obj, np.ndarray):
        dtype_name = obj.dtype.name
        out = {'__class__': 'Array', '__shape__': obj.shape,
               '__dtype__': dtype_name}
        if 'complex' in dtype_name:
            # complex arrays are stored as separate real / imaginary parts
            out['value'] = [obj.real.tolist(), obj.imag.tolist()]
        elif dtype_name == 'object':
            out['value'] = [encode4js(i) for i in obj.flatten().tolist()]
        else:
            out['value'] = obj.flatten().tolist()
        return out

    # bool must be tested before int: bool is a subclass of int
    if isinstance(obj, (bool, np.bool_)):
        return bool(obj)
    # the abstract numpy scalar types cover every width (int8 ... uint64,
    # float16 ... longdouble), not only the 32- and 64-bit variants
    if isinstance(obj, (int, np.integer)):
        return int(obj)
    if isinstance(obj, (float, np.floating)):
        return float(obj)
    if isinstance(obj, str):
        return str(obj)
    if isinstance(obj, bytes):
        return obj.decode('utf-8')
    if isinstance(obj, datetime):
        return {'__class__': 'Datetime', 'isotime': obj.isoformat()}
    if isinstance(obj, (complex, np.complexfloating)):
        # coerce to builtin floats so that narrower numpy complex types
        # do not leave numpy scalars in the output
        return {'__class__': 'Complex',
                'value': (float(obj.real), float(obj.imag))}

    if isinstance(obj, io.IOBase):
        # in-memory streams have no 'name' attribute
        out = {'__class__': 'IOBase', 'class': obj.__class__.__name__,
               'name': getattr(obj, 'name', None), 'closed': obj.closed,
               'readable': False, 'writable': False}
        # readable() / writable() raise ValueError on a closed stream;
        # in that case the flag is left as False
        for flag in ('readable', 'writable'):
            try:
                out[flag] = getattr(obj, flag)()
            except ValueError:
                pass
        return out

    if isinstance(obj, h5py.File):
        return {'__class__': 'HDF5File',
                'value': (obj.name, obj.filename, obj.mode, obj.libver),
                'keys': list(obj.keys())}
    if isinstance(obj, h5py.Group):
        return {'__class__': 'HDF5Group',
                'value': (obj.name, obj.file.filename),
                'keys': list(obj.keys())}
    if isinstance(obj, slice):
        return {'__class__': 'Slice',
                'value': (obj.start, obj.stop, obj.step)}

    if isinstance(obj, list):
        return {'__class__': 'List',
                'value': [encode4js(item) for item in obj]}
    if isinstance(obj, tuple):
        items = [encode4js(item) for item in obj]
        if hasattr(obj, '_fields'):  # named tuple
            return {'__class__': 'NamedTuple',
                    '__name__': obj.__class__.__name__,
                    '_fields': obj._fields,
                    'value': items}
        return {'__class__': 'Tuple', 'value': items}
    if isinstance(obj, dict):
        out = {'__class__': 'Dict'}
        for key, val in obj.items():
            out[encode4js(key)] = encode4js(val)
        return out

    if isinstance(obj, logging.Logger):
        # map the numeric level back to its name; 'DEBUG' when none matches
        level = 'DEBUG'
        for key, val in LoggingLevels.items():
            if obj.level == val:
                level = key
        return {'__class__': 'Logger', 'name': obj.name, 'level': level}

    if isinstance(obj, MinimizerResult):
        out = {'__class__': 'MinimizerResult'}
        for attr in ('aborted', 'aic', 'bic', 'call_kws', 'chisqr',
                     'covar', 'errorbars', 'ier', 'init_vals',
                     'init_values', 'last_internal_values',
                     'lmdif_message', 'message', 'method', 'ndata', 'nfev',
                     'nfree', 'nvarys', 'params', 'redchi', 'residual',
                     'success', 'var_names'):
            # attributes missing from this result are stored as None
            out[attr] = encode4js(getattr(obj, attr, None))
        return out
    if isinstance(obj, Parameters):
        out = {'__class__': 'Parameters'}
        o_ast = obj._asteval
        out['unique_symbols'] = {key: encode4js(o_ast.symtable[key])
                                 for key in o_ast.user_defined_symbols()}
        out['params'] = [(p.name, p.__getstate__()) for p in obj.values()]
        return out
    if isinstance(obj, Parameter):
        return {'__class__': 'Parameter', 'name': obj.name,
                'state': obj.__getstate__()}
    if isinstance(obj, Model):
        return {'__class__': 'Model', 'value': obj.dumps()}
    if isinstance(obj, ModelResult):
        return {'__class__': 'ModelResult', 'value': obj.dumps()}

    if isgroup(obj):
        try:
            classname = obj.__class__.__name__
        except Exception:
            classname = 'Group'
        out = {'__class__': classname}
        if classname == 'ParameterGroup':  # save in order of parameter names
            parnames = dir(obj)
            for par in obj.__params__.keys():
                if par in parnames:
                    out[par] = encode4js(getattr(obj, par))
        else:
            for item in dir(obj):
                out[item] = encode4js(getattr(obj, item))
        return out

    if hasattr(obj, '__getstate__') and not callable(obj):
        return {'__class__': 'StatefulObject',
                '__type__': obj.__class__.__name__,
                'value': encode4js(obj.__getstate__())}
    if isinstance(obj, type):
        return {'__class__': 'Type', 'value': repr(obj),
                'module': getattr(obj, '__module__', None)}
    if callable(obj):
        return {'__class__': 'Method', '__name__': repr(obj)}
    if hasattr(obj, 'dumps'):
        print("Encode Warning: using dumps for ", obj)
        return {'__class__': 'DumpableObject', 'value': obj.dumps()}

    # last resort: encode every non-callable, non-dunder attribute
    print("Encode Warning: generic object dump for ", repr(obj))
    out = {'__class__': 'Object', '__repr__': repr(obj),
           '__classname__': obj.__class__.__name__}
    for attr in dir(obj):
        if attr.startswith('__') and attr.endswith('__'):
            continue
        thing = getattr(obj, attr)
        if not callable(thing):
            out[attr] = encode4js(thing)
    return out

191 

def decode4js(obj):
    """Return the Python object described by an encoded object.

    This reverses ``encode4js``.  Anything that is not a dict, and any dict
    without a ``'__class__'`` entry, is returned unchanged.  Otherwise the
    ``'__class__'`` tag selects how the object is rebuilt.  The dict passed
    in is not modified.

    Tags that are not recognized print a "cannot decode" message, and the
    dict (without its ``'__class__'`` entry) is returned as-is.

    Arguments
    ---------
      obj     encoded object, typically the result of ``json.loads``

    Returns
    -------
      the decoded Python object
    """
    if not isinstance(obj, dict):
        return obj
    classname = obj.get('__class__', None)
    if classname is None:
        return obj

    # work on a shallow copy without the tag, so that the caller's dict is
    # left intact and can be decoded again
    obj = {key: val for key, val in obj.items() if key != '__class__'}
    out = obj

    if classname == 'Complex':
        out = obj['value'][0] + 1j*obj['value'][1]
    elif classname in ('List', 'Tuple', 'NamedTuple'):
        out = [decode4js(item) for item in obj['value']]
        if classname == 'Tuple':
            out = tuple(out)
        elif classname == 'NamedTuple':
            out = namedtuple(obj['__name__'], obj['_fields'])(*out)
    elif classname == 'Array':
        dtype_name = obj['__dtype__']
        if dtype_name.startswith('complex'):
            # complex arrays are stored as [real part, imaginary part]
            re = np.asarray(obj['value'][0], dtype='double')
            im = np.asarray(obj['value'][1], dtype='double')
            out = re + 1j*im
        elif dtype_name.startswith('object'):
            val = [decode4js(v) for v in obj['value']]
            out = np.array(val, dtype=dtype_name)
        else:
            out = np.asarray(obj['value'], dtype=dtype_name)
        out.shape = obj['__shape__']
    elif classname in ('Dict', 'dict'):
        out = {key: decode4js(val) for key, val in obj.items()}
    elif classname == 'Datetime':
        out = datetime.fromisoformat(obj['isotime'])

    elif classname == 'IOBase':
        mode = 'r'
        if obj['readable'] and obj['writable']:
            mode = 'a'
        elif not obj['readable'] and obj['writable']:
            mode = 'w'
        # NOTE(review): re-opening with mode 'w' truncates an existing file
        # of that name -- confirm this is intended.
        out = open(obj['name'], mode=mode)
        if obj['closed']:
            out.close()

    elif classname == 'Parameters':
        out = Parameters()
        out.clear()
        unique_symbols = {key: decode4js(val)
                          for key, val in obj['unique_symbols'].items()}
        state = {'unique_symbols': unique_symbols, 'params': []}
        for name, parstate in obj['params']:
            par = Parameter(decode4js(name))
            par.__setstate__(decode4js(parstate))
            state['params'].append(par)
        out.__setstate__(state)
    elif classname in ('Parameter', 'parameter'):
        name = decode4js(obj['name'])
        state = decode4js(obj['state'])
        out = Parameter(name)
        out.__setstate__(state)

    elif classname == 'Model':
        # placeholder Model instance, used only to call loads()
        mod = Model(lambda x: x)
        out = mod.loads(decode4js(obj['value']))

    elif classname == 'ModelResult':
        # placeholder ModelResult instance, used only to call loads()
        params = Parameters()
        res = ModelResult(Model(lambda x: x, None), params)
        out = res.loads(decode4js(obj['value']))

    elif classname == 'Logger':
        out = getLogger(obj['name'], level=obj['level'])

    elif classname == 'StatefulObject':
        dtype = obj.get('__type__')
        if dtype in HAS_STATE:
            out = HAS_STATE[dtype]()
            out.__setstate__(decode4js(obj.get('value')))
        else:
            print(f"Warning: cannot re-create stateful object of type '{dtype}'")

    elif classname in LarchGroupTypes:
        out = {}
        for key, val in obj.items():
            if (isinstance(val, dict) and
                val.get('__class__', None) == 'Method' and
                val.get('__name__', None) is not None):
                continue  # ignore class methods for subclassed Groups
            out[key] = decode4js(val)
        if classname == 'FeffDatFile':
            path = FeffDatFile()
            path._set_from_dict(**out)
            out = path
        else:
            out = LarchGroupTypes[classname](**out)
    elif classname == 'Method':
        # only ufuncs are looked up (in SCIPY_FUNCTIONS, by the name taken
        # from their repr); other methods are returned as the plain dict
        mname = obj.get('__name__', '')
        if 'ufunc' in mname:
            mname = mname.replace('<ufunc', '').replace('>', '').replace("'","").strip()
            out = SCIPY_FUNCTIONS.get(mname, None)

    else:
        print("cannot decode ", classname, obj)
    return out
302 else: 

303 print("cannot decode ", classname, obj) 

304 return out