import numpy as np
import torch
from .dataset import light_curve_preprocess, fold_lightcurve, bin_timeseries
from .models import lc_component, ps_component, parameter_component, combined_net
from .config import PROPOSAL_NUM, LABELS
import matplotlib.pyplot as plt
import matplotlib
from .conformal import ConformalModelPrecomputed
[docs]class DeepLC:
"""Base class for light curve classification."""
[docs] def __init__(
self,
combined_model=None,
lc_component_model=None,
ps_component_model=None,
parameter_model=None,
conformal_calibration=False,
device="auto",
) -> None:
"""Initialize the classifier.
Parameters
----------
lc_component_model : str, optional
path to the light curve component model, by default None
ps_component_model : str, optional
path to the power spectrum component model, by default None
parameter_model : str, optional
path to the parameter model, by default None
combined_model : str, optional
path to the combined model, by default None
conformal_calibration : bool, optional
whether to use conformal calibration, by default False
device : str, optional
device of the model, by default 'auto'
"""
if device == "auto":
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
models = [
lc_component_model,
ps_component_model,
parameter_model,
combined_model,
]
num_models = sum(model is not None for model in models)
if num_models == 0:
raise ValueError("At least one model must be provided.")
elif num_models > 1:
raise ValueError("Only one model can be provided.")
if combined_model:
self.loaded_model = "Combined"
if isinstance(combined_model, str):
self.model_dict = torch.load(combined_model, map_location=self.device, weights_only=False)
elif isinstance(combined_model, dict):
self.model_dict = combined_model
self.nclasses = self.model_dict["num_classes"]
if conformal_calibration and "T" in self.model_dict.keys():
self.use_conformal_calibration = True
self.T = self.model_dict["T"]
self.Qhat = self.model_dict["Qhat"]
self.penalties = self.model_dict["penalties"]
else:
self.use_conformal_calibration = False
self.lc_model = lc_component(
topN=PROPOSAL_NUM, nclasses=self.nclasses, device=self.device
)
self.ps_model = ps_component(
topN=PROPOSAL_NUM, nclasses=self.nclasses, device=self.device
)
# TODO self.parameter_model = parameter_component()
self.model = combined_net(nclasses=self.nclasses)
self.lc_model.load_state_dict(self.model_dict["lc_net_state_dict"])
self.ps_model.load_state_dict(self.model_dict["ps_net_state_dict"])
self.model.load_state_dict(self.model_dict["net_state_dict"])
self.lc_model.to(self.device)
self.ps_model.to(self.device)
self.lc_model.eval()
self.ps_model.eval()
elif lc_component_model:
self.loaded_model = "LC Component"
if isinstance(lc_component_model, str):
self.model_dict = torch.load(lc_component_model, map_location=self.device, weights_only=False)
elif isinstance(lc_component_model, dict):
self.model_dict = lc_component_model
self.nclasses = self.model_dict["num_classes"]
if conformal_calibration and "T" in self.model_dict.keys():
self.use_conformal_calibration = True
self.T = self.model_dict["T"]
self.Qhat = self.model_dict["Qhat"]
self.penalties = self.model_dict["penalties"]
else:
self.use_conformal_calibration = False
self.model = lc_component(
topN=PROPOSAL_NUM, nclasses=self.nclasses, device=self.device
)
self.model.load_state_dict(self.model_dict["net_state_dict"])
elif ps_component_model:
self.loaded_model = "PS Component"
if isinstance(ps_component_model, str):
self.model_dict = torch.load(ps_component_model, map_location=self.device)
elif isinstance(ps_component_model, dict):
self.model_dict = ps_component_model
self.nclasses = self.model_dict["num_classes"]
if conformal_calibration and "T" in self.model_dict.keys():
self.use_conformal_calibration = True
self.T = self.model_dict["T"]
self.Qhat = self.model_dict["Qhat"]
self.penalties = self.model_dict["penalties"]
else:
self.use_conformal_calibration = False
self.model = ps_component(
topN=PROPOSAL_NUM, nclasses=self.nclasses, device=self.device
)
self.model.load_state_dict(self.model_dict["net_state_dict"])
elif parameter_model:
self.loaded_model = "Parameter Component"
if isinstance(parameter_model, str):
self.model_dict = torch.load(parameter_model)
elif isinstance(parameter_model, dict):
self.model_dict = parameter_model
self.nclasses = self.model_dict["num_classes"]
if conformal_calibration and "T" in self.model_dict.keys():
self.use_conformal_calibration = True
self.T = self.model_dict["T"]
self.Qhat = self.model_dict["Qhat"]
self.penalties = self.model_dict["penalties"]
else:
self.use_conformal_calibration = False
self.model = parameter_component(nclasses=self.nclasses)
self.model.load_state_dict(self.model_dict["net_state_dict"])
self.model.to(self.device)
self.model.eval()
[docs] def predict(
self,
light_curve,
multiband_FAP=False,
show_intermediate_results=False,
return_intermediate_data=False,
return_conformal_predictive_sets=False,
return_ood_criteria=False,
):
"""Classify the light curve data.
Parameters
----------
light_curve : (N, 2) array for time and flux,
or (N, 3) array for time, flux, and filter,
or (N, 4) array for time, flux, flux_error and filter
show_intermediate_results : bool, optional
whether to show intermediate results, by default False
return_intermediate_data : bool, optional
whether to return intermediate data, by default False
return_conformal_predictive_sets : bool, optional
whether to return conformal predictive sets, by default False
"""
(
lc_img,
ps_img,
folded_img,
lc_param,
ps_param,
lc_data,
ps_data,
) = light_curve_preprocess(light_curve, multiband_FAP=multiband_FAP)
# move them to the corresponding device
lc_img = lc_img.unsqueeze(0).to(self.device)
ps_img = ps_img.unsqueeze(0).to(self.device)
folded_img = folded_img.unsqueeze(0).to(self.device)
lc_param = lc_param.unsqueeze(0).to(self.device)
ps_param = ps_param.unsqueeze(0).to(self.device)
lc_data = [lc_data]
ps_data = [ps_data]
if self.loaded_model == "Combined":
if self.use_conformal_calibration and return_conformal_predictive_sets:
cmodel = ConformalModelPrecomputed(
self.model,
self.T,
self.Qhat,
self.penalties,
self.nclasses,
allow_zero_sets=True,
)
cmodel.to(self.device)
cmodel.eval()
else:
cmodel = self.model
if show_intermediate_results or return_intermediate_data:
(
lc_concat_out,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
) = self.lc_model(
lc_img,
None,
None,
lc_param,
None,
lc_data,
None,
return_part_data=True,
)
(
ps_concat_out,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
) = self.ps_model(
None,
ps_img,
folded_img,
None,
ps_param,
lc_data.copy(),
ps_data,
return_part_data=True,
)
if return_conformal_predictive_sets:
concat_logits, sets = cmodel(lc_concat_out, ps_concat_out)
predicted_label = [LABELS[s] for s in sets[0]]
else:
concat_logits = cmodel(lc_concat_out, ps_concat_out)
predicted_label = LABELS[torch.argmax(concat_logits, 1)]
if return_intermediate_data:
return predicted_label, (
lc_raw_logits.cpu().detach().numpy(),
lc_concat_logits.cpu().detach().numpy(),
lc_part_logits.cpu().detach().numpy(),
part_lc_list,
ps_raw_logits.cpu().detach().numpy(),
ps_concat_logits.cpu().detach().numpy(),
ps_part_logits.cpu().detach().numpy(),
part_ps_list,
)
elif not return_intermediate_data and show_intermediate_results:
figs = self.plot_intermediate_data(
(
ps_param.cpu().detach().numpy(),
lc_data,
ps_data,
lc_raw_logits.cpu().detach().numpy(),
lc_concat_logits.cpu().detach().numpy(),
lc_part_logits.cpu().detach().numpy(),
part_lc_list,
part_lc_params.cpu().detach().numpy(),
ps_raw_logits.cpu().detach().numpy(),
ps_concat_logits.cpu().detach().numpy(),
ps_part_logits.cpu().detach().numpy(),
part_ps_list,
part_ps_params.cpu().detach().numpy(),
)
)
return predicted_label, figs
figs = self.plot_intermediate_data(
(
ps_param.cpu().detach().numpy(),
lc_data,
ps_data,
lc_raw_logits.cpu().detach().numpy(),
lc_concat_logits.cpu().detach().numpy(),
lc_part_logits.cpu().detach().numpy(),
part_lc_list,
part_lc_params.cpu().detach().numpy(),
ps_raw_logits.cpu().detach().numpy(),
ps_concat_logits.cpu().detach().numpy(),
ps_part_logits.cpu().detach().numpy(),
part_ps_list,
part_ps_params.cpu().detach().numpy(),
)
)
return predicted_label, (
lc_raw_logits.cpu().detach().numpy(),
lc_concat_logits.cpu().detach().numpy(),
lc_part_logits.cpu().detach().numpy(),
part_lc_list,
ps_raw_logits.cpu().detach().numpy(),
ps_concat_logits.cpu().detach().numpy(),
ps_part_logits.cpu().detach().numpy(),
part_ps_list,
), figs
else:
lc_concat_out = self.lc_model(
lc_img,
None,
None,
lc_param,
None,
lc_data,
None,
return_part_data=False,
combined_mode=True,
)
ps_concat_out = self.ps_model(
None,
ps_img,
folded_img,
None,
ps_param,
lc_data,
ps_data,
return_part_data=False,
combined_mode=True,
)
if return_conformal_predictive_sets:
concat_logits, sets = cmodel(lc_concat_out, ps_concat_out)
predicted_label = [LABELS[s] for s in sets[0]]
else:
concat_logits = cmodel(lc_concat_out, ps_concat_out)
predicted_label = LABELS[torch.argmax(concat_logits, 1)]
return predicted_label
elif self.loaded_model == "LC Component":
if show_intermediate_results or return_intermediate_data:
(
lc_concat_out,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
) = self.model(
lc_img,
None,
None,
lc_param,
None,
lc_data,
None,
return_part_data=True,
)
predicted_label = LABELS[torch.argmax(lc_concat_logits, 1)]
if return_intermediate_data:
return predicted_label, (
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
)
fig = self.plot_intermediate_data(
(
lc_data,
lc_raw_logits.cpu().detach().numpy(),
lc_concat_logits.cpu().detach().numpy(),
lc_part_logits.cpu().detach().numpy(),
part_lc_list,
part_lc_params.cpu().detach().numpy(),
)
)
return predicted_label, fig
else:
(
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
) = self.model(lc_img, None, None, lc_param, None, lc_data, None)
predicted_label = LABELS[torch.argmax(lc_concat_logits, 1)]
return predicted_label
elif self.loaded_model == "PS Component":
if show_intermediate_results or return_intermediate_data:
(
ps_concat_out,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
) = self.model(
None,
ps_img,
folded_img,
None,
ps_param,
lc_data,
ps_data,
return_part_data=True,
)
predicted_label = LABELS[torch.argmax(ps_concat_logits, 1)]
if return_intermediate_data:
return predicted_label, (
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
)
fig = self.plot_intermediate_data(
(
lc_data,
ps_param.cpu(),
ps_data,
ps_raw_logits.cpu().detach().numpy(),
ps_concat_logits.cpu().detach().numpy(),
ps_part_logits.cpu().detach().numpy(),
part_ps_list,
part_ps_params.cpu().detach().numpy(),
)
)
return predicted_label, fig
else:
(
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
) = self.model(
None,
ps_img,
folded_img,
None,
ps_param,
lc_data,
ps_data,
)
predicted_label = LABELS[torch.argmax(ps_concat_logits, 1)]
return predicted_label
def plot_intermediate_data(self, intermediate_data):
if self.loaded_model == "Combined":
(
ps_param,
lc_data,
ps_data,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
) = intermediate_data
fig1 = plot_lc_component(
lc_data,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
)
fig2 = plot_ps_component(
lc_data,
ps_param,
ps_data,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
)
return fig1, fig2
elif self.loaded_model == "LC Component":
(
lc_data,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
) = intermediate_data
fig = plot_lc_component(
lc_data,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
)
elif self.loaded_model == "PS Component":
(
lc_data,
ps_param,
ps_data,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
) = intermediate_data
fig = plot_ps_component(
lc_data,
ps_param,
ps_data,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
)
def plot_lc_component(
lc_data,
lc_raw_logits,
lc_concat_logits,
lc_part_logits,
part_lc_list,
part_lc_params,
):
predicted_raw_lable = LABELS[np.argmax(lc_raw_logits)]
predicted_label = LABELS[np.argmax(lc_concat_logits)]
predicted_part_labels = [
LABELS[i] for i in np.argmax(lc_part_logits, axis=2).squeeze()
]
lc_data = lc_data[0]
lc_mask = np.all(part_lc_params != 0, axis=1)
sub_lc_num = sum(lc_mask)
indices = np.where(lc_mask)[0]
selected_lc_list = [part_lc_list[i] for i in indices]
selected_lc_labels = [predicted_part_labels[i] for i in indices]
ratio_list = [1] * (int((sub_lc_num + 1) / 2) + 1)
ratio_list[0] = 2
fig1, ax = plt.subplots(
int((sub_lc_num + 1) / 2) + 1,
2,
figsize=(6, (sub_lc_num + 1) + 2),
dpi=90,
constrained_layout=True,
gridspec_kw={"height_ratios": ratio_list},
)
if sub_lc_num == 0:
ax = ax.reshape(-1, 2)
gs1 = ax[0, 1].get_gridspec()
for a in ax[0, :]:
a.remove()
ax_lc = fig1.add_subplot(gs1[0, :])
colors = [matplotlib.colormaps["Dark2"](i / 6) for i in range(6)]
# if lc_panel_num is odd number, add a blank panel
if sub_lc_num % 2 == 1:
ax[-1, 0].remove()
ax[-1, 1].remove()
gs2 = ax[-1, 1].get_gridspec()
ax[-1, 0] = fig1.add_subplot(gs2[-1, :])
if lc_data.shape[1] == 2:
ax_lc.plot(
lc_data[:, 0],
lc_data[:, 1],
"k.",
ms=np.round(10 / np.log10(len(lc_data[:, 0])+10), 1),
)
ax_lc.set_title(f"{predicted_label} ({predicted_raw_lable})")
# plot vspans for selected light curves
for i, lc in enumerate(selected_lc_list):
ax_lc.axvspan(lc[0, 0], lc[-1, 0], color=colors[i], alpha=0.3)
# plot selected light curves
ax[i // 2 + 1, i % 2].plot(
lc[:, 0],
lc[:, 1],
".",
color=colors[i],
ms=np.round(10 / np.log10(len(lc[:, 0])+10), 1),
)
# show parameters
ax[i // 2 + 1, i % 2].set_title(f"{selected_lc_labels[i]}")
else:
bands = np.unique(lc_data[:, 3])
# currently we only support 2 bands
multiband_colors = [
matplotlib.colormaps["binary"](i / (2 + 1))
for i in range(2 + 1)
]
for band in bands:
band_mask = lc_data[:, 3] == band
ax_lc.plot(
lc_data[band_mask, 0],
lc_data[band_mask, 1],
".",
ms=np.round(10 / np.log10(len(lc_data[band_mask, 0])+10), 1),
# color is a grey scale for different bands
color=multiband_colors[int(band)],
)
ax_lc.set_title(f"{predicted_label} ({predicted_raw_lable})")
# plot vspans for selected light curves
for i, lc in enumerate(selected_lc_list):
ax_lc.axvspan(lc[:, 0].min(), lc[:, 0].max(), color=colors[i], alpha=0.3)
# plot selected light curves
for band in bands:
band_mask = lc[:, 3] == band
ax[i // 2 + 1, i % 2].plot(
lc[band_mask, 0],
lc[band_mask, 1],
".",
color=multiband_colors[int(band)],
ms=np.round(10 / np.log10(len(lc[band_mask, 0])+10), 1),
)
# show parameters
ax[i // 2 + 1, i % 2].set_title(f"{selected_lc_labels[i]}", color=colors[i])
# all the y axis are scientific notation
ax_lc.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
for a in ax.flatten():
a.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
# add labels
ax_lc.set_xlabel("Time (days)")
ax_lc.set_ylabel("Variability")
return fig1
def plot_ps_component(
lc_data,
ps_param,
ps_data,
ps_raw_logits,
ps_concat_logits,
ps_part_logits,
part_ps_list,
part_ps_params,
):
predicted_raw_lable = LABELS[np.argmax(ps_raw_logits)]
predicted_label = LABELS[np.argmax(ps_concat_logits)]
predicted_part_labels = [
LABELS[i] for i in np.argmax(ps_part_logits, axis=2).squeeze()
]
ps_data = ps_data[0]
lc_data = lc_data[0]
period = ps_param[0, 1]
lg_fap100 = ps_param[0, 2]
phase_list, flux_list = fold_lightcurve(lc_data, period)
ps_mask = np.all(part_ps_params != 0, axis=1)
ps_panel_num = sum(ps_mask)
indices = np.where(ps_mask)[0]
selected_ps_list = [part_ps_list[i] for i in indices]
selected_ps_params = [part_ps_params[i] for i in indices]
selected_ps_labels = [predicted_part_labels[i] for i in indices]
fig2, ax = plt.subplots(
ps_panel_num + 1,
2,
figsize=(6, (ps_panel_num + 1) * 1.5 + 2),
dpi=90,
constrained_layout=True,
)
if ps_panel_num == 0:
ax = ax.reshape(-1, 2)
colors = [matplotlib.colormaps["Dark2"](i / 6) for i in range(6)]
if lc_data.shape[1] > 2:
# bands = np.unique(lc_data[:, 3])
multiband_colors = [
matplotlib.colormaps["binary"](i / (2 + 1))
for i in range(2 + 1)
]
else:
# bands = [0]
multiband_colors = ["k"] * 2
ax[0, 0].plot(ps_data[:, 0], ps_data[:, 1], "k-", lw=1)
if lg_fap100 != 0:
ax[0, 0].axhline(10**lg_fap100, color="r", ls="--", lw=1)
ax[0, 0].set_yscale('log')
ax[0, 0].set_ylabel("Amp")
ax[0, 0].set_title(f"{predicted_label}")
for band, (phase, flux) in enumerate(zip(phase_list, flux_list)):
if len(phase) == 0:
continue
new_phase, new_flux = bin_timeseries(phase, flux, 512)
ax[0, 1].plot(
new_phase,
new_flux,
".",
ms=np.round(10 / np.log10(len(new_phase)+10), 1),
color=multiband_colors[int(band) + 1],
)
ax[0, 1].set_title(f"{predicted_raw_lable} ({period:.2f} days)", loc="right", pad=0)
# plot vspans for selected light curves
for i, ps in enumerate(selected_ps_list):
ax[0, 0].axvspan(ps[0, 0], ps[-1, 0], color=colors[i], alpha=0.5)
# plot selected light curves
ax[i + 1, 0].plot(ps[:, 0], ps[:, 1], "-", color=colors[i], lw=1)
ax[i + 1, 0].set_yscale('log')
ax[i + 1, 0].set_ylabel("Amp")
# set x tick labels to only the start point and the end point
ax[i + 1, 0].set_xticks([ps[0, 0], ps[-1, 0]])
ax[i + 1, 0].xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
period = selected_ps_params[i][1]
phase_list, flux_list = fold_lightcurve(lc_data, period)
for band, (phase, flux) in enumerate(zip(phase_list, flux_list)):
if len(phase) == 0:
continue
new_phase, new_flux = bin_timeseries(phase, flux, 512)
ax[i + 1, 1].plot(
new_phase,
new_flux,
".",
ms=np.round(10 / np.log10(len(new_phase)+10), 1),
color=multiband_colors[int(band) + 1],
)
# show parameters
ax[i + 1, 1].set_title(
f"{selected_ps_labels[i]} ({period:.2f} days)", loc="right", pad=0
)
# all the y-axis are scientific notation
for a in ax[:, 1]:
a.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
for a in ax[:, 0]:
a.minorticks_off()
# add lables for the last pannel
ax[-1, 0].set_xlabel("Frequency (day$^{-1}$)")
ax[-1, 1].set_xlabel("Phase")
return fig2