Source code for vip_hci.preproc.badpixremoval

#! /usr/bin/env python
"""
Module with functions for correcting bad pixels in cubes.

.. [AAC01]
   | Aach & Metzler 2001
   | **Defect interpolation in digital radiography how object-oriented
     transform coding helps**
   | *SPIE, Proceedings Volume 4322, Medical Imaging 2001: Image Processing*
   | `https://doi.org/10.1117/12.431161
     <https://doi.org/10.1117/12.431161>`_

"""
import warnings
__author__ = ('V. Christiaens, Carlos Alberto Gomez Gonzalez, '
              'Srikanth Kompella')
__all__ = ['frame_fix_badpix_isolated',
           'cube_fix_badpix_isolated',
           'cube_fix_badpix_annuli',
           'cube_fix_badpix_clump',
           'cube_fix_badpix_ifs',
           'cube_fix_badpix_interp',
           'frame_fix_badpix_fft']

import numpy as np
from skimage.draw import disk, ellipse
from scipy.ndimage import median_filter
from astropy.stats import sigma_clipped_stats
from multiprocessing import cpu_count
from ..stats import clip_array, sigma_filter
from ..var import frame_center, get_annulus_segments, frame_filter_lowpass
from ..config import timing, time_ini, Progressbar
from ..config.utils_conf import pool_map, iterable
from .rescaling import find_scal_vector, frame_rescaling
from .cosmetics import frame_pad
from multiprocessing import Process
import multiprocessing
from multiprocessing import set_start_method
# Flag read by the cube-level routines to decide whether the shared-memory
# multiprocessing path may be used; it is cleared below when no
# ``shared_memory`` module can be imported.
shared_mem = True
try:
    from multiprocessing import shared_memory
except ImportError:
    print('Failed to import shared_memory from multiprocessing')
    try:
        # Fall back on a standalone ``shared_memory`` module (the messages
        # printed here point to the ``shared-memory38`` package).
        print('Trying to import shared_memory directly(for python 3.7)')
        import shared_memory
    except ModuleNotFoundError:
        shared_mem = False
        print("WARNING: multiprocessing unavailable for bad pixel correction.")
        print('Either pip install shared-memory38, or upgrade to python>=3.8')

# Numba is optional: ``no_numba`` records whether ``njit`` could be imported,
# and an ImportWarning is emitted when it could not.
try:
    from numba import njit
    no_numba = False
except ImportError:
    msg = "Numba python bindings are missing."
    warnings.warn(msg, ImportWarning)
    no_numba = True


def frame_fix_badpix_isolated(array, bpm_mask=None, correct_only=False,
                              sigma_clip=3, num_neig=5, size=5,
                              protect_mask=0, excl_mask=None, cxy=None,
                              mad=False, ignore_nan=True, verbose=True,
                              full_output=False):
    """
    Replace isolated bad pixels of a 2d frame by a median-filtered value.

    Bad pixels are either taken from ``bpm_mask`` (``correct_only=True``) or
    identified with ``clip_array`` (sigma clipping). Each bad pixel is then
    replaced by the value of a ``size`` x ``size`` median filter of the frame
    at that position. This is fast but only suited to isolated (sparse) bad
    pixels.

    Parameters
    ----------
    array : numpy ndarray
        Input 2d array. It is not modified; a copy is corrected.
    bpm_mask : numpy ndarray, optional
        Input bad pixel map with the same shape as ``array``; non-zero values
        mark bad pixels. If None, a bad pixel map is built with sigma-clipped
        statistics.
    correct_only : bool, optional
        If True, ``bpm_mask`` must be provided and only those pixels are
        corrected. Otherwise bad pixels are (re-)identified with
        ``clip_array``, to which ``bpm_mask`` (if any) is passed as known
        bad pixels.
    sigma_clip : int, optional
        Lower and upper sigma threshold passed to ``clip_array``.
    num_neig : int, optional
        Side of the square window used for the local statistics. If not
        larger than 0, ``clip_array`` is called with ``neighbor=False``.
    size : odd int, optional
        Side of the box of the median filter used for the replacement values.
    protect_mask : int or float, optional
        If larger than 0, radius of a circular aperture centred on ``cxy``
        inside which no pixel is flagged as bad.
    excl_mask : numpy ndarray, optional
        Mask with the same shape as ``array``, non-zero where pixels must be
        neither used as good neighbours nor corrected.
    cxy : None or tuple, optional
        (x, y) centre of the protected aperture. If None, ``frame_center`` is
        used.
    mad : bool, optional
        Passed to ``clip_array`` (use the median absolute deviation instead
        of the standard deviation).
    ignore_nan : bool, optional
        If True, NaN pixels are not flagged as bad pixels (hence not
        corrected) when bad pixels are identified.
    verbose : bool, optional
        If True, print the number of replaced pixels and timing information.
    full_output : bool, optional
        If True, also return the bad pixel map.

    Returns
    -------
    frame : numpy ndarray
        Frame with bad pixels corrected.
    bpm_mask : 2d numpy ndarray of bool
        [full_output=True] The bad pixel map that was used.

    Raises
    ------
    TypeError
        If ``array`` is not 2d or ``size`` is even.
    ValueError
        If ``correct_only`` is True but no ``bpm_mask`` is given.
    AssertionError
        If ``bpm_mask`` or ``excl_mask`` does not have the shape of ``array``.

    """
    if array.ndim != 2:
        raise TypeError('Array is not a 2d array or single frame')
    if size % 2 == 0:
        raise TypeError('Size of the median blur kernel must be an odd integer')
    if correct_only and bpm_mask is None:
        raise ValueError("Bad pixel map should be provided if correct_only "
                         "is True.")

    if bpm_mask is not None:
        assert bpm_mask.shape == array.shape, (
            "Input bad pixel mask should have same shape as array\n")
        # astype copies, so the caller's mask is never altered
        bpm_mask = bpm_mask.astype('bool')

    if excl_mask is None:
        excl_mask = np.zeros(array.shape, dtype=bool)
    else:
        assert excl_mask.shape == array.shape, (
            "Input exclusion mask should have same shape as array\n")
    excluded = np.nonzero(excl_mask)

    if verbose:
        start = time_ini()

    use_neighbours = bool(num_neig > 0)

    frame = array.copy()
    if cxy is None:
        cy, cx = frame_center(frame)
    else:
        cx, cy = cxy

    if bpm_mask is None or not correct_only:
        # Known bad and excluded pixels are both handed to clip_array
        if bpm_mask is None:
            known = np.zeros(array.shape, dtype=bool)
        else:
            known = bpm_mask
        known = known + excl_mask
        known[known > 1] = 1
        nan_idx = np.nonzero(np.isnan(frame))
        flagged = clip_array(frame, sigma_clip, sigma_clip, known,
                             neighbor=use_neighbours, num_neighbor=num_neig,
                             mad=mad)
        detected = np.zeros_like(frame)
        detected[flagged] = 1
        if ignore_nan:
            detected[nan_idx] = 0
        if protect_mask:
            aperture = disk((cy, cx), protect_mask, shape=detected.shape)
            detected[aperture] = 0
        # excluded pixels are never reported nor corrected as bad pixels
        detected[excluded] = 0
        bpm_mask = detected.astype('bool')

    bad_idx = np.nonzero(bpm_mask)
    frame[bad_idx] = median_filter(frame, size, mode='mirror')[bad_idx]

    if verbose:
        msg = "Done replacing {} bad pixels using the median of neighbors"
        print(msg.format(np.sum(bpm_mask)))
        timing(start)

    if full_output:
        return frame, bpm_mask
    return frame
def cube_fix_badpix_isolated(array, bpm_mask=None, correct_only=False,
                             sigma_clip=3, num_neig=5, size=5,
                             frame_by_frame=False, protect_mask=0,
                             excl_mask=None, cxy=None, mad=False,
                             ignore_nan=True, verbose=True, full_output=False,
                             nproc=1):
    """
    Replace isolated bad pixels of a cube by a median-filtered value.

    Bad pixels are either taken from ``bpm_mask`` (``correct_only=True``) or
    identified by sigma clipping, and replaced by the value of a ``size`` x
    ``size`` median filter. This is fast but only suited to isolated (sparse)
    bad pixels.

    Parameters
    ----------
    array : numpy ndarray
        Input 3d array. It is not modified; a copy is corrected.
    bpm_mask : numpy ndarray, optional
        2d or 3d input bad pixel map (same last 2 dimensions as ``array``),
        non-zero where pixels are bad. If None, bad pixels are identified
        with sigma-clipped statistics.
    correct_only : bool, optional
        If True, ``bpm_mask`` must be provided and only those pixels are
        corrected. Otherwise the algorithm identifies (more) bad pixels.
        This is honoured in both the static and the frame-by-frame modes.
    sigma_clip : int, optional
        Lower and upper sigma threshold used for the bad pixel
        identification.
    num_neig : int, optional
        Side of the square window used for the local statistics. If not
        larger than 0, statistics are computed without neighbourhood.
    size : odd int, optional
        Side of the box of the median filter used for the replacement values.
    frame_by_frame : bool, optional
        If True, bad pixels are identified and corrected independently in
        each frame with ``frame_fix_badpix_isolated``. If False (default), a
        single static bad pixel map is derived from the mean frame of the
        cube and applied to all frames.
    protect_mask : int or float, optional
        If larger than 0, radius of a circular aperture centred on ``cxy``
        inside which no pixel is flagged as bad.
    excl_mask : numpy ndarray, optional
        2d or 3d mask (same last 2 dimensions as ``array``), non-zero where
        pixels must be neither used as good neighbours nor corrected. A 3d
        mask is collapsed with a median along the first axis in static mode;
        a 2d mask is repeated for every frame in frame-by-frame mode.
    cxy : None, tuple or 2d numpy ndarray, optional
        (x, y) centre of the protected aperture. If None, ``frame_center`` of
        the first frame is used. A tuple (or list) applies to all frames. A
        (n_frames x 2) ndarray gives one position per frame and is only
        accepted in frame-by-frame mode.
    mad : bool, optional
        If True, the median absolute deviation is used instead of the
        standard deviation. This is honoured in both modes.
    ignore_nan : bool, optional
        If True, NaN pixels are not flagged as bad pixels.
    verbose : bool, optional
        If True, print progress, number of replaced pixels and timing.
    full_output : bool, optional
        If True, also return the bad pixel map(s).
    nproc : int, optional
        Number of processes. A value other than 1 enables the shared-memory
        multiprocessing path, only in frame-by-frame mode and only when a
        ``shared_memory`` module could be imported.

    Returns
    -------
    array_out : numpy ndarray
        Cube with bad pixels corrected.
    final_bpm : 2d or 3d numpy ndarray
        [full_output=True] The static bad pixel map (static mode) or the cube
        of bad pixel maps (frame-by-frame mode).

    Raises
    ------
    TypeError
        If ``array`` is not 3d, ``size`` is even, or ``cxy`` has an
        unsupported type.
    ValueError
        If ``correct_only`` is True without ``bpm_mask``, or ``cxy`` is an
        ndarray of the wrong shape or is given per frame in static mode.
    AssertionError
        If ``bpm_mask`` or ``excl_mask`` does not match the last 2 dimensions
        of ``array``.

    """
    if array.ndim != 3:
        raise TypeError('Array is not a 3d array or cube')
    if size % 2 == 0:
        raise TypeError('Size of the median blur kernel must be an odd integer')
    if correct_only and bpm_mask is None:
        msg = "Bad pixel map should be provided if correct_only is True."
        raise ValueError(msg)

    if bpm_mask is not None:
        msg = "Input bad pixel mask should have same last 2 dims as array\n"
        assert bpm_mask.shape[-2:] == array.shape[-2:], msg
        bpm_mask = bpm_mask.astype('bool')

    if excl_mask is not None:
        # The original shape check sat in an unreachable branch; validate
        # the mask up-front for both processing modes instead.
        msg = "Input exclusion mask should have same last 2 dims as array\n"
        assert excl_mask.shape[-2:] == array.shape[-2:], msg

    if verbose:
        start = time_ini()

    neigh = bool(num_neig > 0)
    nz = array.shape[0]

    if cxy is None:
        cy, cx = frame_center(array[0])
    elif isinstance(cxy, (tuple, list)):
        cx, cy = cxy
    elif isinstance(cxy, np.ndarray):
        # check ndim first so that a 1d array yields a clear error
        if cxy.ndim != 2 or cxy.shape[0] != nz or cxy.shape[1] != 2:
            raise ValueError("cxy does not have right shape")
        if not frame_by_frame:
            msg = "cxy must be a tuple or None if not in frame_by_frame mode"
            raise ValueError(msg)
        cx = cxy[:, 0]
        cy = cxy[:, 1]
    else:
        raise TypeError("cxy must be None, a tuple or a 2d numpy ndarray")

    array_out = array.copy()
    final_bpm = np.zeros_like(array_out, dtype=bool)
    n_frames = nz
    count_bp = 0

    if frame_by_frame:
        if np.isscalar(cx):
            cx = [cx]*nz
            cy = [cy]*nz
        if bpm_mask is not None and bpm_mask.ndim == 2:
            bpm_mask = np.array([bpm_mask]*n_frames)
        if excl_mask is not None and excl_mask.ndim == 2:
            # a 2d exclusion mask applies to every frame
            excl_mask = np.array([excl_mask]*n_frames)

        def _frame_kwargs(idx):
            """Keyword arguments of the per-frame correction of frame idx."""
            return {'bpm_mask': None if bpm_mask is None else bpm_mask[idx],
                    'correct_only': correct_only,
                    'sigma_clip': sigma_clip,
                    'num_neig': num_neig,
                    'size': size,
                    'protect_mask': protect_mask,
                    'excl_mask': (None if excl_mask is None
                                  else excl_mask[idx]),
                    'cxy': (cx[idx], cy[idx]),
                    'mad': mad,
                    'ignore_nan': ignore_nan,
                    'verbose': False}

        if nproc == 1 or not shared_mem:
            for i in Progressbar(range(n_frames), desc="processing frames"):
                res = frame_fix_badpix_isolated(array[i], full_output=True,
                                                **_frame_kwargs(i))
                array_out[i] = res[0]
                final_bpm[i] = res[1]
        else:
            if verbose:
                print("Cleaning frames using ADACS' multiprocessing approach")
            # dummy call of the function to create a cached version of the
            # code prior to forking
            frame_fix_badpix_isolated(array[0], full_output=False,
                                      **_frame_kwargs(0))

            # Shared buffers receiving the corrected cube and the cube of
            # bad pixel maps written by the worker processes.
            shm_arr = shared_memory.SharedMemory(create=True,
                                                 size=array.nbytes)
            shm_fbpm = None
            try:
                sh_arr = np.ndarray(array.shape, dtype=array.dtype,
                                    buffer=shm_arr.buf)
                shm_fbpm = shared_memory.SharedMemory(create=True,
                                                      size=final_bpm.nbytes)
                sh_fbpm = np.ndarray(final_bpm.shape, dtype=final_bpm.dtype,
                                     buffer=shm_fbpm.buf)

                def mp_clean_isolated(j, frame, **kwargs):
                    """Correct one frame and store the results in slot j."""
                    sh_res = frame_fix_badpix_isolated(frame,
                                                       full_output=True,
                                                       **kwargs)
                    sh_arr[j], sh_fbpm[j] = sh_res

                # The unwrapping function is bound to a module-level name so
                # that the pool can refer to it by name.
                global _mp_clean_isolated

                def _mp_clean_isolated(args):
                    j, frame, kwargs = args
                    mp_clean_isolated(j, frame, **kwargs)

                context = multiprocessing.get_context('fork')
                pool = context.Pool(processes=nproc, maxtasksperchild=1)
                args = [[j, array[j], _frame_kwargs(j)]
                        for j in range(n_frames)]
                try:
                    pool.map_async(_mp_clean_isolated, args,
                                   chunksize=1).get(timeout=10_000_000)
                finally:
                    pool.close()
                    pool.join()
                array_out[:] = sh_arr[:]
                final_bpm[:] = sh_fbpm[:]
            finally:
                # always release the shared segments, even if a worker failed
                for shm in (shm_arr, shm_fbpm):
                    if shm is not None:
                        shm.close()
                        shm.unlink()
        count_bp = np.sum(final_bpm)
    else:
        if excl_mask is None:
            excl_mask = np.zeros(array.shape[-2:], dtype=bool)
        elif excl_mask.ndim == 3:
            excl_mask = np.median(excl_mask, axis=0)
        ind_excl = np.where(excl_mask)

        if bpm_mask is None or not correct_only:
            if bpm_mask is None:
                bpm_mask = np.zeros(array.shape[-2:], dtype=bool)
            elif bpm_mask.ndim == 3:
                bpm_mask = np.median(bpm_mask, axis=0)
            bpm_mask = bpm_mask+excl_mask
            # the static map is derived from the mean frame of the cube
            mean_frame = np.nanmean(array, axis=0)
            ori_nan_mask = np.where(np.isnan(mean_frame))
            ind = clip_array(mean_frame, sigma_clip, sigma_clip, bpm_mask,
                             neighbor=neigh, num_neighbor=num_neig, mad=mad)
            final_bpm = np.zeros_like(array[0], dtype=bool)
            final_bpm[ind] = 1
            if ignore_nan:
                final_bpm[ori_nan_mask] = 0
            if protect_mask:
                cir = disk((cy, cx), protect_mask, shape=final_bpm.shape)
                final_bpm[cir] = 0
            final_bpm[ind_excl] = 0
            final_bpm = final_bpm.astype('bool')
        else:
            if bpm_mask.ndim == 3:
                final_bpm = np.median(bpm_mask, axis=0)
            else:
                final_bpm = bpm_mask.copy()

        bad_idx = np.where(final_bpm)  # identical for all frames
        for i in Progressbar(range(n_frames), desc="processing frames"):
            frame = array_out[i]
            smoothed = median_filter(frame, size, mode='mirror')
            frame[bad_idx] = smoothed[bad_idx]
            array_out[i] = frame
        if verbose:
            count_bp = n_frames*np.sum(final_bpm)

    if verbose:
        msg = "Done replacing {:.0f} bad pixels using the median of neighbors"
        print(msg.format(count_bp))
        if not frame_by_frame:
            msg = "(i.e. {:.0f} static bad pixels per channel))"
            print(msg.format(count_bp/n_frames))
        timing(start)

    if full_output:
        return array_out, final_bpm
    return array_out
def cube_fix_badpix_annuli(array, fwhm, cy=None, cx=None, sig=5.,
                           bpm_mask=None, protect_mask=0, excl_mask=None,
                           r_in_std=50, r_out_std=None, verbose=True,
                           half_res_y=False, min_thr=None, max_thr=None,
                           min_thr_np=None, bad_values=None,
                           full_output=False):
    """
    Correct bad pixels annulus per annulus around the star location.

    The frame is divided in concentric annuli centred on (cx, cy). The median
    and standard deviation of the good pixels of each annulus are passed to
    ``correct_ann_outliers``, which identifies and replaces outlying pixels.
    This function is faster than ``cube_fix_badpix_clump``; hence to be
    preferred when there is only one bright source with a circularly
    symmetric PSF.

    Parameters
    ----------
    array : 3D or 2D array
        Input 3d cube or 2d image.
    fwhm : float or 1D array
        Full width at half maximum of the PSF in pixels: one value per
        channel (cube) or a single value.
    cy, cx : None, float or 1D array, optional
        Coordinates of the centre of the annuli. If None, the values returned
        by ``frame_center(array)`` are used. If floats, they are used for all
        frames of a cube. If 1D arrays, they must have the length of the
        first dimension of the cube.
    sig : float, optional
        Number of standard deviations above or below the annulus median
        beyond which a pixel is considered bad.
    bpm_mask : 3D or 2D array, optional
        Input bad pixel map, with the same last 2 dimensions as ``array``.
        These pixels are excluded from the annulus statistics.
    protect_mask : int or float, optional
        If larger than 0, radius of a circular aperture centred on (cx, cy)
        inside which identified bad pixels are left uncorrected.
    excl_mask : numpy ndarray, optional
        2d or 3d mask, non-zero where pixels must be neither used as good
        neighbours nor corrected.
    r_in_std : float or None, optional
        Inner radius (in units of ``fwhm``, see the code) of the annulus used
        to estimate the background noise. None is treated as 0.
    r_out_std : float or None, optional
        Outer radius (in units of ``fwhm``) of that annulus. If None, the
        largest annulus starting at the inner radius and fitting in the
        frame is used.
    verbose : bool, optional
        Whether to print the number of bad pixels in each frame.
    half_res_y : bool, optional
        Whether the data have half the resolution vertically (every row is
        duplicated, e.g. SINFONI data). Only every other row is then
        processed and the result is duplicated back.
    min_thr, max_thr : None or float, optional
        Pixels with a value lower (resp. larger) than this threshold are
        automatically considered bad. If None, a value below the minimum
        (resp. above the maximum) of ``array`` is used so that the threshold
        has no effect.
    min_thr_np : None or float, optional
        Pixels with a value lower than this threshold are considered bad even
        within the radius of ``protect_mask``.
    bad_values : list or None, optional
        Known bad values (e.g. 0). Pixels with these values are added to the
        input bad pixel map.
    full_output : bool, optional
        Whether to also return the bad pixel map(s) and the annuli map(s).

    Returns
    -------
    array_corr : 2d or 3d array
        The bad pixel corrected frame/cube.
    bpix_map : 2d or 3d array
        [full_output=True] The bad pixel map or the cube of bad pixel maps.
    ann_frame_cumul : 2d or 3d array
        [full_output=True] Map(s) holding the index of the annulus to which
        each good pixel was assigned.

    Raises
    ------
    AssertionError
        If ``array`` is not 2d or 3d, or a mask has the wrong shape.
    ValueError
        If ``half_res_y`` is True and the frames have an odd number of rows.

    """
    ndims = array.ndim
    assert ndims == 2 or ndims == 3, "Object is not two or three dimensional.\n"

    # Default thresholds must lie strictly outside the data range so that
    # they never flag a pixel (the upper one used to be amax-1, which marked
    # the brightest pixels as bad). Floats avoid integer wrap-around.
    if min_thr is None:
        min_thr = float(np.amin(array))-1
    if max_thr is None:
        max_thr = float(np.amax(array))+1

    if bpm_mask is not None:
        msg = "Input bad pixel mask should have same last 2 dims as array\n"
        assert bpm_mask.shape[-2:] == array.shape[-2:], msg
        bpm_mask = bpm_mask.astype('bool')

    if bad_values is not None:
        if bpm_mask is None:
            bpm_mask = np.zeros(array.shape, dtype=bool)
        elif bpm_mask.ndim < ndims:
            # a 2d map must be repeated before it can be indexed per frame
            bpm_mask = np.array([bpm_mask]*array.shape[0], dtype=bool)
        for bad in bad_values:
            bpm_mask[array == bad] = 1

    def bp_removal_2d(array, cy, cx, fwhm, sig, protect_mask, bpm_mask_ori,
                      excl_mask, r_in_std, r_out_std, verbose):
        """Identify and correct the bad pixels of a single 2d frame."""
        msg = "Input exclusion mask should have same shape as array\n"
        assert excl_mask.shape == array.shape, msg
        ind_excl = np.where(excl_mask)
        # full-resolution original, used at the end to restore excluded pixels
        frame_ori = array.copy()
        n_x = array.shape[1]
        n_y = array.shape[0]

        # Squash the frame if twice less resolved vertically than horizontally
        if half_res_y:
            if n_y % 2 != 0:
                msg = 'The input frames do not have of an even number of rows. '
                msg2 = 'Hence, you should not use option half_res_y = True'
                raise ValueError(msg+msg2)
            n_y = int(n_y/2)
            cy = int(cy/2)
            array = np.array(frame_ori[::2], dtype=float)
            excl_mask = np.array(excl_mask[::2], dtype=float)
            if bpm_mask_ori is not None:
                bpm_mask_ori = np.array(bpm_mask_ori[::2], dtype=float)

        # 1/ Stddev of background
        if r_in_std or r_out_std:
            r_in = min((r_in_std or 0)*fwhm, cx-2, cy-2, n_x-cx-2, n_y-cy-2)
            if r_out_std:
                r_out = r_out_std*fwhm
            else:
                r_out = min(n_y-(cy+r_in), cy-r_in, n_x-(cx+r_in), cx-r_in)
            width = max(2, r_out-r_in)
            array_crop = get_annulus_segments(array, r_in, width, mode="val")
        else:
            array_crop = array
        _, _, stddev = sigma_clipped_stats(array_crop, sigma=2.5)

        # 2/ Define each annulus, its median and stddev
        ymax = max(cy, n_y-cy)
        xmax = max(cx, n_x-cx)
        if half_res_y:
            ymax *= 2
        rmax = np.sqrt(ymax**2+xmax**2)
        # the annuli definition is optimized for Airy rings
        ann_width = max(1.5, 0.5*fwhm)  # 0.61*fwhm
        nrad = int(rmax/ann_width)+1
        if half_res_y:
            d_bord_max = max(2*(n_y-cy), 2*cy, n_x-cx, cx)
        else:
            d_bord_max = max(n_y-cy, cy, n_x-cx, cx)
        # index beyond which the outermost annuli are merged
        rr_far = int(d_bord_max/ann_width)

        ann_frame_cumul = np.zeros_like(array)
        med_neig = np.zeros(nrad)
        std_neig = np.zeros(nrad)

        bpm_mask = excl_mask.copy()
        if bpm_mask_ori is not None:
            bpm_mask += bpm_mask_ori.astype(bool)
        if min_thr_np is not None:
            bpm_mask[np.where(array < min_thr_np)] = 1

        # Cumulative boolean footprints of the outer and inner discs. Known
        # bad pixels are put in the inner footprint so that they are removed
        # from every annulus; boolean logic keeps bad pixels lying outside
        # the current annulus from leaking into its statistics.
        big_mask = np.zeros(array.shape, dtype=bool)
        sma_mask = np.zeros(array.shape, dtype=bool)
        sma_mask[np.where(bpm_mask)] = True

        for rr in range(nrad):
            if rr > rr_far:
                # just to merge farthest annuli with very few elements
                rr_big = nrad
                rr_sma = rr_far
            else:
                rr_big = rr
                rr_sma = rr
            rad_big = (rr_big+1)*ann_width
            rad_sma = rr_sma*ann_width
            if half_res_y:
                big_idx = ellipse(r=cy, c=cx, r_radius=rad_big/2,
                                  c_radius=rad_big, shape=(n_y, n_x))
            else:
                big_idx = disk((cy, cx), radius=rad_big, shape=(n_y, n_x))
            big_mask[big_idx] = True
            if rr != 0:
                if half_res_y:
                    sma_idx = ellipse(r=cy, c=cx, r_radius=rad_sma/2,
                                      c_radius=rad_sma, shape=(n_y, n_x))
                else:
                    sma_idx = disk((cy, cx), radius=rad_sma,
                                   shape=(n_y, n_x))
                sma_mask[sma_idx] = True

            ann_idx = np.nonzero(big_mask & ~sma_mask)
            ann_frame_cumul[ann_idx] = rr

            # We delete iteratively max and min outliers in each annulus,
            # so that the annuli median and stddev are not corrupted by bpixs
            neigh = array[ann_idx].astype(float)
            n_pix_init = neigh.shape[0]
            n_rm = 0
            # NOTE(review): as in the original code, the first condition only
            # holds until one value has been removed, so at most one outlier
            # is rejected per annulus - confirm the intended minimum size.
            while neigh.shape[0] >= n_pix_init and n_rm < n_pix_init/5:
                min_neigh = np.amin(neigh)
                if reject_outliers(neigh, min_neigh, m=5, stddev=stddev):
                    neigh = np.delete(neigh, np.argmin(neigh))
                    n_rm += 1
                else:
                    max_neigh = np.amax(neigh)
                    if reject_outliers(neigh, max_neigh, m=5, stddev=stddev):
                        neigh = np.delete(neigh, np.argmax(neigh))
                        n_rm += 1
                    else:
                        break
            med_neig[rr] = np.median(neigh)
            std_neig[rr] = np.std(neigh)

        # 3/ Indices of the protected aperture centered on the provided
        # coordinates of the star
        if protect_mask:
            if half_res_y:
                circl_new = ellipse(cy, cx, r_radius=protect_mask/2.,
                                    c_radius=protect_mask, shape=(n_y, n_x))
            else:
                circl_new = disk((cy, cx), radius=protect_mask,
                                 shape=(n_y, n_x))
        else:
            circl_new = []

        # 4/ Loop on all pixels to check bpix
        array_corr, bpix_map = correct_ann_outliers(array, bpm_mask,
                                                    ann_width, sig, med_neig,
                                                    std_neig, cy, cx, min_thr,
                                                    max_thr, stddev,
                                                    half_res_y)

        # 5/ Count bpix and uncorrect if within the circle
        nbpix_tot = int(np.sum(bpix_map))
        nbpix_tbc = int(nbpix_tot - np.sum(bpix_map[circl_new]))
        if min_thr_np is not None:
            # inside the aperture, only pixels below min_thr_np stay corrected
            bp_tmp = np.zeros_like(bpix_map)
            bp_tmp[circl_new] = 1
            fin_mask = np.where((array >= min_thr_np) & (bp_tmp == 1))
            bpix_map[fin_mask] = 0
            array_corr[fin_mask] = array[fin_mask]
        else:
            bpix_map[circl_new] = 0
            array_corr[circl_new] = array[circl_new]
        if verbose:
            print(nbpix_tot, ' bpix in total, and ', nbpix_tbc, ' corrected.')

        # Unsquash all the frames
        if half_res_y:
            array_corr = np.repeat(array_corr, 2, axis=0).astype(float)
            bpix_map = np.repeat(bpix_map, 2, axis=0).astype(float)
            ann_frame_cumul = np.repeat(ann_frame_cumul, 2,
                                        axis=0).astype(float)

        # Include + exclude relevant pixels: excluded pixels get back their
        # original (full-resolution) values
        array_corr[ind_excl] = frame_ori[ind_excl]
        bpix_map[ind_excl] = 0

        return array_corr, bpix_map, ann_frame_cumul

    if cy is None or cx is None:
        cy, cx = frame_center(array)

    if ndims == 2:
        if excl_mask is None:
            excl_mask = np.zeros(array.shape, dtype=bool)
        array_corr, bpix_map, ann_frame_cumul = bp_removal_2d(array, cy, cx,
                                                              fwhm, sig,
                                                              protect_mask,
                                                              bpm_mask,
                                                              excl_mask,
                                                              r_in_std,
                                                              r_out_std,
                                                              verbose)
    if ndims == 3:
        array_corr = array.copy()
        n_z = array.shape[0]
        bpix_map = np.zeros_like(array)
        ann_frame_cumul = np.zeros_like(array)
        if np.isscalar(fwhm):
            fwhm = [fwhm]*n_z
        if np.isscalar(cx) and np.isscalar(cy):
            cy = [cy]*n_z
            cx = [cx]*n_z
        if bpm_mask is None:
            bpm_mask = np.zeros(array_corr.shape, dtype=bool)
        elif bpm_mask.ndim == 2:
            bpm_mask = np.array([bpm_mask]*n_z, dtype=bool)
        if excl_mask is None:
            excl_mask = np.zeros(array_corr.shape, dtype=bool)
        elif excl_mask.ndim == 2:
            excl_mask = np.array([excl_mask]*n_z, dtype=bool)
        for i in range(n_z):
            if verbose:
                print('************Frame # ', i, ' *************')
                print('centroid assumed at coords:', cx[i], cy[i])
            res_i = bp_removal_2d(array[i], cy[i], cx[i], fwhm[i], sig,
                                  protect_mask, bpm_mask[i], excl_mask[i],
                                  r_in_std, r_out_std, verbose)
            array_corr[i], bpix_map[i], ann_frame_cumul[i] = res_i

    if full_output:
        return array_corr, bpix_map, ann_frame_cumul
    return array_corr
[docs] def cube_fix_badpix_clump(array, bpm_mask=None, correct_only=False, cy=None, cx=None, fwhm=4., sig=4., protect_mask=0, excl_mask=None, half_res_y=False, min_thr=None, max_nit=15, mad=True, bad_values=None, verbose=True, full_output=False, nproc=1): """ Function to identify and correct clumps of bad pixels. Very fast when a bad pixel map is provided. If a bad pixel map is not provided, the bad pixel clumps will be searched iteratively and replaced by the median of good neighbouring pixel values, when enough of them are available. The size of the box is set by the closest odd integer larger than fwhm (to avoid accidentally replacing point sources). Parameters ---------- array : 3D or 2D array Input 3d cube or 2d image. bpm_mask: 3D or 2D array, opt Input bad pixel array. If 2D and array is 3D: should have same last 2 dimensions as array. If 3D, should have exact same dimensions as input array. If not provided, the algorithm will attempt to identify bad pixel clumps automatically. correct_only : bool, opt If True and bpix_map is provided, will only correct for provided bad pixels. Else, the algorithm will determine (more) bad pixels. cy,cx : float or 1D array. opt Vector with approximate y and x coordinates of the star for each channel (cube_like), or single 2-elements vector (frame_like). These will be used if bpix_map is None and protect_mask set to True. If left to None, default values will correspond to the central pixel coordinates. fwhm: float or 1D array, opt Vector containing the full width half maximum of the PSF in pixels, for each channel (cube_like); or single value (frame_like). Should be provided if bpix map is None. sig: float, optional Value representing the number of "sigmas" above or below the "median" of the neighbouring pixel, to consider a pixel as bad. See details on parameter "m" of function reject_outlier. 
protect_mask : int or float, optional If larger than 0, radius of a circular aperture (at the center of the frames) in which no bad pixels will be identified. This can be useful to protect the star and vicinity. excl_mask : numpy ndarray, optional Binary mask with same dimensions as array, with 1 in areas that should not be considered as good neighbouring pixels during the identification of bad pixels. These should not be considered as bad pixels to be corrected neither (i.e. different to bpm_mask). half_res_y: bool, {True,False}, optional Whether the input data has only half the angular resolution vertically compared to horizontally (e.g. the case of SINFONI data); in other words there are always 2 rows of pixels with exactly the same values. The algorithm will just consider every other row (hence making it twice faster), then apply the bad pixel correction on all rows. min_thr: float, tuple or None, opt If a float is provided, corresponds to a minimum absolute threshold below which pixels are not considered bad (can be used to avoid the identification of bad pixels within noise). If a tuple of 2 values, corresponds to the range of values within which not to consider a pixel as bad. (e.g. (-0.1, 10.)). max_nit: float, optional Maximum number of iterations on a frame to correct bpix. Typically, it should be set to less than ny/2 or nx/2. This is a mean of precaution in case the algorithm gets stuck with 2 neighbouring pixels considered bpix alternately on two consecutively iterations hence leading to an infinite loop (very very rare case). mad : {False, True}, bool optional If True, the median absolute deviation will be used instead of the standard deviation. bad_values: list or None, optional If not None, should correspond to a list of known bad values (e.g. 0). These pixels will be added to the input bad pixel map. verbose: bool, {False,True}, optional Whether to print the number of bad pixels and number of iterations required for each frame. 
full_output: bool, {False,True}, optional Whether to return as well the cube of bad pixel maps and the cube of defined annuli. nproc: int, optional This feature is added following ADACS update. Refers to the number of processors available for calculations. Choosing a number >1 enables multiprocessing for the correction of frames. Returns ------- array_corr: 2d or 3d array The bad pixel corrected frame/cube. bpix_map: 2d or 3d array [full_output=True] The bad pixel map or the cube of bpix maps """ array_corr = array.copy() ndims = array_corr.ndim assert ndims == 2 or ndims == 3, "Object is not two or three dimensional.\n" if bpm_mask is not None: msg = "Input bad pixel mask should have same last 2 dims as array\n" assert bpm_mask.shape[-2:] == array.shape[-2:], msg bpm_mask = bpm_mask.astype('bool') if bad_values is not None: if bpm_mask is None: bpm_mask = np.zeros(array_corr.shape, dtype=bool) for bad in bad_values: bpm_mask[np.where(array_corr == bad)] = 1 if correct_only and bpm_mask is None: msg = "Bad pixel map should be provided if correct_only is True." raise ValueError(msg) def bp_removal_2d(array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask_ori, excl_mask, min_thr, half_res_y, mad, verbose): msg = "Input exclusion mask should have same shape as array\n" assert excl_mask.shape == array_corr.shape, msg ind_excl = np.where(excl_mask) n_x = array_corr.shape[1] n_y = array_corr.shape[0] if half_res_y: if n_y % 2 != 0: msg = 'The input frames do not have of an even number of rows. 
' msg2 = 'Hence, you should not use option half_res_y = True' raise ValueError(msg+msg2) n_y = int(n_y/2) frame = array_corr.copy() array_corr = np.zeros([n_y, n_x]) excl_mask_corr = np.zeros([n_y, n_x]) for yy in range(n_y): array_corr[yy] = frame[2*yy] excl_mask_corr[yy] = excl_mask[2*yy] excl_mask = excl_mask_corr if bpm_mask_ori is not None: bpm_mask_tmp = np.zeros([n_y, n_x]) for yy in range(n_y): bpm_mask_tmp[yy] = bpm_mask_ori[2*yy] bpm_mask_ori = bpm_mask_tmp fwhm_round = int(round(fwhm)) # This should reduce the chance to accidently correct a bright planet: if fwhm_round % 2 == 0: neighbor_box = max(3, fwhm_round+1) else: neighbor_box = max(3, fwhm_round) nneig = sum(np.arange(3, neighbor_box+2, 2)) # 1/ Create a tuple-array with coordinates of a circle of radius 1.8fwhm # centered on the approximate coordinates of the star if protect_mask: if half_res_y: circl_new = ellipse(int(cy/2), cx, r_radius=0.5*protect_mask, c_radius=protect_mask, shape=(n_y, n_x)) else: circl_new = disk((cy, cx), radius=protect_mask, shape=(n_y, n_x)) else: circl_new = [] # 3/ Create a bad pixel map, by detecting them with clip_array bpm_mask = excl_mask.copy() if bpm_mask_ori is not None: bpm_mask += bpm_mask_ori.astype(bool) bp = clip_array(array_corr, sig, sig, bpm_mask, out_good=False, neighbor=True, num_neighbor=neighbor_box, mad=mad, half_res_y=half_res_y) bpix_map = np.zeros_like(array_corr) bpix_map[bp] = 1 if min_thr is not None: if np.isscalar(min_thr): min_thr = (-min_thr, min_thr) elif not isinstance(min_thr, tuple): msg = "if provided, min_thr should be float or tuple" raise ValueError(msg) else: if len(min_thr) != 2: msg = "if min_thr is a tuple, it should have 2 elements" raise ValueError(msg) cond1 = array_corr > min_thr[0] cond2 = array_corr < min_thr[1] bpix_map[np.where(cond1 & cond2)] = 0 nbpix_tot = int(np.sum(bpix_map)) bpix_map[circl_new] = 0 bpix_map[ind_excl] = 0 nbpix_tbc = int(np.sum(bpix_map)) bpix_map_cumul = np.zeros(bpix_map.shape, dtype=bool) 
bpix_map_cumul[:] = bpix_map[:] nit = 0 # 4/ Loop over the bpix correction with sigma_filter, until 0 bpix left while nbpix_tbc > 0 and nit < max_nit: nit = nit+1 if verbose: msg = "Iteration {}: {} bad pixels identified".format(nit, nbpix_tot) # if bpm_mask_ori is not None: # nbpix_ori = np.sum(bpm_mask_ori) # msg += " ({} new ones)".format(nbpix_tot-nbpix_ori) if protect_mask: msg += ", {} to be corrected".format(nbpix_tbc) print(msg) array_corr = sigma_filter(array_corr, bpix_map, neighbor_box=neighbor_box, min_neighbors=nneig, half_res_y=half_res_y, verbose=verbose) bpm_mask = None # known bad ones are corrected above +=bpix_map_cumul bp = clip_array(array_corr, sig, sig, bpm_mask, out_good=False, neighbor=True, num_neighbor=neighbor_box, mad=mad, half_res_y=half_res_y) bpix_map = np.zeros(array_corr.shape, dtype=bool) bpix_map[bp] = 1 if min_thr is not None: cond1 = array_corr > min_thr[0] cond2 = array_corr < min_thr[1] bpix_map[np.where(cond1 & cond2)] = 0 nbpix_tot = int(np.sum(bpix_map)) bpix_map[circl_new] = 0 bpix_map[ind_excl] = 0 nbpix_tbc = int(np.sum(bpix_map)) bpix_map_cumul = bpix_map_cumul+bpix_map if verbose: print('All bad pixels are corrected.') if half_res_y: frame = array_corr.copy() frame_bpix = bpix_map_cumul.copy() n_y = 2*n_y array_corr = np.zeros([n_y, n_x]) bpix_map_cumul = np.zeros([n_y, n_x]) for yy in range(n_y): array_corr[yy] = frame[int(yy/2)] bpix_map_cumul[yy] = frame_bpix[int(yy/2)] return array_corr, bpix_map_cumul if ndims == 2: if bpm_mask is None or not correct_only: if (cy is None or cx is None) and protect_mask: cy, cx = frame_center(array) if excl_mask is None: excl_mask = np.zeros(array_corr.shape, dtype=bool) array_corr, bpix_map_cumul = bp_removal_2d(array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask, excl_mask, min_thr, half_res_y, mad, verbose) else: fwhm_round = int(round(fwhm)) fwhm_round = fwhm_round+1-(fwhm_round % 2) # make it odd neighbor_box = max(3, fwhm_round) # to not replace a companion nneig = 
sum(np.arange(3, neighbor_box+2, 2)) array_corr = sigma_filter(array_corr, bpm_mask, neighbor_box, nneig, half_res_y, verbose) bpix_map_cumul = bpm_mask if ndims == 3: n_z = array_corr.shape[0] if bpm_mask is None or not correct_only: if bpm_mask is None: bpm_mask = np.zeros(array_corr.shape, dtype=bool) elif bpm_mask.ndim == 2: bpm_mask = np.array([bpm_mask]*n_z, dtype=bool) if excl_mask is None: excl_mask = np.zeros(array_corr.shape, dtype=bool) elif excl_mask.ndim == 2: excl_mask = np.array([excl_mask]*n_z, dtype=bool) if cy is None or cx is None: cy, cx = frame_center(array) cy = [cy]*n_z cx = [cx]*n_z elif np.isscalar(cy) and np.isscalar(cx): cy = [cy]*n_z cx = [cx]*n_z if np.isscalar(fwhm): fwhm = [fwhm]*n_z if nproc == 1 or not shared_mem: bpix_map_cumul = np.zeros_like(array_corr) for i in range(n_z): if verbose: print('************Frame # ', i, ' *************') res = bp_removal_2d(array_corr[i], cy[i], cx[i], fwhm[i], sig, protect_mask, bpm_mask[i], excl_mask[i], min_thr, half_res_y, mad, verbose) array_corr[i], bpix_map_cumul[i] = res else: if verbose: print("Cleaning frames using ADACS' multiprocessing approach") # creating shared memory buffer space for the image cube. shm_clump = shared_memory.SharedMemory(create=True, size=array_corr.nbytes) obj_tmp_shared_clump = np.ndarray(array_corr.shape, dtype=array_corr.dtype, buffer=shm_clump.buf) # creating shared memory buffer space for the bad pixel cube. 
shm_clump_bpix = shared_memory.SharedMemory(create=True, size=array_corr.nbytes) # works with dtype=obj_tmp.dtype but not dtype=int bpix_map_cumul_shared = np.ndarray(array_corr.shape, dtype=array_corr.dtype, buffer=shm_clump_bpix.buf) def mp_clump_slow(j, array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask, excl_mask, min_thr, half_res_y, mad, verbose): res = bp_removal_2d(array_corr, cy, cx, fwhm, sig, protect_mask, bpm_mask, excl_mask, min_thr, half_res_y, mad, verbose) obj_tmp_shared_clump[j], bpix_map_cumul_shared[j] = res global _mp_clump_slow def _mp_clump_slow(args): mp_clump_slow(*args) context = multiprocessing.get_context('fork') pool = context.Pool(processes=nproc, maxtasksperchild=1) args = [] for i in range(n_z): args.append([i, array_corr[i], cy[i], cx[i], fwhm[i], sig, protect_mask, bpm_mask[i], excl_mask[i], min_thr, half_res_y, mad, verbose]) try: pool.map_async(_mp_clump_slow, args, chunksize=1).get( timeout=10_000_000) finally: pool.close() pool.join() bpix_map_cumul = np.zeros_like(array_corr, dtype=array_corr.dtype) bpix_map_cumul[:] = bpix_map_cumul_shared[:] array_corr[:] = obj_tmp_shared_clump[:] shm_clump.close() shm_clump.unlink() shm_clump_bpix.close() shm_clump_bpix.unlink() else: if np.isscalar(fwhm): fwhm_round = int(round(fwhm)) else: fwhm_round = int(np.median(fwhm)) fwhm_round = fwhm_round+1-(fwhm_round % 2) # make it odd neighbor_box = max(3, fwhm_round) # to not replace a companion nneig = sum(np.arange(3, neighbor_box+2, 2)) if nproc == 1 or not shared_mem: for i in range(n_z): if verbose: print('Using serial approach') print('************Frame # ', i, ' *************') if bpm_mask.ndim == 3: bpm = bpm_mask[i] else: bpm = bpm_mask array_corr[i] = sigma_filter(array_corr[i], bpm, neighbor_box, nneig, half_res_y, verbose) else: if verbose: print("Cleaning frames using ADACS' multiprocessing approach") # dummy calling sigma_filter function to create a cached version of the numba function if bpm_mask.ndim == 3: dummy_bpm = 
bpm_mask[0] else: dummy_bpm = bpm_mask # Actual dummy call is here. sigma_filter(array_corr[0], dummy_bpm, neighbor_box, nneig, half_res_y, verbose) # creating shared memory that each process writes into. shm_clump = shared_memory.SharedMemory( create=True, size=array_corr.nbytes) # creating an array that uses shared memory buffer and has the properties of array_corr. obj_tmp_shared_clump = np.ndarray( array_corr.shape, dtype=array_corr.dtype, buffer=shm_clump.buf) # function that is called repeatedly by each process. def mp_clean_clump(j, array_corr, bpm, neighbor_box, nneig, half_res_y, verbose): obj_tmp_shared_clump[j] = sigma_filter(array_corr, bpm, neighbor_box, nneig, half_res_y, verbose) global _mp_clean_clump # function that converts the args into bite-sized pieces for mp_clean_clump. def _mp_clean_clump(args): mp_clean_clump(*args) context = multiprocessing.get_context('fork') pool = context.Pool(processes=nproc, maxtasksperchild=1) args = [] for j in range(n_z): if bpm_mask.ndim == 3: bpm = bpm_mask[j] else: bpm = bpm_mask args.append([j, array_corr[j], bpm, neighbor_box, nneig, half_res_y, verbose]) try: pool.map_async(_mp_clean_clump, args, chunksize=1).get( timeout=10_000_000) finally: pool.close() pool.join() array_corr[:] = obj_tmp_shared_clump[:] shm_clump.close() shm_clump.unlink() bpix_map_cumul = bpm_mask # make it a binary map bpix_map_cumul[np.where(bpix_map_cumul > 1)] = 1 if full_output: return array_corr, bpix_map_cumul else: return array_corr
def cube_fix_badpix_ifs(array, lbdas, fluxes=None, mask=None, cy=None, cx=None,
                        clumps=True, sigma_clip=3, num_neig=5, size=5,
                        protect_mask=0, mad=False, fwhm=4, min_thr=None,
                        max_nit=15, imlib='vip-fft', interpolation='lanczos4',
                        ignore_nan=True, verbose=True, full_output=False):
    """
    Identify and correct bad pixels in an IFS cube, leveraging the radial
    expansion of the PSF with wavelength.

    Bad pixels are identified with either `cube_fix_badpix_clump` or
    `cube_fix_badpix_isolated` in SDI residual frames (each channel minus the
    median of the other channels, flux- and spatially-rescaled to it), and are
    then corrected in the original cube with `cube_fix_badpix_isolated`.

    Parameters
    ----------
    array : 3D or 4D array
        Input 3D (spectral) or 4D (spectral+temporal) cube. In the latter
        case, dimensions should be spectral x temporal x vertical x horizontal.
    lbdas: 1d array or list
        Vector with the wavelengths, used for first guess on scaling factor.
    fluxes: 1d array or list, optional
        Vector with the (unsaturated) fluxes at the different wavelengths,
        used for first guess on flux factor. If None, unit fluxes are assumed.
    mask: 2D-array, opt
        Binary mask, with ones where the residual intensities should be
        evaluated. If None is provided, the whole field is used.
    cy, cx : None, float or 1D array, optional
        If None: no reference point is passed to the rescaling and bad pixel
        routines, which then use their own default centre.
        If floats: coordinates of the center, assumed to be the same in all
        frames.
        If 1D arrays: they must be the same length as the 0th (spectral)
        dimension of the input cube.
    clumps: bool, optional
        Whether to use `cube_fix_badpix_clump` (True) or
        `cube_fix_badpix_isolated` (False) in the SDI residual cube.
    sigma_clip : int, optional
        In case no bad pixel mask is provided all the pixels above and below
        sigma_clip*STDDEV will be marked as bad.
    num_neig : int, optional
        The side of the square window around each pixel where the sigma
        clipped statistics are calculated (STDDEV and MEDIAN). If the value is
        equal to 0 then the statistics are computed in the whole frame.
    size : odd int, optional
        The size the box (size x size) of adjacent pixels for the median
        filter.
    protect_mask : int or float, optional
        If larger than 0, radius of a circular aperture (at the center of the
        frames) in which no bad pixels will be identified. This can be useful
        to protect the star and vicinity.
    mad : {False, True}, bool optional
        If True, the median absolute deviation will be used instead of the
        standard deviation.
    fwhm: float or 1D array, opt
        Vector containing the full width half maximum of the PSF in pixels,
        for each channel (cube_like); or single value (frame_like). Only
        forwarded to `cube_fix_badpix_clump` (i.e. used when clumps is True).
    min_thr: float, tuple or None, opt
        If a float is provided, corresponds to a minimum absolute threshold
        below which pixels are not considered bad in the residual images (can
        be used to avoid the identification of bad pixels within noise). If a
        tuple of 2 values, corresponds to the range of values within which not
        to consider a pixel as bad, in the residual images (e.g. (-1, 5)).
    max_nit: float, optional
        Maximum number of iterations on a frame to correct bpix. Typically, it
        should be set to less than ny/2 or nx/2. This is a mean of precaution
        in case the algorithm gets stuck with 2 neighbouring pixels considered
        bpix alternately on two consecutively iterations hence leading to an
        infinite loop (very very rare case).
    imlib, interpolation : str, optional
        Forwarded as is to `find_scal_vector` and `frame_rescaling`.
    ignore_nan: bool, optional
        Whether to not consider NaN values as bad pixels. If False, will also
        correct them.
    verbose : bool, optional
        If True additional information will be printed.
    full_output: bool, {False,True}, optional
        Whether to return as well the cube of bad pixel maps and the SDI
        residual cube.

    Returns
    -------
    array_out : numpy ndarray
        Cube with bad pixels corrected.
    final_bpm: 3d or 4d array [if full_output is True]
        The cube of bad pixel maps.
    array_res: 3d or 4d array [if full_output is True]
        SDI-residual cube in which bad pixels are identified.

    Raises
    ------
    TypeError
        If the input cube is neither 3D nor 4D.
    """
    cube = array.copy()
    ndims = cube.ndim

    # User-supplied centres must be honoured: they used to be silently
    # replaced by the geometric frame centre whenever both were provided.
    centre_given = cy is not None and cx is not None
    if centre_given:
        # NOTE(review): for 1D cy/cx, confirm the cxy layout expected by
        # cube_fix_badpix_isolated (not visible from this function).
        cxy = (cx, cy)
    else:
        cxy = None

    def _ref_point(channel):
        # (x, y) point about which a given spectral channel is rescaled.
        if not centre_given:
            return None
        if np.isscalar(cx) and np.isscalar(cy):
            return (cx, cy)
        # presumably one centre per spectral channel - see cy, cx docs
        return (cx[channel], cy[channel])

    def _res_scaled_images(spec_cube):
        # SDI residuals: for each channel, subtract every other channel
        # (flux-scaled and spatially rescaled to it) and median-combine.
        flux_guess = [1]*len(lbdas) if fluxes is None else fluxes
        scal_vec, flux_vec = find_scal_vector(spec_cube, lbdas, flux_guess,
                                              mask=mask, nfp=2, fm="sum",
                                              imlib=imlib,
                                              interpolation=interpolation)
        n_ch = spec_cube.shape[0]
        res_array = np.zeros_like(spec_cube)
        for z in range(n_ch):
            diffs = []
            for zp in range(n_ch):
                if zp == z:
                    continue
                flux_scal = flux_vec[zp]/flux_vec[z]
                resc_fr = frame_rescaling(flux_scal*spec_cube[zp],
                                          ref_xy=_ref_point(zp),
                                          scale=scal_vec[zp]/scal_vec[z],
                                          imlib=imlib,
                                          interpolation=interpolation)
                diffs.append(spec_cube[z]-resc_fr)
            res_array[z] = np.median(np.array(diffs), axis=0)
        return res_array

    def _identify(res_cube):
        # Bad pixel identification in a 3D residual cube; returns
        # (corrected residual cube, bad pixel map cube).
        if clumps:
            return cube_fix_badpix_clump(res_cube, bpm_mask=None, cy=cy,
                                         cx=cx, fwhm=fwhm, sig=sigma_clip,
                                         protect_mask=protect_mask,
                                         verbose=verbose, min_thr=min_thr,
                                         max_nit=max_nit, mad=mad,
                                         full_output=True)
        return cube_fix_badpix_isolated(res_cube, bpm_mask=None,
                                        sigma_clip=sigma_clip,
                                        num_neig=num_neig, size=size,
                                        frame_by_frame=True,
                                        protect_mask=protect_mask, cxy=cxy,
                                        mad=mad, ignore_nan=ignore_nan,
                                        verbose=verbose, full_output=True)

    if ndims == 3:
        array_res = _res_scaled_images(cube)
        _, final_bpm = _identify(array_res)
        final_bpm[np.where(final_bpm > 1)] = 1  # make it a binary map
        # bad pixel correction in original cube
        array_out = cube_fix_badpix_isolated(cube, bpm_mask=final_bpm,
                                             sigma_clip=sigma_clip,
                                             num_neig=num_neig, size=size,
                                             frame_by_frame=True,
                                             protect_mask=protect_mask,
                                             cxy=cxy, mad=mad,
                                             ignore_nan=ignore_nan,
                                             verbose=verbose,
                                             full_output=False)
    elif ndims == 4:
        n_z = cube.shape[1]
        array_out = np.zeros_like(cube)
        array_res = np.zeros_like(cube)
        final_bpm = np.zeros_like(cube)
        # cy/cx are passed unchanged to each spectral cube: expanding scalars
        # to n_z (temporal) copies produced lists of the wrong length for
        # routines that index them by spectral channel.
        for i in range(n_z):
            if verbose:
                print('************ Cube #{}/{} *************'.format(i+1,
                                                                      n_z))
            array_res[:, i] = _res_scaled_images(cube[:, i])
            _, final_bpm[:, i] = _identify(array_res[:, i])
            final_bpm[np.where(final_bpm > 1)] = 1  # make it a binary map
            # bad pixel correction in original cube
            array_out[:, i] = cube_fix_badpix_isolated(cube[:, i],
                                                       final_bpm[:, i],
                                                       correct_only=False,
                                                       sigma_clip=sigma_clip,
                                                       num_neig=num_neig,
                                                       size=size,
                                                       frame_by_frame=True,
                                                       protect_mask=protect_mask,
                                                       cxy=cxy, mad=mad,
                                                       ignore_nan=ignore_nan,
                                                       verbose=verbose,
                                                       full_output=False)
    else:
        raise TypeError("Input cube should be 3d or 4d")

    if full_output:
        return array_out, final_bpm, array_res
    else:
        return array_out
def cube_fix_badpix_interp(array, bpm_mask, mode='fft', excl_mask=None, fwhm=4.,
                           kernel_sz=None, psf=None, half_res_y=False, nit=500,
                           tol=1, nproc=1, full_output=False, **kwargs):
    """
    Correct clumps of bad pixels by interpolation, either with a user-defined
    kernel (through `frame_filter_lowpass`) or through the FFT-based algorithm
    described in [AAC01]_ (through `frame_fix_badpix_fft`). A bad pixel map
    must be provided (e.g. found with function `cube_fix_badpix_clump`).

    Parameters
    ----------
    array : 3D or 2D array
        Input 3d cube or 2d image.
    bpm_mask: 3D or 2D array
        Input bad pixel array. Should have same x,y dimensions as array. If 2D,
        but input array is 3D, the same bpix_map will be assumed for all
        frames.
    mode: str, optional {'fft', 'gauss', 'psf'}
        'fft' selects the iterative FFT-based interpolation. Any other value
        is forwarded to `frame_filter_lowpass`, e.g. a 2D Gaussian ('gauss')
        or an input normalized PSF ('psf').
    excl_mask: 3D or 2D array, optional
        [Only used if mode != 'fft'] Input exclusion mask array. Pixels in the
        exclusion mask will neither be used for interpolation, nor replaced as
        bad pixels. excl_mask should have same x,y dimensions as array. If 2D,
        but input array is 3D, the same exclusion mask will be assumed for all
        frames.
    fwhm: float, 1D array or tuple of 2 floats, opt
        If mode is 'gauss', the fwhm of the Gaussian.
    kernel_sz: int or None, optional
        Size of the kernel in pixels for 2D Gaussian and Moffat convolutions.
        If None, astropy.convolution will automatically consider 8*fwhm
        kernel sizes.
    psf: 2D or 3D array, optional
        If mode is 'psf', a normalized PSF array. If a 3D cube is provided
        (e.g. for spectral cubes), the first dimension should match that of
        the input array (which should also be 3D). Else, the same 2D PSF
        kernel will be for all input frames, whether the input is 2D or 3D.
        If half_res_y is True, psf should be provided vertically squashed.
    half_res_y: bool, {True,False}, optional
        Whether the input data has only half the angular resolution vertically
        compared to horizontally (e.g. the case for some IFUs); in other words
        there are always 2 rows of pixels with exactly the same values. If so,
        only every other row is processed and the result is duplicated back
        to the original number of rows.
    nit: int or list of int, optional
        For FFT-based iterative interpolation, the number of iterations to
        use. If a list is provided, a list of bad pixel corrected
        frames/cubes is returned.
    tol: float
        Tolerance in terms of E_g (see [AAC01]_). The iterative process is
        stopped if the error E_g gets lower than this tolerance.
    nproc : None or int, optional
        Number of processes for parallel computing. If None the number of
        processes will be set to (cpu_count()/2). Note: only used for input
        3D cube and 'fft' mode.
    full_output: bool, {False,True}, optional
        In the case of FT-based interpolation, whether to return as well the
        reconstructed images.
    **kwargs : dict
        Passed through to `frame_filter_lowpass` (mode != 'fft').

    Returns
    -------
    array_corr: 2d or 3d array;
        The bad pixel corrected frame/cube. If the bad pixel map is empty, the
        input array itself is returned (alone, whatever full_output).
    recon_cube: 2d or 3d array; [full_output=True & mode='fft']
        The reconstructed frame/cube. If nit is a list, a list of
        reconstructed frames/cubes is returned.

    Raises
    ------
    TypeError
        If the bad pixel map does not have the y/x dimensions of the array.
    ValueError
        If half_res_y is True and the number of rows is odd, or if a 3D psf
        does not match the number of frames.
    """
    array_corr = array.copy()
    ndims = array_corr.ndim
    assert ndims == 2 or ndims == 3, "Object is not two or three dimensional.\n"

    if bpm_mask.shape[-2:] != array.shape[-2:]:
        raise TypeError("Bad pixel map has wrong y/x dimensions.")

    if excl_mask is None:
        excl_mask = np.zeros(array.shape, dtype=bool)
    elif excl_mask.ndim == 2 and ndims == 3:
        excl_mask = np.array([excl_mask]*array.shape[0], dtype=bool)
    msg = "Input exclusion mask should have same shape as array\n"
    assert excl_mask.shape[-2:] == array.shape[-2:], msg

    if np.sum(bpm_mask) == 0:
        msg = "Warning: no bad pixel found in bad pixel map. "
        msg += "Returning input array as is."
        print(msg)
        return array

    nz = array.shape[0] if ndims == 3 else None

    # same bad pixel map for all frames of a cube
    if bpm_mask.ndim == 2 and ndims == 3:
        bpm_mask = np.array([bpm_mask]*nz, dtype=float)

    def squash_v(arr):
        # keep every other row (rows come in identical pairs)
        if arr.shape[-2] % 2:
            raise ValueError("Input array y dimension should be even")
        return arr[..., ::2, :].astype(float)

    if half_res_y:
        array_corr = squash_v(array_corr)
        bpm_mask = squash_v(bpm_mask)
        # the squashed exclusion mask used to overwrite bpm_mask for 2D input
        excl_mask = squash_v(excl_mask)

    if mode != 'fft':
        # NaNs cannot be stored in integer arrays
        if not np.issubdtype(array_corr.dtype, np.inexact):
            array_corr = array_corr.astype(float)
        # first replace all bad pixels with NaNs - they will be interpolated
        array_corr[np.where(bpm_mask+excl_mask)] = np.nan
        if ndims == 2:
            array_corr_filt = frame_filter_lowpass(array_corr, mode=mode,
                                                   fwhm_size=fwhm,
                                                   conv_mode='conv',
                                                   kernel_sz=kernel_sz,
                                                   psf=psf, iterate=True,
                                                   half_res_y=half_res_y,
                                                   **kwargs)
        else:
            array_corr_filt = array_corr.copy()
            if np.isscalar(fwhm):
                fwhm = [fwhm]*nz
            elif len(fwhm) == 2 and len(fwhm) != nz:
                # a single (fwhm_y, fwhm_x) pair shared by all frames
                fwhm = [fwhm]*nz
            if psf is None:
                psf = [psf]*nz
            elif psf.ndim == 2:
                psf = [psf]*nz
            elif psf.shape[0] != nz:
                raise ValueError(
                    "input psf must have same z dimension as array")
            for z in range(nz):
                array_corr_filt[z] = frame_filter_lowpass(array_corr[z],
                                                          mode=mode,
                                                          fwhm_size=fwhm[z],
                                                          conv_mode='conv',
                                                          kernel_sz=kernel_sz,
                                                          psf=psf[z],
                                                          iterate=True,
                                                          half_res_y=half_res_y,
                                                          **kwargs)
        # replace only the bad pixels (array_corr_filt is low-pass filtered)
        bad = np.where(bpm_mask)
        array_corr[bad] = array_corr_filt[bad]
    elif ndims == 2:
        res = frame_fix_badpix_fft(array_corr, bpm_mask, nit=nit, tol=tol,
                                   full_output=full_output)
        # As in the 3D branch below, a pair is only returned when full_output
        # is set - whether nit is an int or a list.
        if full_output:
            array_corr, recon_cube = res
        else:
            array_corr = res
    else:
        if nproc is None:
            nproc = cpu_count()//2
        res = pool_map(nproc, frame_fix_badpix_fft, iterable(array_corr),
                       iterable(bpm_mask), nit, tol, 2, False, full_output,
                       msg="Correcting bad pixels")
        if full_output and isinstance(nit, int):
            array_corr = np.array(res[:, 0], dtype=np.float64)
            recon_cube = np.array(res[:, 1], dtype=np.float64)
        elif full_output:
            # one corrected (resp. reconstructed) cube per requested nit
            all_corr = res[:, 0]
            all_recon = res[:, 1]
            array_corr = []
            recon_cube = []
            for j in range(len(nit)):
                array_corr.append(np.array([all_corr[i][j]
                                            for i in range(nz)]))
                recon_cube.append(np.array([all_recon[i][j]
                                            for i in range(nz)]))
        else:
            array_corr = np.array(res, dtype=np.float64)

    if half_res_y:
        # unsquash vertically: duplicate each row
        # NOTE(review): recon_cube is returned at the squashed resolution.
        array_corr = np.repeat(array_corr, 2, axis=-2)

    if mode == 'fft' and full_output:
        return array_corr, recon_cube
    else:
        return array_corr
def find_outliers(frame, sig_dist, in_bpix=None, stddev=None, neighbor_box=3,
                  min_thr=None, mid_thr=None):
    """ Provides a bad pixel (or outlier) map for a given frame.

    Parameters
    ----------
    frame: 2d array
        Input 2d image.
    sig_dist: float
        Threshold used to discriminate good from bad neighbours, in terms of
        normalized distance to the median value of the set (see
        reject_outliers)
    in_bpix: 2d array, optional
        Input bpix map (typically known from the previous iteration), to only
        look for bpix around those locations. For each flagged pixel, the
        pixel itself and the 8 pixels offset by +-half a box along each axis
        (clipped to the frame) are tested.
    stddev: float, optional
        Reference standard deviation passed to reject_outliers. If None, the
        standard deviation of the whole frame is used.
    neighbor_box: int, optional
        The side of the square window around each pixel where the sigma and
        median are calculated for the bad pixel DETECTION and CORRECTION.
    min_thr: {None,float}, optional
        Any pixel whose value is lower than this threshold (expressed in adu)
        will be automatically considered bad and hence sigma_filtered. If
        None, it is not used.
    mid_thr: {None, float}, optional
        Pixels whose value is lower than this threshold (expressed in adu)
        will have its neighbours checked; if there is at max. 1 neighbour
        pixel whose value is lower than mid_thr+(5*stddev), then the pixel is
        considered bad (because it means it is a cold pixel in the middle of
        significant signal). If None, it is not used.

    Returns
    -------
    bpix_map : numpy ndarray
        Map with the shape and dtype of frame, with 1 at the location of bad
        pixels and 0 elsewhere.
    """
    assert frame.ndim == 2, "Object is not two dimensional.\n"
    ny, nx = frame.shape
    bpix_map = np.zeros_like(frame)
    if stddev is None:
        stddev = np.std(frame)
    half_box = int(neighbor_box/2)

    if in_bpix is None:
        candidates = np.ndindex(ny, nx)
    else:
        # pixels to check: previously flagged ones and their surroundings
        to_check = np.zeros((ny, nx), dtype=bool)
        for by, bx in zip(*np.nonzero(in_bpix)):
            rows = [max(0, by-half_box), by, min(ny-1, by+half_box)]
            # x indices are clipped with nx (ny was used, which skipped or
            # overran columns in non-square frames)
            cols = [max(0, bx-half_box), bx, min(nx-1, bx+half_box)]
            to_check[np.ix_(rows, cols)] = True
        candidates = zip(*np.nonzero(to_check))

    for yy, xx in candidates:
        bpix_map[yy, xx] = _pixel_is_outlier(frame, yy, xx, half_box,
                                             sig_dist, stddev, min_thr,
                                             mid_thr)

    return bpix_map


def _neighbour_values(frame, yy, xx, half_box):
    """Return the flattened box of neighbours of pixel (yy, xx), without the
    pixel itself. At an edge, the box is extended in the direction opposite
    to the edge so that its size is preserved where the frame allows it."""
    ny, nx = frame.shape
    hbox_b = min(half_box, yy)         # below the pixel
    hbox_t = min(half_box, ny-1-yy)    # above the pixel
    hbox_l = min(half_box, xx)         # left of the pixel
    hbox_r = min(half_box, nx-1-xx)    # right of the pixel
    if yy > ny-1-half_box:
        hbox_b += yy-(ny-1-half_box)
    elif yy < half_box:
        hbox_t += half_box-yy
    if xx > nx-1-half_box:
        hbox_l += xx-(nx-1-half_box)
    elif xx < half_box:
        hbox_r += half_box-xx
    # clipping only matters for frames smaller than the box
    y0 = max(0, yy-hbox_b)
    x0 = max(0, xx-hbox_l)
    box = frame[y0:yy+hbox_t+1, x0:xx+hbox_r+1]
    centre = (yy-y0)*box.shape[1] + (xx-x0)
    return np.delete(box, centre)


def _pixel_is_outlier(frame, yy, xx, half_box, sig_dist, stddev, min_thr,
                      mid_thr):
    """Return 1 if pixel (yy, xx) is bad according to the absolute thresholds
    or to reject_outliers applied to its neighbours, 0 otherwise."""
    value = frame[yy, xx]
    if min_thr is not None and value < min_thr:
        return 1
    neighbours = _neighbour_values(frame, yy, xx, half_box)
    if mid_thr is not None and value < mid_thr:
        # cold pixel surrounded by significant signal
        if np.sum(neighbours < mid_thr+5*stddev) <= 1:
            return 1
    return reject_outliers(neighbours, value, m=sig_dist, stddev=stddev)


def reject_outliers(data, test_value, m=5., stddev=None, debug=False):
    """ Function to reject outliers from a set.

    Instead of the classic standard deviation criterion (e.g. 5-sigma), the
    discriminant is determined as follow:

    - for each value in data, an absolute distance to the median of data is
      computed and put in a new array "d" (of same size as data)
    - the median value of "d" (floored at stddev) is used as scale "mdev"
    - the absolute distance of test_value to the median of data, in units of
      "mdev", is then compared to "m" (parameter): if it is lower than m,
      test_value is considered good, otherwise it is an outlier.

    If neither test_value nor any element of data departs from the median by
    more than stddev, test_value is never considered an outlier.

    A single plain NumPy implementation is used: the former numba variant
    was re-created at every call and did not apply the same criterion.

    Parameters
    ----------
    data: numpy ndarray
        Input array with respect to which test_value is determined to be an
        outlier or not
    test_value: float
        Value to be tested as an outlier in the context of the input array
        data
    m: float, optional
        Criterion used to test if test_value is an outlier (similar to the
        number of "sigma" in std_dev statistics)
    stddev: float, optional (but strongly recommended)
        Global std dev of the non-PSF part of the considered frame. It is
        needed as a reference to know the typical variation of the noise, and
        hence avoid detecting outliers out of very close pixel values. If the
        9 pixels of data happen to be very uniform in values at some location,
        the departure in value of only one pixel could make it appear as a bad
        pixel. If stddev is not provided, the stddev of data is used (not
        recommended).
    debug: bool, optional
        Whether to print intermediate quantities.

    Returns
    -------
    test_result: 0 or 1
        0 if test_value is not an outlier. 1 otherwise.
    """
    data = np.asarray(data)
    if stddev is None:
        stddev = np.std(data)

    med = np.median(data)
    d = np.abs(data - med)
    mdev = np.median(d)
    dist = np.abs(test_value - med)
    if debug:
        print("data = ", data)
        print("median(data)= ", med)
        print("d = ", d)
        print("mdev = ", mdev)
        print("stddev(box) = ", np.std(data))
        print("stddev(frame) = ", stddev)
        print("max(d) = ", np.max(d))

    # nothing departs from the median by more than the noise level
    if max(np.max(d), dist) <= stddev:
        return 0

    mdev = mdev if mdev > stddev else stddev
    if mdev == 0:
        # perfectly uniform neighbours and no noise reference: any
        # departure from the median is infinitely significant
        return 1 if dist > 0 else 0
    test = dist/mdev
    if debug:
        print("s =", d/mdev)
        print("test =", test)
    # the result is now also set in debug mode
    return 0 if test < m else 1


def correct_ann_outliers(array, bpix_map, ann_width, sig, med_neig, std_neig,
                         cy, cx, min_thr, max_thr, stddev, half_res_y=False):
    """ Function to correct outliers in concentric annuli.

    Each bad pixel is replaced by the median of its annulus plus a uniform
    random term of amplitude sqrt(|median|). The whole frame is processed at
    once with NumPy array operations.

    Parameters
    ----------
    array: numpy ndarray
        Input 2d array in which outliers are searched and corrected.
    bpix_map: numpy ndarray or None
        Input array of known bad pixels. If provided, newly found outliers
        are added to it in place and all its non-zero pixels are corrected.
    ann_width: float
        Width of concentric annuli in pixels.
    sig: float
        Number of sigma to consider a pixel intensity as an outlier.
    med_neig, std_neig: 1d arrays
        Median and standard deviation of good pixel intensities in each
        annulus. They must cover all annuli present in the frame.
    cy, cx: floats
        Coordinates of the center of the concentric annuli.
    min_thr, max_thr: {None,float}
        Any pixel whose value is lower (resp. larger) than this threshold
        will be automatically considered bad and hence corrected. If None, it
        is not used.
    stddev: float or None
        Global std dev of the non-PSF part of the considered frame, used as a
        floor for the dispersion in each annulus, to avoid detecting outliers
        out of very close pixel values. If None, the stddev of array is used
        (not recommended).
    half_res_y: bool, {True,False}, optional
        Whether the input data have only half the angular resolution
        vertically compared to horizontally (e.g. SINFONI data). Vertical
        distances to the center are then doubled.

    Returns
    -------
    array_corr: np.array
        Array with corrected outliers.
    bpix_map: np.array
        Array with 1 at the location of outliers.
    """
    n_y, n_x = array.shape
    med_neig = np.asarray(med_neig)
    std_neig = np.asarray(std_neig)
    if stddev is None:
        stddev = np.std(array)
    rand_arr = 2*(np.random.rand(n_y, n_x)-0.5)
    array_corr = array.copy()
    if bpix_map is None:
        bpix_map = np.zeros([n_y, n_x])

    # index of the annulus each pixel belongs to
    yy, xx = np.indices((n_y, n_x))
    dy = 2*(cy-yy) if half_res_y else cy-yy
    rr = (np.sqrt(dy**2+(cx-xx)**2)/ann_width).astype(int)

    med = med_neig[rr]
    dev = np.maximum(stddev, np.minimum(std_neig[rr], med))

    bad = (array < med-sig*dev) | (array > med+sig*dev)
    # absolute thresholds are optional
    if min_thr is not None:
        bad |= array < min_thr
    if max_thr is not None:
        bad |= array > max_thr
    bpix_map[bad] = 1

    to_fix = bpix_map.astype(bool)
    array_corr[to_fix] = (med[to_fix] +
                          np.sqrt(np.abs(med[to_fix]))*rand_arr[to_fix])

    return array_corr, bpix_map
def frame_fix_badpix_fft(array, bpm_mask, nit=500, tol=1, pad_fac=2,
                         verbose=True, full_output=False):
    """
    Function to interpolate bad pixels with the FFT-based algorithm in
    [AAC01]_.

    Parameters
    ----------
    array : 2D ndarray
        Input image.
    bpm_mask : 2D ndarray
        Bad pixel map, with the same shape as ``array``. The weight map used
        by the algorithm is built as ``1 - bpm_mask``.
    nit: int or list of int, optional
        The number of iterations to use. If a list (or tuple) is provided,
        the largest value sets the number of iterations, and the estimate
        obtained after each listed number of iterations is stored, so that a
        list of bad pixel corrected frames is returned. The list can be
        shorter than ``nit`` if the tolerance is reached early.
    tol: float, opt
        Tolerance in terms of E_g (see [AAC01]_). The iterative process is
        stopped if the error E_g gets lower than this tolerance.
    pad_fac: int or float, opt
        Padding factor before calculating 2D-FFT.
    verbose: bool
        Whether to print additional information during processing.
    full_output: bool
        Whether to also return the reconstructed estimate f_hat of the input
        array.

    Returns
    -------
    array_corr: 2D ndarray or list of 2D ndarray
        Image in which the bad pixels have been interpolated.
    f_est: 2D ndarray or list of 2D ndarray
        [full_output=True] Reconstructed estimate (f_hat in [AAC01]_) of the
        input array

    Raises
    ------
    TypeError
        If ``array`` is not 2D, or if ``bpm_mask`` has a different shape.

    """
    if array.ndim != 2:
        raise TypeError("Input array should be 2D")
    if array.shape != bpm_mask.shape:
        raise TypeError("Input bad pixel map should have same shape as array")

    if isinstance(nit, (list, tuple)):
        nit_max = max(nit)
        return_list = True
        snapshots = set(nit)
    else:
        nit_max = nit
        return_list = False
        snapshots = ()

    final_array_corr = []
    final_f_est = []

    # Pad zeros for better results
    ini_y, ini_x = array.shape
    pad_fac = (int(pad_fac*ini_x/ini_y), pad_fac)
    g = frame_pad(array, pad_fac, keep_parity=False, fillwith=0)
    w = frame_pad(1-bpm_mask, pad_fac, keep_parity=False, fillwith=0)

    # Following AAC01 notations:
    g *= w
    if verbose:
        start = time_ini()
    G_i = np.fft.fft2(g)
    W = np.fft.fft2(w)

    # Initialisation
    dims = g.shape
    ny, nx = dims
    npix = float(ny * nx)
    F_est = np.zeros(dims, dtype=complex)
    it = 0
    Eg = tol + 1

    for it in Progressbar(range(nit_max),
                          desc="iterative bad pixel correction"):
        # 1. select line as max(G_i) and infer conjugate coordinates
        ind = np.unravel_index(np.argmax(np.abs(G_i.real[:, 0: nx // 2])),
                               (ny, nx // 2))
        ind_conj = (np.mod(ny - ind[0], ny), np.mod(nx - ind[1], nx))
        ind_twice = ((2 * ind[0]) % ny, (2 * ind[1]) % nx)

        # 2. compute the new F_i
        # handle cases with no conjugate:
        cond1 = (ind[0] == 0) and (ind[1] == 0)
        cond2 = (ind[0] == ny / 2) and (ind[1] == 0)
        cond3 = (ind[0] == 0) and (ind[1] == nx / 2)
        cond4 = (ind[0] == ny / 2) and (ind[1] == nx / 2)
        if cond1 or cond2 or cond3 or cond4:
            F_i = npix * G_i[ind] / W[0, 0]
            # 3a. update F_est
            F_est[ind] += F_i
        # handle cases where conjugate indices exist
        else:
            a = np.power(np.abs(W[0, 0]), 2)
            b = np.power(np.abs(W[ind_twice]), 2)
            if a == b:
                # avoid later division by 0; the minimum of |W| is only
                # needed here, so it is not computed at every iteration
                Wmin = np.amin(np.abs(W))
                W[ind_twice] += Wmin*1e-11
                a = np.power(np.abs(W[0, 0]), 2)
                b = np.power(np.abs(W[ind_twice]), 2)
            c = a - b
            F_i = (npix/c) * (G_i[ind] * W[0, 0]
                              - np.conj(G_i[ind]) * W[ind_twice])
            # 3b. update F_est
            F_est[ind] += F_i
            F_est[ind_conj] += np.conj(F_i)

        # 4. get the new error spectrum
        G_i = get_err_spec(F_i, W, ind, npix, G_i, dims)

        # 5. Calculate new error - to check if still larger than tolerance
        Eg = np.sum(np.power(np.abs(G_i).ravel(), 2))/npix

        # `it` is 0-based: it+1 iterations have been performed at this point
        if ((return_list and (it + 1) in snapshots) or (it == nit_max-1)
                or (Eg < tol)):
            # Calculate reconstructed image (used for both outputs)
            f_est = np.fft.ifft2(F_est).real
            array_corr = g + f_est * (1 - w)
            # crop zeros to return to initial size
            cy, cx = frame_center(array_corr)
            wy = (ini_y - 1) / 2
            wx = (ini_x - 1) / 2
            y0 = int(cy - wy)
            # +1 cause endpoint is excluded when slicing
            y1 = int(cy + wy + 1)
            x0 = int(cx - wx)
            x1 = int(cx + wx + 1)
            final_array_corr.append(array_corr[y0:y1, x0:x1])
            final_f_est.append(f_est[y0:y1, x0:x1])
        if Eg < tol:
            break

    if verbose:
        msg = "FFT-interpolation terminated after {} iterations (Eg={})"
        print(msg.format(it+1, Eg))
        timing(start)

    if not return_list:
        final_array_corr = final_array_corr[-1]
        final_f_est = final_f_est[-1]

    if full_output:
        return final_array_corr, final_f_est
    else:
        return final_array_corr
def get_err_spec(F_i, W, ind, npix, G_i, dims):
    """
    Update the error spectrum after the selection of one spectral line.

    The contribution of the selected line (and of its conjugate line, when
    it has one) convolved with the spectrum of the weight map is removed
    from the current error spectrum:

    ``conv[y, x] = F_i * W[y - ind[0], x - ind[1]]
    + conj(F_i) * W[y + ind[0], x + ind[1]]``

    with indices taken modulo the array dimensions. The second term is
    omitted for lines without a distinct conjugate, i.e. when ``ind[0]`` is
    0 or ny/2 and ``ind[1]`` is 0 or nx/2.

    The circular shifts are done with ``np.roll`` on whole arrays instead of
    a pixel-by-pixel loop, so no just-in-time compilation is needed at each
    call.

    Parameters
    ----------
    F_i : complex
        Value estimated for the selected spectral line.
    W : 2D ndarray (complex)
        Spectrum of the weight map, with shape ``dims``.
    ind : tuple of 2 int
        (y, x) indices of the selected spectral line.
    npix : float
        Number of pixels used to normalise the convolution.
    G_i : 2D ndarray (complex)
        Current error spectrum.
    dims : tuple of 2 int
        Dimensions (ny, nx) of the spectra.

    Returns
    -------
    G_new : 2D ndarray (complex)
        Updated error spectrum, ``G_i - conv/npix``.

    """
    ny, nx = dims
    row, col = int(ind[0]), int(ind[1])

    # np.roll(W, (row, col))[y, x] == W[(y - row) % ny, (x - col) % nx]
    conv = F_i * np.roll(W, (row, col), axis=(0, 1))

    # lines that coincide with their own conjugate contribute only once
    no_conjugate = ((row == 0 or row == ny / 2)
                    and (col == 0 or col == nx / 2))
    if not no_conjugate:
        # np.roll(W, (-row, -col))[y, x] == W[(y + row) % ny, (x + col) % nx]
        conv = conv + np.conj(F_i) * np.roll(W, (-row, -col), axis=(0, 1))

    # NOTE(review): the convolution has always been stored in single
    # precision (complex64); kept as is to preserve numerical results, but
    # double precision may be preferable - to be confirmed.
    conv = conv.astype(np.complex64)

    return G_i - conv/float(npix)