Coverage for src/lccalib/averaging.py: 0%
267 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-04 10:22 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-04 10:22 +0000
1"""
2Tools to make a single light curve catalog with averaged flux
3"""
5import os
6import itertools
7import ssl
8import logging
9import pandas as pd
10import numpy as np
11from scipy.stats import norm
12import matplotlib.pyplot as plt
14from astroquery.gaia import GaiaClass
15from astropy.table import vstack
17from saltworks.linearmodels import LinearModel, RobustLinearSolver
18from saltworks.plottools import binplot
19from saltworks.dataproxy import DataProxy
21from .match import match
23# pylint: disable=invalid-name,too-many-locals, too-many-arguments,dangerous-default-value
def get_gaia_match(stars, offset=0.02, maxq=5000):
    """Return aligned Gaia and input star subset catalogs.

    Gaia DR3 sources are queried inside the (ra, dec) bounding box of
    ``stars``, enlarged by ``offset`` on each side (same unit as the ra/dec
    columns, presumably degrees), and cross-matched to the input stars with
    ``match(..., arcsecrad=20)``.

    If the first query returns more than ``maxq - 100`` rows (i.e. it was
    probably truncated by ``TOP``), the box is split into four quadrants
    which are queried separately; there is no further splitting.

    TODO: apply proper motions.

    :param stars: input star catalog with ``ra`` and ``dec`` columns
    :param float offset: margin added around the bounding box
    :param int maxq: maximum number of rows returned by each query
    :return: (matched Gaia sources as a DataFrame, matched input stars,
        index returned by ``match``: Gaia row of each input star, -1 when
        unmatched)
    """
    ra_min, dec_min = stars["ra"].min() - offset, stars["dec"].min() - offset
    ra_max, dec_max = stars["ra"].max() + offset, stars["dec"].max() + offset

    # pylint: disable=protected-access
    # Certificate verification is disabled for the Gaia archive queries
    # (presumably to work around TLS verification failures). The previous
    # default context is restored afterwards so that the rest of the process
    # keeps verifying certificates.
    previous_https_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        gaia = GaiaClass(
            gaia_tap_server="https://gea.esac.esa.int/",
            gaia_data_server="https://gea.esac.esa.int/",
        )

        def _query(box_ra_min, box_ra_max, box_dec_min, box_dec_max):
            """Return at most ``maxq`` Gaia DR3 sources inside the given box."""
            # (a variant restricted to sources with XP spectra used the extra
            # condition "has_xp_continuous = 'True'")
            query = (
                f"select TOP {maxq} source_id, ra, dec from gaiadr3.gaia_source "
                f"where ra <= {box_ra_max} and ra >= {box_ra_min} "
                f"and dec <= {box_dec_max} and dec >= {box_dec_min}"
            )
            return gaia.launch_job(query, dump_to_file=False).get_results()

        gaia_ids = _query(ra_min, ra_max, dec_min, dec_max)
        if len(gaia_ids) > maxq - 100:
            # close to the TOP limit: split in 4 quadrants, but no more.
            # NOTE(review): the quadrants share their edges (<= / >=), so a
            # source lying exactly on a boundary may be returned twice.
            ra_mid = (ra_max + ra_min) / 2
            dec_mid = (dec_max + dec_min) / 2
            gaia_ids = vstack(
                [
                    _query(ra_min, ra_mid, dec_min, dec_mid),
                    _query(ra_min, ra_mid, dec_mid, dec_max),
                    _query(ra_mid, ra_max, dec_min, dec_mid),
                    _query(ra_mid, ra_max, dec_mid, dec_max),
                ]
            )
    finally:
        ssl._create_default_https_context = previous_https_context

    index = match(gaia_ids, stars, arcsecrad=20)
    selected_stars = stars[index != -1]
    selected_ids = gaia_ids[index[index != -1]]
    return selected_ids.to_pandas(), selected_stars, index
def cut_from_epoch_number(T, min_d=3):
    """Cut stars observed on fewer than ``min_d`` distinct nights.

    An epoch is a distinct integer part of ``mjd``: several measurements of
    the same star during the same night count as a single epoch.

    :param T: catalog (record array or DataFrame) with a non-negative integer
        star ``index`` column and an ``mjd`` column
    :param int min_d: minimum number of epochs; values < 1 disable the cut
    :return: the rows of ``T`` whose star has at least ``min_d`` epochs
    """
    if min_d < 1 or len(T) == 0:
        return T
    star = np.asarray(T["index"]).astype(int)
    night = np.asarray(T["mjd"]).astype(int)
    # distinct (star, night) pairs, then number of distinct nights per star
    pairs = np.unique(np.stack([star, night]), axis=1)
    nepochs = np.bincount(pairs[0], minlength=star.max() + 1)
    keep = nepochs >= min_d
    return T[keep[star]]
def star_lc_averager(
    T, star_key="star", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute the robust average flux of each star and associated statistics.

    Each measurement is weighted by ``1 / sqrt(eflux**2 + (error_floor*flux)**2)``
    and the per-star fluxes are fitted with 3-sigma outlier rejection.
    Stars at the end of the star index whose measurements were all rejected
    are dropped from the output catalog.

    :param recarray T: input catalog
    :param str star_key: star identifier column name
    :param str flux_key: flux column name
    :param str eflux_key: flux error column name
    :param float error_floor: relative error floor, default 0
    :param bool show: make control plots, default False
    :return DataFrame avg_cat: star, flux, eflux, rms (of the good
        measurements), nmeas (number of good measurements), chi2 (per degree
        of freedom)
    :return array index: star index of each good measurement
    :return array goods: boolean mask of the measurements kept by outlier rejection
    """
    # pylint: disable=E1101
    # pylint: disable=E1130
    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("star", T[star_key].astype(int))
    dp.make_index("star", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + (error_floor * dp.flux) ** 2)

    model = LinearModel(list(range(len(dp.nt))), dp.star_index, np.ones_like(dp.flux))
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=np.array(weights))
    avg_flux = solver.robust_solution(nsig=3)
    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)
    goods = ~solver.bads

    # Trailing stars without any good measurement would make the bincount
    # statistics below shorter than avg_flux: keep only the first nkept stars.
    ngood = np.trim_zeros(np.bincount(dp.star_index, goods), trim="b")
    nkept = len(ngood)
    index = dp.star_index[goods]
    err = np.sqrt(solver.get_cov().diagonal())
    # stars with 0 (interior) or 1 good measurement give NaN/inf statistics
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_y2 = np.bincount(index, weights=dp.flux[goods] ** 2) / ngood
        mean_2y = (np.bincount(index, weights=dp.flux[goods]) / ngood) ** 2
        chi2 = np.bincount(index, weights=wres[goods] ** 2) / (ngood - 1)
        avg_cat = pd.DataFrame(
            data={
                "star": dp.star_set[:nkept],
                "flux": avg_flux[:nkept],
                "eflux": err[:nkept],
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index),
                "chi2": chi2,
            }
        )

    if show:
        plot_star_averager(avg_cat, res=res, dp=dp, goods=goods)
    return avg_cat, index, goods
def add_columns(avg_cat, T, names, index, goods):
    """Complete an averaged catalog with the per-group mean of extra columns.

    :param DataFrame avg_cat: averaged catalog, one row per group
    :param T: measurement catalog (record array, DataFrame or mapping of arrays)
    :param list names: columns of ``T`` to average and append to ``avg_cat``
    :param array index: group index (row of ``avg_cat``) of each good measurement
    :param array goods: boolean mask selecting the good measurements of ``T``
    :return DataFrame: ``avg_cat`` with the new columns; groups without any
        good measurement get NaN
    """
    nrows = len(avg_cat)
    # minlength keeps the averages aligned with avg_cat even when the last
    # groups have no good measurement
    counts = np.bincount(index, minlength=nrows)
    with np.errstate(divide="ignore", invalid="ignore"):
        means = {
            name: np.bincount(index, weights=T[name][goods], minlength=nrows) / counts
            for name in names
        }
    return avg_cat.assign(**means)
def plot_night_averager(avg_cat, single=True, **kwargs):
    """Compare the residual dispersion of the night average fit to its errors.

    With ``single=True`` the call is forwarded to
    ``plot_single_night_averager(avg_cat, **kwargs)``. Otherwise a 2x2 grid
    is drawn, stacked over all nights: the left column uses mjd and the right
    column magnitude (-2.5 log10 flux) as abscissa; the top row shows binned
    relative predicted errors and residual rms, the bottom row binned chi2.

    :return: (figure, list of the 4 axes), or the forwarded call's result
    """
    if single:
        return plot_single_night_averager(avg_cat, **kwargs)

    fig, axes = plt.subplots(2, 2)
    axes = list(axes.flatten())
    flux = avg_cat["flux"]
    abscissae = {
        "mjd": avg_cat["mjd"].to_numpy(),
        "mag": (-2.5 * np.log10(flux)).to_numpy(),
    }
    for column, (xlabel, x) in enumerate(abscissae.items()):
        top = axes[column]
        bottom = axes[column + 2]

        binplot(
            x,
            (avg_cat["eflux"] / flux).to_numpy(),
            label="predicted errors",
            ax=top,
            data=False,
        )
        binplot(
            x,
            (avg_cat["rms"] / flux).to_numpy(),
            color="r",
            label="res rms",
            ax=top,
        )
        top.grid()
        top.set_ylabel(r"$\sigma_f / f$")
        top.legend()

        binplot(x, avg_cat["chi2"].to_numpy(), ax=bottom)
        bottom.set_xlabel(xlabel)
        bottom.set_ylabel("chi2")
        bottom.grid()
    fig.tight_layout()
    return fig, axes
def plot_single_night_averager(avg_cat, res=None, dp=None):
    """Night averager control plots as a function of mjd.

    Panels, sharing the x axis:
    0. relative residual rms per night, compared to the relative average-flux
       error scaled by sqrt(nmeas);
    1. chi2 per night;
    2. only when ``res`` is given: relative residuals ``res / flux`` of the
       individual measurements, with their robust binned profile.

    :param DataFrame avg_cat: night averaged catalog (mjd, flux, eflux, rms,
        nmeas, chi2 columns)
    :param array res: residuals of the individual measurements, optional
    :param dp: DataProxy with ``mjd`` and ``flux`` fields, required with ``res``
    :return: (figure, array of axes)
    """
    nrows = 2 if res is None else 3
    fig, ax = plt.subplots(nrows, 1, sharex=True)
    ax[0].plot(avg_cat["mjd"], avg_cat["rms"] / avg_cat["flux"], "r.", label="res rms")
    ax[0].plot(
        avg_cat["mjd"],
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / avg_cat["flux"],
        "k.",
        label=r"$\sigma_f \sqrt{N} / f$",
    )
    ax[0].grid()
    ax[0].set_ylabel(r"$\sigma_f / f$")
    ax[0].legend()
    ax[1].plot(avg_cat["mjd"], avg_cat["chi2"], "k.")
    ax[1].set_xlabel("mjd")
    ax[1].set_ylabel("chi2")
    ax[1].grid()
    if res is not None:
        # bin the same relative residuals that are scattered, so that both
        # curves share the "res / f" unit of this panel
        rel_res = res / dp.flux
        binplot(dp.mjd, rel_res, robust=True, ax=ax[2])
        ax[2].plot(dp.mjd, rel_res, "k.")
        ax[2].set_ylabel("res / f")
        ax[2].grid()
    return fig, ax
def plot_star_averager(avg_cat, res=None, dp=None, goods=None, res_lim=10000.0):
    """Star averager control plots as a function of magnitude.

    Panels, sharing the x axis (magnitude = -2.5 log10 flux):
    0. binned relative residual rms compared to the relative average-flux
       error scaled by sqrt(nmeas);
    1. binned chi2 per star, with a side histogram of the finite values;
    2. only when ``res`` is given: binned residuals of the individual
       measurements within [-res_lim, res_lim], with a side histogram of the
       good residuals and its Gaussian fit (legend: fitted sigma).

    :param DataFrame avg_cat: star averaged catalog (flux, eflux, rms, nmeas,
        chi2 columns)
    :param array res: residuals of the individual measurements, optional
    :param dp: DataProxy with a ``flux`` field, required with ``res``
    :param array goods: boolean mask of the good measurements used for the
        residual histogram; all measurements when None
    :param float res_lim: half range of the residual axis, default 10000
    :return: (figure, array of axes)
    """
    nrows = 2 if res is None else 3
    # No layout engine at creation: fig.tight_layout() below lays the figure
    # out, and calling it on a constrained-layout figure only discards the
    # constrained engine with a warning.
    fig, ax = plt.subplots(nrows, 1, sharex=True)

    m = -2.5 * np.log10(np.array(avg_cat["flux"]))
    binplot(m, avg_cat["rms"] / avg_cat["flux"], color="r", label="res rms", ax=ax[0])
    binplot(
        m,
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / avg_cat["flux"],
        color="k",
        label=r"$\sigma_f \sqrt{N}$",
        ax=ax[0],
    )
    ax[0].grid()
    ax[0].set_ylabel(r"$\sigma_f / f$")
    ax[0].legend()

    ax_histy = ax[1].inset_axes([1.05, 0, 0.25, 1], sharey=ax[1])
    ok = np.isfinite(avg_cat["chi2"])
    ax_histy.hist(
        avg_cat["chi2"][ok],
        bins=10,
        density=True,
        histtype="step",
        color="black",
        orientation="horizontal",
    )
    binplot(m[ok], np.array(avg_cat["chi2"])[ok], ax=ax[1])
    ax[1].set_xlabel("mag")
    ax[1].set_ylabel("chi2")
    ax[1].grid()

    if res is not None:
        if goods is None:
            # res[None] would add an axis and turn the histogram into one
            # dataset per measurement
            goods = np.ones(len(res), dtype=bool)
        ax_histy = ax[2].inset_axes([1.05, 0, 0.25, 1], sharey=ax[2])
        y = res[goods]
        res_min, res_max = -res_lim, res_lim
        xx = np.linspace(res_min, res_max, 1000)
        me, sc = norm.fit(y)
        ax_histy.tick_params(axis="y", labelleft=False)
        ax_histy.hist(
            y,
            bins=50,
            density=True,
            histtype="step",
            color="black",
            orientation="horizontal",
        )
        ax_histy.plot(
            norm.pdf(xx, loc=me, scale=sc), xx, color="black", label=f"{int(sc)}"
        )
        ax_histy.legend(fontsize=8)

        binplot(-2.5 * np.log10(dp.flux), res, robust=True, ax=ax[2])
        ax[2].set_ylabel("res")
        ax[2].set_ylim(res_min, res_max)
        ax[2].grid()
    fig.tight_layout(h_pad=0.1)
    return fig, ax
def _deprecated_night_averager(
    T, mjd_key="mjd", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute mean flux per night (deprecated, see ``night_averager``).

    Measurements are grouped by the integer part of ``mjd_key`` and each
    night's flux is averaged by a weighted fit with 3-sigma outlier
    rejection. Measurements with a non-finite weight get a zero weight.

    :param recarray T: input catalog
    :param str mjd_key: mjd column name, default 'mjd'
    :param str flux_key: flux column name, default 'flux'
    :param str eflux_key: flux error column name, default 'error'
    :param float error_floor: relative error floor: ``error_floor * flux`` is
        added in quadrature to the flux error, default 0
    :param bool show: make control plots, default False
    :return DataFrame avg_cat: mjd (mean over all the night's measurements,
        rejected ones included), flux, eflux, rms, nmeas, chi2; nights with
        no good measurement get NaN statistics and nmeas = 0
    :return array indices: night index of each good measurement
    :return array goods: boolean mask of the measurements kept by outlier rejection
    """
    # pylint: disable=E1101
    # pylint: disable=E1130

    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("mjd", T[mjd_key].astype(int))
    dp.make_index("mjd", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + (dp.flux * error_floor) ** 2)
    weights[~np.isfinite(weights)] = 0

    model = LinearModel(
        list(range(len(dp.nt))), dp.mjd_index, np.ones_like(dp.flux), name="avg_flux"
    )
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=np.array(weights))
    avg_flux = solver.robust_solution(nsig=3, local_param="avg_flux")

    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)
    goods = ~solver.bads
    index = dp.mjd_index[goods]
    nnight = len(dp.mjd_set)
    # minlength: nights whose measurements were all rejected (in particular
    # the last ones) must still get a row, otherwise the statistics would not
    # have the length of avg_flux
    ngood = np.bincount(dp.mjd_index, goods, minlength=nnight)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_y2 = (
            np.bincount(index, weights=dp.flux[goods] ** 2, minlength=nnight) / ngood
        )
        mean_2y = (
            np.bincount(index, weights=dp.flux[goods], minlength=nnight) / ngood
        ) ** 2
        chi2 = np.bincount(index, weights=wres[goods] ** 2, minlength=nnight) / (
            ngood - 1
        )
        avg_cat = pd.DataFrame(
            data={
                "mjd": [
                    float(np.mean(T[mjd_key][dp.mjd_index == i]))
                    for i in range(nnight)
                ],
                "flux": avg_flux,
                "eflux": np.sqrt(solver.get_cov().diagonal()),
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index, minlength=nnight),
                "chi2": chi2,
            }
        )

    if show:
        plot_single_night_averager(avg_cat, res=res, dp=dp)

    return avg_cat, index, goods
def night_averager(T, error_floor=0, show=False, **kwargs):
    """Compute the robust mean flux per night of each light curve.

    Measurements are grouped by the ``group_by`` columns plus the integer
    part of ``mjd_key``; the flux of each group is averaged by a weighted fit
    with 3-sigma outlier rejection.

    :param recarray T: input catalog (anything accepted by ``pandas.DataFrame``)
    :param float error_floor: relative error floor: ``error_floor * flux`` is
        added in quadrature to the flux error, default 0
    :param bool show: make control plots, default False
    :param str mjd_key: mjd column name, default 'mjd'
    :param str flux_key: flux column name, default 'flux'
    :param str eflux_key: flux error column name, default 'fluxerr'
    :param float minweight: weight given to measurements whose weight is not
        finite, default 1e-8
    :param str list group_by: extra column names used to build the group
        index (e.g. band, star), default None (group by night only)
    :param bool return_solver: also return the RobustLinearSolver instance,
        default False
    :return DataFrame avg_cat: mjd (mean of the good measurements), flux,
        eflux, rms, nmeas, chi2 per group; groups with no good measurement
        get NaN statistics and nmeas = 0
    :return array indices: group index of each good measurement
    :return array goods: boolean mask of the measurements kept by outlier rejection
    :return solver: only when ``return_solver`` is True
    """
    # pylint: disable=E1101
    # pylint: disable=E1130

    mjd_key = kwargs.get("mjd_key", "mjd")
    flux_key = kwargs.get("flux_key", "flux")
    eflux_key = kwargs.get("eflux_key", "fluxerr")
    minweight = kwargs.get("minweight", 1e-8)
    # None and a missing key both mean "night only"; mjd_key is appended
    # below, so it is not repeated here
    group_by = [k for k in (kwargs.get("group_by") or []) if k != mjd_key]
    return_solver = kwargs.get("return_solver", False)

    # need to build an index for (group_by..., int(mjd)), e.g. (star, int(mjd))
    T = pd.DataFrame(T)
    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)

    # explicit copy: the mjd column is overwritten without touching T
    keys = T[group_by + [mjd_key]].copy()
    keys[mjd_key] = keys[mjd_key].astype(int)
    lc, _ = pd.factorize(pd.MultiIndex.from_frame(keys))
    dp.add_field("lc_night", lc)
    dp.add_field("mjd", keys[mjd_key])
    dp.make_index("lc_night", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + (dp.flux * error_floor) ** 2)
    weights[~np.isfinite(weights)] = minweight

    model = LinearModel(
        list(range(len(dp.nt))),
        dp.lc_night_index,
        np.ones_like(dp.flux),
        name="avg_flux",
    )
    solver = RobustLinearSolver(
        model, np.array(dp.flux), weights=np.array(weights), verbose=1
    )
    avg_flux = solver.robust_solution(nsig=3, local_param="avg_flux")
    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)
    goods = ~solver.bads
    index = dp.lc_night_index[goods]
    # minlength: groups whose measurements were all rejected (in particular
    # the last ones) must still get a row, otherwise the columns below would
    # not have the length of avg_flux
    nlc = len(avg_flux)
    ngood = np.bincount(dp.lc_night_index, goods, minlength=nlc)
    nmeas = np.bincount(index, minlength=nlc)

    logging.debug("building averaged catalog")
    with np.errstate(divide="ignore", invalid="ignore"):
        logging.debug("computing rms")
        mean_y2 = np.bincount(index, weights=dp.flux[goods] ** 2, minlength=nlc) / ngood
        mean_2y = (
            np.bincount(index, weights=dp.flux[goods], minlength=nlc) / ngood
        ) ** 2
        # errors from the diagonal of A^T A (instead of the covariance
        # diagonal np.sqrt(solver.get_cov().diagonal()))
        sig_fisher = 1.0 / np.sqrt((solver.A.T @ solver.A).diagonal())
        mean_mjd = (
            np.bincount(index, weights=T[mjd_key].to_numpy()[goods], minlength=nlc)
            / nmeas
        )
        chi2 = np.bincount(index, weights=wres[goods] ** 2, minlength=nlc) / (
            ngood - 1
        )

    logging.debug("building dataframe")
    avg_cat = pd.DataFrame(
        data={
            "mjd": mean_mjd,
            "flux": avg_flux,
            "eflux": sig_fisher,
            "rms": np.sqrt(mean_y2 - mean_2y),
            "nmeas": nmeas,
            "chi2": chi2,
        }
    )
    logging.debug("done")
    if show:
        plot_single_night_averager(avg_cat, res=res, dp=dp)
    if return_solver:
        return avg_cat, index, goods, solver
    return avg_cat, index, goods
def chain_averaging(lccat, logger, gaia_match=False, extra_cols=None, error_floor=0):
    """Chain night and star averaging.

    Measurements are first averaged per (star, night) with ``night_averager``,
    then the night averages are averaged per star with ``star_lc_averager``.

    :param lccat: star light curve catalog (record array or DataFrame) with
        mjd, flux, error, star, ra, dec and ``extra_cols`` columns
    :param logger: unused, kept for backward compatibility
    :param bool gaia_match: keep only the stars matched to Gaia DR3 and
        replace their ``star`` id by the Gaia ``source_id`` (queries the Gaia
        archive)
    :param list extra_cols: extra columns averaged alongside ra and dec,
        default None (no extra column)
    :param float error_floor: relative error floor passed to both averagers
    :return DataFrame cat: night averaged catalog
    :return DataFrame star_lc_cat: star averaged catalog
    """
    extra_cols = [] if extra_cols is None else list(extra_cols)
    if not isinstance(lccat, np.recarray):
        lccat = lccat.to_records()

    cat, idx, goods = night_averager(
        lccat,
        mjd_key="mjd",
        flux_key="flux",
        eflux_key="error",
        show=False,
        group_by=["star"],
        error_floor=error_floor,
    )
    # add_columns returns a DataFrame
    cat = add_columns(cat, lccat, ["ra", "dec", "star"] + extra_cols, idx, goods)

    star_lc_cat, idx, goods = star_lc_averager(
        cat.to_records(),
        star_key="star",
        flux_key="flux",
        eflux_key="eflux",
        show=False,
        error_floor=error_floor,
    )
    star_lc_cat = add_columns(star_lc_cat, cat, ["ra", "dec"] + extra_cols, idx, goods)

    if gaia_match:
        selected_ids, star_lc_cat, _ = get_gaia_match(
            star_lc_cat, offset=0.02, maxq=5000
        )
        star_lc_cat = star_lc_cat.astype({"star": "int64"})
        star_lc_cat["star"] = selected_ids["source_id"].astype("int64").values
    return cat, star_lc_cat
453# def old_chain_averaging(lccat, logger, gaia_match=False, extra_cols=[]):
454# """Chain night and star averaging.
456# :param recarray lccat: star light curve catalog
457# :return recarray cat: night averaged catalog
458# :return recaraay star_lc_cat: star averaged catalog
459# """
460# if not isinstance(lccat, np.recarray):
461# lccat = lccat.to_records()
462# star_set = list(set(lccat["star"]))
463# cat = []
464# for s in star_set:
465# T = lccat[lccat["star"] == s]
466# T = T[T["ra"] != 0] # todo remove this as soon as mklc is fixed
467# try:
468# cat_, idx, goods = night_averager(
469# T,
470# mjd_key="mjd",
471# flux_key="flux",
472# eflux_key="error",
473# show=False,
474# )
475# cat_ = cat_.assign(star=int(s) * np.ones((len(cat_))).astype(int))
476# cat_ = add_columns(cat_, T, ["ra", "dec"] + extra_cols, idx, goods)
477# cat.append(cat_)
478# except: # pylint: disable=bare-except
479# logger.warning(
480# f"Night averager failed for star {s}"
481# f", number of epochs is {len(set(T['mjd'].astype(int)))}"
482# )
483# cat_ = pd.DataFrame(
484# data={
485# "mjd": T["mjd"],
486# "flux": T["flux"],
487# "eflux": T["error"],
488# "star": int(s) * np.ones((len(T))).astype(int),
489# "rms": np.ones((len(T))) * np.nan,
490# "nmeas": np.ones((len(T))),
491# "chi2": np.ones((len(T))) * np.nan,
492# }
493# )
494# d = {}
495# for k in ["ra", "dec"] + extra_cols:
496# d[k] = T[k]
497# cat_ = cat_.assign(**d)
498# cat.append(cat_)
500# cat = pd.concat(cat)
501# cat = cat[np.isfinite(cat["eflux"])]
502# cat = cat[np.isfinite(cat["ra"])]
504# # cat = cut_from_epoch_number(cat.to_records(), min_d=min_d)
505# star_lc_cat, idx, goods = star_lc_averager(
506# cat.to_records(),
507# star_key="star",
508# flux_key="flux",
509# eflux_key="eflux",
510# show=False,
511# )
513# star_lc_cat = add_columns(star_lc_cat, cat, ["ra", "dec"] + extra_cols, idx, goods)
515# if gaia_match:
516# selected_ids, star_lc_cat, _ = get_gaia_match(
517# star_lc_cat, offset=0.02, maxq=5000
518# )
519# star_lc_cat = star_lc_cat.astype({"star": "int64"})
520# star_lc_cat["star"] = selected_ids["source_id"].astype("int64").values
521# return cat, star_lc_cat
def lc_stack(
    d_iterator, fn_provider, bands, cols=("flux", "eflux", "rms", "nmeas", "chi2")
):
    """Stack all catalogs in a single one, one row per star and key set.

    For every combination of values in ``d_iterator`` and every band,
    ``fn_provider`` gives a catalog, or the path of a parquet file (missing
    files are skipped). The per-band catalogs of one key set are merged star
    by star: ra and dec are averaged over all bands and each column ``c`` of
    ``cols`` becomes ``c_band`` (NaN when the star is absent from that band).

    :param dict d_iterator: keys and iterables of values indexing catalogs
    :param func fn_provider: function which returns a catalog or a catalog
        filename for a given set of key/value (plus ``band``) keyword arguments
    :param str list bands: list of band names
    :param str list cols: list of columns to stack, named col_band in stacked catalog
    :return DataFrame stacked: stacked catalog with one column per key of
        ``d_iterator``; an empty DataFrame when no catalog is found
    """

    # pylint: disable=no-member

    stacked = []

    for values in itertools.product(*d_iterator.values()):
        kwargs = dict(zip(d_iterator.keys(), values))

        lc_catalog = []
        for band in bands:
            kwargs["band"] = band
            cat_ = fn_provider(**kwargs)
            if isinstance(cat_, str):
                if not os.path.exists(cat_):
                    continue
                cat_ = pd.read_parquet(cat_)
            lc_catalog.append(cat_.assign(band=np.full(len(cat_), band)))
        if not lc_catalog:
            continue
        # ignore_index: unique row labels across the concatenated bands
        lc_catalog = pd.concat(lc_catalog, ignore_index=True)

        # flux per band as columns
        dp = DataProxy(lc_catalog)
        dp.add_field("star", lc_catalog["star"].astype(int))
        dp.make_index("star")
        reshaped = pd.DataFrame.from_dict({"star": dp.star_set})
        nstar = len(reshaped)

        reshaped = reshaped.assign(
            **{k: [v] * nstar for k, v in kwargs.items() if k != "band"}
        )
        reshaped = add_columns(
            reshaped,
            lc_catalog,
            ["ra", "dec"],
            dp.star_index,
            np.ones(len(dp.star_index), dtype=bool),
        )

        for band in bands:
            selec = (lc_catalog["band"] == band).to_numpy()
            rows = dp.star_index[selec]
            reshaped = reshaped.assign(
                **{f"{col}_{band}": np.full(nstar, np.nan) for col in cols}
            )
            for col in cols:
                j = reshaped.columns.get_loc(f"{col}_{band}")
                # positional fill: row of each star in reshaped
                reshaped.iloc[rows, j] = lc_catalog[col].to_numpy()[selec]
        stacked.append(reshaped)

    if not stacked:
        return pd.DataFrame()
    return pd.concat(stacked)