"""Various useful contexts.
"""
import collections
from contextlib import contextmanager
import functools
import math
import os
import signal
import time
import threading
import warnings
# IPython is optional: it is imported eagerly here (rather than inside the
# contexts) so that entering a context is not delayed by the import.
IPython = None
try:
    import IPython
except ImportError:  # pragma: no cover
    pass
# Public names exported by ``from <this module> import *``.
__all__ = [
    "is_main_thread", "NoInterrupt", "FPS",
    "CoroutineWrapper", "nointerrupt", "coroutine",
]
def sleep(duration):
    """Sleep for `duration` seconds, more accurately than :func:`time.sleep`.

    :func:`time.sleep` can overshoot (notably on Mac OS X).  While more than
    20 ms remain, this sleeps for half of the remaining time; for the final
    stretch it busy-waits on :func:`time.perf_counter`.  Returns immediately if
    `duration <= 0`.

    See https://stackoverflow.com/a/30672412/1088938
    """
    start_time = time.perf_counter()
    while True:
        remaining_time = duration - (time.perf_counter() - start_time)
        if remaining_time <= 0:
            break
        if remaining_time > 0.02:
            # Far from the deadline: sleep for only half of the remaining time
            # so that an oversleep by time.sleep() is unlikely to overshoot.
            time.sleep(remaining_time / 2)
        # Otherwise spin (busy-wait) until the deadline for accuracy.
def is_main_thread():
    """Return True when called from the interpreter's main thread."""
    main = threading.main_thread()
    return main is threading.current_thread()
class NoInterrupt:
    """Context that suspends signals during the execution of a block, and a
    simple mechanism to allow threads to be interrupted.

    Arguments
    ---------
    ignore : bool
        If `True`, then do not raise a `KeyboardInterrupt` if a soft interrupt is
        caught unless forced by multiple interrupt requests in a limited time.
    timeout : float, None
        If provided (and nonzero), the instance also evaluates as `True`
        (interrupted) once more than `timeout` seconds have elapsed since the
        context was entered (or, before entering, since construction).
    unregister : bool
        If `True`, call :meth:`unregister()` on construction, restoring the
        original signal handlers.

    There are two main entry points: globally by calling the :meth:`suspend()`
    method, and within a :class:`NoInterrupt()` context.

    **Main Thread**

    When executed in a context from the main thread, a signal handler
    is established which captures interrupt signals and represents
    them instead as a boolean flag (conventionally called
    "interrupted").

    Global interrupt suppression can be enabled by calling :meth:`register()`
    (from the main thread) and then :meth:`suspend()`.  This will stay in effect
    until :meth:`resume()` is called; forcing interrupts are still passed on to
    the original handlers.  Repeated calls to :meth:`suspend()` are not nested:
    a single call to :meth:`resume()` undoes them.

    Interrupts can also be suspended in contexts.  These can be nested.  These
    instances will evaluate as ``True`` (interrupted) once the context has ended.

    **Auxiliary Threads**

    Auxiliary threads can create instances of :class:`NoInterrupt()` or use
    contexts, but only the main thread can call :meth:`register()` (a
    `RuntimeError` is raised otherwise).  In these cases the context does not
    suspend signals (see below), but the flag is still useful as it can act as a
    signal to force the auxiliary thread to terminate if an interrupt is received
    in the main thread.

    A couple of notes about using the context in auxiliary threads.

    1. Either :meth:`register()` must be called or a context must first be
       entered in the main thread - otherwise the signal handlers will not be
       installed.  A warning is issued if an auxiliary thread enters a context
       without the handlers being installed.
    2. As stated in the python documents, signal handlers are always
       executed in the main thread.  Likewise, only the main thread is
       allowed to set new signal handlers.  Thus, the signal
       interrupting facilities provided here only work properly in the
       main thread.  Also, forcing an interrupt cannot raise an
       exception in the auxiliary threads: one must wait for them to
       respond to the changed "interrupted" value.

    For more information about killing threads see:

    * http://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread-in-python

    Attributes
    ----------
    force_n : int
        Number of interrupts to force signal.
    force_timeout : float
        Time in which force_n interrupts must be received to trigger a forced interrupt.

    Examples
    --------
    The simplest use-cases look like these:

    Simple context:

    >>> with NoInterrupt():
    ...     pass  # do something

    Context with a cleanly aborted loop:

    >>> with NoInterrupt() as interrupted:
    ...     done = False
    ...     while not interrupted and not done:
    ...         # Do something
    ...         done = True

    Map:

    >>> NoInterrupt().map(abs, [1, -1, 2, -2])
    [1, 1, 2, 2]

    Keyboard interrupt signals are suspended during the execution of
    the block unless forced by the user (3 rapid interrupts within
    1s).  Interrupts are ignored by default unless `ignore=False` is
    specified, in which case they will be raised when the context is
    ended.

    If you want to control when you exit the block, use the
    `interrupted` flag.  This could be used, for example, while
    plotting frames in an animation (see doc/Animation.ipynb).
    Without the :class:`NoInterrupt()` context, if the user sends a keyboard
    interrupt to the process while plotting, at best, a huge
    stack-trace is produced, and at worst, the kernel will crash
    (randomly depending on where the interrupt was received).  With
    this context, the interrupt will change `interrupted` to True so
    you can exit the context when it is safe.

    The last case is mapping a function to data.  This will allow the
    user to interrupt the process between function calls.

    In the following examples we demonstrate this by simulating
    interrupts:

    >>> import os, signal, time
    >>> def simulate_interrupt(force=False):
    ...     os.kill(os.getpid(), signal.SIGINT)
    ...     if force:
    ...         # Simulate a forced interrupt with multiple signals
    ...         os.kill(os.getpid(), signal.SIGINT)
    ...         os.kill(os.getpid(), signal.SIGINT)
    ...     time.sleep(0.1)  # Wait so signal can be received predictably

    This loop will get interrupted in the middle so that m and n will not be
    the same.

    >>> def f(n, interrupted=False, force=False, interrupt=True):
    ...     while n[0] < 10 and not interrupted:
    ...         n[0] += 1
    ...         if n[0] == 5 and interrupt:
    ...             simulate_interrupt(force=force)
    ...         n[1] += 1

    >>> n = [0, 0]
    >>> f(n, interrupt=False)
    >>> n
    [10, 10]

    >>> n = [0, 0]
    >>> try:  # All doctests need to be wrapped in try blocks to not kill py.test!
    ...     f(n)
    ... except KeyboardInterrupt as err:
    ...     print("KeyboardInterrupt: {}".format(err))
    KeyboardInterrupt:
    >>> n
    [5, 4]

    Now we protect the loop from interrupts.

    >>> n = [0, 0]
    >>> try:
    ...     with NoInterrupt(ignore=False) as interrupted:
    ...         f(n)
    ... except KeyboardInterrupt as err:
    ...     print("KeyboardInterrupt: {}".format(err))
    KeyboardInterrupt:
    >>> n
    [10, 10]

    One can ignore the exception if desired (this is the default as of 0.4.11):

    >>> n = [0, 0]
    >>> with NoInterrupt() as interrupted:
    ...     f(n)
    >>> n
    [10, 10]

    Three rapid exceptions will still force an interrupt when it occurs.  This
    might occur at random places in your code, so don't do this unless you
    really need to stop the process.

    >>> n = [0, 0]
    >>> try:
    ...     with NoInterrupt(ignore=False) as interrupted:
    ...         f(n, force=True)
    ... except KeyboardInterrupt as err:
    ...     print("KeyboardInterrupt: {}".format(err))
    KeyboardInterrupt:
    >>> n
    [5, 4]

    If `f()` is slow, we might want to interrupt it at safe times.  This is
    what the `interrupted` flag is for:

    >>> n = [0, 0]
    >>> try:
    ...     with NoInterrupt(ignore=False) as interrupted:
    ...         f(n, interrupted)
    ... except KeyboardInterrupt as err:
    ...     print("KeyboardInterrupt: {}".format(err))
    KeyboardInterrupt:
    >>> n
    [5, 5]

    Again: the exception can be ignored

    >>> n = [0, 0]
    >>> with NoInterrupt() as interrupted:
    ...     f(n, interrupted)
    >>> n
    [5, 5]
    """

    # Each time a signal is raised, it is appended to the _signals_raised dict
    # and the corresponding entry of _signal_count is incremented.  At the end
    # of the outermost context of the main thread, _signals_raised is cleared,
    # but _signal_count is NOT reset.  Each instance stores a copy of
    # _signal_count (when constructed and when entered) so that it can
    # determine whether a signal was raised since then.  This allows threads to
    # use the interrupted flag even if there is no active context in the main
    # thread.
    _instances = set()
    _original_handlers = {}  # Dictionary of original handlers
    _signals_raised = {}  # Dictionary of signals raised
    _signal_count = {}  # Dictionary of signal counts
    _signals = set((signal.SIGINT, signal.SIGTERM))
    _signals_suspended = set()

    # A forced interrupt (the original handler is called even while the signal
    # is suspended) occurs when force_n signals of the same type arrive within
    # force_timeout seconds.
    force_n = 3
    force_timeout = 1

    # Re-entrant lock: a signal handler runs in the main thread and may
    # interrupt one of these methods while the main thread holds the lock.
    _lock = threading.RLock()

    def __init__(self, ignore=True, timeout=None, unregister=False):
        if unregister:
            self.unregister()
        with self._lock:
            self.ignore = ignore
            self._active = True
            self.timeout = timeout
            # Also set in __enter__; initialised here so that bool(self) works
            # with a timeout before the context is entered.
            self.tic = time.perf_counter()
            self.signal_count_at_start = dict(self._signal_count)

    @classmethod
    def is_registered(cls):
        """Return True if handlers are registered."""
        with cls._lock:
            return bool(cls._signals.intersection(cls._original_handlers))

    @classmethod
    def register(cls):
        """Register the handlers so that signals can be suspended.

        Raises
        ------
        RuntimeError
            If called from a thread other than the main thread (only the main
            thread may install signal handlers).
        """
        if not is_main_thread():
            _msg = " ".join(
                [
                    "Can only register handlers from the main thread.",
                    "(Called from thread {})".format(threading.get_ident()),
                ]
            )
            raise RuntimeError(_msg)
        with cls._lock:
            if not cls.is_registered():
                cls._original_handlers = {
                    _signum: signal.signal(_signum, cls.handle_signal)
                    for _signum in cls._signals
                }
            assert cls.is_registered()

    @classmethod
    def unregister(cls, full=False):
        """Reset handlers to the original values.  No more signal suspension.

        A previous handler recorded as ``None`` (one not installed from Python,
        which cannot be reinstalled) is replaced by :data:`signal.SIG_DFL`.

        Arguments
        ---------
        full : bool
            If True, do a full reset, including counts.
        """
        with cls._lock:
            while cls._original_handlers:
                _signum, _handler = cls._original_handlers.popitem()
                if _handler is None:
                    # signal.signal() rejects None as a handler.
                    _handler = signal.SIG_DFL
                signal.signal(_signum, _handler)
            if full:
                cls.reset()
                cls._signal_count = {}
            assert not cls.is_registered()

    @classmethod
    def set_signals(cls, signals):
        """Change the set of signals that are handled.

        If handlers are currently registered, they are re-registered for the new
        signals.  Otherwise the new set is only recorded and will be used by the
        next :meth:`register()`.

        Note: This does not change the signals listed in
        :attr:`_signals_suspended`.

        Arguments
        ---------
        signals : set()
            Set of signal numbers.
        """
        signals = set(signals)
        with cls._lock:
            if signals == cls._signals:
                return
            was_registered = cls.is_registered()
            if was_registered:
                cls.unregister()
            cls._signals = signals
            if was_registered:
                cls.register()

    @classmethod
    def suspend(cls, signals=None):
        """Suspend the specified signals (default: all handled signals).

        A warning is issued for any signal without a registered handler: such
        signals are recorded as suspended but cannot actually be suspended.
        """
        with cls._lock:
            if signals is None:
                signals = cls._signals
            for signum in signals:
                if signum not in cls._original_handlers:
                    warnings.warn(
                        " ".join(
                            [
                                "No handler registered for signal {}.",
                                "Signal will not be suspended.",
                            ]
                        ).format(signum)
                    )
                cls._signals_suspended.add(signum)

    @classmethod
    def resume(cls, signals=None):
        """Resume the specified signals (default: all suspended signals)."""
        with cls._lock:
            if signals is None:
                signals = set(cls._signals_suspended)
            for signum in signals:
                cls._signals_suspended.discard(signum)

    @classmethod
    def reset(cls):
        """Reset the signal logs and return last signal `(signum, frame, time)`.

        Returns None if no signal has been recorded since the last reset.
        """
        res = None
        with cls._lock:
            if hasattr(cls, "_last_signal"):
                res = cls._last_signal
                del cls._last_signal
            cls._signals_raised = {}
        return res

    @classmethod
    def handle_signal(cls, signum, frame):
        """Custom signal handler.

        This stores the signal for later processing unless it was forced or the
        signal is not currently suspended, in which case the original handler
        is called.
        """
        with cls._lock:
            cls._last_signal = (signum, frame, time.perf_counter())
            cls._signals_raised.setdefault(signum, [])
            cls._signals_raised[signum].append(cls._last_signal)
            cls._signal_count.setdefault(signum, 0)
            cls._signal_count[signum] += 1
            if cls._forced_interrupt(signum) or signum not in cls._signals_suspended:
                cls.handle_original_signal(signum=signum, frame=frame)

    @classmethod
    def handle_original_signal(cls, signum, frame):
        """Call the original handler for `signum`.

        * A Python callable is called directly.
        * :data:`signal.SIG_IGN` means the signal was ignored: do nothing.
        * Otherwise (:data:`signal.SIG_DFL`, ``None``, or no recorded handler)
          Python cannot invoke the default action itself, so the original
          handlers are restored (if registered) and the signal is re-sent to
          this process.
        """
        with cls._lock:
            handler = cls._original_handlers.get(signum)
        if handler == signal.SIG_IGN:
            return
        if callable(handler):
            handler(signum, frame)
        elif cls.is_registered():
            cls.unregister()
            os.kill(os.getpid(), signum)
            cls.register()
        else:
            os.kill(os.getpid(), signum)

    @classmethod
    def _forced_interrupt(cls, signum):
        """Return True if the last :attr:`force_n` signals `signum` were received
        within :attr:`force_timeout` seconds of each other.
        """
        with cls._lock:
            signals_raised = cls._signals_raised.get(signum, [])
            return cls.force_n <= len(signals_raised) and cls.force_timeout > (
                signals_raised[-1][-1] - signals_raised[-cls.force_n][-1]
            )

    #############
    # Dummy handlers to thwart ipykernel's attempts to restore the
    # default signal handlers.
    # https://github.com/ipython/ipykernel/issues/328
    @staticmethod
    def _pre_handler_hook():
        pass

    @staticmethod
    def _post_handler_hook():
        pass

    @staticmethod
    def _get_kernel():
        """Return the running IPython kernel, or None if there is none."""
        if not IPython:
            return None
        return getattr(IPython.get_ipython(), "kernel", None)

    @classmethod
    def _install_kernel_hooks(cls):
        """Replace the kernel's handler hooks with the no-op dummies above."""
        kernel = cls._get_kernel()
        if kernel:
            kernel.pre_handler_hook = cls._pre_handler_hook
            kernel.post_handler_hook = cls._post_handler_hook

    @classmethod
    def _remove_kernel_hooks(cls):
        """Remove the dummy hooks installed by :meth:`_install_kernel_hooks`."""
        kernel = cls._get_kernel()
        if kernel:
            for name in ("pre_handler_hook", "post_handler_hook"):
                try:
                    delattr(kernel, name)
                except AttributeError:
                    pass  # Already removed (or never installed).

    def __enter__(self):
        """Enter context.

        In the main thread this installs the handlers (if needed) and suspends
        the handled signals.  In an auxiliary thread only the flag is set up,
        with a warning if the handlers are not registered.
        """
        with self._lock:
            self.tic = time.perf_counter()
            self._active = True
            self.signal_count_at_start = dict(self._signal_count)
            if is_main_thread():
                self._install_kernel_hooks()
                if not self.is_registered():
                    self.register()
                self.suspend()
                self._instances.add(self)
            elif not self.is_registered():
                _msg = "\n".join(
                    [
                        "Thread {} entering unregistered NoInterrupt() context.",
                        "Interrupts will not be processed! "
                        + "Call register() in main thread.",
                    ]
                ).format(threading.get_ident())
                warnings.warn(_msg)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit context.

        Only when the outermost main-thread context exits are the signals
        resumed, the IPython hooks removed and the logs reset.  If a signal was
        received and `ignore` is False, the original handler is then called
        (which may raise, e.g. `KeyboardInterrupt`).
        """
        with self._lock:
            self._active = False
            if not is_main_thread():
                return
            self._instances.discard(self)
            if self._instances:
                # Still inside an outer protected context somewhere.
                return
            self.resume()
            # Remove the hooks before calling the original handler, which
            # may raise.
            self._remove_kernel_hooks()
            last_signal = self.reset()
            if last_signal and not self.ignore:
                signum, frame, _time = last_signal
                self.handle_original_signal(signum=signum, frame=frame)

    def __bool__(self):
        """Return True if interrupted, timed out, or the context has exited."""
        with self._lock:
            timed_out = self.timeout and time.perf_counter() - self.tic > self.timeout
            return (
                not self._active
                or timed_out
                or any(
                    self._signal_count.get(_signum, 0)
                    > self.signal_count_at_start.get(_signum, 0)
                    for _signum in self._signals
                )
            )

    __nonzero__ = __bool__  # For python 2.

    def map(self, function, sequence, *v, **kw):
        """Map function onto sequence until interrupted or done.

        Interrupts will not occur inside function() unless forced.
        """
        res = []
        with self as interrupted:
            for s in sequence:
                if interrupted:
                    break
                res.append(function(s, *v, **kw))
        return res
def nointerrupt(f):
    """Decorator that suspends signals and passes an interrupted flag
    to the protected function.

    The wrapped function is called inside a :class:`NoInterrupt()` context and
    receives that context as the keyword argument ``interrupted`` unless the
    caller supplies one.  Can only be called from the main thread: a
    RuntimeError is raised otherwise (use a :class:`NoInterrupt()` context
    directly instead).

    Examples
    --------
    >>> @nointerrupt
    ... def f(interrupted):
    ...     for n in range(3):
    ...         if interrupted:
    ...             break
    ...         print(n)
    ...         time.sleep(0.1)
    >>> f()
    0
    1
    2
    """
    # A plain string: the previous " ".join(<str>) inserted a space between
    # every character, turning "{}" into "{ }" so .format() raised KeyError.
    _msg = (
        "@nointerrupt function called from non-main thread {}. "
        "(Use a NoInterrupt() context instead.)"
    )

    @functools.wraps(f)
    def wrapper(*v, **kw):
        if not is_main_thread():
            raise RuntimeError(_msg.format(threading.get_ident()))
        with NoInterrupt() as interrupted:
            kw.setdefault("interrupted", interrupted)
            return f(*v, **kw)

    return wrapper
class CoroutineWrapper(object):
    """Wrapper for coroutine contexts that allows them to function as a context but also as
    a function.  Similar to :func:`open()` which may be used both in a function or as a
    file object.  Note: be sure to call :meth:`close()` if you do not use this as a
    context.

    The coroutine is primed (advanced to its first ``yield``) exactly once, on
    the first call or when the context is entered, whichever comes first.
    """

    def __init__(self, coroutine):
        self.coroutine = coroutine
        self.started = False
        self.res = None  # Last value produced by the coroutine.

    def _start(self):
        """Prime the coroutine if this has not been done yet."""
        if not self.started:
            self.res = next(self.coroutine)
            self.started = True

    def __enter__(self, *v, **kw):
        # Priming again would send None into an already-running coroutine.
        self._start()
        return self.send

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def send(self, *v):
        """Send a value to the coroutine and return what it yields next."""
        self.res = self.coroutine.send(*v)
        return self.res

    def __call__(self, *v):
        self._start()
        return self.send(*v)

    def close(self):
        """Close the underlying coroutine."""
        self.coroutine.close()
def coroutine(coroutine):
    """Decorator turning a generator-based coroutine into a stateful callable.

    Calling the decorated function returns a :class:`CoroutineWrapper`, which can
    be used as a context (closing the coroutine on exit) or called directly like
    a function (remember to call ``close()`` in that case).  This allows you to
    write functions that maintain state between calls.

    Examples
    --------
    Here is an example based on that suggested by Thomas Kluyver:

    http://takluyver.github.io/posts/readable-python-coroutines.html

    >>> @coroutine
    ... def get_have_seen(case_sensitive=False):
    ...     seen = set()  # Set of words already seen.  This is the "state"
    ...     word = (yield)
    ...     while True:
    ...         if not case_sensitive:
    ...             word = word.lower()
    ...         result = word in seen
    ...         seen.add(word)
    ...         word = (yield result)
    >>> with get_have_seen(case_sensitive=False) as have_seen:
    ...     print(have_seen("hello"))
    ...     print(have_seen("hello"))
    ...     print(have_seen("Hello"))
    ...     print(have_seen("hi"))
    ...     print(have_seen("hi"))
    False
    True
    True
    False
    True
    >>> have_seen("hi")
    Traceback (most recent call last):
    ...
    StopIteration

    You can also use this as a function (like open()) but don't forget to close
    it.

    >>> have_seen = get_have_seen(case_sensitive=True)
    >>> have_seen("hello")
    False
    >>> have_seen("hello")
    True
    >>> have_seen("Hello")
    False
    >>> have_seen("hi")
    False
    >>> have_seen("hi")
    True
    >>> have_seen.close()
    >>> have_seen("hi")
    Traceback (most recent call last):
    ...
    StopIteration
    """

    @functools.wraps(coroutine)
    def make_wrapper(*v, **kw):
        generator = coroutine(*v, **kw)
        return CoroutineWrapper(generator)

    return make_wrapper
class FPS_Frames:
    """Helper class for FPS.

    This is the object returned by entering an `FPS()` context.

    Attributes
    ----------
    the_frames : iterable
        Iterable of the actual frames.
    frames : int, None
        Number of frames or None if unavailable.
    frame : int
        Number of the current frame.
    fps : float
        Current frames-per-second, based on the last `max_tics` updates
        (NaN until at least two tics have been recorded).
    max_fps : float, None
        If provided, then sleep to rate-limit the iterations to this maximum fps.
    max_tics : int
        How many of the last updates will be used to calculate the fps.  (All if None).
    """

    def __init__(self, interrupted, the_frames_or_frames, max_fps, max_tics):
        self.interrupted = interrupted
        self.frames, self.the_frames = self.get_frames_and_the_frames(
            the_frames_or_frames
        )
        self.max_tics = max_tics
        self.max_fps = max_fps
        self._reset()

    @staticmethod
    def get_frames_and_the_frames(the_frames_or_frames):
        """Return `(frames, the_frames)`.

        Returns
        -------
        frames : int or None
            Number of frames if finite or known, otherwise None.
        the_frames : iterable
            The actual frames.
        """
        try:  # Do we have a length?
            frames = len(the_frames_or_frames)
            the_frames = the_frames_or_frames
        except TypeError:  # No.  Infinite/indefinite or number
            try:
                the_frames = iter(the_frames_or_frames)
                frames = None
            except TypeError:  # Not an iterator.  Delegate to range().
                the_frames = range(the_frames_or_frames)
                frames = the_frames_or_frames
        return frames, the_frames

    def _reset(self):
        """Restart the frame counter and the timing history."""
        self._frame = -1
        self.tic = time.perf_counter()
        self.tics = collections.deque([self.tic], maxlen=self.max_tics)

    def __bool__(self):
        """True while not interrupted and frames remain (always if `frames` is None)."""
        return not bool(self.interrupted) and (
            self.frames is None or self.frame < self.frames
        )

    @property
    def frame(self):
        return self._frame

    @frame.setter
    def frame(self, frame):
        # A tic is only recorded when advancing by exactly one frame.  Setting
        # frame 0 records none: the tic taken in _reset() marks the start.
        if frame > 0 and frame == self._frame + 1:
            self.tics.append(time.perf_counter())
        self._frame = frame

    @property
    def fps(self):
        """Frames per second over the recorded tics (NaN with a single tic)."""
        assert 0 < len(self.tics)
        if len(self.tics) == 1:
            return math.nan
        elapsed = self.tics[-1] - self.tics[0]
        if elapsed <= 0:
            # Tics within the clock resolution: avoid ZeroDivisionError.
            return math.inf
        return (len(self.tics) - 1) / elapsed

    def __str__(self):
        return "{:.2f}".format(self.fps)

    def __format__(self, format_spec):
        return float(self.fps).__format__(format_spec)

    def __iter__(self):
        self._reset()
        the_frames = self.the_frames
        if not self.max_fps:
            for frame in the_frames:
                if not self:
                    break
                self.frame += 1
                yield frame
        else:
            # More complicated version with rate limiting.
            dt = 1.0 / self.max_fps
            tic = None
            for frame in the_frames:
                if not self:
                    break
                toc = time.perf_counter()
                if tic is not None:
                    wait = max(0, tic + dt - toc)
                    sleep(wait)
                    # Our sleep makes this a little more accurate since
                    # time.sleep() can take longer than wait
                    # https://stackoverflow.com/questions/1133857
                    toc += wait
                tic = toc
                self.frame += 1
                yield frame

    def __float__(self):
        return self.fps
class FPS:
    """Context manager to measure framerate and provide interrupt control.

    This can be used in two ways:

    1. As an iterator, which will run for the specified number of frames or until the
       timeout is exceeded;
    2. As a flag that can be run while `bool(fps)`.  In this second usage, you must
       manually update `fps.frame` to get a proper fps computation.

    The `fps` instance can be used to display the current frame-rate.

    Arguments
    ---------
    frames : int | iterable
        Either the number of frames (the frames are then `range(frames)`) or an
        iterable of the frames themselves.
    timeout : float, None
        Timeout in seconds.  The default `None` means no timeout, except when
        `frames` has no length (e.g. an indefinite iterator), in which case
        :attr:`_default_timeout` seconds are used.
    max_fps : float, None
        If provided, then sleep to rate-limit the iterations to this maximum fps.
    unregister : bool
        If `True`, then call NoInterrupt.unregister() before the loop.
    max_tics : int, None
        Maximum number of last updates used to calculate the fps.  Will use all if None.

    Examples
    --------
    >>> import numpy as np
    >>> fps = FPS(frames=0.1*np.arange(10), timeout=10)
    >>> for t in fps:
    ...     print(f"t={t:.1f}: fps={fps:.0f}")
    ...     sleep(0.1)
    t=0.0: fps=nan
    t=0.1: fps=10...
    t=0.2: fps=10...
    t=0.3: fps=10...
    t=0.4: fps=10...
    t=0.5: fps=10...
    t=0.6: fps=10...
    t=0.7: fps=10...
    t=0.8: fps=10...
    t=0.9: fps=10...
    >>> print(1/np.diff(fps.tics))
    [9... 9... 9... 9... 9... 9... 9... 9... 9...]

    This actually creates a context under the hood.  If you want to be explicit, you can
    do something like this:

    >>> with FPS(frames=0.1*np.arange(10), timeout=10) as fps:
    ...     for t in fps:
    ...         print(f"t={t:.1f}: fps={fps:0.0f}")
    ...         sleep(0.1)
    ...     print("Loop done!")
    ...     print(1/np.diff(fps.tics))
    t=0.0: fps=nan
    t=0.1: fps=10...
    t=0.2: fps=10...
    t=0.3: fps=10...
    t=0.4: fps=10...
    t=0.5: fps=10...
    t=0.6: fps=10...
    t=0.7: fps=10...
    t=0.8: fps=10...
    t=0.9: fps=10...
    Loop done!
    [9... 9... 9... 9... 9... 9... 9... 9... 9...]

    Note that you can get the results after if needed:

    >>> print(f"{fps:.0f}")
    10...

    But you can't do this:

    >>> fps = FPS(frames=0.1*np.arange(10), timeout=10)
    >>> print(fps)
    Traceback (most recent call last):
    ...
    ValueError: Unintialized FPS object: please use in a context or after a loop.

    If you don't need to report the actual performance, you can just use the FPS class
    as an iterator.  This will break only at the start of the loop if interrupted, or if
    the timeout is exceeded.

    >>> import numpy as np
    >>> for t in FPS(frames=0.1*np.arange(10), timeout=10):
    ...     print(f"t={t:.1f}")
    ...     sleep(0.1)
    t=0.0
    t=0.1
    t=0.2
    t=0.3
    t=0.4
    t=0.5
    t=0.6
    t=0.7
    t=0.8
    t=0.9
    """

    # Timeout (seconds) used when `frames` has no length and no timeout is given.
    _default_timeout = 10

    def __init__(
        self, frames=200, timeout=None, max_fps=None, unregister=True, max_tics=20
    ):
        self.frames = frames
        self.timeout = timeout
        self.max_tics = max_tics
        self.max_fps = max_fps
        self.unregister = unregister

    def __enter__(self):
        """Enter a :class:`NoInterrupt` context and return the frames object."""
        fps = FPS_Frames(
            interrupted=None,
            the_frames_or_frames=self.frames,
            max_tics=self.max_tics,
            max_fps=self.max_fps,
        )
        if fps.frames is None and self.timeout is None:
            # Indefinite number of frames: use a default timeout so that the
            # loop eventually stops.
            self.timeout = self._default_timeout
        self._interrupted = NoInterrupt(
            timeout=self.timeout, unregister=self.unregister
        )
        fps.interrupted = self._interrupted.__enter__()
        return fps

    def __exit__(self, exc_type, exc_value, traceback):
        return self._interrupted.__exit__(exc_type, exc_value, traceback)

    def __iter__(self):
        with self as self._frames_obj:
            for frame in self._frames_obj:
                yield frame

    def __str__(self):
        """Delegate to self._frames_obj if it exists."""
        return str(self._frames_obj)

    def __format__(self, format_spec):
        return self._frames_obj.__format__(format_spec)

    @property
    def tics(self):
        return self._frames_obj.tics

    def __getattr__(self, key):
        """Delegate unknown attributes to the frames object.

        Only called when normal attribute lookup fails.  The instance dict is
        checked directly: ``hasattr(self, "_frames_obj")`` would re-enter this
        method.

        Raises
        ------
        ValueError
            If the frames object does not exist yet (before entering the
            context or iterating).
        AttributeError
            For missing dunder names before initialization, so that protocol
            probes (e.g. :func:`copy.deepcopy`, ``hasattr(obj, "__setstate__")``)
            fall back normally.
        """
        try:
            frames_obj = self.__dict__["_frames_obj"]
        except KeyError:
            if key.startswith("__") and key.endswith("__"):
                raise AttributeError(
                    f"{self.__class__.__name__!r} object has no attribute {key!r}"
                ) from None
            raise ValueError(
                f"Unintialized {self.__class__.__name__} object:"
                + " please use in a context or after a loop."
            ) from None
        return getattr(frames_obj, key)

    def __len__(self):
        frames, the_frames = FPS_Frames.get_frames_and_the_frames(self.frames)
        if frames is None:
            raise TypeError(f"object of type '{self.__class__.__name__}' has no len()")
        return frames