Source code for vip_hci.psfsub.loci

#! /usr/bin/env python
"""
Module with the Locally Optimized Combination of Images (LOCI) least-squares
algorithm for ADI and ADI+mSDI post-processing.

.. [PUE12]
   | Pueyo et al. 2012
   | **Application of a Damped Locally Optimized Combination of Images Method to
     the Spectral Characterization of Faint Companions Using an Integral Field
     Spectrograph**
   | *The Astrophysical Journal Supplements, Volume 199, p. 6*
   | `https://arxiv.org/abs/1111.6102
     <https://arxiv.org/abs/1111.6102>`_

"""

__author__ = "Carlos Alberto Gomez Gonzalez, Thomas Bédrine"
__all__ = ["xloci", "XLOCI_Params"]

import numpy as np
import scipy as sp
import pandas as pn
from multiprocessing import cpu_count
from sklearn.metrics import pairwise_distances
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Union, List
from ..var import get_annulus_segments
from ..config import time_ini, timing
from ..config.utils_param import setup_parameters, separate_kwargs_dict
from ..config.paramenum import (Metric, Adimsdi, Imlib, Interpolation, Collapse,
                                Solver, ALGO_KEY)
from ..config.utils_conf import pool_map, iterable, Progressbar
from ..preproc import (cube_derotate, cube_collapse, check_pa_vector,
                       check_scal_vector)
from ..preproc.rescaling import _find_indices_sdi
from ..preproc import cube_rescaling_wavelengths as scwave
from ..preproc.derotation import _find_indices_adi, _define_annuli


[docs] @dataclass class XLOCI_Params: """ Set of parameters for the LOCI algorithm. See function `xloci` below for the documentation. """ cube: np.ndarray = None angle_list: np.ndarray = None scale_list: np.ndarray = None fwhm: float = 4 metric: Enum = Metric.MANHATTAN dist_threshold: int = 100 delta_rot: Union[float, Tuple[float]] = (0.1, 1) delta_sep: Union[float, Tuple[float]] = (0.1, 1) radius_int: int = 0 asize: int = 4 n_segments: int = 4 nproc: int = 1 solver: Enum = Solver.LSTSQ tol: float = 1e-2 optim_scale_fact: float = 2 adimsdi: Enum = Adimsdi.SKIPADI imlib: Enum = Imlib.VIPFFT interpolation: Enum = Interpolation.LANCZOS4 collapse: Enum = Collapse.MEDIAN verbose: bool = True full_output: bool = False
def xloci(*all_args: List, **all_kwargs: dict):
    """Locally Optimized Combination of Images (LOCI) algorithm as in [LAF07]_.

    The PSF is modeled (for ADI and ADI+mSDI) with a least-square combination
    of neighbouring frames (solving the equation a x = b by computing a vector
    x of coefficients that minimizes the Euclidean 2-norm || b - a x ||^2).

    This algorithm is also compatible with IFS data to perform LOCI-SDI, in a
    similar fashion as suggested in [PUE12]_ (albeit without dampening zones).

    Parameters
    ----------
    all_args: list, optional
        Positional arguments for the LOCI algorithm. Full list of parameters
        below.
    all_kwargs: dictionary, optional
        Mix of keyword arguments that can initialize a XLOCI_Params and the
        optional 'rot_options' dictionary, with keyword values for
        "border_mode", "mask_val", "edge_blend", "interp_zeros", "ker" (see
        documentation of ``vip_hci.preproc.frame_rotate``). Can also contain
        a XLOCI_Params named as `algo_params`.

    LOCI parameters
    ---------------
    cube : numpy ndarray, 3d or 4d
        Input cube.
    angle_list : numpy ndarray, 1d
        Corresponding parallactic angle for each frame.
    scale_list : numpy ndarray, 1d, optional
        If provided, triggers mSDI reduction. These should be the scaling
        factors used to re-scale the spectral channels and align the speckles
        in case of IFS data (ADI+mSDI cube). Usually approximated by the last
        channel wavelength divided by the other wavelengths in the cube (more
        thorough approaches can be used to get the scaling factors, e.g. with
        ``vip_hci.preproc.find_scal_vector``).
    fwhm : float, optional
        Size of the FWHM in pixels. Default is 4.
    metric : Enum, see `vip_hci.config.paramenum.Metric`
        Distance metric to be used ('cityblock', 'cosine', 'euclidean', 'l1',
        'l2', 'manhattan', 'correlation', etc). It uses the scikit-learn
        function ``sklearn.metrics.pairwise.pairwise_distances``.
    dist_threshold : int, optional
        Indices with a distance larger than ``dist_threshold`` percentile will
        initially be discarded. 100 by default.
    delta_rot : float or tuple of floats, optional
        Factor for adjusting the parallactic angle threshold, expressed in
        FWHM. Default is 1 (excludes 1 FWHM on each side of the considered
        frame). If a tuple of two floats is provided, they are used as the
        lower and upper intervals for the threshold (grows linearly as a
        function of the separation).
    delta_sep : float or tuple of floats, optional
        The threshold separation in terms of the mean FWHM (for ADI+mSDI
        data). If a tuple of two values is provided, they are used as the
        lower and upper intervals for the threshold (grows as a function of
        the separation).
    radius_int : int, optional
        The radius of the innermost annulus. By default is 0; if >0 the
        central circular region is discarded.
    asize : int, optional
        The size of the annuli, in pixels.
    n_segments : int or list of int or 'auto', optional
        The number of segments for each annulus. When a single integer is
        given it is used for all annuli. When set to 'auto', the number of
        segments is automatically determined for every annulus, based on the
        annulus width.
    nproc : None or int, optional
        Number of processes for parallel computing. If None the number of
        processes will be set to cpu_count()/2. By default the algorithm works
        in single-process mode.
    solver : Enum, see `vip_hci.config.paramenum.Solver`
        Choosing the solver of the least squares problem. ``lstsq`` uses the
        standard scipy least squares solver. ``nnls`` uses the scipy
        non-negative least-squares solver.
    tol : float, optional
        Valid when ``solver`` is set to lstsq. Sets the cutoff for 'small'
        singular values; used to determine effective rank of a. Singular
        values smaller than ``tol * largest_singular_value`` are considered
        zero. Smaller values of ``tol`` lead to smaller residuals (more
        aggressive subtraction).
    optim_scale_fact : float, optional
        If >1, the least-squares optimization is performed on a larger
        segment, similar to LOCI. The optimization segments share the same
        inner radius, mean angular position and angular width as their
        corresponding subtraction segments.
    adimsdi : Enum, see `vip_hci.config.paramenum.Adimsdi`
        Changes the way the 4d cubes (ADI+mSDI) are processed. ``skipadi``:
        the multi-spectral frames are rescaled wrt the largest wavelength to
        align the speckles and the least-squares model is subtracted on each
        spectral cube separately. ``double``: a first subtraction is done on
        the rescaled spectral frames (as in the ``skipadi`` case). Then the
        residuals are processed again in an ADI fashion.
    imlib : Enum, see `vip_hci.config.paramenum.Imlib`
        See the documentation of ``vip_hci.preproc.frame_rotate``.
    interpolation : Enum, see `vip_hci.config.paramenum.Interpolation`
        See the documentation of ``vip_hci.preproc.frame_rotate``.
    collapse : Enum, see `vip_hci.config.paramenum.Collapse`
        Sets the way of collapsing the frames for producing a final image.
    verbose: bool, optional
        If True prints info to stdout.
    full_output: bool, optional
        Whether to return the final median combined image only or along with
        2 other residual cubes (before and after derotation).

    Returns
    -------
    cube_res : numpy ndarray, 3d
        [full_output=True] Cube of residuals.
    cube_der : numpy ndarray, 3d
        [full_output=True] Derotated cube of residuals.
    frame_der_median : numpy ndarray, 2d
        Median combination of the de-rotated cube of residuals.
    """
    # Separate kwargs that belong to XLOCI_Params from the rotation options.
    class_params, rot_options = separate_kwargs_dict(
        initial_kwargs=all_kwargs, parent_class=XLOCI_Params
    )

    # A pre-built parameter object may be passed under the ALGO_KEY keyword.
    algo_params = None
    if ALGO_KEY in rot_options.keys():
        algo_params = rot_options[ALGO_KEY]
        del rot_options[ALGO_KEY]

    if algo_params is None:
        algo_params = XLOCI_Params(*all_args, **class_params)

    # Module-level global read by the patch workers (avoids pickling the cube
    # for every multiprocessing task).
    global ARRAY
    ARRAY = algo_params.cube

    if algo_params.verbose:
        start_time = time_ini()

    # ADI datacube
    if algo_params.cube.ndim == 3:
        func_params = setup_parameters(params_obj=algo_params, fkt=_leastsq_adi)
        res = _leastsq_adi(**func_params)

        if algo_params.verbose:
            timing(start_time)

        if algo_params.full_output:
            cube_res, cube_der, frame = res
            return cube_res, cube_der, frame
        else:
            frame = res
            return frame

    # ADI+mSDI (IFS) datacubes
    elif algo_params.cube.ndim == 4:
        z, n, y_in, x_in = algo_params.cube.shape
        # fwhm may be a per-channel vector; collapse it to one integer value.
        algo_params.fwhm = int(np.round(np.mean(algo_params.fwhm)))
        n_annuli = int((y_in / 2 - algo_params.radius_int) / algo_params.asize)

        # Processing separately each wavelength in ADI fashion
        if algo_params.adimsdi == Adimsdi.SKIPADI:
            if algo_params.verbose:
                print("ADI lst-sq modeling for each wavelength individually")
                print("{} frames per wavelength".format(n))

            cube_res = np.zeros((z, y_in, x_in))
            # NOTE: the loop variable shadows the channel count `z`; range(z)
            # is evaluated before the shadowing, so iteration is correct.
            for z in Progressbar(range(z)):
                ARRAY = algo_params.cube[z]
                add_params = {
                    "cube": algo_params.cube[z],
                    "verbose": False,
                    "full_output": False,
                }
                func_params = setup_parameters(
                    params_obj=algo_params, fkt=_leastsq_adi, **add_params
                )
                res = _leastsq_adi(**func_params)
                cube_res[z] = res

            frame = cube_collapse(cube_res, algo_params.collapse)
            if algo_params.verbose:
                print("Done combining the residuals")
                timing(start_time)

            if algo_params.full_output:
                return cube_res, frame
            else:
                return frame
        else:
            # SDI path: validate the scaling-factor vector first.
            if algo_params.scale_list is None:
                raise ValueError("Scaling factors vector must be provided")
            else:
                if np.array(algo_params.scale_list).ndim > 1:
                    raise ValueError("Scaling factors vector is not 1d")
                if not algo_params.scale_list.shape[0] == z:
                    raise ValueError("Scaling factors vector has wrong length")

            if algo_params.verbose:
                print("SDI lst-sq modeling exploiting the spectral variability")
                print("{} spectral channels per IFS frame".format(z))
                print(
                    "N annuli = {}, mean FWHM = "
                    "{:.3f}".format(n_annuli, algo_params.fwhm)
                )

            # One SDI least-squares subtraction per IFS frame, in parallel.
            add_params = {"fr": iterable(
                range(n)), "scal": algo_params.scale_list}

            func_params = setup_parameters(
                params_obj=algo_params, fkt=_leastsq_sdi_fr, as_list=True, **add_params
            )

            res = pool_map(
                algo_params.nproc,
                _leastsq_sdi_fr,
                *func_params,
            )
            cube_out = np.array(res)

            # Choosing not to exploit the rotational variability
            # NOTE(review): this branch appears unreachable — it sits inside
            # the `else` of the SKIPADI check above, so adimsdi != SKIPADI
            # here. Possibly an extraction/refactoring artifact; confirm
            # against the intended Adimsdi enum values.
            if algo_params.adimsdi == Adimsdi.SKIPADI:
                if algo_params.verbose:
                    print("Skipping the ADI least-squares subtraction")
                    print("{} ADI frames".format(n))
                    timing(start_time)

                cube_der = cube_derotate(
                    cube_out,
                    algo_params.angle_list,
                    imlib=algo_params.imlib,
                    interpolation=algo_params.interpolation,
                    nproc=algo_params.nproc,
                    **rot_options,
                )
                frame = cube_collapse(cube_der, mode=algo_params.collapse)

            # Exploiting rotational variability
            elif algo_params.adimsdi == Adimsdi.DOUBLE:
                if algo_params.verbose:
                    print("ADI lst-sq modeling exploiting the angular variability")
                    print("{} ADI frames".format(n))
                    timing(start_time)

                ARRAY = cube_out
                add_params = {"cube": cube_out}
                func_params = setup_parameters(
                    params_obj=algo_params, fkt=_leastsq_adi, **add_params
                )
                res = _leastsq_adi(
                    **func_params,
                    **rot_options,
                )

                if algo_params.full_output:
                    cube_out, cube_der, frame = res
                else:
                    frame = res

            if algo_params.verbose:
                timing(start_time)

            if algo_params.full_output:
                return cube_out, cube_der, frame
            else:
                return frame
def _leastsq_adi(
    cube,
    angle_list,
    fwhm=4,
    metric="manhattan",
    dist_threshold=50,
    delta_rot=0.5,
    radius_int=0,
    asize=4,
    n_segments=4,
    nproc=1,
    solver="lstsq",
    tol=1e-2,
    optim_scale_fact=1,
    imlib="vip-fft",
    interpolation="lanczos4",
    collapse="median",
    verbose=True,
    full_output=False,
    **rot_options
):
    """Least-squares model PSF subtraction for ADI.

    Splits the FOV into annular segments, fits each segment of each frame as
    a least-squares combination of the other (PA-distant) frames, subtracts
    the model, derotates and collapses the residual cube.

    Returns ``frame_der_median`` or, if ``full_output``, the tuple
    ``(cube_res, cube_der, frame_der_median)``.
    """
    y = cube.shape[1]
    if not asize < y // 2:
        raise ValueError("asize is too large")

    angle_list = check_pa_vector(angle_list)
    n_annuli = int((y / 2 - radius_int) / asize)
    if verbose:
        print("Building {} annuli:".format(n_annuli))

    # Per-annulus PA threshold: linear ramp if a tuple was given.
    if isinstance(delta_rot, tuple):
        delta_rot = np.linspace(delta_rot[0], delta_rot[1], num=n_annuli)
    elif isinstance(delta_rot, (int, float)):
        delta_rot = [delta_rot] * n_annuli

    if nproc is None:
        nproc = cpu_count() // 2  # Hyper-threading doubles the # of cores

    annulus_width = asize
    if isinstance(n_segments, int):
        n_segments = [n_segments] * n_annuli
    elif n_segments == "auto":
        # Grow the number of segments with radius so each segment keeps a
        # roughly constant arc length.
        n_segments = list()
        n_segments.append(2)  # for first annulus
        n_segments.append(3)  # for second annulus
        ld = 2 * np.tan(360 / 4 / 2) * annulus_width
        for i in range(2, n_annuli):  # rest of annuli
            radius = i * annulus_width
            ang = np.rad2deg(2 * np.arctan(ld / (2 * radius)))
            n_segments.append(int(np.ceil(360 / ang)))

    # annulus-wise least-squares combination and subtraction
    cube_res = np.zeros_like(cube)

    ayxyx = []  # contains per-segment data
    pa_thresholds = []

    for ann in range(n_annuli):
        n_segments_ann = n_segments[ann]
        inner_radius_ann = radius_int + ann * annulus_width

        # PA threshold for this annulus (first element of _define_annuli).
        pa_threshold = _define_annuli(
            angle_list,
            ann,
            n_annuli,
            fwhm,
            radius_int,
            asize,
            delta_rot[ann],
            n_segments_ann,
            verbose,
        )[0]

        # Pixel indices of the subtraction and (possibly larger) optimization
        # segments.
        indices = get_annulus_segments(
            cube[0], inner_radius=inner_radius_ann, width=asize, nsegm=n_segments_ann
        )
        ind_opt = get_annulus_segments(
            cube[0],
            inner_radius=inner_radius_ann,
            width=asize,
            nsegm=n_segments_ann,
            optim_scale_fact=optim_scale_fact,
        )

        # store segment data for multiprocessing
        ayxyx += [
            (
                ann,
                indices[nseg][0],
                indices[nseg][1],
                ind_opt[nseg][0],
                ind_opt[nseg][1],
            )
            for nseg in range(n_segments_ann)
        ]
        pa_thresholds.append(pa_threshold)

    msg = "Patch-wise least-square combination and subtraction:"
    # reverse order of processing, as outer segments take longer
    res_patch = pool_map(
        nproc,
        _leastsq_patch,
        iterable(ayxyx[::-1]),
        pa_thresholds,
        angle_list,
        metric,
        dist_threshold,
        solver,
        tol,
        verbose=verbose,
        msg=msg,
        progressbar_single=True,
    )

    for patch in res_patch:
        matrix_res, yy, xx = patch
        cube_res[:, yy, xx] = matrix_res

    cube_der = cube_derotate(
        cube_res, angle_list, imlib, interpolation, nproc=nproc, **rot_options
    )
    frame_der_median = cube_collapse(cube_der, collapse)

    if verbose:
        print("Done processing annuli")

    if full_output:
        return cube_res, cube_der, frame_der_median
    else:
        return frame_der_median


def _leastsq_patch(ayxyx, pa_thresholds, angles, metric, dist_threshold,
                   solver, tol):
    """Helper function for ``_leastsq_adi``: process one annular segment.

    Parameters
    ----------
    ayxyx : tuple
        Per-segment data: (annulus index, yy, xx, yy_opt, xx_opt).
    pa_thresholds : list
        Per-annulus parallactic angle thresholds.
    angles, metric, dist_threshold, solver, tol
        Same for every annulus/segment; see ``xloci``.

    Returns
    -------
    matrix_res, yy, xx
        Residual matrix (n_frames x n_px) and the segment pixel indices.
    """
    iann, yy, xx, yy_opt, xx_opt = ayxyx
    pa_threshold = pa_thresholds[iann]

    # Reads the module-level global ARRAY set by the caller.
    values = ARRAY[:, yy, xx]  # n_frames x n_pxs_segment
    values_opt = ARRAY[:, yy_opt, xx_opt]
    n_frames = ARRAY.shape[0]

    # Pairwise frame distances; skipped (all ones) when no filtering applies.
    if dist_threshold < 100:
        mat_dists_ann_full = pairwise_distances(values, metric=metric)
    else:
        mat_dists_ann_full = np.ones((values.shape[0], values.shape[0]))

    # Keep only frames beyond the PA threshold of each frame.
    if pa_threshold > 0:
        mat_dists_ann = np.zeros_like(mat_dists_ann_full)
        for i in range(n_frames):
            ind_fr_i = _find_indices_adi(angles, i, pa_threshold, None, False)
            mat_dists_ann[i][ind_fr_i] = mat_dists_ann_full[i][ind_fr_i]
    else:
        mat_dists_ann = mat_dists_ann_full

    # Discard distances above the requested percentile; NaN marks exclusion.
    threshold = np.percentile(mat_dists_ann[mat_dists_ann != 0], dist_threshold)
    mat_dists_ann[mat_dists_ann > threshold] = np.nan
    mat_dists_ann[mat_dists_ann == 0] = np.nan

    matrix_res = np.zeros((values.shape[0], yy.shape[0]))
    for i in range(n_frames):
        vector = pn.DataFrame(mat_dists_ann[i])
        if vector.sum().values > 0:
            ind_ref = np.where(~np.isnan(vector))[0]
            A = values_opt[ind_ref]
            b = values_opt[i]
            if solver == "lstsq":
                try:
                    coef = sp.linalg.lstsq(A.T, b, cond=tol)[0]  # SVD method
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.
                except Exception:
                    coef = sp.optimize.nnls(A.T, b)[0]  # if SVD does not work
            elif solver == "nnls":
                coef = sp.optimize.nnls(A.T, b)[0]
            elif solver == "lsq":  # TODO
                coef = sp.optimize.lsq_linear(
                    A.T, b, bounds=(0, 1), method="trf", lsq_solver="lsmr"
                )["x"]
            else:
                raise ValueError("`solver` not recognized")
        else:
            msg = "No frames left in the reference set. Try increasing "
            msg += "`dist_threshold` or decreasing `delta_rot`."
            raise RuntimeError(msg)

        # Model = coefficient-weighted combination of the reference frames.
        recon = np.dot(coef, values[ind_ref])
        matrix_res[i] = values[i] - recon

    return matrix_res, yy, xx


def _leastsq_sdi_fr(
    fr,
    scal,
    radius_int,
    fwhm,
    asize,
    n_segments,
    delta_sep,
    tol,
    optim_scale_fact,
    metric,
    dist_threshold,
    solver,
    imlib,
    interpolation,
    collapse,
):
    """Optimized least-squares based subtraction on a multi-spectral frame
    (IFS data).

    Rescales the spectral channels of frame ``fr`` (read from the global 4d
    ``ARRAY``) to align the speckles, subtracts a per-segment least-squares
    model across wavelengths and de-scales the residuals back.
    """
    z, n, y_in, x_in = ARRAY.shape

    scale_list = check_scal_vector(scal)
    # rescaled cube, aligning speckles; global consumed by _leastsq_patch_ifs
    global MULTISPEC_FR
    MULTISPEC_FR = scwave(
        ARRAY[:, fr, :, :], scale_list, imlib=imlib, interpolation=interpolation
    )[0]

    # Exploiting spectral variability (radial movement)
    fwhm = int(np.round(np.mean(fwhm)))
    annulus_width = int(np.ceil(asize))  # equal size for all annuli
    n_annuli = int(np.floor((y_in / 2 - radius_int) / annulus_width))

    if isinstance(n_segments, int):
        n_segments = [n_segments for _ in range(n_annuli)]
    elif n_segments == "auto":
        # Same radius-dependent segmentation rule as in _leastsq_adi.
        n_segments = list()
        n_segments.append(2)  # for first annulus
        n_segments.append(3)  # for second annulus
        ld = 2 * np.tan(360 / 4 / 2) * annulus_width
        for i in range(2, n_annuli):  # rest of annuli
            radius = i * annulus_width
            ang = np.rad2deg(2 * np.arctan(ld / (2 * radius)))
            n_segments.append(int(np.ceil(360 / ang)))

    cube_res = np.zeros_like(MULTISPEC_FR)  # shape (z, resc_y, resc_x)

    if isinstance(delta_sep, tuple):
        delta_sep_vec = np.linspace(delta_sep[0], delta_sep[1], n_annuli)
    else:
        delta_sep_vec = [delta_sep] * n_annuli

    for ann in range(n_annuli):
        # Shift the outermost annulus inward by one pixel to stay in frame.
        if ann == n_annuli - 1:
            inner_radius = radius_int + (ann * annulus_width - 1)
        else:
            inner_radius = radius_int + ann * annulus_width
        ann_center = inner_radius + (annulus_width / 2)

        indices = get_annulus_segments(
            MULTISPEC_FR[0], inner_radius, annulus_width, n_segments[ann]
        )
        ind_opt = get_annulus_segments(
            MULTISPEC_FR[0],
            inner_radius,
            annulus_width,
            n_segments[ann],
            optim_scale_fact=optim_scale_fact,
        )

        for seg in range(n_segments[ann]):
            yy = indices[seg][0]
            xx = indices[seg][1]
            segm_res = _leastsq_patch_ifs(
                seg,
                indices,
                ind_opt,
                scal,
                ann_center,
                fwhm,
                delta_sep_vec[ann],
                metric,
                dist_threshold,
                solver,
                tol,
            )
            cube_res[:, yy, xx] = segm_res

    # Undo the spectral rescaling and collapse into one residual frame.
    frame_desc = scwave(
        cube_res,
        scale_list,
        full_output=False,
        inverse=True,
        y_in=y_in,
        x_in=x_in,
        imlib=imlib,
        interpolation=interpolation,
        collapse=collapse,
    )
    return frame_desc


def _leastsq_patch_ifs(
    nseg,
    indices,
    indices_opt,
    scal,
    ann_center,
    fwhm,
    delta_sep,
    metric,
    dist_threshold,
    solver,
    tol,
):
    """Helper function for ``_leastsq_sdi_fr``: one segment across channels.

    Returns the residual matrix (n_channels x n_px) for segment ``nseg``.
    """
    yy = indices[nseg][0]
    xx = indices[nseg][1]
    values = MULTISPEC_FR[:, yy, xx]

    yy_opt = indices_opt[nseg][0]
    # FIX: was indices_opt[nseg][0], which duplicated the y coordinates and
    # sampled the wrong pixels for the optimization segment.
    xx_opt = indices_opt[nseg][1]
    values_opt = MULTISPEC_FR[:, yy_opt, xx_opt]

    n_wls = ARRAY.shape[0]

    if dist_threshold < 100:
        mat_dists_ann_full = pairwise_distances(values, metric=metric)
    else:
        mat_dists_ann_full = np.ones((values.shape[0], values.shape[0]))

    # Keep only channels whose radial speckle motion exceeds delta_sep.
    if delta_sep > 0:
        mat_dists_ann = np.zeros_like(mat_dists_ann_full)
        for z in range(n_wls):
            ind_fr_i = _find_indices_sdi(scal, ann_center, z, fwhm, delta_sep)
            mat_dists_ann[z][ind_fr_i] = mat_dists_ann_full[z][ind_fr_i]
    else:
        mat_dists_ann = mat_dists_ann_full

    threshold = np.percentile(mat_dists_ann[mat_dists_ann != 0], dist_threshold)
    mat_dists_ann[mat_dists_ann > threshold] = np.nan
    mat_dists_ann[mat_dists_ann == 0] = np.nan

    matrix_res = np.zeros((values.shape[0], yy.shape[0]))
    for z in range(n_wls):
        vector = pn.DataFrame(mat_dists_ann[z])
        if vector.sum().values != 0:
            ind_ref = np.where(~np.isnan(vector))[0]
            A = values_opt[ind_ref]
            b = values_opt[z]
            if solver == "lstsq":
                coef = sp.linalg.lstsq(A.T, b, cond=tol)[0]  # SVD method
            elif solver == "nnls":
                coef = sp.optimize.nnls(A.T, b)[0]
            elif solver == "lsq":  # TODO
                coef = sp.optimize.lsq_linear(
                    A.T, b, bounds=(0, 1), method="trf", lsq_solver="lsmr"
                )["x"]
            else:
                raise ValueError("solver not recognized")
        else:
            msg = "No frames left in the reference set. Try increasing "
            msg += "`dist_threshold` or decreasing `delta_sep`."
            raise RuntimeError(msg)

        recon = np.dot(coef, values[ind_ref])
        matrix_res[z] = values[z] - recon

    return matrix_res