# Source code for vip_hci.config.utils_conf

#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""
Module with utilities: progress bars, array sanity checks and a thin
multiprocessing abstraction layer.
"""

__author__ = "Carlos Alberto Gomez Gonzalez, Ralf Farkas"
__all__ = ["Progressbar", "check_array", "sep", "vip_figsize", "vip_figdpi"]

import os
import sys
import numpy as np

import itertools as itt
from inspect import signature, Parameter
from functools import wraps
import multiprocessing
from vip_hci import __version__

# Horizontal rule used as a separator in verbose console output.
sep = "―" * 80
# Default matplotlib figure size (inches) and resolution used by VIP plots.
vip_figsize = (8, 5)
vip_figdpi = 100


def print_precision(array, precision=3):
    """Prints an array with a given floating point precision. 3 by default."""
    formatted = np.array2string(array.astype(float), precision=precision)
    print(formatted)


class SaveableEmpty(object):
    """
    Placeholder object used by ``Saveable.load`` to rebuild an instance
    without invoking its ``__init__`` (the same trick pickle/copy use).
    """


class Saveable(object):
    """
    Mixin providing npz-based ``save``/``load`` persistence for VIP objects.

    Subclasses must define ``_saved_attributes``, the list of attribute
    names to persist.
    """

    def save(self, filename):
        """
        Save a VIP object to a npz file.

        Parameters
        ----------
        filename : str
            Output path. numpy appends ``.npz`` if it is missing.

        Raises
        ------
        RuntimeError
            If the class does not define ``_saved_attributes``.
        """
        vip_object = self.__class__.__name__

        if not hasattr(self, "_saved_attributes"):
            raise RuntimeError(
                "_saved_attributes not found for class {}" "".format(vip_object)
            )

        data = {}
        for a in self._saved_attributes:
            if hasattr(self, a):
                data[a] = getattr(self, a)

                # set marker to re-build the original datatype
                # (for non-np types like float, string, ...)
                if not isinstance(getattr(self, a), np.ndarray):
                    data["_item_{}".format(a)] = True

        # write the file once, after collecting all attributes (the original
        # code called savez_compressed inside the loop, re-writing the file
        # on every iteration)
        np.savez_compressed(
            filename, _vip_version=__version__, _vip_object=vip_object, **data
        )

    @classmethod
    def load(cls, filename):
        """
        Restore an object previously stored with ``save``.

        Parameters
        ----------
        filename : str
            Path of the npz file, with or without the ``.npz`` extension.

        Returns
        -------
        self : object
            Instance of ``cls`` with the saved attributes restored.
        """
        try:
            data = np.load(filename, allow_pickle=True)
        except OSError:
            # retry with the extension numpy appends on save
            data = np.load(filename + ".npz", allow_pickle=True)

        if "_vip_object" not in data:
            raise RuntimeError("The file you specified is not a VIP object.")

        file_vip_object = data["_vip_object"].item()
        if file_vip_object != cls.__name__:
            raise RuntimeError(
                "The object in the file is of type '{}', please "
                "use that class's 'load()' method instead."
                "".format(file_vip_object)
            )

        file_vip_version = data["_vip_version"].item()
        if file_vip_version != __version__:
            print(
                "The file was saved with VIP {}. There may be some "
                "compatibility issues. Use with care."
                "".format(file_vip_version)
            )

        # rebuild the instance without calling __init__ (pickle-style)
        self = SaveableEmpty()
        self.__class__ = cls

        for k in data:
            if k.startswith("_"):
                continue

            if "_item_{}".format(k) in data:
                setattr(self, k, data[k].item())  # un-pack np array
            else:
                setattr(self, k, data[k])

        # add non-saved, but expected attributes (backwards compatibility)
        for exp_k in self._saved_attributes:
            if exp_k not in data:
                setattr(self, exp_k, None)

        return self


class Progressbar(object):
    """Show progress bars. Supports multiple backends.

    Examples
    --------
    .. code:: python

        from vip_hci.var import Progressbar
        Progressbar.backend = "tqdm"

        from time import sleep

        for i in Progressbar(range(50)):
            sleep(0.02)

        # or:

        bar = Progressbar(total=50)
        for i in range(50):
            sleep(0.02)
            bar.update()

        # Progressbar can be disabled globally using
        Progressbar.backend = "hide"

        # or locally using the ``verbose`` keyword:
        Progressbar(iterable, verbose=False)

    """

    backend = "pyprind"

    def __new__(
        cls,
        iterable=None,
        desc=None,
        total=None,
        leave=True,
        backend=None,
        verbose=True,
    ):
        # fall back to the class-wide backend unless one is given explicitly
        if backend is None:
            backend = Progressbar.backend

        # ``verbose=False`` silences just this bar
        if not verbose:
            backend = "hide"

        if backend == "tqdm":
            from tqdm import tqdm

            return tqdm(
                iterable=iterable,
                desc=desc,
                total=total,
                leave=leave,
                ascii=True,
                ncols=80,
                file=sys.stdout,
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} "
                "[{elapsed}<{remaining}{postfix}]",
            )  # remove rate_fmt
        if backend == "tqdm_notebook":
            from tqdm import tqdm_notebook

            return tqdm_notebook(iterable=iterable, desc=desc, total=total, leave=leave)
        if backend == "pyprind":
            from pyprind import ProgBar, prog_bar

            ProgBar._adjust_width = lambda self: None  # keep constant width
            if iterable is None:
                return ProgBar(total, title=desc, stream=1)
            return prog_bar(iterable, title=desc, stream=1, iterations=total)
        if backend == "hide":
            return NoProgressbar(iterable=iterable)
        raise NotImplementedError("unknown backend")

    def set(b):
        """Set the class-wide default backend."""
        Progressbar.backend = b
class NoProgressbar(object):
    """Wraps an ``iterable`` to behave like ``Progressbar``, but without
    producing output.
    """

    def __init__(self, iterable=None):
        self.iterable = iterable

    def __iter__(self):
        return self.iterable.__iter__()

    def __next__(self):
        return self.iterable.__next__()

    def __getattr__(self, key):
        # delegate unknown attributes to the wrapped iterable.
        # BUGFIX: the original returned ``self.iterable.key``, i.e. it always
        # looked up the literal attribute named "key" instead of ``key``.
        return getattr(self.iterable, key)

    def update(self):
        pass


def algo_calculates_decorator(*calculated_attributes):
    """
    Decorator for PostProc methods, describe what they calculate.

    There are three benefits from decorating a method:

    - if ``verbose=True``, prints a message about the calculated attributes
      and the ones which can be calculated next.
    - the attributes which *can* be calculated by this method are tracked, so
      if a user tries to access them *before* the function is called, an
      informative error message can be shown
    - the object knows which attributes to reset when ``run()`` is called a
      second time, on a different dataset.

    Parameters
    ----------
    *calculated_attributes : list of strings
        Strings denominating the attributes the decorated function calculates.

    Examples
    --------
    .. code:: python

        from .conf import algo_calculates_decorator as calculates

        class HCIMyAlgo(HCIPostPRocAlgo):
            def __init__(self, my_algo_param):
                self.store_args(locals())

            @calculates("final_frame", "snr_map")
            def run(dataset=None, verbose=True):
                frame, snr = my_heavy_calculation()

                self.final_frame = frame
                self.snr_map = snr

    """

    def decorator(fkt):
        @wraps(fkt)
        def wrapper(self, *args, **kwargs):
            # run the actual method
            res = fkt(self, *args, **kwargs)

            # get the kwargs the fkt sees. Note that this is a combination of
            # the *default* kwargs and the kwargs *passed* by the user
            sig = signature(fkt)
            params = sig.parameters
            all_kwargs = {
                k: v.default
                for k, v in params.items()
                if v.default is not Parameter.empty
            }
            all_kwargs.update(kwargs)

            # track which calculator methods were executed on this object
            if not hasattr(self, "_called_calculators"):
                self._called_calculators = []
            self._called_calculators.append(fkt.__name__)

            # show help message
            if all_kwargs.get("verbose", False):
                self._show_attribute_help(fkt.__name__)

            return res

        # set an attribute on the wrapper so _get_calculations() can find it:
        wrapper._calculates = calculated_attributes
        return wrapper

    return decorator
def check_array(input_array, dim, msg=None):
    """Checks the dimensionality of input. In case the check is not
    successful, a TypeError is raised.

    Parameters
    ----------
    input_array : list, tuple or np.ndarray
        Input data.
    dim : int or tuple
        Number of dimensions that ``input_array`` should have. ``dim`` can
        take one of these values: 1, 2, 3, 4, (1,2), (2,3), (3,4) or (2,3,4).
    msg : str, optional
        String to be used in the error message (``input_array`` name).

    Raises
    ------
    TypeError
        If ``input_array`` has the wrong type or dimensionality.
    ValueError
        If ``dim`` is not one of the supported values.
    """
    if not isinstance(input_array, (list, tuple, np.ndarray)):
        # fixed message typo: was "tuple of numpy ndarray"
        raise TypeError("`input_array` must be a list, tuple or numpy ndarray")

    if msg is None:
        msg = "Input array"
    else:
        msg = "`" + msg + "`"

    # validate ``dim`` itself before using it
    error_msg = "`dim` must be: 1, 2, 3, 4, (1,2), (2,3), (3,4) or (2,3,4)"
    if isinstance(dim, int):
        if dim < 1 or dim > 4:
            raise ValueError(error_msg)
    elif isinstance(dim, tuple):
        if dim not in ((1, 2), (2, 3), (3, 4), (2, 3, 4)):
            raise ValueError(error_msg)
    else:
        # the original silently skipped all checks for e.g. a float or
        # string ``dim``
        raise TypeError(error_msg)

    msg2 = " must be a "
    msg3 = "d numpy ndarray"

    if dim == 1:
        # 1d accepts list/tuple as well, by converting first
        if isinstance(input_array, (list, tuple)):
            input_array = np.array(input_array)
        if not isinstance(input_array, np.ndarray):
            raise TypeError(msg + msg2 + "list, tuple or a 1" + msg3)
        if not input_array.ndim == dim:
            raise TypeError(msg + msg2 + "list, tuple or a 1" + msg3)

    elif dim in (2, 3, 4):
        if not isinstance(input_array, np.ndarray):
            raise TypeError(msg + msg2 + str(dim) + msg3)
        else:
            if not input_array.ndim == dim:
                raise TypeError(msg + msg2 + str(dim) + msg3)

    elif isinstance(dim, tuple):
        if dim == (1, 2):
            msg_tup = "1 or 2"
        elif dim == (2, 3):
            msg_tup = "2 or 3"
        elif dim == (3, 4):
            msg_tup = "3 or 4"
        elif dim == (2, 3, 4):
            msg_tup = "2, 3 or 4"

        if isinstance(input_array, np.ndarray):
            if input_array.ndim not in dim:
                raise TypeError(msg + msg2 + msg_tup + msg3)
        else:
            raise TypeError(msg + msg2 + msg_tup + msg3)
def frame_or_shape(data):
    """
    Sanitize ``data``, always return a 2d frame.

    If ``data`` is a 2d frame, it is returned unchanged. If it is a shape
    tuple, return an empty array of that shape.

    Parameters
    ----------
    data : 2d ndarray or shape tuple

    Returns
    -------
    array : 2d ndarray

    """
    if isinstance(data, np.ndarray):
        array = data
        if array.ndim != 2:
            raise TypeError("`data` is not a frame or 2d array")
    elif isinstance(data, tuple):
        array = np.zeros(data)
    else:
        raise TypeError("`data` must be a tuple (shape) or a 2d array")
    return array


def eval_func_tuple(f_args):
    """Takes a tuple of a function and args, evaluates and returns result"""
    return f_args[0](*f_args[1:])


class FixedObj(object):
    # marker wrapper: tells ``pool_map`` NOT to apply itertools.repeat on
    # this argument (see ``iterable``)
    def __init__(self, v):
        self.v = v


def iterable(v):
    """Helper function for ``pool_map``: prevents the argument from being
    wrapped in ``itertools.repeat()``.

    Examples
    --------
    .. code-block:: python

        # we have a worker function which processes a word:
        def worker(word, method):
            # ...

        # we want to process these words in parallel fashion:
        words = ["lorem", "ipsum", "esse", "ea", "eiusmod"]

        # but all with method = 1:
        pool_map(3, worker, iterable(words), method)

        # this results in calling
        #
        # worker(words[0], 1)
        # worker(words[1], 1)
        # worker(words[2], 1)
        # ...
    """
    return FixedObj(v)


def pool_map(nproc, fkt, *args, **kwargs):
    """
    Abstraction layer for multiprocessing. When ``nproc=1``, the builtin
    ``map()`` is used. For ``nproc>1`` a ``multiprocessing.Pool`` is created.

    Parameters
    ----------
    nproc : int
        Number of processes to use.
    fkt : callable
        The function to be called with each ``*args``
    *args : function arguments
        Arguments passed to ``fkt``. By default, ``itertools.repeat`` is
        applied on all the arguments, except when you wrap the argument in
        ``iterable()``.
    msg : str or None, optional
        Description to be displayed.
    progressbar_single : bool, optional
        Display a progress bar when single-processing is used. Defaults to
        ``False``.
    verbose : bool, optional
        Show more output. Also disables the progress bar when set to
        ``False``.

    Returns
    -------
    res : list
        A list with the results.

    """
    msg = kwargs.get("msg", None)
    verbose = kwargs.get("verbose", True)
    progressbar_single = kwargs.get("progressbar_single", False)
    _generator = kwargs.get("_generator", False)  # not exposed in docstring

    # wrap every plain argument in repeat(); FixedObj args iterate as-is
    args_r = [a.v if isinstance(a, FixedObj) else itt.repeat(a) for a in args]
    z = zip(itt.repeat(fkt), *args_r)

    if nproc == 1:
        if progressbar_single:
            # length of the first FixedObj arg determines the total
            total = len([a.v for a in args if isinstance(a, FixedObj)][0])
            z = Progressbar(z, desc=msg, verbose=verbose, total=total)
        res = map(eval_func_tuple, z)
        if not _generator:
            res = list(res)
    else:
        multiprocessing.set_start_method("fork", force=True)
        from multiprocessing import Pool

        # deactivate multithreading so the workers don't oversubscribe cores
        os.environ["MKL_NUM_THREADS"] = "1"
        os.environ["NUMEXPR_NUM_THREADS"] = "1"
        os.environ["OMP_NUM_THREADS"] = "1"

        if verbose and msg is not None:
            print("{} with {} processes".format(msg, nproc))
        pool = Pool(processes=nproc)
        if _generator:
            res = pool.imap(eval_func_tuple, z)
        else:
            res = pool.map(eval_func_tuple, z)
        pool.close()
        pool.join()

        # reactivate multithreading
        ncpus = multiprocessing.cpu_count()
        os.environ["MKL_NUM_THREADS"] = str(ncpus)
        os.environ["NUMEXPR_NUM_THREADS"] = str(ncpus)
        os.environ["OMP_NUM_THREADS"] = str(ncpus)

    return res


def pool_imap(nproc, fkt, *args, **kwargs):
    """
    Generator version of ``pool_map``. Useful when showing a progress bar for
    multiprocessing (see examples).

    Parameters
    ----------
    nproc : int
        Number of processes to use.
    fkt : callable
        The function to be called with each ``*args``
    *args : function arguments
        Arguments passed to ``fkt``
    msg : str or None, optional
        Description to be displayed.
    progressbar_single : bool, optional
        Display a progress bar when single-processing is used. Defaults to
        ``True``.
    verbose : bool, optional
        Show more output. Also disables the progress bar when set to
        ``False``.

    Examples
    --------
    .. code-block:: python

        # using pool_map

        res = pool_map(2, my_worker_function, *args)

        # using pool_imap with a progessbar:

        res = list(Progressbar(pool_imap(2, my_worker_function, *args)))

    """
    kwargs["_generator"] = True
    return pool_map(nproc, fkt, *args, **kwargs)


def repeat(*args):
    """
    Applies ``itertools.repeat`` to every ``args``.

    Examples
    --------

    # instead of using

    import itertools as itt

    my_fkt(itt.repeat(a), itt.repeat(b), itt.repeat(c), d, itt.repeat(e))

    # you could use `repeat`:

    my_fkt(*repeat(a, b, c), d, *repeat(e))

    """
    return [itt.repeat(a) for a in args]


def make_chunks(l, n):
    """
    Chunks a list into ``n`` parts. The order of ``l`` is not kept. Useful
    for parallel processing when a single call is too fast, so the overhead
    from managing the processes is heavier than the calculation itself.

    Parameters
    ----------
    l : list
        Input list.
    n : int
        Number of parts.

    Examples
    --------
    .. code-block:: python

        make_chunks(range(13), 3)
            # -> [[0, 3, 6, 9, 12], [1, 4, 7, 10], [2, 5, 8, 11]]
    """
    return [l[i::n] for i in range(n)]


class redirect_output(object):
    """Context manager for redirecting stdout/err to files"""

    def __init__(self, stdout="", stderr=""):
        self.stdout = stdout
        self.stderr = stderr

    def __enter__(self):
        self.sys_stdout = sys.stdout
        self.sys_stderr = sys.stderr

        if self.stdout:
            sys.stdout = open(self.stdout, "w")
        if self.stderr:
            if self.stderr == self.stdout:
                sys.stderr = sys.stdout
            else:
                sys.stderr = open(self.stderr, "w")

    def __exit__(self, exc_type, exc_value, traceback):
        # close the handles we opened (the original leaked them — see its
        # TODO). When stderr aliases stdout, close only once.
        if self.stdout:
            sys.stdout.close()
        if self.stderr and self.stderr != self.stdout:
            sys.stderr.close()
        sys.stdout = self.sys_stdout
        sys.stderr = self.sys_stderr


def lines_of_code():
    """Calculates the lines of code for VIP. Not objective measure of
    developer's work! (note to self).
    """
    cur_path = os.path.dirname(os.path.abspath(__file__))
    path = cur_path[: -len("conf")]

    ignore_set = set(["__init__.py"])

    loclist = []

    for pydir, _, pyfiles in os.walk(path):
        if "exlib/" not in pydir:
            for pyfile in pyfiles:
                if pyfile not in ignore_set and pyfile.endswith(".py"):
                    totalpath = os.path.join(pydir, pyfile)
                    # use a context manager so the file handle is closed
                    with open(totalpath, "r") as fh:
                        nlines = len(fh.read().splitlines())
                    loclist.append((nlines, totalpath.split(path)[1]))

    for linenumbercount, filename in loclist:
        print("{:05d} lines in {}".format(linenumbercount, filename))

    msg = "\nTotal: {} lines in ({}) excluding external libraries."
    print(msg.format(sum([x[0] for x in loclist]), path))