Coverage for src/lccalib/zp.py: 0%
175 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-04 10:22 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-04 10:22 +0000
1"""
2Tools to fit a zp from a color transformation
3"""
5from copy import deepcopy
6import numpy as np
7import pandas as pd
8import matplotlib.pyplot as plt
10from saltworks.plottools import binplot, make_bins
11from saltworks.linearmodels import RobustLinearSolver, linear_func, indic, LinearModel
12from saltworks.indextools import make_index
13import bbf.bspline as bs
15from .match import match
17# pylint: disable=invalid-name,too-many-locals,too-many-arguments
def get_matched_cats(survey, star_lc_cat, logger, by=None, **kwargs):
    """Align secondary star catalog with star lc catalog.

    Without ``by``, the whole ``star_lc_cat`` is matched in one go. With
    ``by = {key: values}``, only the first key of the dict is used. For each
    value ``k`` in ``by[key]``, three things happen:

    - the rows of ``star_lc_cat`` whose ``key`` column equals ``k`` are
      matched separately;
    - ``key=k`` is also forwarded to the secondary catalog provider;
    - the per-value results are concatenated.

    :param survey: secondary star catalog provider
    :param recarray star_lc_cat: lc averaged catalog (needs ``ra``/``dec``
        columns, plus the ``by`` key column when ``by`` is given)
    :param logger: logger receiving the match statistics
    :param dict by: ``{column: values}`` to loop over, or None for a single
        match over the whole catalog
    :param kwargs: forwarded to ``_get_matched_cats`` (e.g. ``arcsecrad``) and
        from there to ``survey.get_secondary_star_catalog``
    :return recarray star_selec: aligned secondary star catalog
    :return recarray lc_selec: aligned lc star catalog
    :raises ValueError: if ``by`` is given but ``by[key]`` holds no value
    """

    def _log_counts(lc_cat, index, star_selec):
        # Same three statistics for the global and the per-value matches.
        logger.info("number of stars in lc catalog: %d", len(lc_cat))
        logger.info("number of stars considered in ref catalog: %d", len(index))
        logger.info("number of match with ref catalog: %d", len(star_selec))

    if by is None:
        star_selec, lc_selec, index = _get_matched_cats(survey, star_lc_cat, **kwargs)
        _log_counts(star_lc_cat, index, star_selec)
        return star_selec, lc_selec

    key = list(by.keys())[0]
    stack_star = []
    stack_lc = []
    for k in by[key]:
        sub_cat = star_lc_cat[star_lc_cat[key] == k]
        # The current value is also a selection argument for the ref catalog.
        kwargs[key] = k
        star_selec, lc_selec, index = _get_matched_cats(survey, sub_cat, **kwargs)
        logger.info(k)
        _log_counts(sub_cat, index, star_selec)
        stack_star.append(star_selec)
        stack_lc.append(lc_selec)

    if not stack_star:
        # np.hstack([]) would fail with an unhelpful message.
        raise ValueError(f"by[{key!r}] holds no value to loop over")
    return np.hstack(stack_star), np.hstack(stack_lc)
def _get_matched_cats(survey, star_lc_cat, arcsecrad=1, **kwargs):
    """Cross-match one lc catalog against the survey's secondary star catalog.

    Rows of ``star_lc_cat`` with a non-finite ``ra`` or ``dec`` are dropped.
    The secondary catalog (fetched with ``**kwargs``) is then cut to the lc
    ra/dec bounding box, widened by 0.2 on each side (same units as ra/dec,
    presumably degrees). Finally, both catalogs are matched with ``match``
    using ``arcsecrad``.

    :param survey: secondary star catalog provider
    :param star_lc_cat: lc catalog, recarray or DataFrame
    :param arcsecrad: matching radius forwarded to ``match``
    :param kwargs: forwarded to ``survey.get_secondary_star_catalog``
    :return: ``(star_selec, lc_selec, index)``, where:

        - ``index`` has one entry per secondary star surviving the box cut,
          and entries equal to -1 are treated as unmatched;
        - ``star_selec`` and ``lc_selec`` are the aligned matched rows.
    """
    ref_cat = survey.get_secondary_star_catalog(**kwargs)
    if isinstance(ref_cat, pd.DataFrame):
        ref_cat = ref_cat.to_records()

    has_pos = np.isfinite(star_lc_cat["ra"]) & np.isfinite(star_lc_cat["dec"])
    lc_cat = (
        star_lc_cat.to_records()
        if isinstance(star_lc_cat, pd.DataFrame)
        else star_lc_cat
    )
    lc_cat = lc_cat[has_pos]

    def in_padded_range(coord, pad=0.2):
        # Open interval around the lc extent along one coordinate.
        low, high = lc_cat[coord].min(), lc_cat[coord].max()
        return (ref_cat[coord] > low - pad) & (ref_cat[coord] < high + pad)

    ref_cat = ref_cat[in_padded_range("ra") & in_padded_range("dec")]

    index = match(lc_cat, ref_cat, arcsecrad=arcsecrad)
    matched = index != -1
    return ref_cat[matched], lc_cat[index[matched]], index
85# pylint: disable=dangerous-default-value
def plot_diff_mag(
    survey,
    star_selec,
    lc_selec,
    xlabel="mag",
    lims=(-0.2, 0.2),
    fig=None,
    axs=None,
    **kwargs,
):
    """Plot mag difference between lc and secondary star catalogs.

    For each band of ``survey``, the plotted quantity is
    ``-2.5 * log10(lc_selec["flux_<band>"])`` minus the secondary-catalog
    magnitude, shifted by its mean over finite values. It is plotted in bins
    against either the secondary magnitude or a color. Bands whose magnitude
    column is absent from ``star_selec`` are skipped.

    :param survey: secondary star catalog provider
    :param recarray star_selec: aligned secondary star catalog
    :param recarray lc_selec: aligned light curve star catalog
    :param str xlabel: ``"mag"``, or a key of the survey labels giving a pair
        of magnitude columns (e.g. ``"color"``)
    :param lims: y-axis limits ``(lower, upper)``
    :param fig: figure returned as-is when ``axs`` is provided
    :param axs: flat sequence of axes, one per band; created when None
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return: fig, axs
    """
    nband = len(survey.bands)
    # One column up to 3 bands, otherwise two columns of ceil(nband / 2) rows.
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False keeps an array of axes even for a single band.
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, band in zip(axs, survey.bands):
        labels = survey.get_secondary_labels(band, **kwargs)
        if labels["mag"] not in star_selec.dtype.names:
            continue
        mag = -2.5 * np.log10(lc_selec[f"flux_{band}"])
        mag0 = star_selec[labels["mag"]]
        diff = mag - mag0
        if xlabel == "mag":
            x = mag0
        else:
            x = star_selec[labels[xlabel][0]] - star_selec[labels[xlabel][1]]
        # Non-positive fluxes give nan/-inf magnitudes. Keep them out of the
        # offset (an -inf would turn the whole curve into nan) and the points.
        finite_diff = np.isfinite(diff)
        offset = np.nanmean(np.where(finite_diff, diff, np.nan))
        goods = np.isfinite(x) & finite_diff
        binplot(x[goods], diff[goods] - offset, ax=ax)
        ax.grid(True)
        ax.set_ylim(*lims)
        # NOTE(review): diff is lc magnitude minus secondary-catalog magnitude;
        # confirm which one is "aper" and which is "psf" for this label.
        ax.set_ylabel(f"{band}: aper-psf")
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    return fig, axs
def _check_goods(mag0, mag1, c, weights, color_range=None):
    """Return the boolean mask of measurements usable in the fit.

    A measurement is kept when two conditions hold:

    - ``mag0``, ``mag1``, ``c`` and ``weights`` are all finite;
    - if ``color_range`` is given, ``color_range[0] < c < color_range[1]``
      (strict bounds).
    """
    mask = np.isfinite(mag0) & np.isfinite(mag1)
    mask &= np.isfinite(c) & np.isfinite(weights)
    if color_range is None:
        return mask
    low, high = color_range[0], color_range[1]
    return mask & (c > low) & (c < high)
def _zp_fit(mag0, mag1, c, weights, group_by, color_range=None, **kwargs):
    """zp fit using the most basic inputs.

    y = mag1 - mag0

    The linear model fitted to ``y`` (restricted to the measurements kept by
    the mask built below) is ``alpha * c + beta[group_by]``. It optionally
    includes either a B-spline in ``mag0`` named "fnl" (``fnl=True``) or
    ``csky * background`` (only when ``fnl`` is False). The solution comes
    from ``RobustLinearSolver.robust_solution(nsig=3)``, presumably an
    iterative 3-sigma outlier rejection (TODO confirm in saltworks).

    :param mag0: reference magnitudes
    :param mag1: measured magnitudes
    :param c: color, regressor of the "alpha" term
    :param weights: per-measurement weights handed to RobustLinearSolver
    :param group_by: group label of each measurement, used through ``indic``
        for the "beta" (zp) term; presumably one zp per distinct label
    :param color_range: optional ``(min, max)`` open interval on ``c``
    :param kwargs: optional fit options:

        - ``goods``: extra boolean mask, ANDed with the finiteness/color cut;
        - ``fnl``: bool, default False; add a B-spline in ``mag0``;
        - ``fnl_grid``: spline knots; by default 3 points evenly spanning the
          selected ``mag0`` range;
        - ``fnl_order``: spline order, default 4;
        - ``background``: full-length array fitted with coefficient "csky";
          ignored when ``fnl`` is True;
        - ``fix_alpha``: value at which the first model parameter is fixed
          (see the NOTE below).
    :return dict: fit results.

        - Full-length inputs are echoed back: "y", "color", "mag0", "mag1",
          "w". The mask used is under "goods".
        - "res" holds the residuals of ``y[goods]``; "wres" and "bads" come
          from the solver, presumably over the same subset.
        - Fitted values are under "alpha", "zp", "csky" and "fnl"; their
          1-sigma errors are under "alpha_err", "zp_err" and "csky_err", and
          the full covariance is under "cov".
        - "model" is the evaluated model.
    """
    goods = kwargs.get("goods", None)
    fnl = kwargs.get("fnl", False)
    fnl_grid = kwargs.get("fnl_grid", None)
    fnl_order = kwargs.get("fnl_order", 4)
    sky = kwargs.get("background", None)
    fix_alpha = kwargs.get("fix_alpha", None)
    _goods = _check_goods(mag0, mag1, c, weights, color_range=color_range)
    if goods is not None:
        # In-place AND on the freshly built mask; the caller's array is untouched.
        _goods &= goods

    y = np.array((mag1 - mag0))
    w = np.array(weights)  # w = np.ones((goods.sum()))

    if fnl:
        _m0 = np.array(mag0[_goods])
        # Default knots: 3 points spanning the selected mag0 range.
        grid = np.linspace(_m0.min(), _m0.max(), 3)
        if fnl_grid is not None:
            grid = fnl_grid
        spl = bs.BSpline(grid, order=fnl_order)
        # Sparse (COO) design matrix of the spline evaluated at each mag0.
        coo = spl.eval(_m0)
        model = (
            LinearModel(coo.row, coo.col, coo.data, name="fnl")
            + linear_func(c[_goods], name="alpha")
            + indic(np.array(group_by)[_goods], name="beta")
        )
    elif sky is not None:
        # Background term; note it is not used when fnl is True.
        model = (
            linear_func(c[_goods], name="alpha")
            + indic(np.array(group_by)[_goods], name="beta")
            + linear_func(sky[_goods], name="csky")
        )
    else:
        model = linear_func(c[_goods], name="alpha") + indic(
            np.array(group_by)[_goods], name="beta"
        )

    if fix_alpha is not None:
        # NOTE(review): this fixes parameter index 0, which is "alpha" only
        # when "alpha" is the first term. In the fnl branch the "fnl" term is
        # added first, so this likely fixes a spline coefficient instead of
        # alpha. TODO confirm the parameter ordering in saltworks.
        model.params[0] = fix_alpha
        model.params.fix(0)

    solver = RobustLinearSolver(model, y[_goods], weights=w[_goods])
    x = solver.robust_solution(nsig=3)
    model.params.free = x
    res = solver.get_res(y[_goods], model.params.full)  # x)
    # 1-sigma errors of the free parameters from the covariance diagonal.
    err = np.sqrt(solver.get_cov().diagonal())
    # Copy of the parameter structure filled with errors, so errors can be
    # looked up by term name like the values.
    error_model = deepcopy(model.params)
    error_model.free = err
    return dict(
        {
            "y": y,
            # "x": x,
            "alpha": model.params["alpha"].full,
            "zp": model.params["beta"].free[:],
            "color": c,
            "res": res,
            "alpha_err": error_model["alpha"].full,
            "zp_err": error_model["beta"].free[:],
            "mag0": mag0,
            "mag1": mag1,
            "goods": _goods,
            "bads": solver.bads,
            "w": w,
            "model": model(),
            "wres": solver.get_wres(x=x),
            "csky": None if sky is None else model.params["csky"].full,
            "csky_err": None if sky is None else error_model["csky"].full,
            "fnl": None if not fnl else model.params["fnl"].full,
            "cov": solver.get_cov(),
        }
    )
def compute_zp(
    survey,
    band,
    lc_selec,
    star_selec,
    color_range,
    zpkey="ccd",
    error_floor=0,
    background=None,
    **kwargs,
):
    """Fit a zp per zpkey (like ccd, name) and a joined linear color term.

    The fitted quantity is ``-2.5 log10(lc flux) - secondary magnitude``.
    Each measurement is weighted by ``1 / sqrt(emag_lc**2 + emag_ref**2 +
    error_floor**2)``. The secondary catalog quality flags and magnitude cut
    given by the survey labels are applied before the fit.

    :param survey: secondary star catalog provider
    :param band: band name
    :param recarray lc_selec: aligned light curve star catalog (needs
        ``flux_<band>``, ``eflux_<band>`` and ``zpkey`` columns)
    :param recarray star_selec: aligned secondary star catalog
    :param list color_range: color range on which the fit is done
    :param str zpkey: column name of lc_selec on which zp apply
    :param error_floor: additional error term, added in quadrature
    :param background: background term, fitted with its own linear
        coefficient; None to ignore
    :param kwargs: forwarded to ``survey.get_secondary_labels`` only
    :return dict dfit: dict with all fitted quantities
    """
    labels = survey.get_secondary_labels(band, **kwargs)

    flux = lc_selec[f"flux_{band}"]
    mag_psf = -2.5 * np.log10(flux)
    # Error propagation of m = -2.5 log10(f): sigma_m = 2.5 / ln(10) * sigma_f / f.
    # The exact factor is ~1.0857 (previously truncated to 1.08).
    emag_psf = (2.5 / np.log(10)) * (lc_selec[f"eflux_{band}"] / flux)

    mag_ap = star_selec[labels["mag"]]
    emag_ap = star_selec[labels["emag"]]

    color = star_selec[labels["color"][0]] - star_selec[labels["color"][1]]
    # Weights are 1 / total uncertainty.
    wcolor = np.sqrt(1 / (emag_psf**2 + emag_ap**2 + error_floor**2))

    goods = np.ones(len(mag_ap), dtype=bool)
    # Survey-defined quality cuts: keep rows whose flag column exceeds the threshold.
    for column, threshold in labels["goods"].items():
        goods &= star_selec[column] > threshold
    goods &= (mag_ap < labels["mag_cut"][1]) & (mag_ap > labels["mag_cut"][0])

    return _zp_fit(
        mag_ap,
        mag_psf,
        color,
        wcolor,
        lc_selec[zpkey],  # group_by
        color_range=color_range,
        goods=goods,
        background=background,
    )
270# pylint: disable=dangerous-default-value
def plot_zpfit_res(
    zpfits,
    xlabel="mag",
    lims=(-0.03, 0.03),
    fig=None,
    axs=None,
):
    """Plot zp fit residuals, binned, one panel per band.

    :param dict zpfits: ``{band: dfit}``. Each zp fit result dict must hold
        "res", "goods" and the x-axis entry.
    :param str xlabel: can be mag or color. ``"mag"`` reads ``dfit["mag0"]``
        unless the dict has its own ``"mag"`` entry; any other value is used
        directly as a key of the fit dict.
    :param lims: y-axis limits ``(lower, upper)``
    :param fig: figure returned as-is when ``axs`` is provided
    :param axs: flat sequence of axes, one per band; created when None
    :return: fig, ax
    """
    bands = list(zpfits.keys())

    nband = len(bands)
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False keeps an array of axes even for a single band.
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, band in zip(axs, bands):
        dfit = zpfits[band]
        # Fit dicts store the reference magnitude under "mag0", not "mag".
        xkey = "mag0" if xlabel == "mag" and "mag" not in dfit else xlabel
        binplot(
            np.array(dfit[xkey][dfit["goods"]]),
            np.array(dfit["res"]),
            ax=ax,
            data=True,
            label=band,
        )
        ax.grid(True)
        ax.set_ylim(*lims)
        ax.legend()
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    plt.tight_layout()
    return fig, axs
def zpfit_diagnostic(dfit, nbins=15, e0=None, e1=None):
    """Plot zpfit diagnostic including rms of the residual compared to
    measurement error and chi2.

    The layout has one column per x variable (``mag0`` and ``color``, both
    restricted to ``dfit["goods"]``):

    - Top row: binned rms of the residuals, the predicted error
      ``sqrt(mean(1 / w**2))``, and optional error contributions.
    - Bottom row: chi2 per degree of freedom,
      ``sum(wres**2) / (n - 1)`` per bin.

    Entries flagged in ``dfit["bads"]`` are excluded from every statistic.

    :param dict dfit: zp fit result dict ("res", "wres", "bads", "goods",
        "w", "mag0", "color", plus the optional error keys)
    :param int nbins: number of bins along each x variable
    :param str e0: optional key of a full-length per-measurement error in
        ``dfit``, overlaid as a contribution to the predicted error
    :param str e1: second optional error key, same as ``e0``
    :return: fig, list of the 4 axes
    """
    bads = dfit["bads"]
    y = dfit["res"]
    wres = dfit["wres"]
    goods = dfit["goods"]
    # Predicted variance of each fitted measurement from the fit weights.
    var_pred = 1 / dfit["w"][goods] ** 2

    def _bin_sums(values, index):
        # Per-bin sum of values over the entries not flagged as outliers.
        return np.array([values[e][~bads[e]].sum() for e in index])

    fig, ax = plt.subplots(2, 2, figsize=(15, 5), sharex="col")
    ax = list(ax.flatten())
    for x, (ax0, ax1), xlabel in zip(
        [dfit["mag0"][goods], dfit["color"][goods]],
        [[ax[0], ax[2]], [ax[1], ax[3]]],
        ["mag", "color"],
    ):
        _, xbinned, xerr, index = make_bins(x, y, nbins)
        ngood = np.array([(~bads[e]).sum() for e in index])

        mean_y2 = _bin_sums(y**2, index) / ngood
        mean_2y = (_bin_sums(y, index) / ngood) ** 2
        chi2 = _bin_sums(wres**2, index) / (ngood - 1)
        rms = np.sqrt(mean_y2 - mean_2y)
        w2 = _bin_sums(var_pred, index) / ngood

        ax0.errorbar(xbinned, rms, xerr=xerr, ls="None", marker="+", label="res rms")
        ax0.errorbar(
            xbinned,
            np.sqrt(w2),
            xerr=xerr,
            ls="None",
            marker="+",
            label="predicted errors",
        )

        # Optional error contributions; e0 is drawn semi-transparent.
        for ekey, extra in ((e0, {"alpha": 0.5}), (e1, {})):
            if ekey is None:
                continue
            e2 = _bin_sums(dfit[ekey][goods] ** 2, index) / ngood
            ax0.errorbar(
                xbinned,
                np.sqrt(e2),
                xerr=xerr,
                ls="--",
                marker="+",
                label=f"{ekey} contrib.",
                **extra,
            )

        ax1.errorbar(
            xbinned, chi2, xerr=xerr, ls="None", marker="+", label="chi2 / dof"
        )
        ax1.set_ylim(0.5, 10)
        ax1.set_yscale("log")
        ax0.legend()
        ax1.legend()
        ax1.axhline(y=1, color="k", ls="--")
        ax1.set_xlabel(xlabel)
    return fig, ax
398# def make_bins(x, nbins):
399# """Define nbins bin in x.
400# :param array x: x
401# :param int nbins: number of bins
402# :return array bins: bins limit
403# :return array xbinned: binned version of x
404# :return array xerr: bins size
405# :return array index: index of x corresponding to each bin
406# """
407# bins = np.linspace(x.min(), x.max() + abs(x.max() * 1e-7), nbins + 1)
408# yd = np.digitize(x, bins)
409# index = make_index(yd)
410# xbinned = 0.5 * (bins[:-1] + bins[1:])
411# usedbins = np.array(np.sort(list(set(yd)))) - 1
412# xbinned = xbinned[usedbins]
413# bins = bins[usedbins + 1]
414# xerr = np.array([bins, bins]) - np.array([xbinned, xbinned])
415# return bins, xbinned, xerr, index