Source code for sigpy.conv

# -*- coding: utf-8 -*-
"""Convolution functions with  multi-dimension, and multi-channel support.

"""
import numpy as np
import scipy.signal as signal
from sigpy import backend, util, config


__all__ = ['convolve', 'convolve_data_adjoint', 'convolve_filter_adjoint']


def convolve(data, filt, mode='full', strides=None, multi_channel=False):
    r"""Convolution that supports multi-dimensional and multi-channel inputs.

    This function follows the signal processing definition of convolution.
    Note that the cuDNN version only supports inputs with D = 1, 2, or 3.

    Args:
        data (array): data array of shape
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt (array): filter array of shape
            :math:`[n_1, ..., n_D]` if multi_channel is False,
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: output array of shape
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.

    """
    xp = backend.get_array_module(data)
    if xp == np:
        output = _convolve(data, filt, mode=mode, strides=strides,
                           multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(data.dtype, np.floating):
                output = _convolve_cuda(data, filt,
                                        mode=mode, strides=strides,
                                        multi_channel=multi_channel)
            else:
                output = _complex(_convolve_cuda, data, filt,
                                  mode=mode, strides=strides,
                                  multi_channel=multi_channel)
        else:
            raise RuntimeError(
                'cudnn must be installed to perform convolution on GPU.')

    return output
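
# The following helper is an illustrative usage sketch and not part of the
# sigpy API: `_example_convolve_usage` is a hypothetical name. It only
# demonstrates the expected input/output shapes of `convolve` for a batched,
# multi-channel 2D convolution in 'valid' mode on the NumPy backend.
def _example_convolve_usage():
    rng = np.random.default_rng(0)
    data = rng.standard_normal((2, 3, 8, 8))  # batch b=2, c_i=3, 8x8 images
    filt = rng.standard_normal((4, 3, 3, 3))  # c_o=4, c_i=3, 3x3 kernels
    output = convolve(data, filt, mode='valid', multi_channel=True)
    # In 'valid' mode with unit strides, p_d = m_d - n_d + 1 = 8 - 3 + 1 = 6.
    assert output.shape == (2, 4, 6, 6)
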
def convolve_data_adjoint(output, filt, data_shape,
                          mode='full', strides=None, multi_channel=False):
    """Adjoint convolution operation with respect to data.

    Note that the cuDNN version only supports inputs with D = 1, 2, or 3.

    Args:
        output (array): output array of shape
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.
        filt (array): filter array of shape
            :math:`[n_1, ..., n_D]` if multi_channel is False,
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.
        data_shape (tuple of ints): shape of the data array to recover.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if data/output has multiple channels.

    Returns:
        array: data array of shape
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.

    """
    data_shape = tuple(data_shape)
    xp = backend.get_array_module(output)
    if xp == np:
        data = _convolve_data_adjoint(output, filt, data_shape,
                                      mode=mode, strides=strides,
                                      multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(output.dtype, np.floating):
                data = _convolve_data_adjoint_cuda(
                    output, filt, data_shape,
                    mode=mode, strides=strides,
                    multi_channel=multi_channel)
            else:
                data = _complex(_convolve_data_adjoint_cuda,
                                output, filt.conj(), data_shape,
                                mode=mode, strides=strides,
                                multi_channel=multi_channel)
        else:
            raise RuntimeError(
                'cudnn must be installed to perform convolution on GPU.')

    return data
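
# Illustrative sanity check, not part of the sigpy API: the
# `_example_data_adjoint_check` helper is a hypothetical name. For real-valued
# arrays, convolve_data_adjoint should satisfy the adjoint identity
# <convolve(x, w), y> == <x, convolve_data_adjoint(y, w, x.shape)>.
def _example_data_adjoint_check():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 8))
    w = rng.standard_normal((3, 3))
    y = rng.standard_normal((10, 10))  # 'full' output size: 8 + 3 - 1 = 10
    lhs = np.sum(convolve(x, w, mode='full') * y)
    rhs = np.sum(x * convolve_data_adjoint(y, w, x.shape, mode='full'))
    assert np.allclose(lhs, rhs)
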
def convolve_filter_adjoint(output, data, filt_shape,
                            mode='full', strides=None, multi_channel=False):
    """Adjoint convolution operation with respect to filter.

    Note that the cuDNN version only supports inputs with D = 1, 2, or 3.

    Args:
        output (array): output array of shape
            :math:`[..., p_1, ..., p_D]` if multi_channel is False,
            :math:`[..., c_o, p_1, ..., p_D]` otherwise.
        data (array): data array of shape
            :math:`[..., m_1, ..., m_D]` if multi_channel is False,
            :math:`[..., c_i, m_1, ..., m_D]` otherwise.
        filt_shape (tuple of ints): shape of the filter array to recover.
        mode (str): {'full', 'valid'}.
        strides (None or tuple of ints): convolution strides of length D.
        multi_channel (bool): specify if input/output has multiple channels.

    Returns:
        array: filter array of shape
            :math:`[n_1, ..., n_D]` if multi_channel is False,
            :math:`[c_o, c_i, n_1, ..., n_D]` otherwise.

    """
    filt_shape = tuple(filt_shape)
    xp = backend.get_array_module(data)
    if xp == np:
        filt = _convolve_filter_adjoint(output, data, filt_shape,
                                        mode=mode, strides=strides,
                                        multi_channel=multi_channel)
    else:  # pragma: no cover
        if config.cudnn_enabled:
            if np.issubdtype(output.dtype, np.floating):
                filt = _convolve_filter_adjoint_cuda(
                    output, data, filt_shape,
                    mode=mode, strides=strides,
                    multi_channel=multi_channel)
            else:
                filt = _complex(_convolve_filter_adjoint_cuda,
                                output, data.conj(), filt_shape,
                                mode=mode, strides=strides,
                                multi_channel=multi_channel)
        else:
            raise RuntimeError(
                'cudnn must be installed to perform convolution on GPU.')

    return filt
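
# Illustrative sanity check, not part of the sigpy API: the
# `_example_filter_adjoint_check` helper is a hypothetical name. For
# real-valued arrays, convolve_filter_adjoint should satisfy
# <convolve(x, w), y> == <w, convolve_filter_adjoint(y, x, w.shape)>.
def _example_filter_adjoint_check():
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 8))
    w = rng.standard_normal((3, 3))
    y = rng.standard_normal((10, 10))  # 'full' output size: 8 + 3 - 1 = 10
    lhs = np.sum(convolve(x, w, mode='full') * y)
    rhs = np.sum(w * convolve_filter_adjoint(y, x, w.shape, mode='full'))
    assert np.allclose(lhs, rhs)
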
def _get_convolve_params(data_shape, filt_shape,
                         mode, strides, multi_channel):
    D = len(filt_shape) - 2 * multi_channel
    m = tuple(data_shape[-D:])
    n = tuple(filt_shape[-D:])
    b = tuple(data_shape[:-D - multi_channel])
    B = util.prod(b)

    if multi_channel:
        if filt_shape[-D - 1] != data_shape[-D - 1]:
            raise ValueError('Data channel mismatch, '
                             'got {} from data and {} from filt.'.format(
                                 data_shape[-D - 1], filt_shape[-D - 1]))

        c_i = filt_shape[-D - 1]
        c_o = filt_shape[-D - 2]
    else:
        c_i = 1
        c_o = 1

    if strides is None:
        s = (1, ) * D
    else:
        if len(strides) != D:
            raise ValueError('Strides must have length {}.'.format(D))

        s = tuple(strides)

    if mode == 'full':
        p = tuple((m_d + n_d - 1 + s_d - 1) // s_d
                  for m_d, n_d, s_d in zip(m, n, s))
    elif mode == 'valid':
        if (any(m_d >= n_d for m_d, n_d in zip(m, n))
                and any(m_d < n_d for m_d, n_d in zip(m, n))):
            raise ValueError('In valid mode, either data or filter must be '
                             'at least as large as the other in every axis.')

        p = tuple((m_d - n_d + 1 + s_d - 1) // s_d
                  for m_d, n_d, s_d in zip(m, n, s))
    else:
        raise ValueError('Invalid mode, got {}'.format(mode))

    return D, b, B, m, n, s, c_i, c_o, p


def _convolve(data, filt, mode='full', strides=None, multi_channel=False):
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt.shape, mode, strides, multi_channel)

    # Normalize shapes.
    data = data.reshape((B, c_i) + m)
    filt = filt.reshape((c_o, c_i) + n)
    output = np.zeros((B, c_o) + p, dtype=data.dtype)
    slc = tuple(slice(None, None, s_d) for s_d in s)
    for k in range(B):
        for j in range(c_o):
            for i in range(c_i):
                output[k, j] += signal.convolve(
                    data[k, i], filt[j, i], mode=mode)[slc]

    # Reshape.
    if multi_channel:
        output = output.reshape(b + (c_o, ) + p)
    else:
        output = output.reshape(b + p)

    return output


def _convolve_data_adjoint(output, filt, data_shape,
                           mode='full', strides=None, multi_channel=False):
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data_shape, filt.shape, mode, strides, multi_channel)

    # Normalize shapes.
    output = output.reshape((B, c_o) + p)
    filt = filt.reshape((c_o, c_i) + n)
    data = np.zeros((B, c_i) + m, dtype=output.dtype)
    slc = tuple(slice(None, None, s_d) for s_d in s)
    if mode == 'full':
        output_kj = np.zeros([m_d + n_d - 1 for m_d, n_d in zip(m, n)],
                             dtype=output.dtype)
        adjoint_mode = 'valid'
    elif mode == 'valid':
        output_kj = np.zeros([max(m_d, n_d) - min(m_d, n_d) + 1
                              for m_d, n_d in zip(m, n)],
                             dtype=output.dtype)
        if all(m_d >= n_d for m_d, n_d in zip(m, n)):
            adjoint_mode = 'full'
        else:
            adjoint_mode = 'valid'

    for k in range(B):
        for j in range(c_o):
            for i in range(c_i):
                output_kj[slc] = output[k, j]
                data[k, i] += signal.correlate(
                    output_kj, filt[j, i], mode=adjoint_mode)

    # Reshape.
    data = data.reshape(data_shape)

    return data


def _convolve_filter_adjoint(output, data, filt_shape,
                             mode='full', strides=None, multi_channel=False):
    D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
        data.shape, filt_shape, mode, strides, multi_channel)

    # Normalize shapes.
    data = data.reshape((B, c_i) + m)
    output = output.reshape((B, c_o) + p)
    slc = tuple(slice(None, None, s_d) for s_d in s)
    if mode == 'full':
        output_kj = np.zeros([m_d + n_d - 1 for m_d, n_d in zip(m, n)],
                             dtype=output.dtype)
        adjoint_mode = 'valid'
    elif mode == 'valid':
        output_kj = np.zeros([max(m_d, n_d) - min(m_d, n_d) + 1
                              for m_d, n_d in zip(m, n)],
                             dtype=output.dtype)
        if all(m_d >= n_d for m_d, n_d in zip(m, n)):
            adjoint_mode = 'valid'
        else:
            adjoint_mode = 'full'

    filt = np.zeros((c_o, c_i) + n, dtype=output.dtype)
    for k in range(B):
        for j in range(c_o):
            for i in range(c_i):
                output_kj[slc] = output[k, j]
                filt[j, i] += signal.correlate(
                    output_kj, data[k, i], mode=adjoint_mode)

    # Reshape.
    filt = filt.reshape(filt_shape)

    return filt


if config.cudnn_enabled:  # pragma: no cover
    from cupy import cudnn

    def _complex(func, data1, data2, *kargs, **kwargs):
        """Helper function to convert func to support complex floats.

        """
        xp = backend.get_array_module(data1)
        data1r = xp.real(data1)
        data1i = xp.imag(data1)
        data2r = xp.real(data2)
        data2i = xp.imag(data2)

        outputr = func(data1r, data2r, *kargs, **kwargs)
        outputr -= func(data1i, data2i, *kargs, **kwargs)
        outputi = func(data1i, data2r, *kargs, **kwargs)
        outputi += func(data1r, data2i, *kargs, **kwargs)

        output = outputr + 1j * outputi
        output = output.astype(data1.dtype, copy=False)

        return output

    def _convolve_cuda(data, filt,
                       mode='full', strides=None, multi_channel=False):
        xp = backend.get_array_module(data)

        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt.shape, mode, strides, multi_channel)

        if D == 1:
            return _convolve_cuda(
                xp.expand_dims(data, -1),
                xp.expand_dims(filt, -1),
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2, or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        filt = filt.reshape((c_o, c_i) + n)
        output = xp.empty((B, c_o) + p, dtype=data.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_forward(data, filt, None, output,
                                  pads, s, dilations, groups,
                                  auto_tune=auto_tune,
                                  tensor_core=tensor_core)

        # Reshape.
        if multi_channel:
            output = output.reshape(b + (c_o, ) + p)
        else:
            output = output.reshape(b + p)

        return output

    def _convolve_data_adjoint_cuda(output, filt, data_shape,
                                    mode='full', strides=None,
                                    multi_channel=False):
        xp = backend.get_array_module(output)
        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data_shape, filt.shape, mode, strides, multi_channel)

        if D == 1:
            return _convolve_data_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(filt, -1),
                list(data_shape) + [1],
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        output = output.reshape((B, c_o) + p)
        filt = filt.reshape((c_o, c_i) + n)
        data = xp.empty((B, c_i) + m, dtype=output.dtype)
        filt = util.flip(filt, axes=range(-D, 0))
        cudnn.convolution_backward_data(filt, output, None, data,
                                        pads, s, dilations, groups,
                                        deterministic=deterministic,
                                        auto_tune=auto_tune,
                                        tensor_core=tensor_core)

        # Reshape.
        data = data.reshape(data_shape)

        return data

    def _convolve_filter_adjoint_cuda(output, data, filt_shape,
                                      mode='full', strides=None,
                                      multi_channel=False):
        xp = backend.get_array_module(data)
        D, b, B, m, n, s, c_i, c_o, p = _get_convolve_params(
            data.shape, filt_shape, mode, strides, multi_channel)

        if D == 1:
            return _convolve_filter_adjoint_cuda(
                xp.expand_dims(output, -1),
                xp.expand_dims(data, -1),
                list(filt_shape) + [1],
                mode=mode,
                strides=list(strides) + [1] if strides is not None else None,
                multi_channel=multi_channel).squeeze(-1)
        elif D > 3:
            raise ValueError(
                f'cuDNN convolution only supports 1, 2 or 3D, got {D}.')

        dilations = (1, ) * D
        groups = 1
        auto_tune = True
        tensor_core = 'auto'
        deterministic = False
        if mode == 'full':
            pads = tuple(n_d - 1 for n_d in n)
        elif mode == 'valid':
            pads = (0, ) * D

        data = data.reshape((B, c_i) + m)
        output = output.reshape((B, c_o) + p)
        filt = xp.empty((c_o, c_i) + n, dtype=output.dtype)
        cudnn.convolution_backward_filter(data, output, filt,
                                          pads, s, dilations, groups,
                                          deterministic=deterministic,
                                          auto_tune=auto_tune,
                                          tensor_core=tensor_core)
        filt = util.flip(filt, axes=range(-D, 0))
        filt = filt.reshape(filt_shape)

        return filt
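
# Hypothetical entry point, not part of the sigpy module: it simply runs the
# illustrative example helpers defined above on the NumPy (CPU) backend.
if __name__ == '__main__':
    _example_convolve_usage()
    _example_data_adjoint_check()
    _example_filter_adjoint_check()
    print('All illustrative checks passed.')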