#!/usr/bin/python
"""Classes and functions for fitting the mean signal diffusion kurtosis
model"""
import numpy as np
import scipy.optimize as opt
from dipy.core.gradients import check_multi_b, round_bvals, unique_bvals_magnitude
from dipy.core.ndindex import ndindex
from dipy.core.onetime import auto_attr
from dipy.reconst.base import ReconstModel
from dipy.reconst.dti import MIN_POSITIVE_SIGNAL
from dipy.testing.decorators import warning_for_keywords
[docs]
@warning_for_keywords()
def mean_signal_bvalue(data, gtab, *, bmag=None):
    """
    Average the signal over all acquisitions sharing the same b-value.

    Parameters
    ----------
    data : ndarray ([X, Y, Z, ...], g)
        Data signals, with the g acquisitions along the last axis.
    gtab : a GradientTable class instance
        Gradient table; only its ``bvals`` attribute is read here.
    bmag : int, optional
        The order of magnitude that the b-values have to differ to be
        considered a unique b-value. It is forwarded unchanged to
        ``unique_bvals_magnitude``. Default: derive this value from the
        maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.

    Returns
    -------
    msignal : ndarray ([X, Y, Z, ..., nub])
        Mean signal for each unique b-value. The last axis holds the nub
        means, ordered like the unique b-values returned by
        ``unique_bvals_magnitude``.
    ng : ndarray(nub)
        Number of acquisitions that were averaged for each unique b-value.

    Notes
    -----
    This function assumes that directions are evenly sampled on the sphere
    or on the hemisphere.
    """
    # Unique b-values and, for every acquisition, its rounded b-value
    shells, rounded = unique_bvals_magnitude(
        gtab.bvals.copy(), bmag=bmag, rbvals=True
    )

    n_shells = shells.size
    ng = np.zeros(n_shells)
    msignal = np.zeros(data.shape[:-1] + (n_shells,))

    for col, bval in enumerate(shells):
        # Acquisitions belonging to the current unique b-value
        members = rounded == bval
        msignal[..., col] = np.mean(data[..., members], axis=-1)
        ng[col] = np.sum(members)

    return msignal, ng
[docs]
def msk_from_awf(f):
    """
    Compute the mean signal kurtosis predicted by the SMT2 model for a given
    axonal water fraction.

    Parameters
    ----------
    f : float or ndarray ([X, Y, Z, ...])
        Axonal volume fraction estimate(s).

    Returns
    -------
    msk : float or ndarray ([X, Y, Z, ...])
        Mean signal kurtosis, evaluated element-wise (same shape as `f`).

    Notes
    -----
    Evaluates the ratio of two quartic polynomials in `f`, following
    equation 17 of :footcite:p:`NetoHenriques2019`. The ratio is 0 at
    ``f = 0`` and 2.4 at ``f = 1``.

    References
    ----------
    .. footbibliography::
    """
    f2, f3, f4 = f**2, f**3, f**4
    numerator = 216 * f - 504 * f2 + 504 * f3 - 180 * f4
    denominator = 135 - 360 * f + 420 * f2 - 240 * f3 + 60 * f4
    return numerator / denominator
def _msk_from_awf_error(f, msk):
    """Residual between the SMT2-predicted and a measured mean signal
    kurtosis.

    This is the function whose root in `f` is searched for when converting a
    measured mean signal kurtosis into an axonal water fraction.

    Parameters
    ----------
    f : float
        Axonal volume fraction estimate.
    msk : float
        Measured mean signal kurtosis.

    Returns
    -------
    error : float
        ``msk_from_awf(f) - msk``, where ``msk_from_awf`` implements
        equation 17 of :footcite:p:`NetoHenriques2019`.

    References
    ----------
    .. footbibliography::
    """
    predicted = msk_from_awf(f)
    return predicted - msk
def _diff_msk_from_awf(f, msk):
    """
    Derivative of ``msk_from_awf`` with respect to the axonal water fraction.

    Parameters
    ----------
    f : float or ndarray ([X, Y, Z, ...])
        Axonal volume fraction estimate(s).
    msk : float
        Measured mean signal kurtosis. It is not used in the computation; the
        parameter exists so that this function has the same call signature as
        ``_msk_from_awf_error`` and can be handed to a solver as its
        derivative.

    Returns
    -------
    dkdf : float or ndarray ([X, Y, Z, ...])
        Derivative of the mean signal kurtosis with respect to `f`.

    Notes
    -----
    This function corresponds to the differential of equation 17 of
    :footcite:p:`NetoHenriques2019`, obtained with the quotient rule.
    Because the measured kurtosis is a constant offset, the same derivative
    applies to both ``msk_from_awf`` and ``_msk_from_awf_error``.

    References
    ----------
    .. footbibliography::
    """
    f2, f3, f4 = f**2, f**3, f**4

    # Numerator and denominator of equation 17
    num = 216 * f - 504 * f2 + 504 * f3 - 180 * f4
    den = 135 - 360 * f + 420 * f2 - 240 * f3 + 60 * f4

    # Their derivatives with respect to f
    d_num = 216 - 1008 * f + 1512 * f2 - 720 * f3
    d_den = -360 + 840 * f - 720 * f2 + 240 * f3

    # Quotient rule
    return (den * d_num - num * d_den) / (den**2)
[docs]
@warning_for_keywords()
def awf_from_msk(msk, *, mask=None):
    """
    Computes the axonal water fraction from the mean signal kurtosis
    assuming the 2-compartmental spherical mean technique model.

    See :footcite:p:`Kaden2016b` and :footcite:p:`NetoHenriques2019` for
    further details about the method.

    Parameters
    ----------
    msk : ndarray ([X, Y, Z, ...])
        Mean signal kurtosis (msk)
    mask : array-like, optional
        Boolean array (or anything convertible to one) with the same shape as
        `msk`, marking the voxels to be processed. Voxels outside the mask
        are left at 0. Default: process every voxel.

    Returns
    -------
    awf : ndarray ([X, Y, Z, ...])
        ndarray containing the axonal volume fraction estimate.

    Raises
    ------
    ValueError
        If `mask` does not have the same shape as `msk`.

    Notes
    -----
    Computes the axonal water fraction from the mean signal kurtosis
    MSK by numerically inverting equation 17 of
    :footcite:p:`NetoHenriques2019`. Values outside the range covered by
    that equation are clamped: ``msk > 2.4`` gives 1, ``msk < 0`` gives 0.
    NaN inputs give NaN.

    References
    ----------
    .. footbibliography::
    """
    awf = np.zeros(msk.shape)

    # Prepare mask. The coercion has to happen before the shape check so that
    # array-likes without a ``shape`` attribute (e.g. nested lists) work.
    if mask is None:
        mask = np.ones(msk.shape, dtype=bool)
    else:
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != msk.shape:
            raise ValueError("Mask is not the same shape as data.")

    # looping voxels
    for v in ndindex(mask.shape):
        # Skip if out of mask
        if not mask[v]:
            continue

        mski = msk[v]
        if mski > 2.4:
            # 2.4 is the kurtosis predicted by equation 17 at f = 1
            awf[v] = 1
        elif mski < 0:
            awf[v] = 0
        elif np.isnan(mski):
            # NaN fails both comparisons above, so it ends up here
            awf[v] = np.nan
        else:
            fini = mski / 2.4  # Initial guess based on linear assumption
            awf[v] = opt.fsolve(
                _msk_from_awf_error,
                fini,
                args=(mski,),
                fprime=_diff_msk_from_awf,
                col_deriv=True,
            ).item()

    return awf
[docs]
@warning_for_keywords()
def msdki_prediction(msdki_params, gtab, *, S0=1.0):
    """
    Predict the mean signal given the parameters of the mean signal DKI, a
    GradientTable object and the S0 signal.

    See :footcite:p:`NetoHenriques2018` for further details about the method.

    Parameters
    ----------
    msdki_params : ndarray ([X, Y, Z, ...], 2)
        Array containing the mean signal diffusivity and mean signal kurtosis
        in its last axis. The input is not modified.
    gtab : a GradientTable class instance
        The gradient table for this prediction
    S0 : float or ndarray, optional
        The non diffusion-weighted signal in every voxel, or across all
        voxels.

    Returns
    -------
    pred_sig : ndarray ([X, Y, Z, ...], g)
        Predicted mean signal, with one entry along the last axis for every
        b-value of `gtab`.

    Notes
    -----
    The predicted signal is given by:
    $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are the
    mean signal diffusivity and mean signal kurtosis.

    References
    ----------
    .. footbibliography::
    """
    B = design_matrix(round_bvals(gtab.bvals))

    # Work on a copy: the kurtosis column is turned into K * D**2, which is
    # the coefficient that multiplies b**2 / 6 in the signal equation.
    coeffs = msdki_params.copy()
    coeffs[..., 1] = coeffs[..., 1] * coeffs[..., 0] ** 2

    if isinstance(S0, (float, int)) or S0.size == 1:
        # A single S0 value scales every voxel and every b-value
        scale = S0
    else:
        # One S0 per voxel: replicate it along a new trailing b-value axis
        scale = np.zeros(S0.shape + gtab.bvals.shape)
        scale[...] = S0[..., np.newaxis]

    return scale * np.exp(np.dot(coeffs, B[:, :2].T))
[docs]
class MeanDiffusionKurtosisModel(ReconstModel):
    """Mean signal Diffusion Kurtosis Model"""

    def __init__(self, gtab, *args, bmag=None, return_S0_hat=False, **kwargs):
        """Mean Signal Diffusion Kurtosis Model.

        See :footcite:p:`NetoHenriques2018` for further details about the
        model.

        Parameters
        ----------
        gtab : GradientTable class instance
            Gradient table.
        bmag : int, optional
            The order of magnitude that the bvalues have to differ to be
            considered an unique b-value. Default: derive this value from the
            maximal b-value provided: $bmag=log_{10}(max(bvals)) - 1$.
        return_S0_hat : bool, optional
            If True, also return S0 values for the fit.
        args, kwargs : arguments and keyword arguments passed to the
            fit_method. See msdki.wls_fit_msdki for details. The keyword
            ``min_signal`` is extracted here: it must be strictly positive,
            and ``None`` (or omitting it) selects ``MIN_POSITIVE_SIGNAL``.

        Raises
        ------
        ValueError
            If ``min_signal`` is not strictly positive, or if the gradient
            table holds fewer than 3 distinct b-values.

        References
        ----------
        .. footbibliography::
        """
        ReconstModel.__init__(self, gtab)

        self.return_S0_hat = return_S0_hat
        self.ubvals = unique_bvals_magnitude(gtab.bvals, bmag=bmag)
        self.design_matrix = design_matrix(self.ubvals)
        self.bmag = bmag
        self.args = args
        self.kwargs = kwargs

        self.min_signal = self.kwargs.pop("min_signal", MIN_POSITIVE_SIGNAL)
        if self.min_signal is None:
            # An explicit None used to pass validation and then crash in
            # ``fit`` (np.maximum cannot compare against None).
            self.min_signal = MIN_POSITIVE_SIGNAL
        if self.min_signal <= 0:
            e_s = "The `min_signal` key-word argument needs to be strictly"
            e_s += " positive."
            raise ValueError(e_s)

        # Check if at least three b-values are given
        enough_b = check_multi_b(self.gtab, 3, non_zero=False, bmag=bmag)
        if not enough_b:
            mes = "MSDKI requires at least 3 b-values (which can include b=0)"
            raise ValueError(mes)

    @warning_for_keywords()
    def fit(self, data, *, mask=None):
        """Fit method of the MSDKI model class

        Parameters
        ----------
        data : ndarray ([X, Y, Z, ...], g)
            ndarray containing the data signals in its last dimension.
        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed that has the shape data.shape[:-1]

        Returns
        -------
        MeanDiffusionKurtosisFit
            Fit object wrapping the estimated parameters and, when
            ``return_S0_hat`` was requested, the estimated S0.
        """
        S0_params = None

        # Compute mean signal for each unique b-value
        mdata, ng = mean_signal_bvalue(data, self.gtab, bmag=self.bmag)

        # Remove mdata zeros (the fit takes the logarithm of the signal)
        mdata = np.maximum(mdata, self.min_signal)

        # ``min_signal`` is forwarded so that the "no signal" threshold of
        # the fitting routine matches the clipping floor applied above.
        params = wls_fit_msdki(
            self.design_matrix,
            mdata,
            ng,
            *self.args,
            mask=mask,
            min_signal=self.min_signal,
            return_S0_hat=self.return_S0_hat,
            **self.kwargs,
        )

        if self.return_S0_hat:
            params, S0_params = params

        return MeanDiffusionKurtosisFit(self, params, model_S0=S0_params)

    @warning_for_keywords()
    def predict(self, msdki_params, *, S0=1.0):
        """
        Predict a signal for this MeanDiffusionKurtosisModel class instance
        given parameters.

        See :footcite:p:`NetoHenriques2018` for further details about the
        method.

        Parameters
        ----------
        msdki_params : ndarray
            The parameters of the mean signal diffusion kurtosis model
        S0 : float or ndarray, optional
            The non diffusion-weighted signal in every voxel, or across all
            voxels.

        Returns
        -------
        S : (..., N) ndarray
            Simulated mean signal based on the mean signal diffusion kurtosis
            model

        Notes
        -----
        The predicted signal is given by:
        $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are
        the mean signal diffusivity and mean signal kurtosis.

        References
        ----------
        .. footbibliography::
        """
        return msdki_prediction(msdki_params, self.gtab, S0=S0)
[docs]
class MeanDiffusionKurtosisFit:
    """Result of fitting the mean signal diffusion kurtosis model.

    Holds the model, the fitted parameters (mean signal diffusivity and mean
    signal kurtosis along the last axis) and, optionally, the fitted S0.
    """

    @warning_for_keywords()
    def __init__(self, model, model_params, *, model_S0=None):
        """Initialize a MeanDiffusionKurtosisFit class instance."""
        self.model = model
        self.model_params = model_params
        self.model_S0 = model_S0

    def __getitem__(self, index):
        """Return a fit restricted to the voxels selected by `index`.

        `index` addresses the voxel axes only; the trailing parameter axis
        is always kept whole.
        """
        params = self.model_params
        s0 = self.model_S0
        ndim = params.ndim

        if type(index) is not tuple:
            index = (index,)
        elif len(index) >= ndim:
            raise IndexError("IndexError: invalid index")

        # Pad with full slices so the index spans every axis of the params
        padding = (slice(None),) * (ndim - len(index))
        index = index + padding

        if s0 is not None:
            # S0 has no parameter axis, so drop the last index entry
            s0 = s0[index[:-1]]

        return MeanDiffusionKurtosisFit(self.model, params[index], model_S0=s0)

    @property
    def S0_hat(self):
        """Fitted S0 passed at construction (None when not provided)."""
        return self.model_S0

    @auto_attr
    def msd(self):
        r"""
        Mean signal diffusivity (MSD) calculated from the mean signal
        Diffusion Kurtosis Model.

        See :footcite:p:`NetoHenriques2018` for further details about the
        method.

        Returns
        -------
        msd : ndarray
            Calculated signal mean diffusivity (first entry of the last axis
            of the model parameters).

        References
        ----------
        .. footbibliography::
        """
        return self.model_params[..., 0]

    @auto_attr
    def msk(self):
        r"""
        Mean signal kurtosis (MSK) calculated from the mean signal
        Diffusion Kurtosis Model.

        See :footcite:p:`NetoHenriques2018` for further details about the
        method.

        Returns
        -------
        msk : ndarray
            Calculated signal mean kurtosis (second entry of the last axis
            of the model parameters).

        References
        ----------
        .. footbibliography::
        """
        return self.model_params[..., 1]

    @auto_attr
    def smt2f(self):
        r"""
        Computes the axonal water fraction from the mean signal kurtosis
        assuming the 2-compartmental spherical mean technique model.

        See :footcite:p:`Kaden2016b` and :footcite:p:`NetoHenriques2019` for
        further details about the method.

        Returns
        -------
        ndarray
            Axonal volume fraction calculated from MSK.

        Notes
        -----
        Computes the axonal water fraction from the mean signal kurtosis
        MSK using equation 17 of :footcite:p:`NetoHenriques2019`.

        References
        ----------
        .. footbibliography::
        """
        return awf_from_msk(self.msk)

    @auto_attr
    def smt2di(self):
        r"""
        Computes the intrinsic diffusivity from the mean signal diffusional
        kurtosis parameters assuming the 2-compartmental spherical mean
        technique model.

        See :footcite:p:`Kaden2016b` and :footcite:p:`NetoHenriques2019` for
        further details about the method.

        Returns
        -------
        smt2di : ndarray
            Intrinsic diffusivity computed by converting MSDKI to SMT2.

        Notes
        -----
        Computes the intrinsic diffusivity using equation 16 of
        :footcite:p:`NetoHenriques2019`.

        References
        ----------
        .. footbibliography::
        """
        # Complement of the axonal water fraction
        fe = 1 - self.smt2f
        return 3 * self.msd / (1 + 2 * fe**2)

    @auto_attr
    def smt2uFA(self):
        r"""
        Computes the microscopic fractional anisotropy from the mean signal
        diffusional kurtosis parameters assuming the 2-compartmental spherical
        mean technique model.

        See :footcite:p:`Kaden2016b` and :footcite:p:`NetoHenriques2019` for
        further details about the method.

        Returns
        -------
        smt2uFA : ndarray
            Microscopic fractional anisotropy computed by converting MSDKI to
            SMT2.

        Notes
        -----
        Computes the microscopic fractional anisotropy using equation 10 of
        :footcite:p:`NetoHenriques2019`.

        References
        ----------
        .. footbibliography::
        """
        # Complement of the axonal water fraction
        fe = 1 - self.smt2f
        numerator = 3 * (1 - 2 * fe**2 + fe**3)
        denominator = 3 + 2 * fe**3 + 4 * fe**4
        return np.sqrt(numerator / denominator)

    @warning_for_keywords()
    def predict(self, gtab, *, S0=1.0):
        r"""
        Given a mean signal diffusion kurtosis model fit, predict the mean
        signal for the b-values of a gradient table.

        See :footcite:p:`NetoHenriques2018` for further details about the
        method.

        Parameters
        ----------
        gtab : a GradientTable class instance
            Gradient table whose b-values are used for the prediction.
        S0 : float or ndarray, optional
            The mean non-diffusion weighted signal in each voxel, or a single
            value for all voxels. Default: 1. The fitted S0 is not applied
            automatically; pass ``S0_hat`` explicitly to use it.

        Returns
        -------
        S : (..., N) ndarray
            Simulated mean signal based on the mean signal kurtosis model

        Notes
        -----
        The predicted signal is given by:
        $MS(b) = S_0 * exp(-bD + 1/6 b^{2} D^{2} K)$, where $D$ and $K$ are
        the mean signal diffusivity and mean signal kurtosis.

        References
        ----------
        .. footbibliography::
        """
        return msdki_prediction(self.model_params, gtab, S0=S0)
[docs]
@warning_for_keywords()
def wls_fit_msdki(
    design_matrix,
    msignal,
    ng,
    *,
    mask=None,
    min_signal=MIN_POSITIVE_SIGNAL,
    return_S0_hat=False,
):
    r"""
    Fits the mean signal diffusion kurtosis imaging based on a weighted
    least square solution.

    See :footcite:p:`NetoHenriques2018` for further details about the method.

    Parameters
    ----------
    design_matrix : array (nub, 3)
        Design matrix holding the covariants used to solve for the regression
        coefficients of the mean signal diffusion kurtosis model. Note that
        nub is the number of unique b-values
    msignal : ndarray ([X, Y, Z, ..., nub])
        Mean signal along all gradient directions for each unique b-value
        Note that the last dimension should contain the signal means and nub
        is the number of unique b-values.
    ng : ndarray(nub)
        Number of gradient directions used to compute the mean signal for
        all unique b-values
    mask : array-like, optional
        Boolean array (or anything convertible to one) of shape
        ``msignal.shape[:-1]`` marking the voxels to be fitted. Default: fit
        every voxel.
    min_signal : float, optional
        Voxels whose mean over the nub signals is lower than or equal to
        this value are not processed and keep all-zero parameters.
    return_S0_hat : bool, optional
        If True, also return S0 values for the fit.

    Returns
    -------
    params : array (..., 2)
        Containing the mean signal diffusivity and mean signal kurtosis
    S0 : array (...)
        Estimated non diffusion-weighted signal. Only returned (as the second
        element of a tuple) when `return_S0_hat` is True.

    Raises
    ------
    ValueError
        If `mask` does not have the shape ``msignal.shape[:-1]``.

    References
    ----------
    .. footbibliography::
    """
    params = np.zeros(msignal.shape[:-1] + (3,))

    # Prepare mask. The coercion has to happen before the shape check so that
    # array-likes without a ``shape`` attribute (e.g. nested lists) work.
    if mask is None:
        mask = np.ones(msignal.shape[:-1], dtype=bool)
    else:
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != msignal.shape[:-1]:
            raise ValueError("Mask is not the same shape as data.")

    # The transposed design matrix is the same for every voxel
    design_matrix_T = design_matrix.T

    for v in ndindex(mask.shape):
        # Skip if out of mask
        if not mask[v]:
            continue

        signal = msignal[v]

        # Skip if no signal is present
        if np.mean(signal) <= min_signal:
            continue

        # Define weights as diag(ng * yn**2)
        W = np.diag(ng * signal**2)

        # WLS fitting of the log-signal
        BTW = np.dot(design_matrix_T, W)
        inv_BT_W_B = np.linalg.pinv(np.dot(BTW, design_matrix))
        p = np.linalg.multi_dot([inv_BT_W_B, BTW, np.log(signal)])

        # Process parameters: the regression yields (D, K * D**2, log(S0))
        p[1] = p[1] / (p[0] ** 2)
        p[2] = np.exp(p[2])

        params[v] = p

    if return_S0_hat:
        return params[..., :2], params[..., 2]
    return params[..., :2]
[docs]
def design_matrix(ubvals):
    """Build the design matrix of the mean signal diffusion kurtosis model.

    Parameters
    ----------
    ubvals : array (nb,)
        Containing the unique b-values of the data.

    Returns
    -------
    design_matrix : array (nb, 3)
        Design matrix or B matrix for the mean signal diffusion kurtosis
        model. Row j is ``(-b_j, b_j**2 / 6, 1)``, i.e. the regressors that
        multiply the parameters in the following order:
        design_matrix[j, :] = (msd, msk, S0)
    """
    # Start from ones: the last column (S0 regressor) is then already set
    B = np.ones(ubvals.shape + (3,))
    B[:, 0] = -ubvals
    B[:, 1] = 1.0 / 6.0 * ubvals**2
    return B