#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
This module contains generic class- and numerical-related utilities
"""
import warnings
from collections import defaultdict
from typing import Union, Callable, Optional, Tuple
from textwrap import dedent
import numpy as np
from scipy.interpolate import interp1d as sc_interp1d
__all__ = ["find_nearest", "iter_current_next", "partition", "iter_subclasses"]
class _OrderedSet(dict):
def add(self, x):
self[x] = None
def remove(self, x):
self.pop(x, None)
def find_nearest(array, value):
"""
Return the index of the value in `array` that is closest to `value`.
Parameters
----------
array: ndarray
Array of numbers
value: float
The query value
Returns
-------
int
Index of closest value
"""
idx = np.searchsorted(array, value, side="left")
if idx > 0 and (
idx == len(array)
or np.abs(value - array[idx - 1]) <= np.abs(value - array[idx])
):
return idx - 1
else:
return idx
# we need a globally unique value to detect whether a keyword argument was
# passed to iter_current_next
_sentinel = object()
def iter_current_next(iterable, start=_sentinel, end=_sentinel):
"""Iterate over pairs of consecutive values in an iterable.
This creates generator that yields a (current, next) tuple per element. If the
iterable contains less than two elements a StopIteration exception is
raised.
Parameters
----------
iterable: iterable
Iterable to take values from
start: object, optional
If specified, this value will be treated as if it were the first element
of the iterator
end: object, optional
If specified, this value will be treated as if it were the last element
of the iterator
Yields
------
(object, object)
Pairs of items
Examples
--------
>>> for pair in iter_current_next([]):
... print(pair)
>>> for pair in iter_current_next([0]):
... print(pair)
>>> for pair in iter_current_next([0, 1, 2]):
... print(pair)
(0, 1)
(1, 2)
>>> for pair in iter_current_next([0, 1, 2], start=None):
... print(pair)
(None, 0)
(0, 1)
(1, 2)
>>> for pair in iter_current_next([0, 1, 2], end='end_value'):
... print(pair)
(0, 1)
(1, 2)
(2, 'end_value')
>>> for pair in iter_current_next([], start='start', end='end'):
... print(pair)
('start', 'end')
"""
iterable = iter(iterable)
cur = start
try:
if cur is _sentinel:
cur = next(iterable)
while True:
nxt = next(iterable)
yield (cur, nxt)
cur = nxt
except StopIteration:
if cur is not _sentinel and end is not _sentinel:
yield (cur, end)
def iter_subclasses(cls, _seen=None):
"""
iter_subclasses(cls)
Generator over all subclasses of a given class, in depth first order.
Examples
--------
>>> class A(object): pass
>>> class B(A): pass
>>> class C(A): pass
>>> class D(B,C): pass
>>> class E(D): pass
>>>
>>> for cls in iter_subclasses(A):
... print(cls.__name__)
B
D
E
C
>>> # get ALL (new-style) classes currently defined
>>> [cls.__name__ for cls in iter_subclasses(object)] #doctest: +ELLIPSIS
['type', ...'tuple', ...]
"""
if not isinstance(cls, type):
raise TypeError(
"iter_subclasses must be called with " "new-style classes, not %.100r" % cls
)
if _seen is None:
_seen = set()
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
for sub in subs:
if sub not in _seen:
_seen.add(sub)
yield sub
for sub in iter_subclasses(sub, _seen):
yield sub
class ReplaceRefMixin(object):
"""This class is a utility mixin class to replace references to
objects with references to other objects. This is functionality is
used when unfolding timelines.
To use this functionality, a class should inherit from this class,
and keep a list of all attributes that contain references.
Examples
--------
The following class defines `prev` as a referential attribute, to
be replaced when a class instance is copied:
>>> class MyClass(ReplaceRefMixin):
... def __init__(self, prev=None):
... super().__init__()
... self.prev = prev
... self._ref_attrs.append('prev')
Create two instance `a1` and `a2`, where `a1` is the `prev` of
`a2``
>>> a1 = MyClass()
>>> a2 = MyClass(a1)
Copy `a1` and `a2` to `b1` and `b2`, respectively:
>>> from copy import copy
>>> b1 = copy(a1)
>>> b2 = copy(a2)
After copying the prev of `b2` is `a1`, not `b1`:
>>> b2.prev == b1
False
>>> b2.prev == a1
True
To fix that we define an object map:
>>> object_map = {a1: b1, a2: b2}
and replace the references according to the map:
>>> b2.replace_refs(object_map)
>>> b2.prev == b1
True
"""
def __init__(self):
self._ref_attrs = []
def replace_refs(self, o_map):
if hasattr(self, "_ref_attrs"):
for attr in self._ref_attrs:
o = getattr(self, attr)
if o is None:
pass
elif isinstance(o, list):
o_list_new = []
for o_el in o:
if o_el in o_map:
o_list_new.append(o_map[o_el])
else:
warnings.warn(
dedent(
"""reference not found in
o_map: {} start={} end={}, substituting None
""".format(
o_el, o_el.start, o_el.end
)
)
)
o_list_new.append(None)
setattr(self, attr, o_list_new)
else:
if o in o_map:
o_new = o_map[o]
else:
warnings.warn(
dedent(
"""reference not found in o_map:
{} start={} end={}, substituting None
""".format(
o, o.start, o.end
)
)
)
o_new = None
setattr(self, attr, o_new)
class ComparableMixin(object):
"""
Mixin class that makes instances comparable in a rich way (i.e. in !=, <, <=
etc), by just implementing a _cmpkey() method that returns a comparable
value.
source:
http://regebro.wordpress.com/2010/12/13/python-implementing-rich-comparison-the-correct-way/
Examples
--------
>>> class MyClass(ComparableMixin):
... def __init__(self, x):
... self.x = x
... def _cmpkey(self):
... return self.x
>>>
>>> a = MyClass(3)
>>> b = MyClass(4)
>>> a == b
False
>>> a < b
True
"""
def _compare(self, other, method):
try:
return method(self._cmpkey(), other._cmpkey())
except (AttributeError, TypeError):
# _cmpkey not implemented, or return different type,
# so I can't compare with "other".
return NotImplemented
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def partition(func, iterable):
"""
Return a dictionary containing the equivalence classes (actually bags)
of `iterable`, partioned according to `func`. The value of a key `k` is the
list of all elements `e` from iterable such that `k = func(e)`.
Examples
========
The following example groups the integers from 0 to 10 by their
respective modulo 3 values:
>>> lst = range(10)
>>> partition(lambda x: x % 3, lst)
{0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
"""
result = defaultdict(list)
for v in iterable:
result[func(v)].append(v)
return dict(result)
def add_field(a, descr):
"""
Return a new array that is like `a`, but has additional fields.
The contents of `a` are copied over to the appropriate fields in
the new array, whereas the new fields are uninitialized. The
arguments are not modified. Source [8]_.
Parameters
----------
a: np.ndarray
A structured numpy array
descr: np.dtype or list
A numpy type description of the new fields
Returns
-------
np.ndarray
The new structured numpy array
Examples
--------
>>> import numpy as np
>>> sa = np.array([(1, 'Foo'), (2, 'Bar')], \
dtype=[('id', int), ('name', 'S3')])
>>> sa.dtype.descr == np.dtype([('id', int), ('name', 'S3')])
True
>>> sb = add_field(sa, [('score', float)])
>>> sb.dtype.descr == np.dtype([('id', int), ('name', 'S3'), \
('score', float)])
True
>>> np.all(sa['id'] == sb['id'])
True
>>> np.all(sa['name'] == sb['name'])
True
Notes
-----
Source:
.. [8]
https://stackoverflow.com/questions/1201817/\
adding-a-field-to-a-structured-numpy-array
"""
if a.dtype.fields is None:
raise ValueError("`A` must be a structured numpy array")
if isinstance(descr, np.dtype):
descr = descr.descr
b = np.empty(a.shape, dtype=a.dtype.descr + descr)
for name in a.dtype.names:
b[name] = a[name]
return b
def sorted_dict_items(items, key=None):
for item in sorted(items, key=key):
yield item
[docs]def show_diff(a, b):
"""
Show the difference between two strings, using the difflib package. The
difference is printed to stdout.
Parameters
----------
a: str
First string
b: str
Second string
"""
import difflib
differ = difflib.Differ()
for li in differ.compare(a.split("\n"), b.split("\n")):
print(li)
class PrettyPrintTree(object):
def __init__(self):
self.stack = []
def push(self):
self.stack.append(TreeSymbol())
def pop(self):
self.stack.pop()
def next_item(self):
assert len(self.stack) > 0
self.stack[-1].next_item()
def last_item(self):
assert len(self.stack) > 0
self.stack[-1].last_item()
def __str__(self):
return "".join(str(sym) for sym in self.stack)
class TreeSymbol(object):
def __init__(self):
self.symbols = [" │ ", " ├─ ", " └─ ", " "]
self.state = 0
def next_item(self):
self.state = 1
def last_item(self):
self.state = 2
def __str__(self):
sym = self.symbols[self.state]
if self.state == 1:
self.state = 0
elif self.state == 2:
self.state = 3
return sym
def search(states, success, expand, combine):
while len(states) > 0:
state = states.pop(0)
if success(state):
return state
else:
states = combine(expand(state), states)
def interp1d(
x: np.ndarray,
y: np.ndarray,
dtype: Optional[type] = None,
axis: int = -1,
kind: Union[str, int] = "linear",
copy=True,
bounds_error=False,
fill_value=np.nan,
assume_sorted=False,
) -> Callable[[Union[float, int, np.ndarray]], np.ndarray]:
"""
Interpolate a 1-D function using scipy's interp1d method. This utility allows for
handling the case where `x` and `y` are only a single value (i.e. have length one,
which results in a ValueError if using scipy's version directly). It also allows for
specifying the dtype of the output.
The description of the parameters has been taken from `scipy.interpolate.interp1d`.
`x` and `y` are arrays of values used to approximate some function f:
``y = f(x)``. This class returns a function whose call method uses
interpolation to find the value of new points.
Parameters
----------
x : (N,) np.ndarray
A 1-D array of real values.
y : (...,N,...) np.ndarray
A N-D array of real values. The length of `y` along the interpolation
axis must be equal to the length of `x`.
dtype : type, optional
Type of the output array (e.g., `float`, `int`). By default it is set to
None (i.e., the array will have the same type as the outputs from
scipy's interp1d method.
axis : int, optional
Specifies the axis of `y` along which to interpolate.
Interpolation defaults to the last axis of `y`.
kind : str or int, optional
Specifies the kind of interpolation as a string or as an integer
specifying the order of the spline interpolator to use.
The string has to be one of 'linear', 'nearest', 'nearest-up', 'zero',
'slinear', 'quadratic', 'cubic', 'previous', or 'next'. 'zero',
'slinear', 'quadratic' and 'cubic' refer to a spline interpolation of
zeroth, first, second or third order; 'previous' and 'next' simply
return the previous or next value of the point; 'nearest-up' and
'nearest' differ when interpolating half-integers (e.g. 0.5, 1.5)
in that 'nearest-up' rounds up and 'nearest' rounds down. Default
is 'linear'.
copy : bool, optional
If True, the class makes internal copies of x and y.
If False, references to `x` and `y` are used. The default is to copy.
bounds_error : bool, optional
If True, a ValueError is raised any time interpolation is attempted on
a value outside of the range of x (where extrapolation is
necessary). If False, out of bounds values are assigned `fill_value`.
By default, an error is raised unless ``fill_value="extrapolate"``.
fill_value : array-like or (array-like, array_like) or "extrapolate", optional
- if a ndarray (or float), this value will be used to fill in for
requested points outside of the data range. If not provided, then
the default is NaN. The array-like must broadcast properly to the
dimensions of the non-interpolation axes.
- If a two-element tuple, then the first element is used as a
fill value for ``x_new < x[0]`` and the second element is used for
``x_new > x[-1]``. Anything that is not a 2-element tuple (e.g.,
list or ndarray, regardless of shape) is taken to be a single
array-like argument meant to be used for both bounds as
``below, above = fill_value, fill_value``.
- If "extrapolate", then points outside the data range will be
extrapolated.
assume_sorted : bool, optional
If False, values of `x` can be in any order and they are sorted first.
If True, `x` has to be an array of monotonically increasing values.
Returns
-------
interp_fun : callable
The interpolator instance. This method takes an input array, float
or integer and returns an array with the specified dtype (if `dtype`
is not None).
"""
if len(x) > 1:
interp_fun = sc_interp1d(
x=x,
y=y,
kind=kind,
axis=axis,
copy=copy,
bounds_error=bounds_error,
fill_value=fill_value,
assume_sorted=assume_sorted,
)
else:
# If there is only one value for x and y, assume that the method
# will always return the same value for any input.
def interp_fun(
input_var: Union[float, int, np.ndarray]
) -> Callable[[Union[float, int, np.ndarray]], np.ndarray]:
if y.ndim > 1:
result = np.broadcast_to(y, (len(np.atleast_1d(input_var)), y.shape[1]))
else:
result = np.broadcast_to(y, (len(np.atleast_1d(input_var)),))
if not isinstance(input_var, np.ndarray):
# the output of scipy's interp1d is always an array
result = np.array(result[0])
return result
if dtype is not None:
def typed_interp(
input_var: Union[float, int, np.ndarray]
) -> Callable[[Union[float, int, np.ndarray]], np.ndarray]:
return interp_fun(input_var).astype(dtype)
return typed_interp
else:
return interp_fun
# def search_recursive(states, success, expand, combine):
# try:
# if not states:
# return None
# elif success(states[0]):
# return states[0]
# else:
# new_states = combine(expand(states[0]), states[1:])
# return search_recursive(new_states, success, expand, combine)
# except RecursionError:
# warnings.warn('search exhausted stack, bailing out')
# return None
def monotonize_times(
s: np.ndarray,
x: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Interpolate linearly over as many points in `s` as necessary to
obtain a monotonic sequence. The minimum and maximum of `s` are
prepended and appended, respectively, to ensure monotonicity at
the bounds of `s`.
Parameters
----------
s : np.ndarray
a sequence of numbers s(x) which we want to monotonize
x : np.ndarray or None
The input variable of sequence s(x).
Returns
-------
s_mono: np.ndarray
a monotonic sequence that has been linearly interpolated using a subset of s
x_mono: np.ndarray
The input of the monotonic sequence.
"""
eps = np.finfo(float).eps
_s = np.r_[np.min(s) - eps, s, np.max(s) + eps]
if x is not None:
_x = np.r_[np.min(x) - eps, x, np.max(x) + eps]
else:
_x = np.r_[-eps, np.arange(len(s)), len(s) - 1 + eps]
s_mono = np.maximum.accumulate(s)
mask = np.r_[False, True, (np.diff(s_mono) != 0), False]
x_mono = _x[1:-1]
s_mono = interp1d(_x[mask], _s[mask], fill_value="extrapolate")(x_mono)
return s_mono, x_mono
if __name__ == "__main__":
import doctest
doctest.testmod()