Source code for vip_hci.psfsub.medsub
#! /usr/bin/env python
"""
Implementation of a median subtraction algorithm for model PSF subtraction in
high-contrast imaging sequences. Median-ADI was originally proposed in [MAR06]_,
while median-SDI (also referred to as spectral deconvolution) was proposed in
[SPA02]_ and further developed in [THA07]_.
.. [MAR06]
| Marois et al. 2006
| **Angular Differential Imaging: A Powerful High-Contrast Imaging
Technique**
| *The Astrophysical Journal, Volume 641, Issue 1, pp. 556-564*
| `https://arxiv.org/abs/astro-ph/0512335
<https://arxiv.org/abs/astro-ph/0512335>`_
.. [SPA02]
| Sparks & Ford 2002
| **Imaging Spectroscopy for Extrasolar Planet Detection**
| *The Astrophysical Journal, Volume 578, Issue 1, pp. 543-564*
| `https://arxiv.org/abs/astro-ph/0209078
<https://arxiv.org/abs/astro-ph/0209078>`_
.. [THA07]
| Thatte et al. 2007
| **Very high contrast integral field spectroscopy of AB Doradus C: 9-mag
contrast at 0.2arcsec without a coronagraph using spectral deconvolution**
| *MNRAS, Volume 378, Issue 4, pp. 1229-1236*
| `https://arxiv.org/abs/astro-ph/0703565
<https://arxiv.org/abs/astro-ph/0703565>`_
"""
__author__ = "Carlos Alberto Gomez Gonzalez, Thomas Bédrine"
__all__ = ["median_sub", "MEDIAN_SUB_Params"]
import numpy as np
from multiprocessing import cpu_count
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Union, List
from ..config import time_ini, timing
from ..config.paramenum import Imlib, Interpolation, Collapse, ALGO_KEY
from ..config.utils_conf import pool_map, iterable, print_precision
from ..config.utils_param import setup_parameters, separate_kwargs_dict
from ..preproc import (cube_derotate, cube_collapse, check_pa_vector,
check_scal_vector)
from ..preproc import cube_rescaling_wavelengths as scwave
from ..preproc.derotation import _find_indices_adi, _define_annuli
from ..preproc.rescaling import _find_indices_sdi
from ..var import get_annulus_segments, mask_circle
[docs]
@dataclass
class MEDIAN_SUB_Params:
"""
Set of parameters for the median subtraction module.
See function `median_sub` for documentation.
"""
cube: np.ndarray = None
angle_list: np.ndarray = None
scale_list: np.ndarray = None
flux_sc_list: np.ndarray = None
fwhm: float = 4
radius_int: int = 0
asize: int = 4
delta_rot: int = 1
delta_sep: Union[float, Tuple[float]] = (0.1, 1)
mode: str = "fullfr"
nframes: int = 4
sdi_only: bool = False
imlib: Enum = Imlib.VIPFFT
interpolation: Enum = Interpolation.LANCZOS4
collapse: Enum = Collapse.MEDIAN
nproc: int = 1
full_output: bool = False
verbose: bool = True
[docs]
def median_sub(*all_args: List, **all_kwargs: dict):
"""Perform (smart) median-ADI or median-SDI.
In the case of angular differential imaging (ADI), the algorithm is based on
[MAR06]_. The ADI+IFS method is an extension of this basic idea to
multi-spectral cubes, combining ADI with spectral deconvolution (also called
spectral differential imaging or SDI).
References: [MAR06]_ for median-ADI; [SPA02]_ and [THA07]_ for SDI.
Parameters
----------
all_args: list, optional
Positionnal arguments for the median_sub algorithm. Full list of
parameters is provided below.
all_kwargs: dictionary, optional
Mix of keyword arguments that can initialize a MEDIAN_SUB_Params and the
optional ``rot_options`` dictionary (with keywords ``border_mode``,
``mask_val``, ``edge_blend``, ``interp_zeros``, ``ker``; see docstrings
of ``vip_hci.preproc.frame_rotate``). Can also contain a
MEDIAN_SUB_Params object/dictionary named ``algo_params``.
Parameters
----------
cube : numpy ndarray, 3d
Input cube.
angle_list : numpy ndarray, 1d
Corresponding parallactic angle for each frame.
scale_list : numpy ndarray, 1d, optional
If provided, triggers mSDI reduction. These should be the scaling
factors used to re-scale the spectral channels and align the speckles
in case of IFS data (ADI+mSDI cube). Usually, these can be approximated
by the last channel wavelength divided by the other wavelengths in the
cube (more thorough approaches can be used to get the scaling factors,
e.g. with ``vip_hci.preproc.find_scal_vector``).
flux_sc_list : numpy ndarray, 1d
In the case of IFS data (ADI+SDI), this is the list of flux scaling
factors applied to each spectral frame after geometrical rescaling.
These should be set to either the ratio of stellar fluxes between the
last spectral channel and the other channels, or to the second output
of `preproc.find_scal_vector` (when using 2 free parameters). If not
provided, the algorithm will still work, but with a lower efficiency
at subtracting the stellar halo.
fwhm : float or 1d numpy array
Known size of the FWHM in pixels to be used. Default is 4.
radius_int : int, optional
The radius of the innermost annulus. By default is 0, if >0 then the
central circular area is discarded.
asize : int, optional
The size of the annuli, in pixels.
delta_rot : int, optional
Factor for increasing the parallactic angle threshold, expressed in
FWHM. Default is 1 (excludes 1 FWHM on each side of the considered
frame).
delta_sep : float or tuple of floats, optional
The threshold separation in terms of the mean FWHM (for ADI+mSDI data).
If a tuple of two values is provided, they are used as the lower and
upper intervals for the threshold (grows as a function of the
separation).
mode : {'fullfr', 'annular'}, str optional
In ``fullfr`` mode only the median frame is subtracted, in ``annular``
mode also the 4 closest frames given a PA threshold (annulus-wise) are
subtracted.
nframes : int or None, optional
Number of frames (even value) to be used for building the optimized
reference PSF when working in ``annular`` mode. None by default, which
means that all frames, excluding the thresholded ones, are used.
sdi_only: bool, optional
In the case of IFS data (ADI+SDI), whether to perform median-SDI, or
median-ASDI (default).
imlib : Enum, see `vip_hci.config.paramenum.Imlib`
See the documentation of ``vip_hci.preproc.frame_rotate``.
interpolation : Enum, see `vip_hci.config.paramenum.Interpolation`
See the documentation of the ``vip_hci.preproc.frame_rotate`` function.
collapse : Enum, see `vip_hci.config.paramenum.Collapse`
Sets how temporal residual frames should be combined to produce an
ADI image.
nproc : None or int, optional
Number of processes for parallel computing. If None the number of
processes will be set to cpu_count()/2. By default the algorithm works
in single-process mode.
full_output: bool, optional
Whether to return the final median combined image only or with other
intermediate arrays.
verbose : bool, optional
If True prints to stdout intermediate info.
Returns
-------
cube_out : numpy ndarray, 3d
[full_output=True] The cube of residuals.
cube_der : numpy ndarray, 3d
[full_output=True] The derotated cube of residuals.
frame : numpy ndarray, 2d
Median combination of the de-rotated cube.
"""
# Separating the parameters of the ParamsObject from optional rot_options
class_params, rot_options = separate_kwargs_dict(
initial_kwargs=all_kwargs, parent_class=MEDIAN_SUB_Params
)
# Extracting the object of parameters (if any)
algo_params = None
if ALGO_KEY in rot_options.keys():
algo_params = rot_options[ALGO_KEY]
del rot_options[ALGO_KEY]
if algo_params is None:
algo_params = MEDIAN_SUB_Params(*all_args, **class_params)
global ARRAY
ARRAY = algo_params.cube.copy()
if not (ARRAY.ndim == 3 or ARRAY.ndim == 4):
raise TypeError("Input array is not a 3d or 4d array")
if algo_params.verbose:
start_time = time_ini()
if algo_params.nproc is None:
algo_params.nproc = cpu_count() // 2
algo_params.angle_list = check_pa_vector(algo_params.angle_list)
if ARRAY.ndim == 3:
n, y, _ = ARRAY.shape
if ARRAY.shape[0] != algo_params.angle_list.shape[0]:
msg = "Input vector or parallactic angles has wrong length"
raise TypeError(msg)
# The median frame is first subtracted from each frame
model_psf = np.median(ARRAY, axis=0)
ARRAY -= model_psf
# Depending on the ``mode``
cube_out = ARRAY
if algo_params.mode == "fullfr":
# MASK AFTER DEROTATION TO AVOID ARTEFACTS
# if radius_int > 0:
# cube_out = mask_circle(ARRAY, radius_int, fillwith=np.nan)
# else:
# cube_out = ARRAY
if algo_params.verbose:
print("Median psf reference subtracted")
elif algo_params.mode == "annular":
if algo_params.nframes is not None:
if algo_params.nframes % 2 != 0:
raise TypeError("`nframes` argument must be even value")
n_annuli = int((y / 2 - algo_params.radius_int) / algo_params.asize)
if algo_params.verbose:
print("N annuli = {}, FWHM = {}".format(
n_annuli, algo_params.fwhm))
add_params = {
"ann": iterable(range(n_annuli)),
"n_annuli": n_annuli,
"annulus_width": algo_params.asize,
}
func_params = setup_parameters(
params_obj=algo_params,
fkt=_median_subt_ann_adi,
as_list=True,
**add_params,
)
res = pool_map(
algo_params.nproc,
_median_subt_ann_adi,
msg="Processing annuli:",
progressbar_single=True,
*func_params,
)
res = np.array(res, dtype=object)
mres = res[:, 0]
yy = res[:, 1]
xx = res[:, 2]
# cube_out = np.zeros_like(ARRAY)
# cube_out[:] = np.nan
for ann in range(n_annuli):
cube_out[:, yy[ann], xx[ann]] = mres[ann]
if algo_params.verbose:
print("Optimized median psf reference subtracted")
else:
raise RuntimeError("Mode not recognized")
cube_der = cube_derotate(
cube_out,
algo_params.angle_list,
nproc=algo_params.nproc,
imlib=algo_params.imlib,
interpolation=algo_params.interpolation,
**rot_options,
)
if algo_params.radius_int:
cube_out = mask_circle(cube_out, algo_params.radius_int)
cube_der = mask_circle(cube_der, algo_params.radius_int)
frame = cube_collapse(cube_der, mode=algo_params.collapse)
elif ARRAY.ndim == 4:
z, n, y_in, x_in = ARRAY.shape
if algo_params.scale_list is None:
raise ValueError("Scaling factors vector must be provided")
else:
if np.array(algo_params.scale_list).ndim > 1:
raise ValueError("Scaling factors vector is not 1d")
if not algo_params.scale_list.shape[0] == z:
raise ValueError("Scaling factors vector has wrong length")
if algo_params.flux_sc_list is not None:
if np.array(algo_params.flux_sc_list).ndim > 1:
raise ValueError("Scaling factors vector is not 1d")
if not algo_params.flux_sc_list.shape[0] == z:
raise ValueError("Scaling factors vector has wrong length")
# Exploiting spectral variability (radial movement)
algo_params.fwhm = int(np.round(np.mean(algo_params.fwhm)))
n_annuli = int((y_in / 2 - algo_params.radius_int) / algo_params.asize)
if algo_params.nframes is not None:
if algo_params.nframes % 2 != 0:
raise TypeError("`nframes` argument must be even value")
if algo_params.verbose:
print("{} spectral channels per IFS frame".format(z))
print("First median subtraction exploiting spectral variability")
if algo_params.mode == "annular":
print(
"N annuli = {}, mean FWHM = {:.3f}".format(
n_annuli, algo_params.fwhm
)
)
add_params = {
"fr": iterable(range(n)),
"scal": algo_params.scale_list,
"flux_scal": algo_params.flux_sc_list,
"n_annuli": n_annuli,
"annulus_width": algo_params.asize,
}
func_params = setup_parameters(params_obj=algo_params,
fkt=_median_subt_fr_sdi, as_list=True,
**add_params)
res = pool_map(
algo_params.nproc,
_median_subt_fr_sdi,
*func_params,
)
residuals_cube_channels = np.array(res)
if algo_params.verbose:
timing(start_time)
print("{} ADI frames".format(n))
print("Median subtraction in the ADI fashion")
if algo_params.sdi_only:
cube_out = residuals_cube_channels
else:
if algo_params.mode == "fullfr":
median_frame = np.nanmedian(residuals_cube_channels, axis=0)
cube_out = residuals_cube_channels - median_frame
elif algo_params.mode == "annular":
if algo_params.verbose:
print(
"N annuli = {}, mean FWHM = {:.3f}".format(
n_annuli, algo_params.fwhm
)
)
ARRAY = residuals_cube_channels
add_params = {
"ann": iterable(range(n_annuli)),
"n_annuli": n_annuli,
"annulus_width": algo_params.asize,
}
func_params = setup_parameters(
params_obj=algo_params,
fkt=_median_subt_ann_adi,
as_list=True,
**add_params,
)
res = pool_map(
algo_params.nproc,
_median_subt_ann_adi,
msg="Processing annuli:",
progressbar_single=True,
*func_params,
)
res = np.array(res, dtype=object)
mres = res[:, 0]
yy = res[:, 1]
xx = res[:, 2]
pa_thrs = np.array(res[:, 3])
if algo_params.verbose:
print("PA thresholds: ")
print_precision(pa_thrs)
cube_out = np.zeros_like(ARRAY)
cube_out[:] = np.nan
for ann in range(n_annuli):
cube_out[:, yy[ann], xx[ann]] = mres[ann]
else:
raise RuntimeError("Mode not recognized")
cube_der = cube_derotate(
cube_out,
algo_params.angle_list,
imlib=algo_params.imlib,
interpolation=algo_params.interpolation,
nproc=algo_params.nproc,
**rot_options,
)
if algo_params.radius_int:
cube_der = mask_circle(cube_der, algo_params.radius_int)
frame = cube_collapse(cube_der, mode=algo_params.collapse)
if algo_params.verbose:
print("Done derotating and combining")
timing(start_time)
if algo_params.full_output:
return cube_out, cube_der, frame
else:
return frame
def _median_subt_fr_sdi(
fr,
scal,
flux_scal,
n_annuli,
fwhm,
radius_int,
annulus_width,
delta_sep,
nframes,
imlib,
interpolation,
collapse,
mode,
):
"""Optimized median subtraction on a multi-spectral frame (IFS data)."""
z, n, y_in, x_in = ARRAY.shape
scale_list = check_scal_vector(scal)
multispec_fr = scwave(
ARRAY[:, fr, :, :], scale_list, imlib=imlib, interpolation=interpolation
)[
0
] # rescaled cube
if flux_scal is not None:
for i in range(z):
multispec_fr[i] *= flux_scal[i]
if mode == "annular":
cube_res = np.zeros_like(multispec_fr) # shape (z, resc_y, resc_x)
if isinstance(delta_sep, tuple):
delta_sep_vec = np.linspace(delta_sep[0], delta_sep[1], n_annuli)
else:
delta_sep_vec = [delta_sep] * n_annuli
for ann in range(n_annuli):
if ann == n_annuli - 1:
inner_radius = radius_int + (ann * annulus_width - 1)
else:
inner_radius = radius_int + ann * annulus_width
ann_center = inner_radius + (annulus_width / 2)
indices = get_annulus_segments(
multispec_fr[0], inner_radius, annulus_width
)[0]
yy = indices[0]
xx = indices[1]
matrix = multispec_fr[:, yy, xx] # shape (z, npx_annulus)
for j in range(z):
indices_left = _find_indices_sdi(
scal, ann_center, j, fwhm, delta_sep_vec[ann], nframes
)
matrix_masked = matrix[indices_left]
ref_psf_opt = np.nanmedian(matrix_masked, axis=0)
curr_wv = matrix[j]
subtracted = curr_wv - ref_psf_opt
cube_res[j, yy, xx] = subtracted
elif mode == "fullfr":
median_frame = np.nanmedian(multispec_fr, axis=0)
cube_res = multispec_fr - median_frame
if flux_scal is not None:
for i in range(z):
cube_res[i] /= flux_scal[i]
frame_desc = scwave(
cube_res,
scale_list,
full_output=False,
inverse=True,
y_in=y_in,
x_in=x_in,
imlib=imlib,
interpolation=interpolation,
collapse=collapse,
)
return frame_desc
def _median_subt_ann_adi(ann, angle_list, n_annuli, fwhm, radius_int,
annulus_width, delta_rot, nframes):
"""Optimized median subtraction for a given annulus."""
if ARRAY.ndim == 3:
n = ARRAY.shape[0]
elif ARRAY.ndim == 4:
n = ARRAY.shape[1]
# The annulus is built, and the corresponding PA thresholds for frame
# rejection are calculated. The PA rejection is calculated at center of
# the annulus
pa_thr, inner_radius, _ = _define_annuli(angle_list, ann, n_annuli, fwhm,
radius_int, annulus_width,
delta_rot, 1, False)
if ARRAY.ndim == 3:
indices = get_annulus_segments(ARRAY[0], inner_radius, annulus_width)[0]
elif ARRAY.ndim == 4:
indices = get_annulus_segments(
ARRAY[0, 0], inner_radius, annulus_width)[0]
yy = indices[0]
xx = indices[1]
matrix = ARRAY[:, yy, xx] # shape [n x npx_annulus]
matrix_res = np.zeros_like(matrix)
# A second optimized psf reference is subtracted from each frame.
# For each frame we find ``nframes``, depending on the PA threshold,
# to construct this optimized psf reference
for frame in range(n):
if pa_thr != 0:
indices_left = _find_indices_adi(angle_list, frame, pa_thr, nframes)
matrix_disc = matrix[indices_left]
else:
matrix_disc = matrix
ref_psf_opt = np.nanmedian(matrix_disc, axis=0)
curr_frame = matrix[frame]
subtracted = curr_frame - ref_psf_opt
matrix_res[frame] = subtracted
return matrix_res, yy, xx, pa_thr