# Source code for vip_hci.objects.pppca

#! /usr/bin/env python
"""Module for the post-processing PCA algorithm."""

__author__ = "Thomas Bédrine"
__all__ = ["PCABuilder", "PPPCA"]

from typing import Tuple, Optional, List
from dataclasses import dataclass, field

import numpy as np
from pandas import DataFrame
from dataclass_builder import dataclass_builder

from .dataset import Dataset
from .postproc import PostProc
from ..psfsub import (
    pca,
    pca_annular,
    pca_grid,
    pca_annulus,
    PCA_Params,
    PCA_ANNULAR_Params,
)
from ..config.paramenum import Adimsdi, ReturnList, Runmode
from ..config.utils_conf import algo_calculates_decorator as calculates
from ..config.utils_param import setup_parameters


[docs] @dataclass class PPPCA(PostProc, PCA_Params, PCA_ANNULAR_Params): """ Post-processing PCA algorithm, compatible with various options. Depending on what mode you need, parameters may vary. Check the list below to ensure which arguments are required. Currently, four variations of the PCA can be called : - full-frame PCA - annular PCA - grid PCA - single annulus PCA. Some parameters are common to several variations. Common parameters ----------------- full_output: bool, optional Whether to return the final median combined image only or with other intermediate arrays. _algo_name: str, optional Name of the algorithm wrapped by the object. Grid parameters --------------- range_pcs : None or tuple, optional The interval of PCs to be tried. Refer to ``vip_hci.psfsub.pca_grid`` for more information. mode : {'fullfr', 'annular'}, optional Mode for PCA processing (full-frame or just in an annulus). fmerit : {'px', 'max', 'mean'} The function of merit to be maximized. 'px' is *source_xy* pixel's SNR, 'max' the maximum SNR in a FWHM circular aperture centered on ``source_xy`` and 'mean' is the mean SNR in the same circular aperture. plot : bool, optional Whether to plot the SNR and flux as functions of PCs and final PCA frame or not. save_plot: string If provided, the pc optimization plot will be saved to that path. initial_4dshape : None or tuple, optional Shape of the initial ADI+mSDI cube. exclude_negative_lobes : bool, opt Whether to include the adjacent aperture lobes to the tested location or not. Can be set to True if the image shows significant neg lobes. Single annulus parameters ------------------------- r_guess : float Radius of the annulus in pixels. 
""" # Common parameters/returns _algo_name: List[str] = field( default_factory=lambda: [ "pca", "pca_annular", "pca_grid", "pca_annulus", ] ) cube_sig: np.ndarray = None cube_residuals: np.ndarray = None cube_residuals_der: np.ndarray = None full_output = True # Full-frame returns pcs: np.ndarray = None cube_residuals_per_channel: np.ndarray = None cube_residuals_per_channel_der: np.ndarray = None cube_residuals_resc: np.ndarray = None final_residuals_cube: np.ndarray = None medians: np.ndarray = None # Grid parameters frames_final: np.ndarray = None range_pcs: Tuple[int] = None mode: str = "fullfr" fmerit: str = "mean" plot: bool = True save_plot: str = None exclude_negative_lobes: bool = False initial_4dshape: Tuple = None dataframe: DataFrame = None pc_list: List = None opt_number_pc: int = None # Single annulus parameters annulus_width: float = None # Note: also used for Grid in annular mode r_guess: float = None # TODO: write test
[docs] @calculates( "frame_final", "cube_reconstructed", "cube_residuals", "cube_residuals_der", "pcs", "cube_residuals_per_channel", "cube_residuals_per_channel_der", "cube_residuals_resc", "final_residuals_cube", "medians", "dataframe", "opt_number_pc", ) def run( self, runmode: Optional[str] = Runmode.CLASSIC, dataset: Optional[Dataset] = None, nproc: Optional[int] = 1, verbose: Optional[bool] = True, full_output: Optional[bool] = True, **rot_options: Optional[dict], ): """ Run the post-processing PCA algorithm for model PSF subtraction. Depending on the desired mode - full-frame or annular - parameters used will diverge, and calculated attributes may vary as well. In full-frame case : 3D case: cube_reconstructed cube_residuals cube_residuals_der 3D case, source_xy is None: cube_residuals pcs 4D case, adimsdi="double": cube_residuals_per_channel cube_residuals_per_channel_der 4D case, adimsdi="single": cube_residuals cube_residuals_resc Parameters ---------- runmode : Enum, see ``vip_hci.config.paramenum.Runmode`` Mode of execution for the PCA. dataset : Dataset, optional Dataset to process. If not provided, ``self.dataset`` is used (as set when initializing this object). nproc : int, optional verbose : bool, optional If True prints to stdout intermediate info. full_output: boolean, optional Whether to return the final median combined image only or with other intermediate arrays. 
rot_options: dictionary, optional Dictionary with optional keyword values for "border_mode", "mask_val", "edge_blend", "interp_zeros", "ker" (see documentation of ``vip_hci.preproc.frame_rotate``) """ self.snr_map = None self._update_dataset(dataset) if self.dataset.fwhm is None: raise ValueError("`fwhm` has not been set") self._explicit_dataset() self.full_output = full_output match (runmode): case Runmode.CLASSIC: # TODO : review the wavelengths attribute to be a scale_list instead params_dict = self._create_parameters_dict(PCA_Params) all_params = {"algo_params": self, **rot_options} res = pca(**all_params) self._find_pca_mode(res=res) if self.results is not None and self.frame_final is not None: self.results.register_session( params=params_dict, frame=self.frame_final, algo_name=self._algo_name[0], ) case Runmode.ANNULAR: if self.nproc is None: self.nproc = nproc params_dict = self._create_parameters_dict(PCA_ANNULAR_Params) all_params = {"algo_params": self, **rot_options} res = pca_annular(**all_params) self.cube_residuals = res[0] self.cube_residuals_der = res[1] if isinstance(res[2], list): self.frames_final = res[2] else: self.frame_final = res[2] if self.results is not None and self.frame_final is not None: self.results.register_session( params=params_dict, frame=self.frame_final, algo_name=self._algo_name[1], ) case Runmode.GRID: add_params = { "full_output": full_output, "verbose": verbose, } func_params = setup_parameters( params_obj=self, fkt=pca_grid, **add_params ) res = pca_grid(**func_params, **rot_options) if self.source_xy is not None and self.fwhm is not None: ( self.cube_residuals, self.frame_final, self.dataframe, self.opt_number_pc, ) = res if self.results is not None: self.results.register_session( params=func_params, frame=self.frame_final, algo_name=self._algo_name[2], ) elif self.full_output: ( self.final_residuals_cube, self.pc_list ) = res else: self.final_residuals_cube = res case Runmode.ANNULUS: add_params = { "angs": 
self.angle_list, } func_params = setup_parameters( params_obj=self, fkt=pca_annulus, **add_params ) res = pca_annulus(**func_params, **rot_options) self.frame_final = res if self.results is not None: self.results.register_session( params=func_params, frame=self.frame_final, algo_name=self._algo_name[3], ) case _: raise ValueError("Invalid run mode selected.")
def _find_pca_mode(self, res): """ Identify the mode of PCA used and extracts return elements accordingly. Nine modes are currently known and each of them looks at specific conditions. Every mode and its set of conditions is verified to be True or not, and associates its return elements via the `match...case` if recognized. Parameters ---------- res: any The return of the PCA function, can consist of a multitude of items. """ conditions = { "cube": isinstance(self.cube, np.ndarray), "scale": self.scale_list is not None, "adimsdidouble": self.adimsdi == Adimsdi.DOUBLE, "adimsdisingle": self.adimsdi == Adimsdi.SINGLE, "ncompunit": isinstance(self.ncomp, (float, int)), "ncompit": isinstance(self.ncomp, (tuple, list)), "source": self.source_xy is not None, "nosource": self.source_xy is None, "reforsource": self.cube_ref is not None or self.source_xy is None, "nobatch": self.batch is None, "batch": self.batch is not None, "cubeorscale": isinstance(self.cube, str) or self.scale_list is None, } pca_modes = { ReturnList.ADIMSDI_DOUBLE: conditions["cube"] and conditions["scale"] and conditions["adimsdidouble"], ReturnList.ADIMSDI_SINGLE_NO_GRID: conditions["cube"] and conditions["scale"] and conditions["adimsdisingle"] and conditions["ncompunit"], ReturnList.ADIMSDI_SINGLE_GRID_NO_SOURCE: conditions["cube"] and conditions["scale"] and conditions["adimsdisingle"] and conditions["ncompit"] and conditions["nosource"], ReturnList.ADIMSDI_SINGLE_GRID_SOURCE: conditions["cube"] and conditions["scale"] and conditions["adimsdisingle"] and conditions["ncompit"] and conditions["source"], ReturnList.ADI_FULLFRAME_GRID: conditions["cubeorscale"] and conditions["reforsource"] and conditions["nobatch"] and conditions["ncompit"], ReturnList.ADI_INCREMENTAL_BATCH: conditions["cubeorscale"] and conditions["reforsource"] and conditions["batch"], ReturnList.ADI_FULLFRAME_STANDARD: conditions["cubeorscale"] and conditions["reforsource"] and conditions["nobatch"] and conditions["ncompunit"], 
ReturnList.PCA_GRID_SN: conditions["cubeorscale"] and conditions["source"] and conditions["ncompit"], ReturnList.PCA_ROT_THRESH: conditions["cubeorscale"] and conditions["source"] and conditions["ncompunit"], } pca_mode = None for mode, state in pca_modes.items(): if state: pca_mode = mode break match (pca_mode): case ReturnList.ADIMSDI_DOUBLE: self.frame_final, self.cube_residuals, self.cube_residuals_der = res case ReturnList.ADIMSDI_SINGLE_NO_GRID: self.frame_final, self.cube_residuals, _ = res case ReturnList.ADIMSDI_SINGLE_GRID_NO_SOURCE: self.final_residuals_cube, self.frame_final, _ = res case ReturnList.ADIMSDI_SINGLE_GRID_SOURCE: self.final_residuals_cube, self.pc_list = res case ReturnList.ADI_FULLFRAME_GRID: if self.cube.ndim == 4: self.frames_final, self.pc_list, _ = res else: self.frames_final, self.pc_list = res case ReturnList.ADI_INCREMENTAL_BATCH: if self.cube.ndim == 4: self.frame_final, self.pcs, self.medians, _ = res else: self.frame_final, self.pcs, self.medians = res case ReturnList.ADI_FULLFRAME_STANDARD: if self.cube.ndim == 4: ( self.frame_final, self.pcs, self.cube_reconstructed, self.cube_residuals, self.cube_residuals_der, _, ) = res else: ( self.frame_final, self.pcs, self.cube_reconstructed, self.cube_residuals, self.cube_residuals_der, ) = res case ReturnList.PCA_GRID_SN: if self.cube.ndim == 4: self.final_residuals_cube, self.frame_final, _, self.opt_number_pc = res else: self.final_residuals_cube, self.frame_final, _ = res case ReturnList.PCA_ROT_THRESH: if self.cube.ndim == 4: ( self.frame_final, self.cube_reconstructed, self.cube_residuals, self.cube_residuals_der, _, ) = res else: ( self.frame_final, self.cube_reconstructed, self.cube_residuals, self.cube_residuals_der, ) = res case _: raise RuntimeError("No PCA mode could be identified.")
# Builder companion for ``PPPCA``, generated by ``dataclass_builder`` from the
# ``PPPCA`` dataclass (see the dataclass_builder package for the builder API).
# Exported alongside ``PPPCA`` through ``__all__``.
PCABuilder = dataclass_builder(PPPCA)