# -*- coding: utf-8 -*-
"""Utility functions."""
import numpy as np

from sigpy import backend

# Public API of this module (controls ``from ... import *``).
__all__ = [
    'prod', 'vec', 'split', 'rss', 'resize',
    'flip', 'circshift', 'downsample', 'upsample',
    'dirac', 'randn', 'triang', 'hanning',
    'monte_carlo_sure', 'axpy', 'xpay', 'leja',
]

def _normalize_axes(axes, ndim):
    if axes is None:
        return tuple(range(ndim))
        return tuple(a % ndim for a in sorted(axes))

def _normalize_shape(shape):
    if isinstance(shape, int):
        return (shape, )
        return tuple(shape)

def _expand_shapes(*shapes):

    shapes = [list(shape) for shape in shapes]
    max_ndim = max(len(shape) for shape in shapes)
    shapes_exp = [[1] * (max_ndim - len(shape)) + shape
                  for shape in shapes]

    return tuple(shapes_exp)

def _check_same_dtype(*arrays):

    dtype = arrays[0].dtype
    for a in arrays:
        if a.dtype != dtype:
            raise TypeError(
                'inputs dtype mismatch, got {a_dtype}, and {dtype}.'.format(
                    a_dtype=a.dtype, dtype=dtype))

def prod(shape):
    """Computes product of shape.

    Args:
        shape (tuple or list): shape.

    Returns:
        numpy.int64: Product of all entries of ``shape`` (1 for an empty
        shape). 64-bit accumulation avoids overflow of the platform default
        integer for large shapes.
    """
    return np.prod(shape, dtype=np.int64)

def vec(inputs):
    """Vectorize inputs.

    Flattens every array in ``inputs`` and concatenates the results into a
    single one-dimensional array.

    Args:
        inputs (list of arrays): Arrays to vectorize. The array module is
            chosen from ``inputs[0]`` via ``backend.get_array_module``, so the
            list must be non-empty.

    Returns:
        array: Vectorized result.
    """
    xp = backend.get_array_module(inputs[0])
    return xp.concatenate([i.ravel() for i in inputs])

def split(vec, oshapes):
    """Split input into specified output shapes.

    Consecutive chunks of ``vec`` are reshaped into each shape of
    ``oshapes`` in turn.

    Args:
        vec (array): Input array. Chunks are taken along its first axis, so it
            is presumably one-dimensional (e.g. the output of ``vec``).
        oshapes (list of tuple of ints): Output shapes.

    Returns:
        list of arrays: Splitted outputs, one per entry of ``oshapes``.
    """
    outputs = []
    for oshape in oshapes:
        osize = prod(oshape)
        outputs.append(vec[:osize].reshape(oshape))
        # Advance past the chunk just consumed.
        vec = vec[osize:]

    return outputs

def rss(input, axes=(0, )):
    """Root sum of squares.

    Computes ``sqrt(sum(|input|**2))`` over ``axes``.

    Args:
        input (array): Input array.
        axes (None or tuple of ints): Axes to perform operation. ``None``
            reduces over all axes.

    Returns:
        array: Result.
    """
    xp = backend.get_array_module(input)
    return xp.sum(xp.abs(input)**2, axis=axes)**0.5

def resize(input, oshape, ishift=None, oshift=None):
    """Resize with zero-padding or cropping.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        ishift (None or tuple of ints): Input shift. Defaults to centering
            the copied region in the input.
        oshift (None or tuple of ints): Output shift. Defaults to centering
            the copied region in the output.

    Returns:
        array: Zero-padded or cropped result. When the (rank-expanded) shapes
        already match, this is ``input.reshape(oshape)`` rather than a copy.
    """
    ishape1, oshape1 = _expand_shapes(input.shape, oshape)

    if ishape1 == oshape1:
        return input.reshape(oshape)

    if ishift is None:
        ishift = [max(i // 2 - o // 2, 0) for i, o in zip(ishape1, oshape1)]

    if oshift is None:
        oshift = [max(o // 2 - i // 2, 0) for i, o in zip(ishape1, oshape1)]

    # Size of the overlap that fits in both arrays after shifting.
    copy_shape = [min(i - si, o - so)
                  for i, si, o, so in zip(ishape1, ishift, oshape1, oshift)]
    islice = tuple([slice(si, si + c) for si, c in zip(ishift, copy_shape)])
    oslice = tuple([slice(so, so + c) for so, c in zip(oshift, copy_shape)])

    xp = backend.get_array_module(input)
    output = xp.zeros(oshape1, dtype=input.dtype)
    input = input.reshape(ishape1)
    output[oslice] = input[islice]

    return output.reshape(oshape)
def flip(input, axes=None):
    """Flip input.

    Args:
        input (array): Input array.
        axes (None or tuple of ints): Axes to perform operation. ``None``
            flips every axis; negative axes count from the end.

    Returns:
        array: Flipped result (a reversed-stride slice of ``input``).
    """
    axes = _normalize_axes(axes, input.ndim)
    slc = []
    for d in range(input.ndim):
        if d in axes:
            slc.append(slice(None, None, -1))
        else:
            slc.append(slice(None))

    slc = tuple(slc)
    output = input[slc]
    return output
def circshift(input, shifts, axes=None):
    """Circular shift input.

    Args:
        input (array): Input array.
        shifts (tuple of ints): Shifts, one per entry of ``axes``.
        axes (None or tuple of ints): Axes to perform operation. ``None``
            uses every axis, in which case ``shifts`` needs one entry per
            dimension of ``input``.

    Returns:
        array: Result.

    Raises:
        ValueError: If ``axes`` and ``shifts`` have different lengths.
    """
    if axes is None:
        axes = range(input.ndim)

    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if len(axes) != len(shifts):
        raise ValueError(
            'axes and shifts must have the same length, got {} and {}.'.format(
                len(axes), len(shifts)))

    xp = backend.get_array_module(input)
    for axis, shift in zip(axes, shifts):
        input = xp.roll(input, shift, axis=axis)

    return input
def downsample(input, factors, shift=None):
    """Downsample input.

    Keeps every ``factors[d]``-th sample along each leading axis ``d``,
    starting at ``shift[d]``. Uses basic slicing, so for NumPy arrays the
    result is a view of ``input``.

    Args:
        input (array): Input array.
        factors (tuple of ints): Downsampling factors.
        shift (None or tuple of ints): Shifts. Defaults to 0 for every axis.

    Returns:
        array: Result.
    """
    if shift is None:
        shift = [0] * len(factors)

    slc = tuple(slice(s, None, f) for s, f in zip(shift, factors))

    return input[slc]
def upsample(input, oshape, factors, shift=None):
    """Upsample input.

    Places ``input`` on a zero-filled grid of shape ``oshape`` at every
    ``factors[d]``-th position along each leading axis ``d``, starting at
    ``shift[d]``.

    Args:
        input (array): Input array.
        oshape (tuple of ints): Output shape.
        factors (tuple of ints): Upsampling factors.
        shift (None or tuple of ints): Shifts. Defaults to 0 for every axis.

    Returns:
        array: Result.
    """
    if shift is None:
        shift = [0] * len(factors)

    slc = tuple(slice(s, None, f) for s, f in zip(shift, factors))

    xp = backend.get_array_module(input)
    output = xp.zeros(oshape, dtype=input.dtype)
    output[slc] = input

    return output
def dirac(shape, dtype=float, device=backend.cpu_device):
    """Create Dirac delta.

    Args:
        shape (tuple of ints): Output shape.
        dtype (Dtype): Output data-type. The default is the builtin ``float``
            (the removed ``np.float`` alias referred to the same type).
        device (Device): Output device.

    Returns:
        array: Dirac delta array: zeros with a single one at index
        ``s // 2`` along each axis of size ``s``.
    """
    device = backend.Device(device)
    xp = device.xp
    with device:
        # Zero-padding a single 1 with resize's default centering puts it at
        # the center index of every axis.
        return resize(xp.ones([1], dtype=dtype), shape)
def randn(shape, scale=1, dtype=float, device=backend.cpu_device):
    """Create random Gaussian array.

    Args:
        shape (tuple of ints): Output shape.
        scale (float): Standard deviation. For complex dtypes, the real and
            imaginary parts each get standard deviation ``scale / sqrt(2)``,
            so the complex samples have standard deviation ``scale``.
        dtype (Dtype): Output data-type. The default is the builtin ``float``
            (the removed ``np.float`` alias referred to the same type).
        device (Device): Output device.

    Returns:
        array: Random Gaussian array.
    """
    device = backend.Device(device)
    xp = device.xp

    with device:
        if np.issubdtype(dtype, np.complexfloating):
            real_dtype = np.array([], dtype=dtype).real.dtype
            # Draw interleaved (real, imag) pairs in a trailing axis of
            # length 2, then reinterpret each pair as one complex value.
            real_shape = tuple(shape) + (2, )
            output = xp.random.normal(size=real_shape, scale=scale / 2**0.5)
            output = output.astype(real_dtype)
            output = output.view(dtype=dtype).reshape(shape)
            return output
        else:
            return xp.random.normal(size=shape, scale=scale).astype(dtype)
def triang(shape, dtype=float, device=backend.cpu_device):
    """Create multi-dimensional triangular window.

    The window is the outer product of one 1-D triangular window per axis.

    Args:
        shape (int or tuple of ints): Output shape.
        dtype (Dtype): Output data-type. The default is the builtin ``float``
            (the removed ``np.float`` alias referred to the same type).
        device (Device): Output device.

    Returns:
        array: triangular filter.
    """
    device = backend.Device(device)
    xp = device.xp
    shape = _normalize_shape(shape)
    with device:
        window = xp.ones(shape, dtype=dtype)
        # Iterate axes from last to first; reshaping to [i] + [1] * n
        # broadcasts each 1-D window along its own axis.
        for n, i in enumerate(shape[::-1]):
            x = xp.arange(i, dtype=dtype)
            w = 1 - xp.abs(x - i // 2 + ((i + 1) % 2) / 2) / ((i + 1) // 2)
            window *= w.reshape([i] + [1] * n)

    return window
def hanning(shape, dtype=float, device=backend.cpu_device):
    """Create multi-dimensional hanning window.

    The window is the outer product of one 1-D Hann window per axis.

    Args:
        shape (int or tuple of ints): Output shape.
        dtype (Dtype): Output data-type. The default is the builtin ``float``
            (the removed ``np.float`` alias referred to the same type).
        device (Device): Output device.

    Returns:
        array: hanning filter.
    """
    device = backend.Device(device)
    xp = device.xp
    shape = _normalize_shape(shape)
    with device:
        window = xp.ones(shape, dtype=dtype)
        # Iterate axes from last to first; reshaping to [i] + [1] * n
        # broadcasts each 1-D window along its own axis.
        for n, i in enumerate(shape[::-1]):
            x = xp.arange(i, dtype=dtype)
            # max(1, ...) guards against division by zero for i in (0, 1).
            w = 0.5 - 0.5 * xp.cos(2 * np.pi * x / max(1, (i - (i % 2))))
            window *= w.reshape([i] + [1] * n)

    return window
def monte_carlo_sure(f, y, sigma, eps=1e-10):
    """Monte Carlo Stein Unbiased Risk Estimator (SURE).

    Monte carlo SURE assumes the observation y = x + e,
    where e is a white Gaussian array with standard deviation sigma.
    Monte carlo SURE provides an unbiased estimate of mean-squared error, ie:
    1 / n || f(y) - x ||_2^2

    Args:
        f (function): x -> f(x).
        y (array): observed measurement.
        sigma (float): noise standard deviation.
        eps (float): step size of the finite difference used to estimate the
            divergence of f at y. NOTE(review): for single-precision y, the
            default 1e-10 is likely below float32 resolution relative to
            O(1) values, which would make the divergence term vanish.

    Returns:
        float: SURE (a scalar from the array module of y).

    References:
        Ramani, S., Blu, T. and Unser, M. 2008.
        Monte-Carlo Sure: A Black-Box Optimization of Regularization
        Parameters for General Denoising Algorithms.
        IEEE Transactions on Image Processing 17, 9 (2008), 1540-1554.
    """
    device = backend.get_device(y)
    xp = device.xp

    n = y.size
    f_y = f(y)
    # Random probe for the Monte Carlo divergence estimate:
    # div f(y) ~= Re<b, f(y + eps * b) - f(y)> / eps.
    b = randn(y.shape, dtype=y.dtype, device=device)
    divf_y = xp.real(xp.vdot(b, (f(y + eps * b) - f_y))) / eps
    sure = xp.mean(xp.abs(y - f_y)**2) - sigma**2 + 2 * sigma**2 * divf_y / n

    return sure
def leja(x):
    """Perform Leja ordering of the roots of a polynomial.

    Orders roots in a way suitable to accurately compute polynomial
    coefficients from them.

    Args:
        x (array): roots to be ordered.

    Returns:
        array: ordered roots, i.e. the values of ``x`` permuted by the
        column swaps performed below.

    References:
        Lang, M. and B. Frenzel. 1993.
        A New and Efficient Program for Finding All Polynomial Roots.
        Rice University ECE Technical Report, no. TR93-08, 1993.
    """
    n = np.size(x)
    # duplicate roots to n+1 rows
    a = np.tile(np.reshape(x, (1, n)), (n+1, 1))
    # take abs of first row
    a[0, :] = np.abs(a[0, :])

    # Scratch buffer used to swap columns of a (and, below, entries of foo).
    # NOTE(review): tmp is complex while a keeps x's dtype. For real x,
    # ``a[:, ind] = tmp`` casts complex -> real (NumPy warns). For integer x,
    # the np.abs(...) distances written into a below are truncated.
    # Presumably callers pass complex roots; confirm.
    tmp = np.zeros(n+1, dtype=complex)

    # find index of max abs value and move that root to column 0
    ind = np.argmax(a[0, :])
    if ind != 0:
        tmp[:] = a[:, 0]
        a[:, 0] = a[:, ind]
        a[:, ind] = tmp

    x_out = np.zeros(n, dtype=complex)
    x_out[0] = a[n-1, 0]  # first entry of last row
    # Row 1: distances from the remaining roots to the root placed first.
    a[1, 1:] = np.abs(a[1, 1:] - x_out[0])

    foo = a[0, 0:n]

    for l in range(1, n-1):
        # foo[j] accumulates the product of rows 0..l in column j, i.e. the
        # product of distances to the roots already placed.
        foo = np.multiply(foo, a[l, :])
        # Choose, among columns >= l, the root maximizing that product.
        ind = np.argmax(foo[l:])
        ind = ind + l
        if l != ind:
            tmp[:] = a[:, l]
            a[:, l] = a[:, ind]
            a[:, ind] = tmp
            # also swap inds in foo
            tmp[0] = foo[l]
            foo[l] = foo[ind]
            foo[ind] = tmp[0]
        x_out[l] = a[n-1, l]
        # Row l+1: distances from the remaining roots to the root just placed.
        a[l+1, (l+1):n] = np.abs(a[l+1, (l+1):] - x_out[l])

    # Row n is never overwritten with distances, only column-permuted, so it
    # holds the original roots in Leja order. This replaces the x_out above.
    x_out = a[n, :]

    return x_out
def axpy(y, a, x):
    """Compute y = a * x + y in place.

    Args:
        y (array): Output array, updated in place.
        a (scalar or array): Input scalar.
        x (array): Input array.

    Returns:
        None: the result is written into ``y``.
    """
    y += a * x
def xpay(y, a, x):
    """Compute y = x + a * y in place.

    Args:
        y (array): Output array, updated in place.
        a (scalar or array): Input scalar.
        x (array): Input array.

    Returns:
        None: the result is written into ``y``.
    """
    # Two in-place updates avoid allocating a temporary for a * y.
    y *= a
    y += x