import os
import pathlib
from multiprocessing import Pool
from time import sleep, time
from urllib.error import HTTPError

import astropy.units as u
import numpy as np
import pandas as pd
import requests
from astropy.coordinates import SkyCoord
from astropy.cosmology import LambdaCDM

from .diagnose import plot_match
from .helpers import GalaxyCatalog, Transient
# [docs] -- stray Sphinx source-link artifact from the rendered docs; commented out so the module imports cleanly.
def associate_transient(
    idx,
    row,
    glade_catalog,
    n_samples,
    verbose,
    priorfunc_z,
    priorfunc_offset,
    priorfunc_absmag,
    likefunc_offset,
    likefunc_absmag,
    cosmo,
    catalogs,
    cat_cols,
):
    """Associate a single transient with its most probable host galaxy.

    Each catalog in ``catalogs`` is queried in order; candidate hosts are
    scored with the supplied prior and likelihood functions, and the best and
    second-best candidates found are recorded. Results from a later catalog
    overwrite those from an earlier one (the loop does not break on success,
    matching the original behavior).

    Parameters
    ----------
    idx : int
        Index of this transient in the parent catalog; returned unchanged so
        results can be matched back after parallel execution.
    row : pandas.Series
        Catalog row with ``name``, ``ra``, ``dec`` (decimal degrees) and,
        optionally, ``redshift``.
    glade_catalog : pandas.DataFrame or None
        Preloaded GLADE catalog, or None if it could not be read.
    n_samples : int
        Number of samples used for the probabilistic association.
    verbose : int
        Verbosity level (0 = silent, 1 = progress, >= 2 = per-host details).
    priorfunc_z, priorfunc_offset, priorfunc_absmag : callable
        Prior functions for redshift, offset, and absolute magnitude.
    likefunc_offset, likefunc_absmag : callable
        Likelihood functions for offset and absolute magnitude.
    cosmo : astropy.cosmology.Cosmology
        Cosmology used for distance calculations.
    catalogs : list of str
        Galaxy-catalog names to query, in priority order.
    cat_cols : bool
        If True, extra catalog columns would be returned (not implemented yet).

    Returns
    -------
    tuple
        ``(idx, best_objid, best_prob, best_ra, best_dec, second_best_objid,
        second_best_prob, second_best_ra, second_best_dec, query_time,
        best_cat)``. All numeric fields are NaN and ``best_cat`` is ``""``
        when no good host is found.
    """
    # Build the transient; fall back to a redshift-free object when the row
    # has no usable redshift. A missing label raises AttributeError via
    # ``row.redshift`` (not KeyError, which the old code caught).
    try:
        transient = Transient(
            name=row["name"],
            position=SkyCoord(row.ra * u.deg, row.dec * u.deg),
            redshift=float(row.redshift),
            n_samples=n_samples,
        )
    except (KeyError, AttributeError):
        transient = Transient(
            name=row["name"], position=SkyCoord(row.ra * u.deg, row.dec * u.deg), n_samples=n_samples
        )

    if verbose > 0:
        # Second line was missing the f-prefix, so the coordinates were
        # printed literally; fixed.
        print(
            f"Associating for {transient.name} at RA, DEC = "
            f"{transient.position.ra.deg:.6f}, {transient.position.dec.deg:.6f}"
        )

    transient.set_prior("redshift", priorfunc_z)
    transient.set_prior("offset", priorfunc_offset)
    transient.set_prior("absmag", priorfunc_absmag)
    transient.set_likelihood("offset", likefunc_offset)
    transient.set_likelihood("absmag", likefunc_absmag)

    # Defaults returned when no good host is found. Every field of the return
    # tuple is initialized here -- previously best_objid and the second-best
    # fields were left unbound, raising UnboundLocalError on failure.
    best_objid = best_prob = best_ra = best_dec = np.nan
    second_best_objid = second_best_prob = second_best_ra = second_best_dec = np.nan
    query_time = np.nan
    best_cat = ""

    for cat_name in catalogs:
        cat = GalaxyCatalog(name=cat_name, n_samples=n_samples, data=glade_catalog)
        try:
            cat.get_candidates(transient, timequery=True, verbose=verbose, cosmo=cosmo)
        except requests.exceptions.HTTPError:
            print(f"Candidate retrieval failed for {transient.name} in catalog {cat_name}.")
            continue

        if cat.ngals > 0:
            cat = transient.associate(cat, cosmo, verbose=verbose)

            if transient.best_host != -1:
                best_idx = transient.best_host
                second_best_idx = transient.second_best_host

                if verbose >= 2:
                    print_cols = [
                        "objID",
                        "z_prob",
                        "offset_prob",
                        "absmag_prob",
                        "total_prob",
                        "ra",
                        "dec",
                        "offset_arcsec",
                        "z_best_mean",
                        "z_best_std",
                    ]
                    print("Properties of best host:")
                    for key in print_cols:
                        print(key)
                        print(cat.galaxies[key][best_idx])
                    print("Properties of second best host:")
                    for key in print_cols:
                        print(key)
                        print(cat.galaxies[key][second_best_idx])

                best_objid = cat.galaxies["objID"][best_idx]
                best_prob = cat.galaxies["total_prob"][best_idx]
                best_ra = cat.galaxies["ra"][best_idx]
                best_dec = cat.galaxies["dec"][best_idx]
                second_best_objid = cat.galaxies["objID"][second_best_idx]
                second_best_prob = cat.galaxies["total_prob"][second_best_idx]
                second_best_ra = cat.galaxies["ra"][second_best_idx]
                second_best_dec = cat.galaxies["dec"][second_best_idx]
                best_cat = cat_name
                query_time = cat.query_time

                if cat_cols:
                    print("WARNING! cat_cols not implemented yet.")

                if verbose > 0:
                    # Second line was missing the f-prefix; fixed so the
                    # chosen galaxy's ID and coordinates are interpolated.
                    print(f"Found a good host in {cat_name}!")
                    print(
                        f"Chosen galaxy has catalog ID of {best_objid}"
                        f"and RA, DEC = {best_ra:.6f}, {best_dec:.6f}"
                    )

                try:
                    plot_match(
                        [best_ra],
                        [best_dec],
                        None,
                        None,
                        cat.galaxies["z_best_mean"][best_idx],
                        cat.galaxies["z_best_std"][best_idx],
                        transient.position.ra.deg,
                        transient.position.dec.deg,
                        transient.name,
                        transient.redshift,
                        0,
                        f"{transient.name}_{cat_name}",
                    )
                except HTTPError:
                    # ``time`` is the function imported from the time module,
                    # so ``time.sleep`` would raise AttributeError; use sleep().
                    print("Couldn't get an image. Waiting 60s before moving on.")
                    sleep(60)
                    continue

    if (transient.best_host == -1) and (verbose > 0):
        print("No good host found!")

    return (
        idx,
        best_objid,
        best_prob,
        best_ra,
        best_dec,
        second_best_objid,
        second_best_prob,
        second_best_ra,
        second_best_dec,
        query_time,
        best_cat,
    )
# [docs] -- stray Sphinx source-link artifact from the rendered docs; commented out so the module imports cleanly.
def prepare_catalog(
    transient_catalog,
    debug_names,
    transient_name_col="name",
    transient_coord_cols=("ra", "dec"),
    debug=False,
):
    """Standardize a transient catalog for host association.

    Adds the (NaN-initialized) association result columns, normalizes
    coordinates to decimal degrees in ``ra``/``dec``, renames the name column
    to ``name``, and randomly shuffles the rows.

    Parameters
    ----------
    transient_catalog : pandas.DataFrame
        Input transient catalog.
    debug_names : list of str
        Transient names to keep when ``debug`` is True.
    transient_name_col : str, optional
        Column holding the transient name (default ``"name"``).
    transient_coord_cols : tuple of str, optional
        Columns holding RA and Dec; values may be sexagesimal strings
        (containing ``:``) or decimal degrees (default ``("ra", "dec")``).
    debug : bool, optional
        If True, restrict the catalog to ``debug_names`` (default False).

    Returns
    -------
    pandas.DataFrame
        Shuffled catalog with standardized ``name``/``ra``/``dec`` columns
        and empty association fields.

    Raises
    ------
    ValueError
        If the coordinate columns are missing or cannot be parsed.
    """
    association_fields = [
        "host_id",
        "host_ra",
        "host_dec",
        "host_prob",
        "host_2_id",
        "host_2_ra",
        "host_2_dec",
        "host_2_prob",
        "smallcone_prob",
        "missedcat_prob",
        "sn_ra_deg",
        "sn_dec_deg",
        "prob_association_time",
    ]
    for field in association_fields:
        transient_catalog[field] = np.nan
    transient_catalog["prob_host_flag"] = 0

    # debugging with just the ones we got wrong
    if debug and len(debug_names) > 0:
        # .copy() so later column assignments act on a real frame, not a view
        # of the original (avoids pandas SettingWithCopy issues).
        transient_catalog = transient_catalog[
            transient_catalog[transient_name_col].isin(debug_names)
        ].copy()

    # Convert coords if needed: a ":" in the first value signals sexagesimal.
    if ":" in str(transient_catalog[transient_coord_cols[0]].values[0]):
        ra = transient_catalog[transient_coord_cols[0]].values
        dec = transient_catalog[transient_coord_cols[1]].values
        transient_coords = SkyCoord(ra, dec, unit=(u.hourangle, u.deg))
    else:
        # Try parsing as floats. astype("float") raises ValueError on
        # unparsable strings and a missing column raises KeyError -- the old
        # code only caught KeyError, so parse failures escaped uncaught.
        try:
            ra = transient_catalog[transient_coord_cols[0]].astype("float").values
            dec = transient_catalog[transient_coord_cols[1]].astype("float").values
            transient_coords = SkyCoord(ra, dec, unit=(u.deg, u.deg))
        except (KeyError, ValueError) as err:
            raise ValueError("ERROR: I could not understand your provided coordinates.") from err

    transient_catalog["ra"] = transient_coords.ra.deg
    transient_catalog["dec"] = transient_coords.dec.deg
    transient_catalog.rename(columns={transient_name_col: "name"}, inplace=True)

    # Randomly shuffle and rebuild a clean 0..n-1 index (the second
    # reset_index of the original was redundant).
    transient_catalog = transient_catalog.sample(frac=1).reset_index(drop=True)
    return transient_catalog
# [docs] -- stray Sphinx source-link artifact from the rendered docs; commented out so the module imports cleanly.
def associate_sample(
    transient_catalog,
    catalogs,
    priors=None,
    likes=None,
    n_samples=1000,
    verbose=False,
    parallel=True,
    save=True,
    save_path="./",
    cat_cols=False,
    cosmology=None,
):
    """Associate every transient in a catalog with its most probable host.

    Parameters
    ----------
    transient_catalog : pandas.DataFrame
        Catalog prepared by :func:`prepare_catalog` (``name``/``ra``/``dec``).
    catalogs : list of str
        Galaxy-catalog names to query, in priority order.
    priors : dict
        Prior functions keyed by ``"z"``, ``"offset"``, ``"absmag"``. Required.
    likes : dict
        Likelihood functions keyed by ``"offset"``, ``"absmag"``. Required.
    n_samples : int, optional
        Number of samples for the probabilistic association (default 1000).
    verbose : int or bool, optional
        Verbosity level passed through to :func:`associate_transient`.
    parallel : bool, optional
        If True, fan out over a multiprocessing pool (default True).
    save : bool, optional
        If True, write the updated catalog to ``save_path`` (default True).
    save_path : str, optional
        Directory for the output CSV (default ``"./"``).
    cat_cols : bool, optional
        Passed through to :func:`associate_transient` (not implemented yet).
    cosmology : astropy.cosmology.Cosmology, optional
        Cosmology to use; defaults to LambdaCDM(H0=70, Om0=0.3, Ode0=0.7).

    Returns
    -------
    pandas.DataFrame or None
        The updated catalog when ``save`` is False; otherwise None (the
        catalog is written to disk instead).

    Raises
    ------
    ValueError
        If a required prior or likelihood function is missing.
    """
    # Previously ``cosmo`` was only assigned in the default branch, so
    # passing a cosmology raised NameError below; fixed.
    cosmo = cosmology if cosmology else LambdaCDM(H0=70, Om0=0.3, Ode0=0.7)

    # Guard the None defaults explicitly -- ``key not in None`` is a TypeError.
    if priors is None or likes is None:
        raise ValueError("ERROR: Please provide both prior and likelihood functions.")
    for key in ["offset", "absmag", "z"]:
        if key not in priors:
            raise ValueError(f"ERROR: Please set a prior function for {key}.")
        elif (key not in likes) and (key != "z"):
            raise ValueError(f"ERROR: Please set a likelihood function for {key}.")

    # always load GLADE -- we now use it for spec-zs.
    try:
        glade_catalog = pd.read_csv("GLADE+_HyperLedaSizes_mod_withz.csv")
    except FileNotFoundError:
        glade_catalog = None

    # unpack priors and likelihoods
    priorfunc_z = priors["z"]
    priorfunc_offset = priors["offset"]
    priorfunc_absmag = priors["absmag"]
    likefunc_offset = likes["offset"]
    likefunc_absmag = likes["absmag"]

    # One task tuple per transient; shared by both execution paths (the
    # original duplicated this construction).
    tasks = [
        (
            idx,
            row,
            glade_catalog,
            n_samples,
            verbose,
            priorfunc_z,
            priorfunc_offset,
            priorfunc_absmag,
            likefunc_offset,
            likefunc_absmag,
            cosmo,
            catalogs,
            cat_cols,
        )
        for idx, row in transient_catalog.iterrows()
    ]

    if parallel:
        # Leave a few cores free, but never request a non-positive process
        # count (os.cpu_count() - 5 is <= 0 on small machines).
        n_processes = max(1, os.cpu_count() - 5)
        print("parallelizing...")
        # The context manager closes and joins the pool on exit.
        with Pool(processes=n_processes) as pool:
            results = pool.starmap(associate_transient, tasks)
    else:
        results = [associate_transient(*task) for task in tasks]

    # Update transient_catalog with results. associate_transient returns an
    # 11-tuple; the old code unpacked only 6 values (ValueError) and dropped
    # the object IDs and second-best host entirely.
    for result in results:
        (
            idx,
            best_objid,
            best_prob,
            best_ra,
            best_dec,
            second_best_objid,
            second_best_prob,
            second_best_ra,
            second_best_dec,
            query_time,
            best_cat,
        ) = result
        transient_catalog.at[idx, "prob_host_id"] = best_objid
        transient_catalog.at[idx, "prob_host_ra"] = best_ra
        transient_catalog.at[idx, "prob_host_dec"] = best_dec
        transient_catalog.at[idx, "prob_host_score"] = best_prob
        transient_catalog.at[idx, "prob_host_2_id"] = second_best_objid
        transient_catalog.at[idx, "prob_host_2_ra"] = second_best_ra
        transient_catalog.at[idx, "prob_host_2_dec"] = second_best_dec
        transient_catalog.at[idx, "prob_host_2_score"] = second_best_prob
        transient_catalog.at[idx, "prob_query_time"] = query_time
        transient_catalog.at[idx, "prob_best_cat"] = best_cat

    print("Association of all transients is complete.")

    # Save the updated catalog
    if save:
        # ``time`` is the function from ``from time import time``, so the
        # old ``time.time()`` raised AttributeError.
        ts = int(time())
        save_name = pathlib.Path(save_path, f"associated_transient_catalog_{ts}.csv")
        transient_catalog.to_csv(save_name, index=False)
    else:
        return transient_catalog