# Source code for sigpy.fourier

# -*- coding: utf-8 -*-
"""FFT and non-uniform FFT (NUFFT) functions.

"""
import numpy as np

from math import ceil
from sigpy import backend, interp, util


# Public API of this module: (optionally centered) FFT/IFFT, the non-uniform
# FFT and its adjoint, shape estimation from k-space coordinates, and the
# Toeplitz PSF used for a fast normal NUFFT operator.
__all__ = ['fft', 'ifft', 'nufft', 'nufft_adjoint', 'estimate_shape',
           'toeplitz_psf']


def fft(input, oshape=None, axes=None, center=True, norm='ortho'):
    """FFT function that supports centering.

    Args:
        input (array): input array. Non-complex inputs are cast to
            complex64 before transforming.
        oshape (None or array of ints): output shape.
        axes (None or array of ints): Axes over which to compute the FFT.
        center (bool): If True, the input is resized to oshape and the
            transform is wrapped in ``ifftshift``/``fftshift`` so that the
            origin sits at the array center. If False, a plain
            ``fftn`` is computed.
        norm (None or ``"ortho"``): Keyword to specify the normalization mode.

    Returns:
        array: FFT result of dimension oshape, with the same (complex)
            dtype as the possibly-cast input.

    See Also:
        :func:`numpy.fft.fftn`

    """
    xp = backend.get_array_module(input)
    if not np.issubdtype(input.dtype, np.complexfloating):
        input = input.astype(np.complex64)

    if center:
        output = _fftc(input, oshape=oshape, axes=axes, norm=norm)
    else:
        output = xp.fft.fftn(input, s=oshape, axes=axes, norm=norm)

    # `input` is guaranteed complex here. The FFT backend may return a wider
    # complex dtype than it was given, so cast back to preserve precision.
    if output.dtype != input.dtype:
        output = output.astype(input.dtype, copy=False)

    return output
def ifft(input, oshape=None, axes=None, center=True, norm='ortho'):
    """IFFT function that supports centering.

    Args:
        input (array): input array. Non-complex inputs are cast to
            complex64 before transforming.
        oshape (None or array of ints): output shape.
        axes (None or array of ints): Axes over which to compute the
            inverse FFT.
        center (bool): If True, the input is resized to oshape and the
            transform is wrapped in ``ifftshift``/``fftshift`` so that the
            origin sits at the array center. If False, a plain
            ``ifftn`` is computed.
        norm (None or ``"ortho"``): Keyword to specify the normalization mode.

    Returns:
        array of dimension oshape, with the same (complex) dtype as the
        possibly-cast input.

    See Also:
        :func:`numpy.fft.ifftn`

    """
    xp = backend.get_array_module(input)
    if not np.issubdtype(input.dtype, np.complexfloating):
        input = input.astype(np.complex64)

    if center:
        output = _ifftc(input, oshape=oshape, axes=axes, norm=norm)
    else:
        output = xp.fft.ifftn(input, s=oshape, axes=axes, norm=norm)

    # `input` is guaranteed complex here. The FFT backend may return a wider
    # complex dtype than it was given, so cast back to preserve precision
    # (without an extra copy when no conversion is needed, as in fft()).
    if output.dtype != input.dtype:
        output = output.astype(input.dtype, copy=False)

    return output
def nufft(input, coord, oversamp=1.25, width=4):
    """Non-uniform Fast Fourier Transform.

    Args:
        input (array): input signal domain array of shape
            (..., n_{ndim - 1}, ..., n_1, n_0),
            where ndim is specified by coord.shape[-1]. The nufft
            is applied on the last ndim axes, and looped over
            the remaining axes. Integer and boolean inputs are promoted
            to complex64.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply the nufft.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: Fourier domain data of shape
            input.shape[:-ndim] + coord.shape[:-1].

    References:
        Fessler, J. A., & Sutton, B. P. (2003).
        Nonuniform fast Fourier transforms using min-max interpolation.
        IEEE Transactions on Signal Processing, 51(2), 560-574.

        Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005).
        Rapid gridding reconstruction with a minimal oversampling ratio.
        IEEE transactions on medical imaging, 24(6), 799-808.

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter as a function of width and oversampling
    # (Beatty et al., 2005).
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    # The steps below modify `output` in place, so always work on a copy.
    if np.issubdtype(input.dtype, np.inexact):
        output = input.copy()
    else:
        # Integer/bool arrays cannot hold the fractional values produced by
        # the in-place apodization and scaling; promote as fft() would.
        output = input.astype(np.complex64)

    # Apodize (in place) along the last ndim axes.
    _apodize(output, ndim, oversamp, width, beta)

    # Normalize by sqrt of the number of signal samples, then zero-pad
    # onto the oversampled grid.
    output /= util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape)

    # Unnormalized centered FFT on the oversampled grid.
    output = fft(output, axes=range(-ndim, 0), norm=None)

    # Interpolate onto the non-uniform coordinates.
    coord = _scale_coord(coord, input.shape, oversamp)
    output = interp.interpolate(
        output, coord, kernel='kaiser_bessel', width=width, param=beta)
    output /= width**ndim

    return output
def estimate_shape(coord):
    """Estimate array shape from coordinates.

    For each of the ``coord.shape[-1]`` coordinate axes, the size is taken
    as the difference between the largest and smallest coordinate value
    along that axis, converted with ``int``.

    Args:
        coord (array): Coordinates of shape (..., ndim).

    Returns:
        list of ints: estimated shape, one entry per coordinate axis.

    """
    extents = []
    with backend.get_device(coord):
        for axis in range(coord.shape[-1]):
            values = coord[..., axis]
            span = values.max() - values.min()
            extents.append(int(span))

    return extents
def nufft_adjoint(input, coord, oshape=None, oversamp=1.25, width=4):
    """Adjoint non-uniform Fast Fourier Transform.

    Args:
        input (array): input Fourier domain array of shape
            (...) + coord.shape[:-1]. That is, the last dimensions
            of input must match the first dimensions of coord.
            The nufft_adjoint is applied on the last coord.ndim - 1 axes,
            and looped over the remaining axes.
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply nufft adjoint.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        oshape (tuple of ints): output shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0). If None, the looped
            (leading) axes of input are kept and the last ndim sizes are
            estimated from coord with :func:`estimate_shape`.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: signal domain array with shape specified by oshape.

    See Also:
        :func:`sigpy.fourier.nufft`

    """
    ndim = coord.shape[-1]
    # Kaiser-Bessel shape parameter; must match the one used by nufft().
    beta = np.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    if oshape is None:
        # Number of leading (looped) axes of input not covered by
        # coord.shape[:-1]. Computed as a positive length: the former
        # `input.shape[:-coord.ndim + 1]` became `[:0]` for a 1-D coord
        # and silently dropped every batch axis.
        n_batch = input.ndim - (coord.ndim - 1)
        oshape = list(input.shape[:n_batch]) + estimate_shape(coord)
    else:
        oshape = list(oshape)

    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding onto the oversampled Cartesian grid.
    coord = _scale_coord(coord, oshape, oversamp)
    output = interp.gridding(input, coord, os_shape,
                             kernel='kaiser_bessel', width=width, param=beta)
    output /= width**ndim

    # Unnormalized centered IFFT.
    output = ifft(output, axes=range(-ndim, 0), norm=None)

    # Crop back to oshape. The unnormalized ifft divides by the oversampled
    # grid size, so multiply it back and apply the same 1/sqrt(N) factor
    # that nufft() applies in the forward direction.
    output = util.resize(output, oshape)
    output *= util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:])**0.5

    # Apodize (in place) along the last ndim axes.
    _apodize(output, ndim, oversamp, width, beta)

    return output
def toeplitz_psf(coord, shape, oversamp=1.25, width=4):
    """Toeplitz PSF for fast Normal non-uniform Fast Fourier Transform.

    While fast, this is more memory intensive.

    Args:
        coord (array): Fourier domain coordinate array of shape (..., ndim).
            ndim determines the number of dimensions to apply nufft adjoint.
            coord[..., i] should be scaled to have its range between
            -n_i // 2, and n_i // 2.
        shape (tuple of ints): shape of the form
            (..., n_{ndim - 1}, ..., n_1, n_0).
            This is the shape of the input array of the forward nufft.
        oversamp (float): oversampling factor.
        width (float): interpolation kernel full-width in terms of
            oversampled grid.

    Returns:
        array: PSF to be used by the normal operator defined in
            `sigpy.linop.NUFFT`

    See Also:
        :func:`sigpy.linop.NUFFT`

    """
    xp = backend.get_array_module(coord)
    with backend.get_device(coord):
        ndim = coord.shape[-1]

        # Embed the problem in a grid twice as large along each of the last
        # ndim axes, with coordinates rescaled for that grid.
        # NOTE(review): _scale_coord also adds an offset equal to
        # new_shape[k]; presumably harmless as it is a whole grid period --
        # confirm.
        new_shape = _get_oversamp_shape(shape, ndim, 2)
        new_coord = _scale_coord(coord, new_shape, 2)

        # Unit impulse at the center of the last ndim axes (all batch axes).
        idx = [slice(None)] * len(new_shape)
        for k in range(-1, -(ndim + 1), -1):
            idx[k] = new_shape[k] // 2

        d = xp.zeros(new_shape, dtype=xp.complex64)
        d[tuple(idx)] = 1

        # Apply A^H A to the impulse, then take its FFT. The 2**ndim factor
        # presumably compensates for the doubled grid -- confirm against
        # sigpy.linop.NUFFT.
        psf = nufft(d, new_coord, oversamp, width)
        psf = nufft_adjoint(psf, new_coord, d.shape, oversamp, width)
        fft_axes = tuple(range(-1, -(ndim + 1), -1))
        psf = fft(psf, axes=fft_axes, norm=None) * (2**ndim)

        return psf


def _fftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered N-D FFT: resize to oshape, ifftshift, fftn, fftshift."""
    return _centered_fftn(input, oshape, axes, norm, inverse=False)


def _ifftc(input, oshape=None, axes=None, norm='ortho'):
    """Centered N-D inverse FFT: resize to oshape, ifftshift, ifftn, fftshift."""
    return _centered_fftn(input, oshape, axes, norm, inverse=True)


def _centered_fftn(input, oshape, axes, norm, inverse):
    """Shared implementation of :func:`_fftc` (forward) and :func:`_ifftc`."""
    axes = util._normalize_axes(axes, input.ndim)
    xp = backend.get_array_module(input)
    if oshape is None:
        oshape = input.shape

    transform = xp.fft.ifftn if inverse else xp.fft.fftn
    tmp = util.resize(input, oshape)
    tmp = xp.fft.ifftshift(tmp, axes=axes)
    tmp = transform(tmp, axes=axes, norm=norm)
    return xp.fft.fftshift(tmp, axes=axes)


def _scale_coord(coord, shape, oversamp):
    """Map coordinates onto oversampled-grid index units.

    For each of the last ndim axes (ndim = coord.shape[-1]), coordinates
    are multiplied by ceil(oversamp * n) / n and shifted by
    ceil(oversamp * n) // 2, where n is the matching entry of shape.

    Returns:
        array: a new array; coord itself is left unchanged. Coordinates
            that are not floating point (e.g. integers) are returned as
            float64, since a fractional scale cannot be stored in place
            in an integer array.
    """
    ndim = coord.shape[-1]
    if np.issubdtype(coord.dtype, np.inexact):
        output = coord.copy()
    else:
        output = coord.astype(np.float64)

    for i in range(-ndim, 0):
        scale = ceil(oversamp * shape[i]) / shape[i]
        shift = ceil(oversamp * shape[i]) // 2
        output[..., i] *= scale
        output[..., i] += shift

    return output


def _get_oversamp_shape(shape, ndim, oversamp):
    """Return shape as a list with its last ndim sizes scaled by oversamp
    and rounded up; leading sizes are kept as-is."""
    return list(shape)[:-ndim] + [ceil(oversamp * i) for i in shape[-ndim:]]


def _apodize(input, ndim, oversamp, width, beta):
    """Multiply input in place by a separable apodization profile.

    Along each of the last ndim axes (length i, os_i = ceil(oversamp * i)),
    the profile is a / sinh(a) with
    a = sqrt(beta**2 - (pi * width * (x - i // 2) / os_i)**2),
    where x is the sample index.

    Returns:
        array: input itself, modified in place.
    """
    xp = backend.get_array_module(input)
    output = input
    for a in range(-ndim, 0):
        i = output.shape[a]
        os_i = ceil(oversamp * i)
        idx = xp.arange(i, dtype=output.dtype)

        # Calculate apodization
        apod = (beta**2 - (np.pi * width * (idx - i // 2) / os_i)**2)**0.5
        apod /= xp.sinh(apod)
        # Broadcast the 1-D profile along axis `a`.
        output *= apod.reshape([i] + [1] * (-a - 1))

    return output