Coverage for /usr/local/lib/python3.9/site-packages/lccalib/zp.py: 0%

175 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-02-04 10:22 +0000

1""" 

2Tools to fit a zp from a color transformation 

3""" 

4 

5from copy import deepcopy 

6import numpy as np 

7import pandas as pd 

8import matplotlib.pyplot as plt 

9 

10from saltworks.plottools import binplot, make_bins 

11from saltworks.linearmodels import RobustLinearSolver, linear_func, indic, LinearModel 

12from saltworks.indextools import make_index 

13import bbf.bspline as bs 

14 

15from .match import match 

16 

17# pylint: disable=invalid-name,too-many-locals,too-many-arguments 

18 

19 

def get_matched_cats(survey, star_lc_cat, logger, by=None, **kwargs):
    """Align secondary star catalog with star lc catalog.

    :param survey: secondary star catalog provider
    :param recarray star_lc_cat: lc averaged catalog
    :param logger: logger receiving match statistics
    :param dict by: optional ``{column: values}`` mapping; only its first key
        is used. For each value ``v`` the lc catalog is restricted to rows
        with ``star_lc_cat[column] == v``, ``column=v`` is passed on to the
        secondary catalog query, and the per-value matches are stacked.
    :param kwargs: passed to ``_get_matched_cats`` (``arcsecrad``) and from
        there to ``survey.get_secondary_star_catalog``
    :return recarray star_selec: aligned secondary star catalog
    :return recarray lc_selec: aligned lc star catalog
    :raises ValueError: if ``by`` provides no value to loop over
    """
    if by is None:
        star_selec, lc_selec, index = _get_matched_cats(survey, star_lc_cat, **kwargs)
        logger.info("number of stars in lc catalog: %d", len(star_lc_cat))
        logger.info("number of stars considered in ref catalog: %d", len(index))
        logger.info("number of match with ref catalog: %d", len(star_selec))
        return star_selec, lc_selec

    key = list(by)[0]
    stack_star = []
    stack_lc = []

    for value in by[key]:
        sub_cat = star_lc_cat[star_lc_cat[key] == value]
        # the loop value is also forwarded to the secondary catalog query
        kwargs[key] = value
        star_selec, lc_selec, index = _get_matched_cats(survey, sub_cat, **kwargs)
        logger.info(value)
        logger.info("number of stars in lc catalog: %d", len(sub_cat))
        logger.info("number of stars considered in ref catalog: %d", len(index))
        logger.info("number of match with ref catalog: %d", len(star_selec))
        stack_star.append(star_selec)
        stack_lc.append(lc_selec)

    if not stack_star:
        # np.hstack([]) would fail with an unhelpful message
        raise ValueError(f"`by` gives no value to loop over for key {key!r}")

    return np.hstack(stack_star), np.hstack(stack_lc)

56 

57 

def _get_matched_cats(survey, star_lc_cat, arcsecrad=1, **kwargs):
    """Cross-match an lc catalog against the survey secondary catalog.

    Rows of the lc catalog without finite ra/dec are dropped. The secondary
    catalog is cut to the lc catalog's ra/dec bounding box widened by 0.2
    (same units as the ra/dec columns) before calling ``match``.

    :param survey: secondary star catalog provider
    :param star_lc_cat: lc catalog (recarray or DataFrame)
    :param arcsecrad: match radius forwarded to ``match``
    :param kwargs: forwarded to ``survey.get_secondary_star_catalog``
    :return: ``(star_selec, lc_selec, index)`` where ``index`` has one entry
        per box-selected secondary star, -1 when unmatched, otherwise the
        row of the matching star in the filtered lc catalog
    """
    ref_cat = survey.get_secondary_star_catalog(**kwargs)
    if isinstance(ref_cat, pd.DataFrame):
        ref_cat = ref_cat.to_records()

    # mask is built on the input as given, then applied after conversion
    has_position = np.isfinite(star_lc_cat["ra"]) & np.isfinite(star_lc_cat["dec"])
    lc_cat = (
        star_lc_cat.to_records()
        if isinstance(star_lc_cat, pd.DataFrame)
        else star_lc_cat
    )
    lc_cat = lc_cat[has_position]

    margin = 0.2
    ra_lo, ra_hi = lc_cat["ra"].min(), lc_cat["ra"].max()
    dec_lo, dec_hi = lc_cat["dec"].min(), lc_cat["dec"].max()
    ref_ra = ref_cat["ra"]
    ref_dec = ref_cat["dec"]
    in_box = (ref_ra > ra_lo - margin) & (ref_ra < ra_hi + margin)
    in_box &= (ref_dec > dec_lo - margin) & (ref_dec < dec_hi + margin)
    ref_cat = ref_cat[in_box]

    index = match(lc_cat, ref_cat, arcsecrad=arcsecrad)
    matched = index != -1
    return ref_cat[matched], lc_cat[index[matched]], index

83 

84 

85# pylint: disable=dangerous-default-value 

def plot_diff_mag(
    survey,
    star_selec,
    lc_selec,
    xlabel="mag",
    lims=(-0.2, 0.2),
    fig=None,
    axs=None,
    **kwargs,
):
    """Plot mag difference between lc and secondary star catalogs.

    For each band of ``survey``, ``-2.5 log10(flux_<band>)`` from the lc
    catalog minus the secondary catalog magnitude is plotted with
    ``binplot`` after subtracting its nan-mean. Bands whose magnitude column
    is missing from ``star_selec`` are skipped (their panel stays empty).

    :param survey: secondary star catalog provider
    :param recarray star_selec: aligned secondary star catalog
    :param recarray lc_selec: aligned light curve star catalog
    :param str xlabel: "mag" to plot against the secondary magnitude, or a
        label key (e.g. "color") holding two column names whose difference
        is used as x
    :param lims: y-axis limits
    :param fig: returned unchanged when ``axs`` is given
    :param axs: axes to draw on; a new figure is created when None
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return: fig, axs (flat array of axes)
    """
    bands = list(survey.bands)
    nband = len(bands)
    # one column up to 3 bands, otherwise 2 columns with just enough rows
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False keeps a 2D array even for a single band
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
    axs = np.ravel(axs)

    for ax, band in zip(axs, bands):
        labels = survey.get_secondary_labels(band, **kwargs)
        if labels["mag"] not in star_selec.dtype.names:
            continue
        mag = -2.5 * np.log10(lc_selec[f"flux_{band}"])
        mag0 = star_selec[labels["mag"]]
        diff = mag - mag0
        if xlabel == "mag":
            x = mag0
        else:
            x = star_selec[labels[xlabel][0]] - star_selec[labels[xlabel][1]]
        goods = np.isfinite(x)
        binplot(x[goods], diff[goods] - np.nanmean(diff), ax=ax)
        ax.grid(True)
        ax.set_ylim(*lims)
        # diff is lc (psf) magnitude minus catalog (aperture) magnitude
        ax.set_ylabel(f"{band}: psf-aper")

    # label the bottom panel of each column among the used axes, never a
    # trailing empty panel
    used = axs[: min(nband, len(axs))]
    for ax in used[-(2 if nband > 3 else 1):]:
        ax.set_xlabel(xlabel)
    return fig, axs

131 

132 

def _check_goods(mag0, mag1, c, weights, color_range=None):
    """Return the mask of usable measurements.

    A measurement is usable when ``mag0``, ``mag1``, ``c`` and ``weights``
    are all finite and, if ``color_range`` is given, its color lies strictly
    between ``color_range[0]`` and ``color_range[1]``.
    """
    usable = (np.isfinite(mag0) & np.isfinite(mag1)) & (
        np.isfinite(c) & np.isfinite(weights)
    )
    if color_range is None:
        return usable
    in_range = (c > color_range[0]) & (c < color_range[1])
    return usable & in_range

140 

141 

def _zp_fit(mag0, mag1, c, weights, group_by, color_range=None, **kwargs):
    """zp fit using the most basic inputs.

    Fits ``y = mag1 - mag0`` with a linear model whose terms are, in this
    parameter order:

    - ``fnl`` (only if ``fnl=True``): B-spline in ``mag0``;
    - ``alpha``: linear color term in ``c``;
    - ``beta``: one zero point per unique ``group_by`` value;
    - ``csky`` (only if ``background`` is given): linear term in the
      background array.

    Only measurements passing ``_check_goods`` (and the optional ``goods``
    mask) enter the fit, which is solved by ``RobustLinearSolver`` with
    ``robust_solution(nsig=3)``.

    :param mag0: reference magnitudes
    :param mag1: measured magnitudes
    :param c: colors
    :param weights: fit weights, one per measurement
    :param group_by: group label per measurement; one zp per label
    :param color_range: optional (min, max) exclusive color cut
    :param kwargs: optional fit options (unknown keys are ignored):

        - ``goods``: extra boolean mask ANDed with the quality cut;
        - ``fnl`` (default False): add the spline term;
        - ``fnl_grid``: spline knots, default 3 knots spanning the selected mag0;
        - ``fnl_order`` (default 4): spline order;
        - ``background``: array fitted through the ``csky`` term;
        - ``fix_alpha``: value at which the color term is held fixed.
    :return dict: ``alpha``/``alpha_err``, ``zp``/``zp_err`` (free beta
        values), ``res`` and ``wres`` residuals of the selected measurements,
        ``goods`` selection mask, ``bads`` from the solver (presumably the
        clipped outliers), ``cov`` from ``solver.get_cov()``,
        ``csky``/``csky_err`` and ``fnl`` (None when unused), plus ``y``,
        ``color``, ``mag0``, ``mag1``, ``w`` and the evaluated ``model``.
    """
    goods = kwargs.get("goods", None)
    fnl = kwargs.get("fnl", False)
    fnl_grid = kwargs.get("fnl_grid", None)
    fnl_order = kwargs.get("fnl_order", 4)
    sky = kwargs.get("background", None)
    fix_alpha = kwargs.get("fix_alpha", None)

    _goods = _check_goods(mag0, mag1, c, weights, color_range=color_range)
    if goods is not None:
        _goods &= goods

    y = np.array(mag1 - mag0)
    w = np.array(weights)
    groups = np.array(group_by)[_goods]

    # Terms are summed in list order, which fixes the parameter layout.
    terms = []
    if fnl:
        _m0 = np.array(mag0[_goods])
        grid = (
            fnl_grid
            if fnl_grid is not None
            else np.linspace(_m0.min(), _m0.max(), 3)
        )
        coo = bs.BSpline(grid, order=fnl_order).eval(_m0)
        terms.append(LinearModel(coo.row, coo.col, coo.data, name="fnl"))
    terms.append(linear_func(c[_goods], name="alpha"))
    terms.append(indic(groups, name="beta"))
    if sky is not None:
        # kept even when fnl is on, so the "csky" entries below always exist
        terms.append(linear_func(sky[_goods], name="csky"))
    model = terms[0]
    for term in terms[1:]:
        model = model + term

    if fix_alpha is not None:
        # alpha comes right after the spline coefficients when fnl is on,
        # so index 0 is only alpha without the spline term
        i_alpha = len(model.params["fnl"].full) if fnl else 0
        model.params[i_alpha] = fix_alpha
        model.params.fix(i_alpha)

    solver = RobustLinearSolver(model, y[_goods], weights=w[_goods])
    x = solver.robust_solution(nsig=3)
    model.params.free = x
    res = solver.get_res(y[_goods], model.params.full)
    cov = solver.get_cov()
    error_model = deepcopy(model.params)
    error_model.free = np.sqrt(cov.diagonal())
    return {
        "y": y,
        "alpha": model.params["alpha"].full,
        "zp": model.params["beta"].free[:],
        "color": c,
        "res": res,
        "alpha_err": error_model["alpha"].full,
        "zp_err": error_model["beta"].free[:],
        "mag0": mag0,
        "mag1": mag1,
        "goods": _goods,
        "bads": solver.bads,
        "w": w,
        "model": model(),
        "wres": solver.get_wres(x=x),
        "csky": None if sky is None else model.params["csky"].full,
        "csky_err": None if sky is None else error_model["csky"].full,
        "fnl": None if not fnl else model.params["fnl"].full,
        "cov": cov,
    }

217 

218 

def compute_zp(
    survey,
    band,
    lc_selec,
    star_selec,
    color_range,
    zpkey="ccd",
    error_floor=0,
    background=None,
    **kwargs,
):
    """Fit a zp per zpkey (like ccd, name) and a joined linear color term.

    The fitted quantity is ``mag_psf - mag_ap``. Here ``mag_psf`` is
    ``-2.5 log10(flux_<band>)`` from the lc catalog and ``mag_ap`` is the
    secondary catalog magnitude. Each star is weighted by
    ``1 / sqrt(emag_psf**2 + emag_ap**2 + error_floor**2)``. A star is kept
    only if every ``labels["goods"]`` column exceeds its threshold and
    ``mag_ap`` lies strictly inside ``labels["mag_cut"]``.

    :param survey: secondary star catalog provider
    :param band: band name
    :param recarray lc_selec: aligned light curve star catalog
    :param recarray star_selec: aligned secondary star catalog
    :param list color_range: (min, max) color range on which the fit is done
    :param str zpkey: column name of lc_selec on which zp apply
    :param error_floor: additional error term, added in quadrature
    :param background: background correction in mag, None to ignore
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return dict dfit: dict with all fitted quantities
    """
    labels = survey.get_secondary_labels(band, **kwargs)

    flux = lc_selec[f"flux_{band}"]
    mag_psf = -2.5 * np.log10(flux)
    # sigma_mag = 2.5 / ln(10) * sigma_flux / flux, with 2.5 / ln(10) ~ 1.0857
    emag_psf = 2.5 / np.log(10) * (lc_selec[f"eflux_{band}"] / flux)

    mag_ap = star_selec[labels["mag"]]
    emag_ap = star_selec[labels["emag"]]

    color = star_selec[labels["color"][0]] - star_selec[labels["color"][1]]
    wcolor = np.sqrt(1 / (emag_psf**2 + emag_ap**2 + error_floor**2))

    goods = np.ones(len(mag_ap), dtype=bool)
    for col, threshold in labels["goods"].items():
        goods &= star_selec[col] > threshold
    mag_cut = labels["mag_cut"]
    goods &= (mag_ap > mag_cut[0]) & (mag_ap < mag_cut[1])

    return _zp_fit(
        mag_ap,
        mag_psf,
        color,
        wcolor,
        lc_selec[zpkey],  # group_by: one zp per distinct zpkey value
        color_range=color_range,
        goods=goods,
        background=background,
    )

268 

269 

270# pylint: disable=dangerous-default-value 

def plot_zpfit_res(
    zpfits,
    xlabel="mag",
    lims=(-0.03, 0.03),
    fig=None,
    axs=None,
):
    """Plot zp fit residuals, one panel per band.

    :param dict zpfits: fit dicts keyed by band, as returned by ``compute_zp``
    :param str xlabel: can be mag or color; "mag" reads the ``mag0`` entry of
        a fit dict that has no ``mag`` entry
    :param lims: y-axis limits
    :param fig: figure to lay out when ``axs`` is given
    :param axs: axes to draw on, one per band; created when None
    :return: fig, axs (flat array of axes)
    """
    bands = list(zpfits)
    nband = len(bands)
    # one column up to 3 bands, otherwise 2 columns with just enough rows
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False keeps a 2D array even for a single band
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
    axs = np.ravel(axs)

    for ax, band in zip(axs, bands):
        dfit = zpfits[band]
        # fit dicts store the reference magnitude as "mag0", not "mag"
        key = "mag0" if xlabel == "mag" and "mag" not in dfit else xlabel
        binplot(
            np.array(dfit[key][dfit["goods"]]),
            # residuals only cover the selected ("goods") measurements
            np.array(dfit["res"]),
            ax=ax,
            data=True,
            label=band,
        )
        ax.grid(True)
        ax.set_ylim(*lims)
        ax.legend()

    # label the bottom panel of each column among the used axes
    used = axs[: min(nband, len(axs))]
    for ax in used[-(2 if nband > 3 else 1):]:
        ax.set_xlabel(xlabel)

    # lay out the figure actually drawn on, not whatever pyplot considers current
    layout_fig = fig if fig is not None else (axs[0].figure if len(axs) else None)
    if layout_fig is not None:
        layout_fig.tight_layout()
    return fig, axs

309 

310 

def _binned_sum(values, keep, index):
    """Sum ``values`` over the kept entries of each bin listed in ``index``."""
    return np.array([values[e][keep[e]].sum() for e in index])


def zpfit_diagnostic(dfit, nbins=15, e0=None, e1=None):
    """Plot zpfit diagnostic including rms of the residual compared to
    measurement error and chi2.

    Left column is binned in reference magnitude, right column in color.
    The top row shows the binned residual rms next to the error predicted by
    the fit weights, and optionally the binned rms of ``dfit[e0]`` /
    ``dfit[e1]``. The bottom row shows the binned chi2 per degree of freedom.
    Entries flagged in ``dfit["bads"]`` are excluded from every statistic.

    :param dict dfit: fit dict as returned by ``compute_zp``
    :param int nbins: number of bins
    :param str e0: optional dfit key of a per-measurement error to overlay
    :param str e1: optional second dfit key of a per-measurement error
    :return: fig, list of the 4 axes
    """
    goods = dfit["goods"]
    keep = ~dfit["bads"]
    y = dfit["res"]
    wres = dfit["wres"]

    # Restrict per-measurement inputs to the fitted subset once, outside the
    # per-bin loops (res, wres and bads already cover that subset only).
    var_pred = 1 / np.asarray(dfit["w"])[goods] ** 2
    extra_errors = [
        (name, np.asarray(dfit[name])[goods] ** 2, style)
        for name, style in ((e0, {"alpha": 0.5}), (e1, {}))
        if name is not None
    ]

    fig, ax = plt.subplots(2, 2, figsize=(15, 5), sharex="col")
    ax = list(ax.flatten())
    for x, (ax0, ax1), xlabel in zip(
        [dfit["mag0"][goods], dfit["color"][goods]],
        [[ax[0], ax[2]], [ax[1], ax[3]]],
        ["mag", "color"],
    ):
        _, xbinned, xerr, index = make_bins(x, y, nbins)
        ngood = np.array([keep[e].sum() for e in index])

        mean_y2 = _binned_sum(y**2, keep, index) / ngood
        mean_y_sq = (_binned_sum(y, keep, index) / ngood) ** 2
        rms = np.sqrt(mean_y2 - mean_y_sq)
        chi2 = _binned_sum(wres**2, keep, index) / (ngood - 1)
        w2 = _binned_sum(var_pred, keep, index) / ngood

        ax0.errorbar(xbinned, rms, xerr=xerr, ls="None", marker="+", label="res rms")
        ax0.errorbar(
            xbinned,
            np.sqrt(w2),
            xerr=xerr,
            ls="None",
            marker="+",
            label="predicted errors",
        )
        for name, err2, style in extra_errors:
            ax0.errorbar(
                xbinned,
                np.sqrt(_binned_sum(err2, keep, index) / ngood),
                xerr=xerr,
                ls="--",
                marker="+",
                label=f"{name} contrib.",
                **style,
            )

        ax1.errorbar(
            xbinned, chi2, xerr=xerr, ls="None", marker="+", label="chi2 / dof"
        )
        ax1.set_ylim(0.5, 10)
        ax1.set_yscale("log")
        ax0.legend()
        ax1.legend()
        ax1.axhline(y=1, color="k", ls="--")
        ax1.set_xlabel(xlabel)
    return fig, ax

396 

397 

398# def make_bins(x, nbins): 

399# """Define nbins bin in x. 

400# :param array x: x 

401# :param int nbins: number of bins 

402# :return array bins: bins limit 

403# :return array xbinned: binned version of x 

404# :return array xerr: bins size 

405# :return array index: index of x corresponding to each bin 

406# """ 

407# bins = np.linspace(x.min(), x.max() + abs(x.max() * 1e-7), nbins + 1) 

408# yd = np.digitize(x, bins) 

409# index = make_index(yd) 

410# xbinned = 0.5 * (bins[:-1] + bins[1:]) 

411# usedbins = np.array(np.sort(list(set(yd)))) - 1 

412# xbinned = xbinned[usedbins] 

413# bins = bins[usedbins + 1] 

414# xerr = np.array([bins, bins]) - np.array([xbinned, xbinned]) 

415# return bins, xbinned, xerr, index