# -*- coding: utf-8 -*-
"""MRI linear operators.
This module mainly contains the Sense linear operator,
which integrates multi-channel coil sensitivity maps and
discrete Fourier transform.
"""
import sigpy as sp
def Sense(
    mps,
    coord=None,
    weights=None,
    tseg=None,
    ishape=None,
    coil_batch_size=None,
    comm=None,
    transp_nufft=False,
):
    """Sense linear operator.

    Multiplies the input image by the sensitivity maps and applies a
    Fourier transform: an FFT over the image axes when ``coord`` is None,
    otherwise a NUFFT at ``coord``. Optionally adds time-segmented
    off-resonance correction, k-space weighting and a communicator
    reduction on the adjoint.

    Args:
        mps (array): sensitivity maps of length = number of channels.
        coord (None or array): coordinates. Required when ``tseg`` is given.
        weights (None or array): k-space weights.
            Useful for soft-gating or density compensation.
            The operator output is multiplied by ``weights**0.5``.
        tseg (None or Dictionary): parameters for time-segmented off-resonance
            correction. Parameters are 'b0' (array), 'dt' (float),
            'lseg' (int), and 'n_bins' (int). Lseg is the number of
            time segments used, and n_bins is the number of histogram bins.
        ishape (None or tuple): image shape. Defaults to ``mps.shape[1:]``.
        coil_batch_size (None or int): batch size for processing multi-channel.
            When None, process all coils at the same time.
            Useful for saving memory.
        comm (None or `sigpy.Communicator`): communicator
            for distributed computing.
        transp_nufft (bool): when true, the non-Cartesian transform is built
            as ``NUFFT(oshape, -coord).H`` instead of ``NUFFT(oshape, coord)``.

    Returns:
        The composed linear operator. When coils are processed in batches
        this is a vertical stack (axis 0) of per-batch Sense operators.

    Raises:
        ValueError: if ``tseg`` is given but ``coord`` is None.
    """
    # Validate first: the off-resonance branch needs len(coord) and a NUFFT,
    # both of which would otherwise fail with an unrelated error.
    if tseg is not None and coord is None:
        raise ValueError(
            "Time-segmented off-resonance correction (tseg) requires coord."
        )

    # Get image shape and dimension.
    num_coils = len(mps)
    if ishape is None:
        ishape = mps.shape[1:]
        img_ndim = mps.ndim - 1
    else:
        img_ndim = len(ishape)

    # Serialize linop if coil_batch_size is smaller than num_coils.
    if coil_batch_size is None:
        coil_batch_size = num_coils

    if coil_batch_size < num_coils:
        num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size
        # Every option that shapes the per-coil operator must reach the
        # sub-operators; otherwise batching would silently drop the
        # off-resonance correction and the transposed-NUFFT choice.
        # comm is deliberately not forwarded: it is applied once below.
        A = sp.linop.Vstack(
            [
                Sense(
                    mps[c * coil_batch_size : (c + 1) * coil_batch_size],
                    coord=coord,
                    weights=weights,
                    tseg=tseg,
                    ishape=ishape,
                    transp_nufft=transp_nufft,
                )
                for c in range(num_coil_batches)
            ],
            axis=0,
        )

        if comm is not None:
            C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
            A = A * C

        return A

    # Create Sense linear operator.
    S = sp.linop.Multiply(ishape, mps)

    # Pick the Fourier operator once; both branches below share it.
    # Truthiness is used so that numpy booleans behave like Python bools.
    if coord is None:
        F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0))
    elif transp_nufft:
        F = sp.linop.NUFFT(S.oshape, -coord).H
    else:
        F = sp.linop.NUFFT(S.oshape, coord)

    if tseg is None:
        A = F * S
    else:
        # If B0 provided, perform time-segmented off-resonance compensation.
        # NOTE(review): presumably len(coord) is the number of readout
        # samples, making this the readout duration -- confirm.
        time = len(coord) * tseg["dt"]
        b, ct = sp.mri.util.tseg_off_res_b_ct(
            tseg["b0"], tseg["n_bins"], tseg["lseg"], tseg["dt"], time
        )

        A = None
        for ii in range(tseg["lseg"]):
            Bi = sp.linop.Multiply(F.oshape, b[:, ii])
            Cti = sp.linop.Multiply(S.ishape, ct[:, ii].reshape(S.ishape))

            # Accumulates A = sum_i Bi * F * S * Cti.
            segment = Bi * F * S * Cti
            A = segment if A is None else A + segment

    if weights is not None:
        # Take the square root on the device that holds the weights.
        with sp.get_device(weights):
            P = sp.linop.Multiply(F.oshape, weights**0.5)

        A = P * A

    if comm is not None:
        C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
        A = A * C

    A.repr_str = "Sense"

    return A
def ConvSense(
    img_ker_shape, mps_ker, coord=None, weights=None, grd_shape=None, comm=None
):
    """Convolution linear operator with sensitivity maps kernel in k-space.

    The operator input has shape ``img_ker_shape``; it is convolved
    (mode "valid", multi-channel) with the sensitivity maps kernel.

    Args:
        img_ker_shape (tuple of ints): image kernel shape.
        mps_ker (array): sensitivity maps kernel. Its first axis length is
            used as the number of coils.
        coord (None or array): coordinates. When given, the convolution is
            followed by an inverse FFT and a NUFFT at these coordinates.
        weights (None or array): when given, the operator output is
            multiplied by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. When None and ``coord`` is
            given, it is obtained from ``sp.estimate_shape(coord)``.
        comm (None or communicator): when given, an ``AllReduceAdjoint``
            over ``img_ker_shape`` is composed on the input side.

    Returns:
        The composed linear operator.
    """
    num_axes = len(img_ker_shape)
    num_coils = mps_ker.shape[0]

    # Insert a singleton axis right after the coil axis of the kernel.
    coil_kernels = mps_ker.reshape((num_coils, 1) + mps_ker.shape[1:])

    # Give the image kernel a leading singleton axis before convolving.
    add_channel = sp.linop.Reshape(
        (1,) + tuple(img_ker_shape), img_ker_shape
    )
    convolve = sp.linop.ConvolveData(
        add_channel.oshape, coil_kernels, mode="valid", multi_channel=True
    )
    op = convolve * add_channel

    if coord is not None:
        spatial_grid = (
            sp.estimate_shape(coord) if grd_shape is None else list(grd_shape)
        )
        coil_grid = [num_coils] + spatial_grid
        inverse_fft = sp.linop.IFFT(coil_grid, axes=range(-num_axes, 0))
        nufft = sp.linop.NUFFT(coil_grid, coord)
        op = nufft * inverse_fft * op

    if weights is not None:
        # Compute the square root on the device that holds the weights.
        with sp.get_device(weights):
            weighting = sp.linop.Multiply(op.oshape, weights**0.5)
        op = weighting * op

    if comm is not None:
        reduce_adjoint = sp.linop.AllReduceAdjoint(
            img_ker_shape, comm, in_place=True
        )
        op = op * reduce_adjoint

    return op
def ConvImage(
    mps_ker_shape, img_ker, coord=None, weights=None, grd_shape=None
):
    """Convolution linear operator with image kernel in k-space.

    The operator input has shape ``mps_ker_shape``; it is used as the
    filter in a "valid", multi-channel convolution with the image kernel.

    Args:
        mps_ker_shape (tuple of ints): sensitivity maps kernel shape. Its
            first entry is used as the number of coils.
        img_ker (array): image kernel.
        coord (None or array): coordinates. When given, the convolution is
            followed by an inverse FFT and a NUFFT at these coordinates.
        weights (None or array): when given, the operator output is
            multiplied by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. When None and ``coord`` is
            given, it is obtained from ``sp.estimate_shape(coord)``.

    Returns:
        The composed linear operator.
    """
    num_axes = img_ker.ndim
    num_coils = mps_ker_shape[0]

    # Give the image kernel a leading singleton axis.
    image_kernel = img_ker.reshape((1,) + img_ker.shape)

    # Insert a singleton axis right after the coil axis of the input.
    expand_coils = sp.linop.Reshape(
        (num_coils, 1) + tuple(mps_ker_shape[1:]), mps_ker_shape
    )
    convolve = sp.linop.ConvolveFilter(
        expand_coils.oshape, image_kernel, mode="valid", multi_channel=True
    )
    op = convolve * expand_coils

    if coord is not None:
        spatial_grid = (
            sp.estimate_shape(coord) if grd_shape is None else list(grd_shape)
        )
        coil_grid = [num_coils] + spatial_grid
        inverse_fft = sp.linop.IFFT(coil_grid, axes=range(-num_axes, 0))
        nufft = sp.linop.NUFFT(coil_grid, coord)
        op = nufft * inverse_fft * op

    if weights is not None:
        # Compute the square root on the device that holds the weights.
        with sp.get_device(weights):
            weighting = sp.linop.Multiply(op.oshape, weights**0.5)
        op = weighting * op

    return op