Source code for vip_hci.psfsub.nmf_fullfr

#! /usr/bin/env python
"""
Module with PSF reference approximation using Non-negative matrix factorization
for ADI and RDI data, in full frames.

.. [LEE99]
   | Lee & Seung 1999
   | **Learning the parts of objects by non-negative matrix factorization**
   | *Nature, Volume 401, Issue 6755, pp. 788-791*
   | `https://ui.adsabs.harvard.edu/abs/1999Natur.401..788L
     <https://ui.adsabs.harvard.edu/abs/1999Natur.401..788L>`_

"""


__author__ = "Thomas Bédrine, Carlos Alberto Gomez Gonzalez, Valentin Christiaens"
__all__ = ["nmf", "NMF_Params"]

import numpy as np
from sklearn.decomposition import NMF
from dataclasses import dataclass, field
from enum import Enum
from typing import Tuple, List
from ..config.utils_param import setup_parameters, separate_kwargs_dict
from ..config.paramenum import Collapse, HandleNeg, Initsvd, ALGO_KEY
from ..config import timing, time_ini
from ..preproc import cube_derotate, cube_collapse
from ..preproc.derotation import _compute_pa_thresh, _find_indices_adi
from ..var import (prepare_matrix, reshape_matrix, frame_center, dist,
                   matrix_scaling, mask_circle)

@dataclass
class NMF_Params:
    """
    Container for the parameters of the full-frame NMF algorithm.

    Every field is described in the docstring of the ``nmf`` function.
    The field order defines the positional-argument order of ``nmf``.
    """

    # Input data
    cube: np.ndarray = None
    angle_list: np.ndarray = None
    cube_ref: np.ndarray = None

    # Factorization settings
    ncomp: int = 1
    scaling: Enum = None
    max_iter: int = 10000
    random_state: int = None

    # Masking and PA-based frame rejection
    mask_center_px: int = None
    source_xy: Tuple[int] = None
    delta_rot: float = 1
    fwhm: float = 4

    # Initialization and post-processing
    init_svd: Enum = Initsvd.NNDSVD
    collapse: Enum = Collapse.MEDIAN
    full_output: bool = False
    verbose: bool = True

    # Optional signal estimate and handling of negative pixels
    cube_sig: np.ndarray = None
    handle_neg: Enum = HandleNeg.MASK

    # Extra keyword arguments forwarded to the scikit-learn NMF estimator;
    # a factory is used so that each instance gets its own dictionary.
    nmf_args: dict = field(default_factory=dict)
def nmf(*all_args: List, **all_kwargs: dict):
    """Non Negative Matrix Factorization [LEE99]_ for ADI sequences [GOM17]_.

    Alternative to the full-frame ADI-PCA processing that does not rely on SVD
    or ED for obtaining a low-rank approximation of the datacube. This function
    embeds the scikit-learn NMF algorithm (the ``_project_subtract`` helper of
    this module builds the estimator with the multiplicative update solver).

    Parameters
    ----------
    all_args: list, optional
        Positional arguments for the NMF algorithm. Full list of parameters
        below.
    all_kwargs: dictionary, optional
        Mix of keyword arguments that can initialize a NMF_Params and the
        optional 'rot_options' dictionary, with keyword values for
        "border_mode", "mask_val", "edge_blend", "interp_zeros", "ker" (see
        documentation of ``vip_hci.preproc.frame_rotate``). Can also contain a
        NMF_Params named as `algo_params`.

    NMF parameters
    ----------
    cube : numpy ndarray, 3d
        Input cube. It is copied; the input array is not modified.
    angle_list : numpy ndarray, 1d
        Corresponding parallactic angle for each frame.
    cube_ref : numpy ndarray, 3d, optional
        Reference library cube. For Reference Star Differential Imaging.
    ncomp : int, optional
        How many components are used as for low-rank approximation of the
        datacube.
    scaling : Enum, see `vip_hci.config.paramenum.Scaling`
        With None, no scaling is performed on the input data before the
        factorization. With "temp-mean" then temporal px-wise mean subtraction
        is done, with "spat-mean" then the spatial mean is subtracted, with
        "temp-standard" temporal mean centering plus scaling to unit variance
        is done and with "spat-standard" spatial mean centering plus scaling
        to unit variance is performed.
    max_iter : int optional
        Maximum number of iterations passed to the scikit-learn NMF solver.
    random_state : int or None, optional
        Controls the seed for the Pseudo Random Number generator.
    mask_center_px : None or int
        If None, no masking is done. If an integer > 1 then this value is the
        radius of the circular mask.
    source_xy : tuple of int, optional
        This triggers a frame rejection in the NMF library, with ``source_xy``
        as the coordinates X,Y of the center of the annulus where the PA
        criterion is estimated. Frames are then processed one by one.
    delta_rot : float, optional
        Factor for tuning the parallactic angle threshold, expressed in FWHM.
        Default is 1 (excludes 1xFWHM on each side of the considered frame).
    fwhm : float, optional
        Known size of the FWHM in pixels to be used. Default value is 4.
    init_svd: Enum, see `vip_hci.config.paramenum.Initsvd`
        Method used to initialize the iterative procedure to find H and W.
        'nndsvd': non-negative double SVD recommended for sparseness
        'nndsvda': NNDSVD where zeros are filled with the average of cube;
        recommended when sparsity is not desired
        'random': random initial non-negative matrix
    collapse : Enum, see `vip_hci.config.paramenum.Collapse`
        Sets the way of collapsing the frames for producing a final image.
    full_output: boolean, optional
        Whether to return the final median combined image only or with other
        intermediate arrays.
    verbose : {True, False}, bool optional
        If True prints intermediate info and timing.
    cube_sig : numpy ndarray, 3d, optional
        Cube with an estimate of the significant authentic signals. If
        provided, it is subtracted from the cube before the cube is projected
        onto the NMF components. The array passed by the caller is never
        modified in place.
    handle_neg: Enum, see `vip_hci.config.paramenum.HandleNeg`
        Determines how to handle negative values: mask them, set them to zero,
        or subtract the minimum value in the arrays. Note: 'mask' or 'null'
        may leave significant artefacts after derotation of residual cube
        => those options should be used carefully (e.g. with proper treatment
        of masked values in non-derotated cube of residuals).
    nmf_args : dictionary, optional
        Additional arguments for scikit-learn NMF algorithm. See the
        scikit-learn documentation of ``sklearn.decomposition.NMF``.

    Returns
    -------
    If full_output is False the final frame is returned. If True the algorithm
    returns the reshaped NMF components, the reconstructed cube, the residuals,
    the residuals derotated and the final frame.

    Raises
    ------
    ValueError
        If ``handle_neg`` is not a recognized mode, or if the 'mask' mode
        leaves no strictly positive pixel to work with.
    TypeError
        If ``source_xy`` is given while ``delta_rot`` or ``fwhm`` is None.

    """
    # Separating the parameters of the ParamsObject from the optional
    # rot_options
    class_params, rot_options = separate_kwargs_dict(
        initial_kwargs=all_kwargs, parent_class=NMF_Params
    )

    # Extracting the object of parameters (if any)
    algo_params = rot_options.pop(ALGO_KEY, None)
    if algo_params is None:
        algo_params = NMF_Params(*all_args, **class_params)

    array = algo_params.cube.copy()
    if algo_params.verbose:
        start_time = time_ini()
    n, y, x = array.shape

    matrix_ref = None
    matrix_sig = None
    # Local handle on the signal cube: it is replaced by a private copy before
    # any modification, so that the caller's array is left untouched.
    cube_sig = algo_params.cube_sig
    # Single source of truth for the 'mask' mode, reused everywhere below.
    use_mask = algo_params.handle_neg == HandleNeg.MASK

    # how to handle negative values
    if use_mask:
        if algo_params.mask_center_px:
            array = mask_circle(array, algo_params.mask_center_px)
        # Keep only the pixels that are strictly positive in every frame
        if cube_sig is not None:
            yy, xx = np.where(np.amin(array - np.abs(cube_sig), axis=0) > 0)
        else:
            yy, xx = np.where(np.amin(array, axis=0) > 0)
        H_tmp = np.zeros([algo_params.ncomp, y, x])
        if len(yy) == 0:
            raise ValueError("Remove frame(s) with negative values")
        matrix = array[:, yy, xx]
        matrix = matrix_scaling(matrix, algo_params.scaling)
        if algo_params.cube_ref is not None:
            matrix_ref = algo_params.cube_ref[:, yy, xx]
            matrix_ref = matrix_scaling(matrix_ref, algo_params.scaling)
        if cube_sig is not None:
            matrix_sig = cube_sig[:, yy, xx]
    else:
        if algo_params.handle_neg == HandleNeg.NULL:
            if cube_sig is not None:
                # The mask is computed once, before either array is changed,
                # and applied to a copy of the signal cube.
                cube_sig = cube_sig.copy()
                neg_mask = (array - cube_sig) < 0
                array[neg_mask] = 0
                cube_sig[neg_mask] = 0
            else:
                array[array < 0] = 0
        elif algo_params.handle_neg == HandleNeg.SUBTR_MIN:
            if cube_sig is not None:
                array -= np.amin(array - cube_sig)
            else:
                array -= np.amin(array)
        else:
            raise ValueError("Mode to handle neg. pixels not recognized")

        matrix = prepare_matrix(
            array,
            algo_params.scaling,
            algo_params.mask_center_px,
            mode="fullfr",
            verbose=algo_params.verbose,
        )
        if algo_params.cube_ref is not None:
            matrix_ref = prepare_matrix(
                algo_params.cube_ref,
                algo_params.scaling,
                algo_params.mask_center_px,
                mode="fullfr",
                verbose=algo_params.verbose,
            )
        if cube_sig is not None:
            matrix_sig = prepare_matrix(
                cube_sig,
                algo_params.scaling,
                algo_params.mask_center_px,
                mode="fullfr",
                verbose=algo_params.verbose,
            )

    # Cube that will receive the residuals before derotation
    if cube_sig is not None:
        residuals_cube = cube_sig.copy()
    else:
        residuals_cube = np.zeros_like(array)

    if algo_params.source_xy is None:
        # The whole cube is processed at once
        add_params = {
            "matrix": matrix,
            "matrix_ref": matrix_ref,
            "matrix_sig": matrix_sig,
        }
        func_params = setup_parameters(
            params_obj=algo_params, fkt=_project_subtract, **add_params
        )
        residuals = _project_subtract(**func_params, **algo_params.nmf_args)
        if algo_params.verbose:
            timing(start_time)

        if algo_params.full_output:
            H = residuals[2]
            reconstructed = residuals[1]
            residuals = residuals[0]
            recon_cube = residuals_cube.copy()

        if use_mask:
            # Only the selected pixels were processed: put them back in place
            for fr in range(n):
                residuals_cube[fr][yy, xx] = residuals[fr]
            if algo_params.full_output:
                for fr in range(n):
                    recon_cube[fr][yy, xx] = reconstructed[fr]
                for pp in range(algo_params.ncomp):
                    H_tmp[pp][yy, xx] = H[pp]
                H = H_tmp
        else:
            for fr in range(n):
                residuals_cube[fr] = residuals[fr].reshape((y, x))
            if algo_params.full_output:
                recon_cube = reshape_matrix(reconstructed, y, x)
                H = H.reshape(algo_params.ncomp, y, x)
    else:
        # Frames are processed one by one, with a PA-based frame rejection
        if algo_params.delta_rot is None or algo_params.fwhm is None:
            msg = "Delta_rot or fwhm parameters missing. Needed for the "
            msg += "PA-based rejection of frames from the library"
            raise TypeError(msg)

        recon_cube = np.zeros_like(algo_params.cube)
        yc, xc = frame_center(algo_params.cube[0], False)
        x1, y1 = algo_params.source_xy
        ann_center = dist(yc, xc, y1, x1)
        pa_thr = _compute_pa_thresh(
            ann_center, algo_params.fwhm, algo_params.delta_rot)
        mid_range = (
            np.abs(np.amax(algo_params.angle_list)
                   - np.amin(algo_params.angle_list)) / 2
        )
        # Cap the threshold at 90% of half the PA range
        if pa_thr >= mid_range - mid_range * 0.1:
            new_pa_th = float(mid_range - mid_range * 0.1)
            if algo_params.verbose:
                msg = "PA threshold {:.2f} is too big, will be set to "
                msg += "{:.2f}"
                print(msg.format(pa_thr, new_pa_th))
            pa_thr = new_pa_th

        for fr in range(n):
            ind = _find_indices_adi(algo_params.angle_list, fr, pa_thr)
            add_params = {
                "matrix": matrix,
                "matrix_ref": matrix_ref,
                "matrix_sig": matrix_sig,
                "indices": ind,
                "frame": fr,
            }
            func_params = setup_parameters(
                params_obj=algo_params, fkt=_project_subtract, **add_params
            )
            res_result = _project_subtract(
                **func_params, **algo_params.nmf_args)

            # ! Instead of reshaping, fill frame using get_annulus?
            if algo_params.full_output:
                residuals = res_result[0]
                recon_frame = res_result[1]
                H = res_result[2]
                if use_mask:
                    recon_cube[fr][yy, xx] = recon_frame
                else:
                    recon_cube[fr] = recon_frame.reshape((y, x))
            else:
                residuals = res_result

            # The components returned are those obtained for the last frame
            is_last = fr == n - 1
            if use_mask:
                residuals_cube[fr][yy, xx] = residuals
                if is_last and algo_params.full_output:
                    for pp in range(algo_params.ncomp):
                        H_tmp[pp][yy, xx] = H[pp]
                    H = H_tmp
            else:
                residuals_cube[fr] = residuals.reshape((y, x))
                if is_last and algo_params.full_output:
                    H = H.reshape(algo_params.ncomp, y, x)

    if algo_params.verbose:
        print("Done NMF with sklearn.NMF.")
        timing(start_time)

    residuals_cube_ = cube_derotate(
        residuals_cube, algo_params.angle_list, **rot_options
    )
    frame = cube_collapse(residuals_cube_, mode=algo_params.collapse)
    if algo_params.verbose:
        print("Done derotating and combining.")
        timing(start_time)

    if algo_params.full_output:
        return (H, recon_cube, residuals_cube, residuals_cube_, frame)
    else:
        return frame
def _project_subtract(
    matrix,
    matrix_ref,
    ncomp,
    scaling,
    mask_center_px,
    verbose,
    full_output,
    indices=None,
    frame=None,
    matrix_sig=None,
    max_iter=100,
    random_state=None,
    init_svd="nndsvda",
    **kwargs
):
    """
    NMF projection and model PSF subtraction.

    Helper of ``nmf``: fits a scikit-learn NMF model on a reference library,
    projects the data on the resulting components and subtracts the
    reconstruction from the data.

    Parameters
    ----------
    matrix : numpy ndarray, 2d
        Input data, one row per frame.
    matrix_ref : numpy ndarray, 2d, or None
        Reference library. If None, the library is built from ``matrix``
        (minus ``matrix_sig`` when provided). It is not modified.
    ncomp : int
        Number of NMF components.
    scaling : str
        Not used in this function; accepted so that the parameters of ``nmf``
        can be forwarded as they are.
    mask_center_px : int
        Not used in this function; accepted for the same reason.
    verbose : bool
        Not used in this function; accepted for the same reason.
    full_output : bool
        Whether to return intermediate arrays or not.
    indices : list
        Indices of the frames kept in the library (a rotation threshold is
        used). Only considered when ``matrix_ref`` is None and ``frame`` is
        given.
    frame : int
        Index of the current frame (when indices is a list and a rotation
        threshold was applied). When both ``indices`` and ``frame`` are given,
        only that frame is processed.
    matrix_sig : numpy ndarray, opt
        Matrix with estimate of significant authentic signals. If provided,
        it is subtracted from ``matrix`` before the library is built and
        before the data are projected onto the components.
    max_iter : int, opt
        Maximum number of iterations of the NMF solver.
    random_state : int or None, opt
        Seed passed to the NMF estimator.
    init_svd : str, opt
        Initialization method passed to the NMF estimator.
    **kwargs
        Additional keyword arguments passed to the NMF estimator.

    Returns
    -------
    residuals : numpy ndarray
        Residuals, returned in every case.
    reconstructed : numpy ndarray
        [full_output=True] The reconstructed array.
    H : numpy ndarray
        [full_output=True] The NMF components, of shape [ncomp, n_pixels].

    Raises
    ------
    ValueError
        If the median of the reference library is negative.
    RuntimeError
        If a rotation threshold is used and 10 frames or fewer are left in
        the library.

    """
    if matrix_sig is None:
        matrix_emp = matrix.copy()
    else:
        matrix_emp = matrix - matrix_sig

    # a rotation threshold is used (frames are processed one by one)
    per_frame = indices is not None and frame is not None

    if matrix_ref is not None:
        ref_lib = matrix_ref
    elif per_frame:
        ref_lib = matrix_emp[indices]
    else:
        ref_lib = matrix_emp

    # Cheap sanity checks first, before building and fitting the model
    if np.median(ref_lib) < 0:
        raise ValueError("Mostly negative values in the cube")
    if per_frame and ref_lib.shape[0] <= 10:
        raise RuntimeError(
            "Less than 10 frames left in the PCA library"
            ", Try decreasing the parameter delta_rot"
        )

    # to avoid bug, just consider positive values; np.maximum returns a new
    # array, so neither matrix_ref nor matrix_emp is modified
    ref_lib = np.maximum(ref_lib, 0)

    mod = NMF(
        n_components=ncomp,
        solver="mu",
        init=init_svd,
        max_iter=max_iter,
        random_state=random_state,
        **kwargs
    )

    # H [ncomp, n_pixels]: Non-negative components of the library
    H = mod.fit(ref_lib).components_

    if per_frame:
        curr_frame = matrix[frame]  # current frame
        curr_frame_emp = matrix_emp[frame]
        # W: coefficients [1, ncomp]
        W = mod.transform(curr_frame_emp[np.newaxis, ...])
        reconstructed = np.dot(W, H)
        residuals = curr_frame - reconstructed
    else:
        # the whole matrix is processed at once
        # W: coefficients [n_frames, ncomp]
        W = mod.transform(matrix_emp)
        reconstructed = np.dot(W, H)
        residuals = matrix - reconstructed

    if full_output:
        return residuals, reconstructed, H
    return residuals