"""
ROC curves generation.
"""
# Names exported by ``from <module> import *``: the ROC evaluation class and
# the stand-alone detection-counting helper defined further down.
__all__ = ['EvalRoc',
'compute_binary_map']
import copy
import numpy as np
import matplotlib.pyplot as plt
from hciplot import plot_frames
from scipy import stats
from photutils.segmentation import detect_sources
from munch import Munch
from ..config import time_ini, timing, Progressbar
from ..fm import cube_inject_companions
from ..psfsub.svd import SVDecomposer
from ..var import frame_center, get_annulus_segments, get_circle
# TODO: remove the munch dependency
class EvalRoc(object):
    """
    Class for the generation of receiver operating characteristic (ROC) curves.

    The methods are meant to be called in this order, each one relying on the
    attributes set by the previous one:

    1. ``add_algo`` -- register one or more post-processing algorithms.
    2. ``inject_and_postprocess`` -- inject fake companions and process the
       resulting cubes (sets ``fluxes``, ``dists``, ``thetas``, ``list_xy``
       and, per method, ``frames`` and ``probmaps``).
    3. ``compute_tpr_fps`` -- threshold the maps (sets, per method,
       ``detections``, ``fps`` and ``bmaps``).
    4. ``plot_detmaps`` / ``plot_roc_curves`` -- visualize the results.
    """

    # Suggested styles/thresholds, kept for reference when calling `add_algo`:
    # COLOR_1 = "#d62728" # CADI
    # COLOR_2 = "#ff7f0e" # PCA
    # COLOR_3 = "#2ca02c" # LLSG
    # COLOR_4 = "#9467bd" # SODIRF
    # COLOR_5 = "#1f77b4" # SODINN
    # SYMBOL_1 = "^" # CADI
    # SYMBOL_2 = "X" # PCA
    # SYMBOL_3 = "P" # LLSG
    # SYMBOL_4 = "s" # SODIRF
    # SYMBOL_5 = "p" # SODINN
    # # For model PSF subtraction algos that rely on a S/N map
    # THRESHOLDS_05_5 = [0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5]
    # # For algos that output a likelihood or probability map
    # THRESHOLDS_01_099 = np.linspace(0.1, 0.99, 10).tolist()

    def __init__(self, dataset, plsc=0.0272, n_injections=100, inrad=8,
                 outrad=12, dist_flux=("uniform", 2, 500), mask=None):
        """
        Store the configuration of the ROC experiment.

        Parameters
        ----------
        dataset : object
            Data container. The methods of this class read its ``cube``,
            ``psf``, ``angles`` and ``fwhm`` attributes.
        plsc : float, optional
            Forwarded as ``plsc`` to ``cube_inject_companions`` (presumably
            the plate scale -- confirm against that function).
        n_injections : int, optional
            Number of fake companions to inject (one per synthetic cube).
        inrad, outrad : int, optional
            Inner and outer radius of the annulus from which the injection
            positions are drawn. ``outrad - inrad`` is used as annulus width.
        dist_flux : tuple ('method', *args)
            'method' can be a string, e.g:
                ("skewnormal", skew, mean, var)
                ("uniform", low, high)
                ("normal", loc, scale)
            or a function. The function is called as
            ``method(*args, size=n_injections)``.
        mask : optional
            Stored on the instance; not used by the methods of this class.
        """
        self.dataset = dataset
        self.plsc = plsc
        self.n_injections = n_injections
        self.inrad = inrad
        self.outrad = outrad
        self.dist_flux = dist_flux
        self.mask = mask
        self.methods = []

    def add_algo(self, name, algo, color, symbol, thresholds):
        """
        Register a post-processing algorithm to be evaluated.

        Parameters
        ----------
        name : str
            Label used in printed messages, plot labels and the legend.
        algo : HciPostProcAlgo
            Algorithm object. ``inject_and_postprocess`` calls its ``run`` and
            ``make_snrmap`` methods and reads ``frame_final`` and ``snr_map``.
        color : str
            Color passed to matplotlib for this method's curve and labels.
        symbol : str
            Marker format string passed to ``plt.plot``.
        thresholds : list of float
            Detection thresholds applied to the method's S/N map.
        """
        self.methods.append(Munch(algo=algo, name=name, color=color,
                                  symbol=symbol, thresholds=thresholds))

    def inject_and_postprocess(self, patch_size, cevr=0.9,
                               expvar_mode='annular', nproc=1):
        """
        Inject fake companions and post-process every synthetic cube.

        Sets ``self.optpcs``, ``self.fluxes``, ``self.dists``,
        ``self.thetas`` and ``self.list_xy``, plus the ``frames`` and
        ``probmaps`` lists of every registered method.

        Parameters
        ----------
        patch_size : any
            Not used by this implementation; kept in the signature for
            backward compatibility.
        cevr : float or None, optional
            Target cumulative explained variance ratio. When not None, the
            number of principal components reaching it is stored in
            ``self.optpcs`` and assigned to ``ncomp`` of every ``HCIPca``
            algorithm. When None, ``self.optpcs`` is None and the algorithms
            keep their own ``ncomp``.
        expvar_mode : str, optional
            ``mode`` passed to ``SVDecomposer``.
        nproc : int, optional
            Passed to the algorithms' ``make_snrmap``.

        Notes
        -----
        # TODO `methods` are not returned inside `results` and are *not* saved!
        # TODO order of parameters for `skewnormal` `dist_flux` changed!
        #      (was [3], [1], [2])
        # TODO `save` not implemented
        """
        from .. import hci_postproc

        starttime = time_ini()

        # ===== number of PCs for PCA / rank for LLSG
        # Always define the attribute so the PCA branch below cannot hit an
        # AttributeError when `cevr` is None.
        self.optpcs = None
        if cevr is not None:
            svdecomp = SVDecomposer(self.dataset.cube, mode=expvar_mode,
                                    inrad=self.inrad, outrad=self.outrad,
                                    svd_mode='lapack', verbose=False)
            _ = svdecomp.get_cevr(ncomp_list=None)
            ratio_cumsum = svdecomp.cevr
            self.optpcs = np.searchsorted(ratio_cumsum, cevr) + 1
            # `cevr` is a fraction: format it as a real percentage.
            print("{:.1%} of CEVR with {} PCs".format(cevr, self.optpcs))

        # TODO: setting ncomp (PCA) / rank (LLSG) from `self.optpcs` should be
        # moved inside the HCIPostProcAlgo classes.

        # Getting indices in annulus
        first_frame = self.dataset.cube[0]
        width = self.outrad - self.inrad
        yy, xx = get_annulus_segments(first_frame, self.inrad, width)[0]
        num_patches = yy.shape[0]

        # Defining Fluxes according to chosen distribution
        dist_name = self.dist_flux[0]
        known_dists = dict(skewnormal=stats.skewnorm.rvs,
                           normal=np.random.normal,
                           uniform=np.random.uniform)
        dist_fkt = known_dists.get(dist_name, dist_name)
        if not callable(dist_fkt):
            raise TypeError("unknown flux distribution {!r}: expected one of "
                            "{} or a callable".format(dist_name,
                                                      sorted(known_dists)))
        self.fluxes = dist_fkt(*self.dist_flux[1:], size=self.n_injections)
        self.fluxes.sort()

        # Random pixel of the annulus for each injection, converted to polar
        # coordinates (distance, angle in degrees) around the frame center.
        inds_inj = np.random.randint(0, num_patches, size=self.n_injections)
        cy, cx = frame_center(first_frame)

        self.dists = []
        self.thetas = []
        for ind in inds_inj:
            injx = xx[ind] - cx
            injy = yy[ind] - cy
            self.dists.append(np.sqrt(injx**2 + injy**2))
            self.thetas.append(np.mod(np.arctan2(injy, injx) / np.pi * 180,
                                      360))

        for m in self.methods:
            m.frames = []
            m.probmaps = []

        self.list_xy = []

        # Injections
        for n in Progressbar(range(self.n_injections), desc="injecting"):
            cufc, cox, coy = _create_synt_cube(self.dataset.cube,
                                               self.dataset.psf,
                                               self.dataset.angles, self.plsc,
                                               theta=self.thetas[n],
                                               flux=self.fluxes[n],
                                               dist=self.dists[n],
                                               verbose=False)
            cox = int(np.round(cox))
            coy = int(np.round(coy))
            self.list_xy.append((cox, coy))

            for m in self.methods:
                # TODO: this is not elegant at all.
                # shallow copy. Should not copy e.g. the cube in memory,
                # just reference it.
                algo = copy.copy(m.algo)
                _dataset = copy.copy(self.dataset)
                _dataset.cube = cufc

                if (isinstance(algo, hci_postproc.HCIPca)
                        and self.optpcs is not None):
                    algo.ncomp = self.optpcs
                # elif isinstance(algo, hci_postproc.HCILLSG):
                #     algo.rank = self.optpcs

                algo.run(dataset=_dataset, verbose=False)
                algo.make_snrmap(approximated=True, nproc=nproc,
                                 verbose=False)

                m.frames.append(algo.frame_final)
                m.probmaps.append(algo.snr_map)

        timing(starttime)

    def compute_tpr_fps(self, **kwargs):
        """
        Calculate number of dets/fps for every injection/method/threshold.

        Take the probability maps and the desired thresholds for every method,
        and calculates the binary map, number of detections and FPS using
        ``compute_binary_map``. Sets each methods ``detections``, ``fps`` and
        ``bmaps`` attributes.

        Parameters
        ----------
        **kwargs : keyword arguments
            Passed to ``compute_binary_map``
        """
        starttime = time_ini()

        for m in self.methods:
            m.detections = []
            m.fps = []
            m.bmaps = []

        print('Evaluating injections:')
        for i in Progressbar(range(self.n_injections)):
            x, y = self.list_xy[i]

            for m in self.methods:
                dets, fps, bmaps = compute_binary_map(
                    m.probmaps[i], m.thresholds, fwhm=self.dataset.fwhm,
                    injections=(x, y), **kwargs
                )
                m.detections.append(dets)
                m.fps.append(fps)
                m.bmaps.append(bmaps)

        timing(starttime)

    def plot_detmaps(self, i=None, thr=9, dpi=100,
                     axis=True, grid=False, vmin=-10, vmax='max',
                     plot_type="horiz"):
        """
        Plot the detection maps for one injection.

        Parameters
        ----------
        i : int or None, optional
            Index of the injection, between 0 and self.n_injections. If None,
            takes the injection with index 30, or if there are not more than
            30 injections, the middle one.
        thr : int, optional
            Index of the threshold.
        dpi, axis, grid, vmin, vmax
            Passed to ``plot_frames``. ``vmin``/``vmax`` are only applied to
            the final frames of the ``vert`` layout. ``vmax='max'`` uses half
            of the maximum over the final frames of all methods.
        plot_type : {"horiz" or "vert"}, optional
            Plot type.

            ``horiz``
                One row per algorithm (frame, probmap, binmap)
            ``vert``
                1 row for final frames, 1 row for probmaps and 1 row for
                binmaps

        Raises
        ------
        ValueError
            If ``plot_type`` is not one of the supported values.
        """
        # input parameters
        if i is None:
            i = 30 if len(self.list_xy) > 30 else len(self.list_xy) // 2

        def _has_frame(method):
            # `frames[i]` only exists when the list holds MORE than i items.
            return hasattr(method, "frames") and len(method.frames) > i

        if vmax == 'max':
            frames_i = [m.frames[i] for m in self.methods if _has_frame(m)]
            vmax = np.concatenate(frames_i).max() / 2 if frames_i else None

        # print information
        print('X,Y: {}'.format(self.list_xy[i]))
        print('dist: {:.3f}, flux: {:.3f}'.format(self.dists[i],
                                                  self.fluxes[i]))
        print()

        if plot_type in [1, "horiz"]:
            for m in self.methods:
                print('detection state: {} | false postives: {}'.format(
                    m.detections[i][thr], m.fps[i][thr]))
                labels = ('{} frame'.format(m.name),
                          '{} S/Nmap'.format(m.name),
                          'Thresholded at {:.1f}'.format(m.thresholds[thr]))
                # placeholder frame when this method has no frame #i
                frame = m.frames[i] if len(m.frames) > i else np.zeros((2, 2))
                plot_frames((frame, m.probmaps[i], m.bmaps[i][thr]),
                            label=labels, dpi=dpi, horsp=0.2, axis=axis,
                            grid=grid, cmap=['viridis', 'viridis', 'gray'])

        elif plot_type in [2, "vert"]:
            with_frame = [m for m in self.methods if _has_frame(m)]
            labels = tuple('{} frame'.format(m.name) for m in with_frame)
            plot_frames(tuple(m.frames[i] for m in with_frame),
                        dpi=dpi, label=labels, vmax=vmax, vmin=vmin,
                        axis=axis, grid=grid)

            plot_frames(tuple(m.probmaps[i] for m in self.methods), dpi=dpi,
                        label=tuple(['{} S/Nmap'.format(m.name) for m in
                                     self.methods]), axis=axis, grid=grid)

            for m in self.methods:
                msg = '{} detection: {}, FPs: {}'
                print(msg.format(m.name, m.detections[i][thr], m.fps[i][thr]))

            labels = tuple('Thresholded at {:.1f}'.format(m.thresholds[thr])
                           for m in self.methods)
            plot_frames(tuple(m.bmaps[i][thr] for m in self.methods),
                        dpi=dpi, label=labels, axis=axis, grid=grid,
                        colorbar=False, cmap='bone')
        else:
            raise ValueError("`plot_type` unknown")

    def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None,
                        ymin=-0.05, ymax=1.02, xlog=True, label_skip_one=False,
                        legend_loc='lower right', legend_size=6,
                        show_data_labels=True, hide_overlap_label=True,
                        label_gap=(0, -0.028), save_plot=False, label_params={},
                        line_params={}, marker_params={}, verbose=True):
        """
        Plot TPR against the mean number of false positives for every method.

        Requires ``compute_tpr_fps`` to have been run. Adds the ``tpr`` and
        ``mean_fps`` arrays (one value per threshold) to every method.

        Parameters
        ----------
        dpi, figsize
            Passed to ``plt.figure``; ``dpi`` is also used when saving.
        xmin, xmax, ymin, ymax : float or None, optional
            Axis limits.
        xlog : bool, optional
            If True the x axis uses matplotlib's "symlog" scale.
        label_skip_one : bool or list/tuple of bool, optional
            Label only every second threshold. A single value applies to all
            methods; a sequence gives one flag per method.
        legend_loc, legend_size
            Location and font size of the legend.
        show_data_labels : bool, optional
            Annotate the data points with their threshold value.
        hide_overlap_label : bool, optional
            Hide a threshold label when it overlaps a previously placed one.
        label_gap : tuple (dx, dy), optional
            Offset, in data coordinates, of each label from its data point.
        save_plot : bool or str, optional
            If a string, the figure is saved under that path; if otherwise
            truthy, it is saved as ``roc_curve.pdf``.
        label_params, line_params, marker_params : dict, optional
            Override the default annotation / line / marker keyword
            arguments. The dicts are only read, never modified.
        verbose : bool, optional
            Print the number of injections and the annulus radii.

        Raises
        ------
        AttributeError
            If a method has no ``detections``/``fps`` attributes yet.

        Notes
        -----
        # TODO: load `roc_injections` and `roc_tprfps` from file (`load_res`)
        # TODO: print flux distro information (is it actually stored in inj?
        #       What to do with functions, do they pickle?)
        # TODO: hardcoded `methodconf`?
        """
        labelskw = dict(alpha=1, fontsize=5.5, weight="bold", rotation=0,
                        annotation_clip=True)
        linekw = dict(alpha=0.2)
        markerkw = dict(alpha=0.5, ms=3)
        labelskw.update(label_params)
        linekw.update(line_params)
        markerkw.update(marker_params)

        if verbose:
            print('{} injections'.format(self.n_injections))
            print('Annulus from {} to {} pixels'.format(self.inrad,
                                                        self.outrad))

        fig = plt.figure(figsize=figsize, dpi=dpi)
        ax = fig.add_subplot(111)

        if not isinstance(label_skip_one, (list, tuple)):
            label_skip_one = [label_skip_one] * len(self.methods)

        labels = []

        for imeth, m in enumerate(self.methods):
            if not hasattr(m, "detections") or not hasattr(m, "fps"):
                raise AttributeError("method #{} has no detections/fps. Run "
                                     "`compute_tpr_fps` first.".format(imeth))

            # Rows are injections, columns are this method's own thresholds,
            # so methods with different numbers of thresholds are supported.
            # An injection counts as recovered when exactly one detection was
            # registered for it.
            detections = np.asarray(m.detections)
            m.tpr = (detections == 1).sum(axis=0) / self.n_injections
            m.mean_fps = np.asarray(m.fps).mean(axis=0)

            plt.plot(m.mean_fps, m.tpr, '--', color=m.color, **linekw)
            plt.plot(m.mean_fps, m.tpr, m.symbol, label=m.name, color=m.color,
                     **markerkw)

            if show_data_labels:
                if label_skip_one[imeth]:
                    picked = slice(1, None, 2)
                else:
                    picked = slice(None)
                lab_x = m.mean_fps[picked] + label_gap[0]
                lab_y = m.tpr[picked] + label_gap[1]

                for thr_value, xy in zip(m.thresholds[picked],
                                         zip(lab_x, lab_y)):
                    labels.append(ax.annotate('{:.2f}'.format(thr_value),
                                              xy=xy, xycoords='data',
                                              color=m.color, **labelskw))
                    # TODO: reverse order of `self.methods` for better annot.
                    # z-index?

        plt.legend(loc=legend_loc, prop={'size': legend_size})
        if xlog:
            ax.set_xscale("symlog")
        ax.set_ylim(bottom=ymin, top=ymax)
        ax.set_xlim(left=xmin, right=xmax)
        plt.ylabel('TPR')
        plt.xlabel('Full-frame mean FPs')
        plt.grid(alpha=0.4)

        if show_data_labels:
            # Canvas-sized occupancy map: a label is hidden when its bounding
            # box touches pixels already claimed by an earlier label.
            mask = np.zeros(fig.canvas.get_width_height(), bool)

            # the figure has to be drawn before the label extents are known
            fig.canvas.draw()

            negpad = -2
            for label in labels:
                bbox = label.get_window_extent()
                x0 = int(bbox.x0) + negpad
                x1 = int(np.ceil(bbox.x1)) + negpad
                y0 = int(bbox.y0) + negpad
                y1 = int(np.ceil(bbox.y1)) + negpad

                s = np.s_[x0:x1, y0:y1]
                if np.any(mask[s]):
                    if hide_overlap_label:
                        label.set_visible(False)
                else:
                    mask[s] = True

        if save_plot:
            if isinstance(save_plot, str):
                plt.savefig(save_plot, dpi=dpi, bbox_inches='tight')
            else:
                plt.savefig('roc_curve.pdf', dpi=dpi, bbox_inches='tight')
# Detection counting on thresholded maps (used by ``EvalRoc.compute_tpr_fps``).
def compute_binary_map(frame, thresholds, injections, fwhm, npix=1,
                       overlap_threshold=0.7, max_blob_fact=2, plot=False,
                       debug=False):
    """
    Take a list of ``thresholds``, create binary maps and counts detections/fps.

    A blob which is "too big" is split into apertures, and every aperture adds
    one 'false positive'.

    Parameters
    ----------
    frame : numpy ndarray
        Detection map.
    thresholds : list or numpy ndarray
        List of thresholds (detection criteria).
    injections : tuple, list of tuples
        Coordinates (x,y) of the injected companions. Also accepts 1d/2d
        ndarrays. An empty sequence means "no injection": every blob is then
        a false positive.
    fwhm : float
        FWHM, used for obtaining the size of the circular aperture centered at
        the injection position (and measuring the overlapping with found blobs).
        The circular aperture has 2 * FWHM in diameter.
    npix : int, optional
        The number of connected pixels, each greater than the given threshold,
        that an object must have to be detected. ``npix`` must be a positive
        integer. Passed to ``detect_sources`` function from ``photutils``.
    overlap_threshold : float
        Fraction of overlap (between 0 and 1) a blob has to exceed with the
        aperture around an injection to count as a detection.
    max_blob_fact : float
        Maximum size of a blob (in multiples of the resolution element) before
        it is considered as "too big" (= non-detection).
    plot : bool, optional
        If True, a final resulting plot summarizing the results will be shown.
    debug : bool, optional
        For showing optional information.

    Returns
    -------
    list_detections : list of int
        List of detection count for each threshold.
    list_fps : list of int
        List of false positives count for each threshold.
    list_binmaps : list of 2d ndarray
        List of boolean binary maps: detection maps thresholded for each
        threshold value.

    Raises
    ------
    ValueError
        If an injection coordinate is larger than the frame size.
    """
    def _overlap_injection_blob(injection, fwhm, blob_mask):
        """
        Measure how much a blob overlaps the aperture around one injection.

        Parameters
        ----------
        injection: tuple (x,y)
        fwhm : float
        blob_mask : 2d bool ndarray

        Returns
        -------
        overlap_fact : float between 0 and 1
            Fraction of the area overlap: ``intersection_area`` divided by
            the smaller of the blob area and the aperture area. ``0.0`` when
            that smaller area is empty.
        """
        if len(injection) > 0:
            injection_mask = get_circle(np.ones_like(blob_mask), radius=fwhm,
                                        cy=injection[1], cx=injection[0],
                                        mode="mask")
        else:
            injection_mask = np.zeros_like(blob_mask)

        intersection = injection_mask & blob_mask
        smallest_area = min(blob_mask.sum(), injection_mask.sum())
        if smallest_area == 0:
            # nothing to overlap with: avoid a 0/0 (NaN + RuntimeWarning)
            return 0.0
        return intersection.sum() / smallest_area

    # --------------------------------------------------------------------------
    list_detections = []
    list_fps = []
    list_binmaps = []
    sizey, sizex = frame.shape
    cy, cx = frame_center(frame)
    reselem_mask = get_circle(frame, radius=fwhm, cy=cy, cx=cx, mode="val")
    npix_circ_aperture = reselem_mask.shape[0]

    # normalize injections: accepts combinations of 1d/2d and tuple/list/array.
    injections = np.asarray(injections)
    if injections.ndim == 1:
        injections = np.array([injections])

    if injections.size > 0:
        # Validate once, up front, independently of which blobs are found.
        if (np.any(injections[:, 0] > sizex)
                or np.any(injections[:, 1] > sizey)):
            raise ValueError("Wrong coordinates in `injections`")
        circles = tuple(tuple(xy) for xy in injections)
    else:
        # "no injection" in any of its shapes ((), [], [[]], empty arrays)
        injections = np.empty((0, 2))
        circles = None

    for ithr, threshold in enumerate(thresholds):
        if debug:
            print("\nprocessing threshold #{}: {}".format(ithr + 1, threshold))

        segments = detect_sources(frame, threshold, npix, connectivity=4)

        detections = 0
        fps = 0

        if segments is None:
            # required since photutils v0.7: no source above the threshold
            binmap = np.zeros(frame.shape, dtype=bool)
            blobs = []
        else:
            binmap = (segments.data != 0)
            blobs = segments.segments

            if debug:
                plot_frames((segments.data, binmap),
                            cmap=('tab20b', 'binary'),
                            circle=circles,
                            circle_radius=fwhm, circle_alpha=0.6,
                            label=("segmentation map", "binary map"))

        for segment in blobs:
            label = segment.label
            blob_mask = (segments.data == label)
            blob_area = segment.area

            if debug:
                lab = "blob #{}, area={}px**2".format(label, blob_area)
                plot_frames(blob_mask, circle_radius=fwhm, circle_alpha=0.6,
                            circle=circles,
                            cmap='binary', label_size=8, label=lab,
                            size_factor=3)

            # The size test does not depend on the injections, so it is done
            # once per blob, before looking for a matching injection.
            if blob_area > max_blob_fact * npix_circ_aperture:
                number_of_apertures_in_blob = blob_area / npix_circ_aperture
                fps += number_of_apertures_in_blob  # float, rounded at end
                if debug:
                    print("\tblob is too big (+{:.0f} fps)"
                          "".format(number_of_apertures_in_blob))
                    print("\tskipping all other injections")
                continue

            for iinj, injection in enumerate(injections):
                if debug:
                    print("\ttesting injection #{} at {}".format(iinj + 1,
                                                                 injection))

                overlap = _overlap_injection_blob(injection, fwhm, blob_mask)
                if overlap > overlap_threshold:
                    if debug:
                        print("\toverlap of {}! (+1 detection)"
                              "".format(overlap))
                    detections += 1
                    # continue with next blob, do not check other injections
                    break

                if debug:
                    print("\toverlap of {} -> do nothing".format(overlap))
            else:
                # no injection matched this blob
                if debug:
                    print("\tdid not find a matching injection for this "
                          "blob (+1 fps)")
                fps += 1

        fps = int(np.round(fps))  # -> python `int`

        if debug:
            print("done with threshold #{}".format(ithr + 1))
            print("result: {} detections, {} fps".format(detections, fps))

        list_detections.append(detections)
        list_binmaps.append(binmap)
        list_fps.append(fps)

    if plot:
        labs = tuple(str(det) + ' detections' + '\n' + str(fps) +
                     ' false positives' for det, fps in zip(list_detections,
                                                            list_fps))
        plot_frames(tuple(list_binmaps), title='Final binary maps', label=labs,
                    label_size=8, cmap='binary', circle_alpha=0.8,
                    circle=circles, circle_radius=fwhm,
                    circle_color='deepskyblue', axis=False)

    return list_detections, list_fps, list_binmaps
def _create_synt_cube(cube, psf, ang, plsc, dist, flux, theta=None,
                      verbose=False):
    """
    Inject a single fake companion into ``cube`` and report its position.

    Parameters
    ----------
    cube, psf, ang, plsc
        Forwarded unchanged to ``cube_inject_companions``.
    dist : float
        Radial distance of the companion from the frame center, in the same
        units as the frame coordinates returned by ``frame_center``.
    flux : float
        Passed as ``flevel`` to ``cube_inject_companions``.
    theta : float or None, optional
        Position angle in degrees. When None, the global NumPy random
        generator is re-seeded and an integer angle in [0, 360) is drawn.
    verbose : bool, optional
        Print the angle and flux, and forward the flag to the injection.

    Returns
    -------
    cubefc : the cube returned by ``cube_inject_companions``
    posx, posy : float
        Coordinates of the injected companion in the frame.
    """
    center_y, center_x = frame_center(cube[0])

    if theta is None:
        np.random.seed()
        theta = np.random.randint(0, 360)

    # polar (dist, theta) -> cartesian offset from the frame center
    angle_rad = np.deg2rad(theta)
    posy = dist * np.sin(angle_rad) + center_y
    posx = dist * np.cos(angle_rad) + center_x

    if verbose:
        print('Theta:', theta)
        print('Flux_inj:', flux)

    injected = cube_inject_companions(cube, psf, ang, flevel=flux, plsc=plsc,
                                      rad_dists=[dist], n_branches=1,
                                      theta=theta, verbose=verbose)
    return injected, posx, posy