#! /usr/bin/env python
"""
Module containing basic classes for manipulating post-processing algorithms.
This includes the core PostProc class, parent to every algorithm object implementation,
but also the PPResult class, a container for the results obtained through those said
algorithm objects. PPResult is provided with the Session dataclass, which defines the
type of data stored from the results.
"""
__author__ = "Thomas Bédrine, Carlos Alberto Gomez Gonzalez, Ralf Farkas"
__all__ = ["PostProc", "PPResult", "ALL_SESSIONS", "LAST_SESSION"]
import pickle
from dataclasses import dataclass, field
from typing import (
Tuple,
Union,
Optional,
NoReturn,
Callable,
List,
)
import numpy as np
from hciplot import plot_frames
from sklearn.base import BaseEstimator
from .dataset import Dataset
from ..config.paramenum import ALL_FITS
from ..config.utils_conf import algo_calculates_decorator as calculates
from ..config.utils_param import print_algo_params
from ..fits import write_fits, open_fits, dict_to_fitsheader, fitsheader_to_dict
from ..metrics import snrmap, snr, significance
from ..var import frame_center
# Attribute names that make ``PostProc._get_calculations`` loop indefinitely
# when inspected (see the BLACKMAGIC note in that method); they are skipped
# during calculated-attribute discovery.
PROBLEMATIC_ATTRIBUTE_NAMES = ["_repr_html_", "_estimator_html_repr",
                               "_doc_link_template"]
# Sentinel session IDs understood by ``PPResult.show_session_results``.
LAST_SESSION = -1
ALL_SESSIONS = -2
# Name of the dataset parameter (not referenced in this part of the file;
# presumably used by algorithm subclasses - TODO confirm).
DATASET_PARAM = "dataset"
# Mapping of PostProc attribute names (keys) to the corresponding Dataset
# attribute names (values); used by ``PostProc._explicit_dataset``.
EXPLICIT_PARAMS = {
    "cube": "cube",
    "angle_list": "angles",
    "fwhm": "fwhm",
    "cube_ref": "cuberef",
    "scale_list": "wavelengths",
    "psf": "psfn",
}
# Prefix tagging PostProc parameters in FITS headers written by ``PPResult``.
PREFIX = "postproc_"
@dataclass
class Session:
    """
    Dataclass for post-processing information storage.

    Each session of post-processing with one of the PostProc objects has a defined set
    of parameters, a frame obtained with those parameters and a S/N map generated with
    that frame. The Session class holds them in case you need to access them later or
    compare with another session.
    """

    # Keyword arguments used for the run (ndarray values are filtered out by
    # ``PPResult.register_session`` before storage).
    parameters: dict
    # Final frame produced by the algorithm run.
    frame: np.ndarray
    # S/N map computed from ``frame``; None until ``make_snrmap`` provides one.
    snr_map: np.ndarray
    # Name of the algorithm/function that produced ``frame``.
    algo_name: str

    # TODO: find a proper format for results saving (pdf, images, dictionnaries...)
@dataclass
class PPResult:
    """
    Container for results of post-processing algorithms.

    For each given set of data and parameters, a frame is computed by the PostProc
    algorithms, as well as an associated S/N map. To keep track of each of them, this
    object remembers each set of parameters, frame and S/N map as a session. Sessions
    are numbered in order of creation from 0 to X, and they are displayed to the user
    as going from 1 to X+1.
    """

    sessions: List = field(default_factory=list)

    def __init__(self, load_from_path: str = None):
        """
        Create a PPResult object or load one from a FITS file.

        Parameters
        ----------
        load_from_path : str, optional
            Path of FITS file to optionally load a previously saved PPResult
            object from.
        """
        self.sessions = []
        if load_from_path is not None:
            self.fits_to_results(filepath=load_from_path)

    def register_session(
        self,
        frame: np.ndarray,
        algo_name: Optional[str] = None,
        params: Optional[dict] = None,
        snr_map: Optional[np.ndarray] = None,
    ) -> None:
        """
        Register data for a new session or update data for an existing one.

        Parameters
        ----------
        frame : np.ndarray
            Frame obtained after an iteration of a PostProc object.
        algo_name : str, optional
            Name of the algorithm/function that produced ``frame``.
        params : dictionnary, optional
            Set of parameters used for an iteration of a PostProc object.
        snr_map : np.ndarray, optional
            Signal-to-noise ratio map generated through the ``make_snrmap`` method of
            PostProc. Usually given after generating ``frame``.
        """
        # If the frame is already registered in a session and a S/N map is
        # provided, only attach the S/N map to that session.
        for session in self.sessions:
            if session.frame.shape == frame.shape:
                if (
                    np.allclose(np.abs(session.frame), np.abs(frame), atol=1e-3)
                    and snr_map is not None
                ):
                    session.snr_map = snr_map
                    return

        # TODO: review filter_params to only target cube and angles, not all ndarrays
        # TODO: rename angles-type parameters in all procedural functions
        # Otherwise, register a new session. ``params`` defaults to an empty
        # dict so registering a frame alone (e.g. from ``make_snrmap``) does
        # not crash with a TypeError.
        if params is None:
            params = {}
        # Arrays (cube, angles...) are not stored to keep sessions lightweight.
        filtered_params = {
            key: value
            for key, value in params.items()
            if not isinstance(value, np.ndarray)
        }
        self.sessions.append(
            Session(
                parameters=filtered_params,
                frame=frame,
                snr_map=snr_map,
                algo_name=algo_name,
            )
        )

    def show_session_results(
        self,
        session_id: Optional[int] = LAST_SESSION,
        label: Optional[Union[Tuple[str], bool]] = True,
    ) -> None:
        """
        Print the parameters and plot the frame (and S/N map if able) of session(s).

        Parameters
        ----------
        session_id : int or list of int, optional
            The ID of the session(s) to show. It is possible to get several sessions
            results by giving a list of int, or all of them with the ``ALL_SESSIONS``
            constant. By default, the last session is displayed (index -1).
        label : tuple of str or bool, optional
            Label(s) for the plotted frames; see ``_show_single_session``.

        Raises
        ------
        AttributeError
            If no session has been registered yet.
        ValueError
            If ``session_id`` is neither a valid integer nor a list of integers.
        """
        if not self.sessions:
            raise AttributeError(
                "No session was registered yet. Please register"
                " a session with the function `register_session`."
            )
        # A list must contain only integers; a list with other element types
        # falls through to the ValueError below (the previous version silently
        # did nothing in that case).
        if isinstance(session_id, list) and all(
            isinstance(s_id, int) for s_id in session_id
        ):
            for s_id in session_id:
                self._show_single_session(s_id, label)
        elif session_id == ALL_SESSIONS:
            for s_id in range(len(self.sessions)):
                self._show_single_session(s_id, label)
        # range(-1, N) also accepts LAST_SESSION (-1) as a valid index.
        elif session_id in range(ALL_SESSIONS + 1, len(self.sessions)):
            self._show_single_session(session_id, label)
        else:
            raise ValueError(
                "Given session ID isn't an integer. Please give an integer or a"
                " list of integers (includes constant values such as ALL_SESSIONS"
                " or LAST_SESSION)."
            )

    def results_to_fits(self, filepath: str) -> None:
        """
        Save all configurations as a fits file.

        For each session, the frame (stacked with the S/N map if any) is saved
        as an image, and the parameters are saved in the matching header, each
        key tagged with the ``postproc_`` prefix for later identification.

        Parameters
        ----------
        filepath: str
            The path of the FITS file.

        Raises
        ------
        AttributeError
            If no session has been registered yet.
        """
        if not self.sessions:
            raise AttributeError(
                "No session was registered yet. Please register"
                " a session with the function `register_session`."
            )
        images = []
        headers = []
        for session in self.sessions:
            # Stacks both frame and detection map (if any), else only frame
            if session.snr_map is not None:
                image = np.stack((session.frame, session.snr_map), axis=0)
            else:
                image = session.frame
            images.append(image)
            # Work on a copy so the session's parameters are not permanently
            # polluted with the algorithm name.
            header_params = dict(session.parameters)
            header_params["algo_name"] = session.algo_name
            # Adding a specific prefix to identify the PostProc parameters when
            # extracting the header
            prefix_dict = {
                PREFIX + key: value for key, value in header_params.items()
            }
            headers.append(dict_to_fitsheader(prefix_dict))
        write_fits(
            fitsfilename=filepath, array=tuple(images), header=tuple(headers)
        )
        print(f"Results saved successfully to {filepath} !")

    def fits_to_results(self, filepath: str, session_id: int = ALL_FITS) -> None:
        """
        Load all configurations from a fits file.

        Parameters
        ----------
        filepath: str
            The path of the FITS file.
        session_id : int, optional
            Index of the extension to load, or ``ALL_FITS`` (default) to load
            every session saved in the file.
        """
        data, header = open_fits(fitsfilename=filepath, n=session_id,
                                 header=True)
        self.sessions = []
        if session_id == ALL_FITS:
            for element, element_header in zip(data, header):
                self._register_fits_element(element, element_header)
        else:
            self._register_fits_element(data, header)

    def _register_fits_element(self, element, fits_header) -> None:
        """Rebuild one session from a FITS image and its header, then register it."""
        parameters, algo_name = fitsheader_to_dict(
            initial_header=fits_header, sort_by_prefix=PREFIX
        )
        # A 3D element holds both the frame and the detection map; a 2D
        # element holds the frame only.
        if element.ndim == 3:
            frame, snr_map = element[0], element[1]
        else:
            frame, snr_map = element, None
        self.register_session(
            frame=frame, algo_name=algo_name, params=parameters, snr_map=snr_map
        )

    def _show_single_session(
        self,
        session_id: Optional[int],
        label: Optional[Union[Tuple[str], bool]] = True,
    ) -> None:
        """
        Display an individual session.

        Used as a sub-function called by ``show_session_results``.

        Parameters
        ----------
        session_id : int, optional
            Number of the session to be displayed.
        label : tuple of str or bool, optional
            Defines the label given to the frames plotted. If True, prints the default
            label for each frame, if False, prints nothing. Instead if the label is a
            tuple of str, sets them as the label for each frame.
        """
        if session_id == LAST_SESSION:
            session_label = "last session"
        else:
            session_label = "session n°" + str(session_id + 1)
        session = self.sessions[session_id]
        print(
            "Parameters used for the",
            session_label,
            f"(function used : {session.algo_name}) : ",
        )
        print_algo_params(session.parameters)
        if isinstance(label, bool):
            _frame_label = "Frame obtained for the " + session_label if label else ""
            _snr_label = "S/N map obtained for the " + session_label if label else ""
        else:
            _frame_label, _snr_label = label
        if session.snr_map is not None:
            plot_frames(
                (session.frame, session.snr_map),
                label=(_frame_label, _snr_label),
            )
        else:
            plot_frames(session.frame, label=_frame_label)
@dataclass
class PostProc(BaseEstimator):
    """
    Base post-processing algorithm class.

    Does not need an ``__init__`` because as a parent class for every algorithm
    object, there is no reason to create a PostProc object. Inherited classes
    benefit from the ``dataclass_builder`` support for their initialization and no
    further methods are needed to create those.

    The PostProc is still very useful as it provides crucial utility common to all
    the inherited objects, such as :
        - establishing a list of attributes which need to be calculated
        - updating the dataset used for the algorithm if needed
        - calculating the signal-to-noise ratio map after a corrected frame has
          been generated
        - setting up parameters for the algorithm.
    """

    dataset: Dataset = None
    verbose: bool = True
    results: PPResult = None
    frame_final: np.ndarray = None
    signf: float = None

    def _explicit_dataset(self) -> None:
        """
        Assign specific attributes from dataset to self.

        Many functions wrapped by the PostProc objects do not interact with a
        dataset but with their inner values instead : cube, fwhm, angle_list, etc.
        Those share different names in the functions wrapped and in the dataset,
        see the ``EXPLICIT_PARAMS`` constant for the mapping.
        """
        for self_name, data_name in EXPLICIT_PARAMS.items():
            setattr(self, self_name, getattr(self.dataset, data_name))

    def _create_parameters_dict(self, parent_class: type) -> dict:
        """
        Create a dictionnary with the parameters used inside the PostProc object.

        Parameters
        ----------
        parent_class: class
            Parent of the object that contains the parameters used by
            that object.

        Returns
        -------
        params_dict: dict
            Parameters used by the object under dictionnary form.
        """
        # Only attributes that also exist on the parent class are considered
        # parameters of the algorithm.
        return {
            attr_name: getattr(self, attr_name)
            for attr_name in vars(self)
            if hasattr(parent_class, attr_name)
        }

    def print_parameters(self) -> None:
        """Print out the parameters of the algorithm."""
        for key, value in self.__dict__.items():
            if isinstance(value, np.ndarray):
                print(f"{key} : numpy ndarray (not shown)")
            else:
                print(f"{key} : {value}")

    # TODO: write test
    def compute_significance(self, source_xy: Tuple[float] = None) -> None:
        """
        Compute the significance of a detection.

        Parameters
        ----------
        source_xy: Tuple of floats
            Coordinates (x, y) of the detection.
        """
        # ``snr_map`` is a calculated attribute: reading it before it exists
        # raises AttributeError through ``__getattr__``, so use getattr with a
        # default to trigger its computation instead of crashing.
        if getattr(self, "snr_map", None) is None:
            self.make_snrmap()
        snr_sig = snr(self.frame_final, source_xy=source_xy, fwhm=self.fwhm)
        center_y, center_x = frame_center(self.snr_map)
        radius = np.sqrt(
            (center_y - source_xy[1]) ** 2 + (center_x - source_xy[0]) ** 2
        )
        self.signf = significance(snr_sig, radius, self.fwhm, student_to_gauss=True)
        print(r"{:.1f} sigma detection".format(self.signf))

    def _update_dataset(self, dataset: Optional[Dataset] = None) -> None:
        """
        Handle a dataset passed to ``run()``.

        It is possible to specify a dataset using the constructor, or using the
        ``run()`` function. This helper function checks that there is a dataset
        to work with.

        Parameters
        ----------
        dataset : Dataset or None, optional
        """
        if dataset is not None:
            print(
                "A new dataset was provided to run, all previous results were cleared."
            )
            self.dataset = dataset
            self._reset_results()
        elif self.dataset is None:
            raise AttributeError(
                "No dataset was specified ! Please give a valid dataset inside the"
                " builder of the associated algorithm or inside the `run()` function."
            )
        else:
            print("No changes were made to the dataset.")

    def get_params_from_results(self, session_id: int) -> None:
        """
        Copy a previously registered configuration from the results to the object.

        Parameters
        ----------
        session_id : int
            The ID of the session to load the configuration from.

        Raises
        ------
        AttributeError
            If no PPResult instance was given to the object.
        ValueError
            If ``session_id`` is out of range or the session was produced by a
            different algorithm.
        """
        if self.results is None:
            raise AttributeError(
                "No results were saved yet ! Please give the object a PPResult instance"
                " and run the object at least once."
            )
        res = self.results.sessions
        # ``>=`` (not ``>``): session IDs are 0-based, so ``len(res)`` itself
        # is already out of range.
        if not res or session_id >= len(res):
            raise ValueError(
                f"ID is higher than the current number of sessions registered. "
                f"There are {len(self.results.sessions)} saved now.",
            )
        if res[session_id].algo_name not in self._algo_name:
            raise ValueError(
                "The function used for that session does not match your object."
                " Please choose a session with a corresponding function."
            )
        for key, value in res[session_id].parameters.items():
            setattr(self, key, value)
        print("Configuration loaded :")
        print_algo_params(res[session_id].parameters)

    # TODO : identify the problem around the element `_repr_html_`
    def _get_calculations(self, debug: bool = False) -> dict:
        """
        Get a list of all attributes which are *calculated*.

        This iterates over all the elements in an object and finds the functions
        which were decorated with ``@calculates`` (which are identified by the
        function attribute ``_calculates``). It then stores the calculated
        attributes, together with the corresponding method, and returns it.

        Parameters
        ----------
        debug : bool, optional
            If True, print every directory element as it is inspected.

        Returns
        -------
        calculations : dict
            Dictionary mapping a single "calculated attribute" to the method
            which calculates it.
        """
        calculations = {}
        for element in dir(self):
            # BLACKMAGIC : `_repr_html_` (and the other names listed in
            # PROBLEMATIC_ATTRIBUTE_NAMES) must be skipped. That element of the
            # directory of the PostProc object causes the search of calculated
            # attributes to overflow, looping indefinitely and never reaching
            # the actual elements containing those said attributes. It will be
            # skipped until the issue has been properly identified and fixed.
            # You can set debug=True to observe how the directory loops after
            # reaching that element - acknowledging you are not skipping it.
            if element in PROBLEMATIC_ATTRIBUTE_NAMES:
                if debug:
                    print(
                        "directory element SKIPPED: ",
                        element,
                        ", calculations list : ",
                        calculations,
                    )
                continue
            try:
                if debug:
                    print(
                        "directory element : ",
                        element,
                        ", calculations list : ",
                        calculations,
                    )
                # Elements without a ``_calculates`` attribute raise
                # AttributeError and are simply skipped.
                for attr in getattr(getattr(self, element), "_calculates"):
                    calculations[attr] = element
            except AttributeError:
                pass
        return calculations

    def _reset_results(self) -> None:
        """
        Remove all calculated results from the object.

        By design, the PostProc's can be initialized without a dataset,
        so the dataset can be provided to the ``run`` method. This makes it
        possible to run the same algorithm on multiple datasets. In order not to
        keep results from an older ``run`` call when working on a new dataset,
        the stored results are reset using this function every time the ``run``
        method is called.
        """
        for attr in self._get_calculations():
            try:
                delattr(self, attr)
            except AttributeError:
                pass  # attribute/result was not calculated yet. Skip.

    def __getattr__(self, attr: str) -> NoReturn:
        """
        ``__getattr__`` is only called when an attribute does *not* exist.

        Catching this event allows us to output proper error messages when an
        attribute was not calculated yet.
        """
        calculations = self._get_calculations()
        if attr in calculations:
            msg = f"The {attr} was not calculated yet. "
            msg += f"Call {calculations[attr]} first."
            raise AttributeError(msg)
        # this raises a regular AttributeError:
        return self.__getattribute__(attr)

    def _show_attribute_help(self, function_name: str) -> None:
        """
        Print information about the attributes a method calculated.

        This is called *automatically* when a method is decorated with
        ``@calculates``.

        Parameters
        ----------
        function_name : string
            The name of the method.
        """
        calculations = self._get_calculations()
        print("These attributes were just calculated:")
        for attr, func in calculations.items():
            if hasattr(self, attr) and function_name == func:
                print(f"\t{attr}")
        # ``_called_calculators`` is maintained by the @calculates decorator.
        not_calculated_yet = [
            (a, f)
            for a, f in calculations.items()
            if (f not in self._called_calculators and not hasattr(self, a))
        ]
        if len(not_calculated_yet) > 0:
            print("The following attributes can be calculated now:")
            for attr, func in not_calculated_yet:
                print(f"\t{attr}\twith .{func}()")

    @calculates("snr_map", "detection_map")
    def make_snrmap(
        self,
        approximated: Optional[bool] = False,
        plot: Optional[bool] = False,
        known_sources: Optional[Union[Tuple, Tuple[Tuple]]] = None,
        nproc: Optional[int] = None,
    ) -> None:
        """
        Calculate a S/N map from ``self.frame_final``.

        Parameters
        ----------
        approximated : bool, optional
            If True, a proxy to the S/N calculation will be used. If False, the
            Mawet et al. 2014 definition is used.
        plot : bool, optional
            If True plots the S/N map. False by default.
        known_sources : None, tuple or tuple of tuples, optional
            To take into account existing sources. It should be a tuple of
            float/int or a tuple of tuples (of float/int) with the coordinate(s)
            of the known sources.
        nproc : int or None
            Number of processes for parallel computing.

        Note
        ----
        ``self.verbose`` controls whether timing information is printed. If a
        PPResult instance was given to the object, the frame and S/N map are
        registered there as well.

        This is needed for "classic" algorithms that produce a final residual
        image in their ``.run()`` method. To obtain a "detection map", which can
        be used for counting true/false positives, a SNR map has to be created.
        For other algorithms (like ANDROMEDA) which directly create a SNR or a
        probability map, this method should be overwritten and thus disabled.
        """
        # 4D (IFS+ADI) cubes carry one FWHM per channel: use the mean.
        if self.dataset.cube.ndim == 4:
            fwhm = np.mean(self.dataset.fwhm)
        else:
            fwhm = self.dataset.fwhm
        self.snr_map = snrmap(
            self.frame_final,
            fwhm,
            approximated,
            plot=plot,
            known_sources=known_sources,
            nproc=nproc,
            verbose=self.verbose,
        )
        self.detection_map = self.snr_map
        if self.results is not None:
            self.results.register_session(frame=self.frame_final, snr_map=self.snr_map)

    def save(self, filename: str) -> None:
        """
        Pickle the algo object and save it to disk.

        Note that this also saves the associated ``self.dataset``, in a
        non-optimal way.
        """
        with open(filename, "wb") as file:
            pickle.dump(self, file)

    @calculates("frame_final")
    def run(self) -> None:
        """
        Run the algorithm. Should at least set `` self.frame_final``.

        Note
        ----
        This is the required signature of the ``run`` call. Child classes can
        add their own keyword arguments if needed.
        """
        raise NotImplementedError