# -*- coding: utf-8 -*-
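"""MRI linear operators.

This module mainly contains the Sense linear operator, which combines
multi-channel coil sensitivity maps with the discrete Fourier transform.
It also provides k-space convolution operators with sensitivity-map and
image kernels (ConvSense and ConvImage).
"""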
import sigpy as sp
def Sense(mps, coord=None, weights=None, tseg=None, ishape=None,
          coil_batch_size=None, comm=None, transp_nufft=False):
    """Sense linear operator.

    Composes a pointwise multiplication by the coil sensitivity maps with a
    Fourier transform. A Cartesian FFT is used when ``coord`` is None and a
    NUFFT otherwise. The result is optionally followed by square-root k-space
    weighting.

    Args:
        mps (array): sensitivity maps of length = number of channels.
        coord (None or array): coordinates. When None, a Cartesian FFT over
            the last image dimensions is used.
        weights (None or array): k-space weights.
            Useful for soft-gating or density compensation. The operator
            output is multiplied by ``weights**0.5``.
        tseg (None or Dictionary): parameters for time-segmented off-resonance
            correction. Parameters are 'b0' (array), 'dt' (float),
            'lseg' (int), and 'n_bins' (int). Lseg is the number of
            time segments used, and n_bins is the number of histogram bins.
            Requires ``coord``.
        ishape (None or tuple): image shape. Defaults to ``mps.shape[1:]``.
        coil_batch_size (None or int): batch size for processing multi-channel.
            When None, process all coils at the same time.
            Useful for saving memory.
        comm (None or `sigpy.Communicator`): communicator
            for distributed computing.
        transp_nufft (bool): if true, use the adjoint of a NUFFT with negated
            coordinates in place of the forward NUFFT.

    Returns:
        sigpy.linop.Linop: the Sense operator.

    Raises:
        ValueError: if ``tseg`` is given without ``coord``, or if
            ``tseg['lseg']`` is smaller than 1.
    """
    # Get image shape and dimension.
    num_coils = len(mps)
    if ishape is None:
        ishape = mps.shape[1:]
        img_ndim = mps.ndim - 1
    else:
        img_ndim = len(ishape)

    # Validate up front so the batched path fails the same way as the
    # unbatched one, before any sub-operators are built.
    if tseg is not None:
        if coord is None:
            raise ValueError('tseg (time-segmented off-resonance correction) '
                             'requires coord.')
        if tseg['lseg'] < 1:
            raise ValueError('tseg["lseg"] must be at least 1.')

    # Serialize linop if coil_batch_size is smaller than num_coils.
    if coil_batch_size is None:
        coil_batch_size = num_coils

    if coil_batch_size < num_coils:
        num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size
        # Forward every per-operator option (including tseg and transp_nufft)
        # so each batch matches what the unbatched path would build. comm is
        # intentionally not forwarded: it is applied once to the stack below.
        A = sp.linop.Vstack(
            [Sense(mps[c * coil_batch_size:(c + 1) * coil_batch_size],
                   coord=coord, weights=weights, tseg=tseg, ishape=ishape,
                   transp_nufft=transp_nufft)
             for c in range(num_coil_batches)], axis=0)
        if comm is not None:
            C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
            A = A * C
        return A

    # Multiply by the coil sensitivity maps.
    S = sp.linop.Multiply(ishape, mps)

    # Fourier transform: Cartesian FFT without coordinates, otherwise a NUFFT
    # (or the adjoint of a NUFFT with negated coordinates if transp_nufft).
    if coord is None:
        F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0))
    elif transp_nufft:
        # NOTE(review): the adjoint's input shape is the NUFFT output shape;
        # confirm this composes with S for the intended use case.
        F = sp.linop.NUFFT(S.oshape, -coord).H
    else:
        F = sp.linop.NUFFT(S.oshape, coord)

    if tseg is None:
        A = F * S
    else:
        # Time-segmented off-resonance compensation:
        # A = sum_i Bi * F * S * Cti, with b and ct from tseg_off_res_b_ct.
        # NOTE(review): `time` assumes the first axis of coord indexes readout
        # samples spaced by dt — confirm for multi-shot trajectories.
        time = len(coord) * tseg['dt']
        b, ct = sp.mri.util.tseg_off_res_b_ct(tseg['b0'], tseg['n_bins'],
                                              tseg['lseg'], tseg['dt'], time)
        A = None
        for ii in range(tseg['lseg']):
            Bi = sp.linop.Multiply(F.oshape, b[:, ii])
            Cti = sp.linop.Multiply(S.ishape, ct[:, ii].reshape(S.ishape))
            term = Bi * F * S * Cti
            A = term if A is None else A + term

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(F.oshape, weights**0.5)
            A = P * A

    if comm is not None:
        C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
        A = A * C

    A.repr_str = 'Sense'
    return A
def ConvSense(img_ker_shape, mps_ker, coord=None, weights=None, grd_shape=None,
              comm=None):
    """Convolution linear operator with sensitivity maps kernel in k-space.

    Maps an image kernel to per-coil data through a 'valid' multi-channel
    convolution with the sensitivity-map kernels. When ``coord`` is given,
    the gridded result is inverse-FFT'd and then resampled with a NUFFT.

    Args:
        img_ker_shape (tuple of ints): image kernel shape.
        mps_ker (array): sensitivity maps kernel; its first axis indexes coils.
        coord (None or array): coordinates.
        weights (None or array): k-space weights; the output is multiplied by
            ``weights**0.5``.
        grd_shape (None or list): Shape of grid. Estimated from ``coord``
            when None.
        comm (None or `sigpy.Communicator`): communicator for distributed
            computing.

    Returns:
        sigpy.linop.Linop: the convolution operator.
    """
    n_dims = len(img_ker_shape)
    n_coils = mps_ker.shape[0]

    # Kernels laid out as (coils, 1, ...): one output channel per coil,
    # a single input channel.
    kernels = mps_ker.reshape((n_coils, 1) + mps_ker.shape[1:])

    # Prepend a singleton channel axis to the image kernel before convolving.
    add_channel = sp.linop.Reshape((1, ) + tuple(img_ker_shape), img_ker_shape)
    conv = sp.linop.ConvolveData(add_channel.oshape, kernels,
                                 mode='valid', multi_channel=True)
    op = conv * add_channel

    if coord is not None:
        if grd_shape is not None:
            grid = list(grd_shape)
        else:
            grid = sp.estimate_shape(coord)
        grid = [n_coils] + grid
        ifft = sp.linop.IFFT(grid, axes=range(-n_dims, 0))
        nufft = sp.linop.NUFFT(grid, coord)
        op = nufft * ifft * op

    if weights is not None:
        with sp.get_device(weights):
            root_w = sp.linop.Multiply(op.oshape, weights**0.5)
            op = root_w * op

    if comm is not None:
        reduce_op = sp.linop.AllReduceAdjoint(img_ker_shape, comm,
                                              in_place=True)
        op = op * reduce_op

    return op
def ConvImage(mps_ker_shape, img_ker, coord=None, weights=None,
              grd_shape=None):
    """Convolution linear operator with image kernel in k-space.

    Maps per-coil sensitivity-map kernels to per-coil data through a 'valid'
    multi-channel convolution with the image kernel. When ``coord`` is
    given, the gridded result is inverse-FFT'd and then resampled with a
    NUFFT.

    Args:
        mps_ker_shape (tuple of ints): sensitivity maps kernel shape; its
            first entry is the number of coils.
        img_ker (array): image kernel.
        coord (None or array): coordinates.
        weights (None or array): k-space weights; the output is multiplied by
            ``weights**0.5``.
        grd_shape (None or sequence): Shape of grid. Estimated from ``coord``
            when None.

    Returns:
        sigpy.linop.Linop: the convolution operator.
    """
    ndim = img_ker.ndim
    num_coils = mps_ker_shape[0]

    # Give the image kernel a leading singleton channel axis and the maps
    # kernel a singleton axis after the coil axis, as the convolution expects.
    img_ker = img_ker.reshape((1, ) + img_ker.shape)
    R = sp.linop.Reshape((num_coils, 1) + tuple(mps_ker_shape[1:]),
                         mps_ker_shape)
    C = sp.linop.ConvolveFilter(R.oshape, img_ker,
                                mode='valid', multi_channel=True)
    A = C * R

    if coord is not None:
        if grd_shape is None:
            grd_shape = sp.estimate_shape(coord)
        # list() so both user-supplied and estimated shapes (list or tuple)
        # concatenate cleanly with the coil axis.
        grd_shape = [num_coils] + list(grd_shape)
        iF = sp.linop.IFFT(grd_shape, axes=range(-ndim, 0))
        N = sp.linop.NUFFT(grd_shape, coord)
        A = N * iF * A

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(A.oshape, weights**0.5)
            A = P * A

    return A