Source code for mmfutils.performance.fft

"""FFTW wrappers for high-performance computing.

This module requires you to have installed the fftw libraries and
:mod:`pyfftw`.  Note that you must build the fftw with all precisions
using something like::

    PREFIX=/data/apps/fftw
    VER=3.3.4
    for opt in " " "--enable-sse2 --enable-single" \\
               "--enable-long-double" "--enable-quad-precision"; do
      ./configure --prefix="${PREFIX}/${VER}"\\
                  --enable-threads\\
                  --enable-shared\\
                  $opt
      make -j8 install
    done

Note: The FFTW library does not work with negative indices for axis.
Indices should first be normalized by ``inds % len(shape)``.
"""
import functools
import itertools
import warnings

import numpy.fft
import numpy as np

from .threads import SET_THREAD_HOOKS

del numpy

__all__ = ['fft', 'ifft', 'fftn', 'ifftn', 'fftfreq', 'resample']


# Numpy versions with a default axis specified
def fft_numpy(Phi, axis=-1):
    """Return the 1D discrete Fourier transform of `Phi` via :mod:`numpy`.

    Arguments
    ---------
    Phi : array
       Data to transform.
    axis : int
       Axis along which the transform is taken (the last one by default).
    """
    transform = np.fft.fft
    return transform(Phi, axis=axis)


def ifft_numpy(Phit, axis=-1):
    """Return the 1D inverse discrete Fourier transform of `Phit`.

    Arguments
    ---------
    Phit : array
       Fourier-space data to transform back.
    axis : int
       Axis along which the inverse transform is taken (the last one by
       default).
    """
    inverse = np.fft.ifft
    return inverse(Phit, axis=axis)


def fftn_numpy(Phi, axes=None):
    """Return the n-dimensional discrete Fourier transform of `Phi`.

    Arguments
    ---------
    Phi : array
       Data to transform.
    axes : None or sequence of int
       Axes to transform; passed unchanged to :func:`numpy.fft.fftn`.
    """
    transform = np.fft.fftn
    return transform(Phi, axes=axes)


def ifftn_numpy(Phit, axes=None):
    """Return the n-dimensional inverse discrete Fourier transform of `Phit`.

    Arguments
    ---------
    Phit : array
       Fourier-space data to transform back.
    axes : None or sequence of int
       Axes to transform; passed unchanged to :func:`numpy.fft.ifftn`.
    """
    inverse = np.fft.ifftn
    return inverse(Phit, axes=axes)


# Re-export the numpy helpers under this module's namespace.  Only `fftfreq`
# is listed in `__all__`; `fftshift` must be accessed explicitly.
fftfreq = np.fft.fftfreq
fftshift = np.fft.fftshift


# Module-wide settings handed to every pyfftw call made by the wrappers in
# this module.  `_THREADS` is rebound by `set_num_threads`; `_PLANNER_EFFORT`
# is forwarded as pyfftw's `planner_effort` argument.
_THREADS = 8
_PLANNER_EFFORT = 'FFTW_MEASURE'


def set_num_threads(nthreads):
    """Set the number of threads used by the pyfftw wrappers.

    The value is stored in the module-level `_THREADS`, which the wrappers
    read each time they are called, so the change applies to all subsequent
    transforms.

    Arguments
    ---------
    nthreads : int
       Value forwarded as the `threads` argument of the pyfftw calls.
    """
    global _THREADS
    _THREADS = nthreads


# Register the setter with the hooks collected in `.threads`.
# NOTE(review): presumably those hooks are invoked when the package-wide
# thread count changes, keeping `_THREADS` in sync -- confirm in `.threads`.
SET_THREAD_HOOKS.add(set_num_threads)


try:                     # NOQA  This is too complex, but that is okay
    import pyfftw
    from pyfftw.interfaces.numpy_fft import (fft as _fft,
                                             ifft as _ifft,
                                             fftn as _fftn,
                                             ifftn as _ifftn)

    # Hack to get the version from pyfftw to resolve issue #25
    from pkg_resources import parse_version
    pyfftw_version = getattr(pyfftw, 'version', '0')
    while hasattr(pyfftw_version, 'version'):
        pyfftw_version = pyfftw_version.version

    # By default, pyfftw does not cache the plans.  Here we enable the cache
    # and set the keepalive time to an hour.
    if parse_version(pyfftw_version) >= parse_version('0.9.2'):
        pyfftw.interfaces.cache.enable()
        pyfftw.interfaces.cache.set_keepalive_time(60*60)

    # Also, pyfftw uses a single thread by default.  The wrappers below
    # instead pass the module-level `_THREADS` (8 unless changed) and
    # `_PLANNER_EFFORT` ('FFTW_MEASURE') on every call.
    @functools.wraps(_fft)
    def fft_pyfftw(*v, **kw):
        # The module-level thread count and planner effort always replace
        # whatever the caller may have supplied for these two options.
        kw['threads'] = _THREADS
        kw['planner_effort'] = _PLANNER_EFFORT
        try:
            axis = kw['axis']
        except KeyError:
            # No axis keyword: leave the choice of axis to pyfftw.
            pass
        else:
            # FFTW does not accept negative axes, so wrap the requested axis
            # into the range 0..ndim-1 of the first positional argument.
            ndim = np.ndim(v[0])
            kw['axis'] = (axis + ndim) % ndim
        return _fft(*v, **kw)

    @functools.wraps(_ifft)
    def ifft_pyfftw(*v, **kw):
        # The module-level thread count and planner effort always replace
        # whatever the caller may have supplied for these two options.
        kw['threads'] = _THREADS
        kw['planner_effort'] = _PLANNER_EFFORT
        try:
            axis = kw['axis']
        except KeyError:
            # No axis keyword: leave the choice of axis to pyfftw.
            pass
        else:
            # FFTW does not accept negative axes, so wrap the requested axis
            # into the range 0..ndim-1 of the first positional argument.
            ndim = np.ndim(v[0])
            kw['axis'] = (axis + ndim) % ndim
        return _ifft(*v, **kw)

    @functools.wraps(_fftn)
    def fftn_pyfftw(*v, **kw):
        # n-dimensional FFT through pyfftw's numpy-compatible interface.
        #
        # The array to transform must be the first positional argument; all
        # other arguments are forwarded unchanged, except that `threads` and
        # `planner_effort` are always taken from the module-level settings.
        kw.update(threads=_THREADS, planner_effort=_PLANNER_EFFORT)
        # An explicit `axes=None` means "let the backend choose" (as it does
        # for the numpy fallback and for `get_fftn_pyfftw`), so it must not
        # go through the arithmetic below, which would raise a TypeError.
        if kw.get('axes') is not None:
            # FFTW does not accept negative axes: wrap them into 0..dim-1.
            dim = len(np.shape(v[0]))
            kw['axes'] = (np.asarray(kw['axes']) + dim) % dim
        return _fftn(*v, **kw)

    @functools.wraps(_ifftn)
    def ifftn_pyfftw(*v, **kw):
        # n-dimensional inverse FFT through pyfftw's numpy-compatible
        # interface.
        #
        # The array to transform must be the first positional argument; all
        # other arguments are forwarded unchanged, except that `threads` and
        # `planner_effort` are always taken from the module-level settings.
        kw.update(threads=_THREADS, planner_effort=_PLANNER_EFFORT)
        # An explicit `axes=None` means "let the backend choose" (as it does
        # for the numpy fallback and for `get_ifftn_pyfftw`), so it must not
        # go through the arithmetic below, which would raise a TypeError.
        if kw.get('axes') is not None:
            # FFTW does not accept negative axes: wrap them into 0..dim-1.
            dim = len(np.shape(v[0]))
            kw['axes'] = (np.asarray(kw['axes']) + dim) % dim
        return _ifftn(*v, **kw)

    def get_fft_pyfftw(a, n=None, axis=-1, overwrite_input=False,
                       auto_align_input=True, auto_contiguous=True,
                       avoid_copy=False):
        """Return the object built by `pyfftw.builders.fft` for `a`.

        Arguments
        ---------
        a : array
           Array the transform is planned for.
        axis : int
           Axis to transform.  Negative values are allowed: they are wrapped
           into the range ``0..ndim-1`` before being handed to pyfftw.
        n, overwrite_input, auto_align_input, auto_contiguous, avoid_copy
           Forwarded unchanged to `pyfftw.builders.fft`.

        The number of threads and the planner effort are taken from the
        module-level `_THREADS` and `_PLANNER_EFFORT`.
        """
        ndim = np.ndim(a)
        builder_args = dict(
            a=a,
            n=n,
            axis=(axis + ndim) % ndim,
            threads=_THREADS,
            planner_effort=_PLANNER_EFFORT,
            overwrite_input=overwrite_input,
            auto_align_input=auto_align_input,
            auto_contiguous=auto_contiguous,
            avoid_copy=avoid_copy,
        )
        return pyfftw.builders.fft(**builder_args)

    def get_ifft_pyfftw(a, n=None, axis=-1, overwrite_input=False,
                        auto_align_input=True, auto_contiguous=True,
                        avoid_copy=False):
        """Return the object built by `pyfftw.builders.ifft` for `a`.

        Arguments
        ---------
        a : array
           Array the inverse transform is planned for.
        axis : int
           Axis to transform.  Negative values are allowed: they are wrapped
           into the range ``0..ndim-1`` before being handed to pyfftw.
        n, overwrite_input, auto_align_input, auto_contiguous, avoid_copy
           Forwarded unchanged to `pyfftw.builders.ifft`.

        The number of threads and the planner effort are taken from the
        module-level `_THREADS` and `_PLANNER_EFFORT`.
        """
        ndim = np.ndim(a)
        builder_args = dict(
            a=a,
            n=n,
            axis=(axis + ndim) % ndim,
            threads=_THREADS,
            planner_effort=_PLANNER_EFFORT,
            overwrite_input=overwrite_input,
            auto_align_input=auto_align_input,
            auto_contiguous=auto_contiguous,
            avoid_copy=avoid_copy,
        )
        return pyfftw.builders.ifft(**builder_args)

    def get_fftn_pyfftw(a, s=None, axes=None, overwrite_input=False,
                        auto_align_input=True, auto_contiguous=True,
                        avoid_copy=False):
        """Return the object built by `pyfftw.builders.fftn` for `a`.

        Arguments
        ---------
        a : array
           Array the transform is planned for.
        axes : None or sequence of int
           Axes to transform.  `None` is passed through untouched; otherwise
           negative values are wrapped into the range ``0..ndim-1`` before
           being handed to pyfftw.
        s, overwrite_input, auto_align_input, auto_contiguous, avoid_copy
           Forwarded unchanged to `pyfftw.builders.fftn`.

        The number of threads and the planner effort are taken from the
        module-level `_THREADS` and `_PLANNER_EFFORT`.
        """
        if axes is not None:
            ndim = np.ndim(a)
            axes = (np.asarray(axes) + ndim) % ndim
        builder_args = dict(
            a=a,
            s=s,
            axes=axes,
            threads=_THREADS,
            planner_effort=_PLANNER_EFFORT,
            overwrite_input=overwrite_input,
            auto_align_input=auto_align_input,
            auto_contiguous=auto_contiguous,
            avoid_copy=avoid_copy,
        )
        return pyfftw.builders.fftn(**builder_args)

    def get_ifftn_pyfftw(a, s=None, axes=None, overwrite_input=False,
                         auto_align_input=True, auto_contiguous=True,
                         avoid_copy=False):
        """Return the object built by `pyfftw.builders.ifftn` for `a`.

        Arguments
        ---------
        a : array
           Array the inverse transform is planned for.
        axes : None or sequence of int
           Axes to transform.  `None` is passed through untouched; otherwise
           negative values are wrapped into the range ``0..ndim-1`` before
           being handed to pyfftw.
        s, overwrite_input, auto_align_input, auto_contiguous, avoid_copy
           Forwarded unchanged to `pyfftw.builders.ifftn`.

        The number of threads and the planner effort are taken from the
        module-level `_THREADS` and `_PLANNER_EFFORT`.
        """
        if axes is not None:
            ndim = np.ndim(a)
            axes = (np.asarray(axes) + ndim) % ndim
        builder_args = dict(
            a=a,
            s=s,
            axes=axes,
            threads=_THREADS,
            planner_effort=_PLANNER_EFFORT,
            overwrite_input=overwrite_input,
            auto_align_input=auto_align_input,
            auto_contiguous=auto_contiguous,
            avoid_copy=avoid_copy,
        )
        return pyfftw.builders.ifftn(**builder_args)

    fft = fft_pyfftw
    ifft = ifft_pyfftw
    fftn = fftn_pyfftw
    ifftn = ifftn_pyfftw
except ImportError:              # pragma: nocover
    warnings.warn("Could not import pyfftw... falling back to numpy")
    fft = fft_numpy
    ifft = ifft_numpy
    fftn = fftn_numpy
    ifftn = ifftn_numpy


def resample(f, N):
    """Resample f to a new grid of size N.

    This uses the FFT to resample the function `f` on a new grid with `N`
    points.  Note: this assumes that the function `f` is periodic.  Resampling
    non-periodic functions to finer lattices may introduce aliasing artifacts.

    Arguments
    ---------
    f : array
       The function to be resampled.  May be n-dimensional
    N : int or array
       The number of lattice points in the new array.  If this is an integer,
       then all dimensions of the output array will have this length.

    Returns
    -------
    array
       Complex array with the new shape.  Only axes whose length changes are
       Fourier transformed; the others are copied as they are.

    Examples
    --------
    >>> def f(x, y):
    ...     "Function with only low frequencies"
    ...     return (np.sin(2*np.pi*x)-np.cos(4*np.pi*y))
    >>> L = 1.0
    >>> Nx, Ny = 16, 13   # Small grid
    >>> NX, NY = 31, 24   # Large grid
    >>> dx, dy = L/Nx, L/Ny
    >>> dX, dY = L/NX, L/NY
    >>> x = (np.arange(Nx)*dx - L/2)[:, None]
    >>> y = (np.arange(Ny)*dy - L/2)[None, :]
    >>> X = (np.arange(NX)*dX - L/2)[:, None]
    >>> Y = (np.arange(NY)*dY - L/2)[None, :]
    >>> f_XY = resample(f(x,y), (NX, NY))
    >>> np.allclose(f_XY, f(X,Y))                      # To larger grid
    True
    >>> np.allclose(resample(f_XY, (Nx, Ny)), f(x,y))  # Back down
    True

    Axes of length one are handled too (a constant stays constant):

    >>> np.allclose(resample(np.array([2.0]), 4), 2.0)
    True
    """
    newshape = np.array(f.shape)
    newshape[...] = N           # Broadcasts an integer N over all dimensions

    # Only transform along the axes whose length actually changes.
    axes = np.where(np.not_equal(f.shape, newshape))[0]
    fk = fftn(f, axes=axes)
    fk1 = np.zeros(newshape, dtype=complex)

    # Along each axis, the modes common to both grids are the first
    # `n_pos` entries (zero and positive frequencies) and the last `n_neg`
    # entries (negative frequencies), where `n` is the smaller of the two
    # lengths.
    bands = []
    for n in np.minimum(f.shape, newshape):
        n = int(n)
        n_pos = (n + 1) // 2
        n_neg = n // 2
        # With no negative frequencies (n <= 1) we need an *empty* slice:
        # `slice(-0, None)` would select the whole axis instead.
        if n_neg:
            negative = slice(-n_neg, None)
        else:
            negative = slice(0, 0)
        bands.append((slice(0, n_pos), negative))

    # Copy every combination of low/high bands (the "corners" of the
    # spectrum) into the new spectrum; the remaining modes stay zero.
    for corner in itertools.product(*bands):
        fk1[corner] = fk[corner]

    # The inverse transform divides by the new number of points, so rescale
    # to keep the function values (rather than the sums) unchanged.
    return ifftn(fk1, axes=axes) * np.prod(newshape.astype(float)/f.shape)