# Source code for mmfutils.plot.errors

"""Tools for plotting error bands etc."""
import numpy as np

import scipy.stats
import scipy as sp

import matplotlib.collections
from matplotlib import pyplot as plt

from .rasterize import ListCollection

# ``import scipy.stats`` binds the bare name ``scipy`` in this namespace as
# well as loading the submodule.  Remove that name again; the ``sp`` alias
# still refers to the same module object, so ``sp.stats`` remains usable.
del scipy

# Names exported by ``from mmfutils.plot.errors import *``.
__all__ = ['plot_errorbars', 'plot_err', 'error_line']


def plot_errorbars(x, y, dx=None, dy=None, colour='',
                   linestyle='', pointstyle='.', barwidth=0.5, **kwargs):
    """Plot the points (x, y) with optional horizontal/vertical error bars.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the data points.  Lists and tuples are accepted; they
        are converted with ``np.asarray`` before any arithmetic is done.
    dx, dy : array_like or float, optional
        Half-lengths of the horizontal (`dx`) and vertical (`dy`) bars: a bar
        spans ``x - dx`` to ``x + dx`` (resp. ``y - dy`` to ``y + dy``).
        ``None`` (the default) draws no bars in that direction.
    colour : str
        Colour code appended to every format string passed to ``plt.plot``.
        The empty string adds no colour code.
    linestyle, pointstyle : str
        Format-string fragments used for the data curve itself.  If both are
        empty the data points are only drawn when no error bars are requested
        (as ``'.'`` markers).
    barwidth : float
        Line width (``lw``) of the error bars.
    **kwargs
        Forwarded to the ``plt.plot`` call that draws the data points.  They
        are not applied to the error bars.
    """
    # Plain Python sequences do not support element-wise ``+``/``-`` (two
    # lists would silently be concatenated), so work with arrays throughout.
    x = np.asarray(x)
    y = np.asarray(y)

    if pointstyle != '' or linestyle != '':
        plt.plot(x, y, pointstyle + linestyle + colour, **kwargs)
    elif dx is None and dy is None:
        # Plot points if both error bars are not drawn.
        plt.plot(x, y, '.' + colour, **kwargs)

    if dx is not None:
        dx = np.asarray(dx)
        xmin, xmax = x - dx, x + dx
    if dy is not None:
        dy = np.asarray(dy)
        ymin, ymax = y - dy, y + dy

    # One ``plt.plot`` call per bar, horizontal before vertical for each
    # point.  The call order is kept as in the original implementation
    # because it determines draw order and any automatic colour assignment.
    for n, (xn, yn) in enumerate(zip(x, y)):
        if dx is not None:
            plt.plot([xmin[n], xmax[n]], [yn, yn], '-|' + colour,
                     lw=barwidth)
        if dy is not None:
            plt.plot([xn, xn], [ymin[n], ymax[n]], '-_' + colour,
                     lw=barwidth)
def _errorbar_columns(x, y, yerr, xerr, plot_axis, kwarg):
    """Draw one ``plt.errorbar`` curve per column of the 2D array `y`.

    `plot_axis` is the axis of `y` (and of `yerr`, if given) that runs along
    `x`; it is swapped to the front so each ``[:, n]`` slice is one curve.
    """
    y = y.swapaxes(0, plot_axis)
    if yerr is not None:
        yerr = yerr.swapaxes(0, plot_axis)
    Nx, Ny = y.shape
    for n in range(Ny):
        column_err = None if yerr is None else yerr[:, n]
        plt.errorbar(x, y[:, n], xerr=xerr, yerr=column_err, **kwarg)


def plot_err(x, y, yerr=None, xerr=None, **kwarg):
    """Plot x vs. y with errorbars.

    The layout of the data is inferred from the shapes of `x` and `y`
    (both must provide ``.shape``, e.g. numpy arrays):

    * `x` 1D, `y` 1D: a single ``plt.errorbar(x, y, ...)`` call.
    * `x` 1D, `y` 2D: the first axis of `y` whose length equals ``len(x)``
      is taken to run along `x`; every slice along the other axis is drawn
      as a separate curve.
    * `x` with at most one axis longer than 1 (e.g. shape ``(N, 1)``): `x`
      is flattened, its longest axis is taken to run along `x`, and `y` is
      handled as in the previous case.
    * anything else: ``plt.plot(x, y, **kwarg)`` without errorbars.

    Parameters
    ----------
    x, y : array
        Data to plot.
    yerr : array, optional
        Errors in `y`.  In the multi-curve cases it must have the same
        layout as `y` (it is swapped and sliced in the same way).
    xerr : optional
        Errors in `x`, passed unchanged to every ``plt.errorbar`` call.
    **kwarg
        Forwarded to ``plt.errorbar`` (or ``plt.plot`` in the fallback case).
    """
    if (1 == len(x.shape) and 1 == len(y.shape)):
        plt.errorbar(x, y, yerr=yerr, xerr=xerr, **kwarg)
    elif (1 == len(x.shape) and 1 < len(y.shape)):
        # This branch used to loop with ``for n in Ny`` (iterating over an
        # int), so it always raised TypeError.  It now draws every column and
        # forwards the errors like the branch below.
        plot_axis = np.where(np.array(y.shape) == len(x))[0][0]
        _errorbar_columns(x, y, yerr, xerr, plot_axis, kwarg)
    elif (max(x.shape) == np.prod(x.shape)):
        # x has a single non-trivial axis: treat it as 1D data along that
        # axis.
        plot_axis = np.argmax(x.shape)
        x = x.ravel()
        _errorbar_columns(x, y, yerr, xerr, plot_axis, kwarg)
    else:
        plt.plot(x, y, **kwarg)
def error_line(x, y, dy, fgc='k', bgc='w', N=20, fill=True):
    """Plot a curve (x, y) with gaussian errors dy shown as shading.

    The shading extends out to 5 dy on either side of `y` and is built from
    the ``N - 1`` non-zero values of ``np.linspace(0, 5, N)``.  Each layer is
    weighted by the standard normal pdf at that number of sigmas, relative
    to its value at zero.

    Parameters
    ----------
    x, y, dy
        Abscissa, central values, and one-sigma errors of the curve.
    fgc, bgc
        Foreground and background colours, converted to RGB with
        matplotlib's ``ColorConverter``.
    N : int
        Number of sample points between 0 and 5 sigma (including 0).
    fill : bool
        If true, each layer is a pair of filled polygons between consecutive
        sigma levels, drawn in `fgc` with ``alpha`` equal to the weight.
        Otherwise each layer is a pair of lines: the upper one in `fgc` with
        ``alpha`` equal to the weight, the lower one in a colour blended
        between `fgc` and `bgc` by the weight.

    Returns
    -------
    list
        A single-element list holding a ``PatchCollection`` (fill mode) or
        a ``ListCollection`` (line mode) of the artists created.
    """
    normal_pdf = sp.stats.norm().pdf
    rgb_of = plt.matplotlib.colors.ColorConverter().to_rgb

    background = np.array(rgb_of(bgc))
    foreground = np.array(rgb_of(fgc))
    peak = normal_pdf(0.0)

    axes = plt.gca()
    drawn = []          # patches in fill mode, lines otherwise
    upper_prev = lower_prev = y
    for n_sigma in np.linspace(0, 5, N)[1:]:
        upper = y + dy*n_sigma
        lower = y - dy*n_sigma
        weight = normal_pdf(n_sigma)/peak

        if fill:
            # Closed outline: forward along the inner edge, back along the
            # outer edge.  Upper band first, then the lower band.
            outline_x = np.hstack((x, np.flipud(x)))
            for inner, outer in ((upper_prev, upper), (lower_prev, lower)):
                outline_y = np.hstack((inner, np.flipud(outer)))
                drawn.extend(
                    axes.fill(outline_x, outline_y,
                              fc=foreground, ec=foreground, lw=0,
                              alpha=weight))
        else:
            blended = foreground*weight + (1.0 - weight)*background
            upper_line = axes.plot(x, upper, color=foreground, alpha=weight)
            lower_line = axes.plot(x, lower, color=blended)
            drawn.extend(upper_line + lower_line)

        upper_prev, lower_prev = upper, lower

    if fill:
        # NOTE(review): the patch collection is only returned here, not added
        # to the axes (the patches themselves were added by ``fill``).  This
        # assumes the remove/re-add steps below belong to the line branch
        # only -- confirm against the original layout.
        return [matplotlib.collections.PatchCollection(drawn)]

    # Can't add alphas to LineCollection unfortunately, so the lines are
    # grouped in a ListCollection instead.
    collection = ListCollection(drawn)

    # Remove individual lines from the axis...
    for line in drawn:
        line.remove()

    # ... and add back as a collection.
    axes.add_collection(collection)
    return [collection]