from warnings import warn
import numpy as np
from scipy.optimize import leastsq
from scipy.special import gamma, hyp1f1
from dipy.core.geometry import cart2sphere
from dipy.data import default_sphere
from dipy.reconst.cache import Cache
from dipy.reconst.csdeconv import csdeconv
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.odf import OdfFit, OdfModel
from dipy.reconst.shm import real_sh_descoteaux_from_index
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package
cvxpy, have_cvxpy, _ = optional_package("cvxpy", min_version="1.4.1")
[docs]
class ForecastModel(OdfModel, Cache):
    r"""Fiber ORientation Estimated using Continuous Axially Symmetric Tensors
    (FORECAST).

    FORECAST :footcite:p:`Anderson2005`, :footcite:p:`Kaden2016a`,
    :footcite:p:`Zucchelli2017` is a Spherical Deconvolution reconstruction
    model for multi-shell diffusion data which enables the calculation of a
    voxel adaptive response function using the Spherical Mean Technique (SMT)
    :footcite:p:`Kaden2016a`, :footcite:p:`Zucchelli2017`.

    With FORECAST it is possible to calculate crossing invariant parallel
    diffusivity, perpendicular diffusivity, mean diffusivity, and fractional
    anisotropy :footcite:p:`Kaden2016a`.

    References
    ----------
    .. footbibliography::

    Notes
    -----
    The implementation of FORECAST may require CVXPY (https://www.cvxpy.org/).
    """

    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(
        self,
        gtab,
        *,
        sh_order_max=8,
        lambda_lb=1e-3,
        dec_alg="CSD",
        sphere=None,
        lambda_csd=1.0,
    ):
        r"""Analytical and continuous modeling of the diffusion signal with
        respect to the FORECAST basis.

        This implementation is a modification of the original FORECAST
        model presented in :footcite:p:`Anderson2005` adapted for multi-shell
        data as in :footcite:p:`Kaden2016a`, :footcite:p:`Zucchelli2017`.

        The main idea is to model the diffusion signal as the combination of a
        single fiber response function $F(\mathbf{b})$ times the fODF
        $\rho(\mathbf{v})$

        .. math::

            E(\mathbf{b}) = \int_{\mathbf{v} \in \mathcal{S}^2} \rho(\mathbf{v}) F({\mathbf{b}} | \mathbf{v}) d \mathbf{v}

        where $\mathbf{b}$ is the b-vector (b-value times gradient direction)
        and $\mathbf{v}$ is a unit vector representing a fiber direction.

        In FORECAST $\rho$ is modeled using real symmetric Spherical Harmonics
        (SH) and $F(\mathbf{b})$ is an axially symmetric tensor.

        Parameters
        ----------
        gtab : GradientTable,
            gradient directions and bvalues container class.
        sh_order_max : unsigned int, optional
            an even integer that represent the maximal SH order ($l$) of the
            basis (max 12)
        lambda_lb: float, optional
            Laplace-Beltrami regularization weight.
        dec_alg : str, optional
            Spherical deconvolution algorithm. The possible values are Weighted Least Squares ('WLS'),
            Positivity Constraints using CVXPY ('POS') and the Constraint
            Spherical Deconvolution algorithm ('CSD'). Default is 'CSD'.
            The comparison is case-insensitive; any other value leaves the
            model on the 'WLS' path.
        sphere : array, shape (N,3), optional
            sphere points where to enforce positivity when 'POS' or 'CSD'
            dec_alg are selected. When omitted, the first half of the
            vertices of ``default_sphere`` is used.
        lambda_csd : float, optional
            CSD regularization weight. It is stored on the model; the ``fit``
            method of this class does not forward it to ``csdeconv``.

        Raises
        ------
        ValueError
            If ``sh_order_max`` is not an even number between 0 and 12.

        References
        ----------
        .. footbibliography::

        Examples
        --------
        In this example, where the data, gradient table and sphere tessellation
        used for reconstruction are provided, we model the diffusion signal
        with respect to the FORECAST and compute the fODF, parallel and
        perpendicular diffusivity.

        >>> import warnings
        >>> from dipy.data import default_sphere, get_3shell_gtab
        >>> gtab = get_3shell_gtab()
        >>> from dipy.sims.voxel import multi_tensor
        >>> mevals = np.array(([0.0017, 0.0003, 0.0003],
        ...                    [0.0017, 0.0003, 0.0003]))
        >>> angl = [(0, 0), (60, 0)]
        >>> data, sticks = multi_tensor(gtab,
        ...                             mevals,
        ...                             S0=100.0,
        ...                             angles=angl,
        ...                             fractions=[50, 50],
        ...                             snr=None)
        >>> from dipy.reconst.forecast import ForecastModel
        >>> from dipy.reconst.shm import descoteaux07_legacy_msg
        >>> with warnings.catch_warnings():
        ...     warnings.filterwarnings(
        ...         "ignore", message=descoteaux07_legacy_msg,
        ...         category=PendingDeprecationWarning)
        ...     fm = ForecastModel(gtab, sh_order_max=6)
        >>> f_fit = fm.fit(data)
        >>> d_par = f_fit.dpar
        >>> d_perp = f_fit.dperp
        >>> with warnings.catch_warnings():
        ...     warnings.filterwarnings(
        ...         "ignore", message=descoteaux07_legacy_msg,
        ...         category=PendingDeprecationWarning)
        ...     fodf = f_fit.odf(default_sphere)
        """  # noqa: E501
        OdfModel.__init__(self, gtab)

        # round the bvals in order to avoid numerical errors
        self.bvals = np.round(gtab.bvals / 100) * 100
        self.bvecs = gtab.bvecs

        if 0 <= sh_order_max <= 12 and not bool(sh_order_max % 2):
            self.sh_order_max = sh_order_max
        else:
            # The message states exactly what the check above accepts.
            msg = "sh_order_max must be an even number between 0 and 12"
            raise ValueError(msg)

        if sphere is None:
            sphere = default_sphere
            # only the first half of the default sphere vertices is kept
            self.vertices = sphere.vertices[0 : sphere.vertices.shape[0] // 2, :]
        else:
            self.vertices = sphere

        # Collapse all b0 measurements into a single leading zero entry.
        self.b0s_mask = self.bvals == 0
        self.one_0_bvals = np.r_[0, self.bvals[~self.b0s_mask]]
        self.one_0_bvecs = np.r_[
            np.array([0, 0, 0]).reshape(1, 3), self.bvecs[~self.b0s_mask, :]
        ]

        self.rho = rho_matrix(self.sh_order_max, self.one_0_bvecs)

        # signal regularization matrix
        self.srm = rho_matrix(4, self.one_0_bvecs)
        self.lb_matrix_signal = lb_forecast(4)

        self.b_unique = np.sort(np.unique(self.bvals[self.bvals > 0]))

        # Solver flags. They are kept exactly as before for anyone reading
        # them: "CSD" leaves ``wls`` True; ``fit`` gives CSD precedence.
        self.wls = True
        self.csd = False
        self.pos = False

        if dec_alg.upper() == "POS":
            if not have_cvxpy:
                cvxpy.import_error()
            self.wls = False
            self.pos = True

        if dec_alg.upper() == "CSD":
            self.csd = True

        self.lb_matrix = lb_forecast(self.sh_order_max)
        self.lambda_lb = lambda_lb
        self.lambda_csd = lambda_csd
        self.fod = rho_matrix(sh_order_max, self.vertices)

    @multi_voxel_fit
    def fit(self, data, **kwargs):
        """Fit the FORECAST model to the signal of one voxel.

        The signal is normalized by the mean b0 value, the per-shell mean
        signal is used to estimate the parallel and perpendicular
        diffusivities by least squares, and the SH coefficients are then
        obtained with the deconvolution algorithm selected at construction.

        Parameters
        ----------
        data : 1d ndarray,
            diffusion signal, indexed with the model's b0 mask
        **kwargs : dict
            not used by this method itself

        Returns
        -------
        ForecastFit
            fit object holding the SH coefficients and both diffusivities
        """
        data_b0 = data[self.b0s_mask].mean()
        data_single_b0 = np.r_[data_b0, data[~self.b0s_mask]] / data_b0

        # calculates the mean signal at each b_values
        means = find_signal_means(
            self.b_unique,
            data_single_b0,
            self.one_0_bvals,
            self.srm,
            self.lb_matrix_signal,
        )

        # average diffusivity initialization
        x = np.array([np.pi / 4, np.pi / 4])
        x, status = leastsq(forecast_error_func, x, args=(self.b_unique, means))

        # transform to bound the diffusivities from 0 to 3e-03
        d_par = np.cos(x[0]) ** 2 * 3e-03
        d_perp = np.cos(x[1]) ** 2 * 3e-03

        if d_perp >= d_par:
            d_par, d_perp = d_perp, d_par

        # round to avoid memory explosion
        par_key = int(np.round(d_par * 1e05))
        perp_key = int(np.round(d_perp * 1e05))
        # The separator keeps distinct pairs distinct: without it (111, 0)
        # and (11, 10) would both map to "1110" and share a cached matrix.
        diff_key = f"{par_key}_{perp_key}"

        M_diff = self.cache_get("forecast_matrix", key=diff_key)
        if M_diff is None:
            M_diff = forecast_matrix(self.sh_order_max, d_par, d_perp, self.one_0_bvals)
            self.cache_set("forecast_matrix", key=diff_key, value=M_diff)

        M = M_diff * self.rho
        M0 = M[:, 0]
        c0 = np.sqrt(1.0 / (4 * np.pi))

        # coefficients vector initialization
        n_c = (self.sh_order_max + 1) * (self.sh_order_max + 2) // 2
        coef = np.zeros(n_c)
        coef[0] = c0

        # When both rounded diffusivities coincide, only the first
        # coefficient set above is kept and no deconvolution is run.
        if par_key > perp_key:
            if self.csd:
                # CSD is checked first so the WLS solve, whose result CSD
                # would overwrite anyway, is not computed needlessly.
                coef, _ = csdeconv(data_single_b0, M, self.fod, tau=0.1, convergence=50)
                coef = coef / coef[0] * c0
            elif self.pos:
                c = cvxpy.Variable(M.shape[1])
                design_matrix = cvxpy.Constant(M) @ c
                objective = cvxpy.Minimize(
                    cvxpy.sum_squares(design_matrix - data_single_b0)
                    + self.lambda_lb * cvxpy.quad_form(c, self.lb_matrix)
                )
                constraints = [c[0] == c0, self.fod @ c >= 0]
                prob = cvxpy.Problem(objective, constraints)
                try:
                    prob.solve(solver=cvxpy.OSQP, eps_abs=1e-05, eps_rel=1e-05)
                    solution = c.value
                except Exception:
                    solution = None

                if solution is None:
                    # Either the solver raised or it returned without
                    # assigning a value; fall back to the initialization.
                    warn("Optimization did not find a solution", stacklevel=2)
                    coef = np.zeros(M.shape[1])
                    coef[0] = c0
                else:
                    coef = np.asarray(solution).squeeze()
            elif self.wls:
                # Regularized least squares on all coefficients but the
                # first, which is fixed to c0.
                data_r = data_single_b0 - M0 * c0
                Mr = M[:, 1:]
                Lr = self.lb_matrix[1:, 1:]
                pseudo_inv = np.dot(
                    np.linalg.inv(np.dot(Mr.T, Mr) + self.lambda_lb * Lr), Mr.T
                )
                coef = np.dot(pseudo_inv, data_r)
                coef = np.r_[c0, coef]

        return ForecastFit(self, data, coef, d_par, d_perp)
[docs]
class ForecastFit(OdfFit):
    def __init__(self, model, data, sh_coef, d_par, d_perp):
        """Calculates diffusion properties for a single voxel

        Parameters
        ----------
        model : object,
            the model that produced this fit; its ``gtab`` and
            ``sh_order_max`` attributes are read here
        data : 1d ndarray,
            fitted data
        sh_coef : 1d ndarray,
            forecast sh coefficients
        d_par : float,
            parallel diffusivity
        d_perp : float,
            perpendicular diffusivity
        """
        OdfFit.__init__(self, model, data)
        self.model = model
        self._sh_coef = sh_coef
        self.gtab = model.gtab
        self.sh_order_max = model.sh_order_max
        self.d_par = d_par
        self.d_perp = d_perp
        # SH basis of the most recent ``odf`` call and the sphere it was
        # built for.
        self.rho = None
        self._rho_sphere = None

    @warning_for_keywords()
    def odf(self, sphere, *, clip_negative=True):
        r"""Calculates the fODF for a given discrete sphere.

        Parameters
        ----------
        sphere : Sphere,
            the odf sphere
        clip_negative : boolean, optional
            if True clip the negative odf values to 0, default True

        Returns
        -------
        odf : 1d ndarray
            the fODF evaluated at ``sphere.vertices``
        """
        # Rebuild the basis whenever a different sphere object is passed;
        # reusing the first sphere's basis would give values for the wrong
        # directions.
        if self.rho is None or self._rho_sphere is not sphere:
            self.rho = rho_matrix(self.sh_order_max, sphere.vertices)
            self._rho_sphere = sphere

        odf = np.dot(self.rho, self._sh_coef)

        if clip_negative:
            # A one-sided clip stays correct when every value is negative.
            odf = np.maximum(odf, 0)

        return odf

    def fractional_anisotropy(self):
        r"""Calculates the fractional anisotropy.

        Returns
        -------
        fa : float
            value computed from the parallel and perpendicular diffusivities
        """
        fa = np.sqrt(
            0.5
            * (2 * (self.d_par - self.d_perp) ** 2)
            / (self.d_par**2 + 2 * self.d_perp**2)
        )
        return fa

    def mean_diffusivity(self):
        r"""Calculates the mean diffusivity.

        Returns
        -------
        md : float
            ``(d_par + 2 * d_perp) / 3``
        """
        md = (self.d_par + 2 * self.d_perp) / 3.0
        return md

    @warning_for_keywords()
    def predict(self, *, gtab=None, S0=1.0):
        r"""Predicts the diffusion signal for a given gradient table.

        Parameters
        ----------
        gtab : GradientTable, optional
            gradient directions and bvalues container class. When omitted,
            the gradient table of the fitted model is used.
        S0 : float, optional
            the signal at b-value=0

        Returns
        -------
        S : 1d ndarray
            the predicted signal, one value per entry of ``gtab.bvals``
        """
        if gtab is None:
            gtab = self.gtab

        M_diff = forecast_matrix(self.sh_order_max, self.d_par, self.d_perp, gtab.bvals)
        rho = rho_matrix(self.sh_order_max, gtab.bvecs)
        M = M_diff * rho
        S = S0 * np.dot(M, self._sh_coef)
        return S

    @property
    def sh_coeff(self):
        """The FORECAST SH coefficients"""
        return self._sh_coef

    @property
    def dpar(self):
        """The parallel diffusivity"""
        return self.d_par

    @property
    def dperp(self):
        """The perpendicular diffusivity"""
        return self.d_perp
[docs]
@warning_for_keywords()
def find_signal_means(b_unique, data_norm, bvals, rho, lb_matrix, *, w=1e-03):
    r"""Calculate the mean signal for each shell.

    Shells with more than 20 samples are fitted with a Laplace-Beltrami
    regularized least-squares SH fit and the mean is taken from the first
    coefficient; smaller shells use the plain arithmetic mean.

    Parameters
    ----------
    b_unique : 1d ndarray,
        unique b-values in a vector excluding zero
    data_norm : 1d ndarray,
        normalized diffusion signal
    bvals : 1d ndarray,
        the b-values
    rho : 2d ndarray,
        SH basis matrix for fitting the signal on each shell
    lb_matrix : 2d ndarray,
        Laplace-Beltrami regularization matrix
    w : float,
        weight for the Laplace-Beltrami regularization

    Returns
    -------
    means : 1d ndarray
        the average of the signal for each b-values
    """
    shell_means = np.zeros(len(b_unique))

    for shell_idx, b_value in enumerate(b_unique):
        on_shell = bvals == b_value
        samples = data_norm[on_shell]

        # Too few samples for the SH fit: use the plain average.
        if np.sum(on_shell) <= 20:
            shell_means[shell_idx] = samples.mean()
            continue

        basis = rho[on_shell, :]
        regularized_gram = np.dot(basis.T, basis) + w * lb_matrix
        sh_coef = np.linalg.multi_dot(
            [np.linalg.inv(regularized_gram), basis.T, samples]
        )
        shell_means[shell_idx] = sh_coef[0] / np.sqrt(4 * np.pi)

    return shell_means
[docs]
def forecast_error_func(x, b_unique, E):
    r"""Residual between the measured and the modelled per-shell mean signal.

    Parameters
    ----------
    x : sequence of two floats
        unconstrained parameters; each one is mapped to a diffusivity as
        ``cos(x) ** 2 * 3e-03``
    b_unique : 1d ndarray,
        the b-values at which the mean signal was measured
    E : 1d ndarray,
        measured mean signal, one value per entry of ``b_unique``

    Returns
    -------
    v : 1d ndarray
        ``E`` minus the mean signal reconstructed from ``x``
    """
    first = np.cos(x[0]) ** 2 * 3e-03
    second = np.cos(x[1]) ** 2 * 3e-03

    # The larger of the two values is used as the parallel diffusivity.
    if second >= first:
        d_par, d_perp = second, first
    else:
        d_par, d_perp = first, second

    perpendicular_decay = np.exp(-b_unique * d_perp)
    anisotropic_term = psi_l(0, (b_unique * (d_par - d_perp)))
    modelled_mean = 0.5 * perpendicular_decay * anisotropic_term

    return E - modelled_mean
[docs]
def psi_l(ell, b):
    r"""Evaluate the radial term used by the FORECAST signal model.

    With $n$ = ``ell // 2`` the returned value is

    .. math::

        (-b)^n \frac{\Gamma(n + 1/2)}{\Gamma(2n + 3/2)}
        {}_1F_1(n + 1/2;\ 2n + 3/2;\ -b)

    Parameters
    ----------
    ell : int
        SH order; only ``ell // 2`` enters the computation
    b : float or ndarray
        argument of the function; integer arrays are accepted

    Returns
    -------
    v : float or ndarray
        the function value, with the same shape as ``b``
    """
    n = ell // 2
    gamma_ratio = gamma(n + 1.0 / 2) / gamma(2 * n + 3.0 / 2)
    # Out-of-place products: an in-place ``*=`` on ``(-b) ** n`` raises a
    # casting error when ``b`` is an integer array.
    return (-b) ** n * gamma_ratio * hyp1f1(n + 1.0 / 2, 2 * n + 3.0 / 2, -b)
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def forecast_matrix(sh_order_max, d_par, d_perp, bvals):
    r"""Compute the FORECAST radial matrix

    Parameters
    ----------
    sh_order_max : int
        maximal SH order; even orders from 0 up to this value are used
    d_par : float
        parallel diffusivity
    d_perp : float
        perpendicular diffusivity
    bvals : 1d array-like
        the b-values, one per row of the returned matrix

    Returns
    -------
    M : 2d ndarray, shape (len(bvals), (sh_order_max + 1) * (sh_order_max + 2) / 2)
        radial matrix; the ``2 * ell + 1`` columns belonging to one order
        ``ell`` are identical
    """
    bvals = np.asarray(bvals)
    n_c = (sh_order_max + 1) * (sh_order_max + 2) // 2
    M = np.zeros((bvals.shape[0], n_c))

    # Both factors are independent of the SH order, so build them once.
    perp_decay = 2 * np.pi * np.exp(-bvals * d_perp)
    anisotropy = bvals * (d_par - d_perp)

    start = 0
    for ell in range(0, sh_order_max + 1, 2):
        stop = start + 2 * ell + 1
        # The column depends on ell only: evaluate it once and copy it to
        # every m of this order instead of recomputing it per column.
        column = perp_decay * psi_l(ell, anisotropy)
        M[:, start:stop] = column[:, np.newaxis]
        start = stop

    return M
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def rho_matrix(sh_order_max, vecs):
    r"""Compute the SH matrix $\rho$

    Parameters
    ----------
    sh_order_max : int
        maximal SH order; even orders from 0 up to this value are used
    vecs : 2d ndarray, shape (N, 3)
        Cartesian vectors at which the basis is evaluated

    Returns
    -------
    rho : 2d ndarray, shape (N, (sh_order_max + 1) * (sh_order_max + 2) / 2)
        one column per (order, degree) pair, ordered by increasing even
        order and, within an order, by increasing degree
    """
    _, polar, azimuth = cart2sphere(vecs[:, 0], vecs[:, 1], vecs[:, 2])
    # NaN angles are replaced by 0 before the basis is evaluated.
    polar[np.isnan(polar)] = 0

    n_coef = int((sh_order_max + 1) * (sh_order_max + 2) / 2)
    basis = np.zeros((vecs.shape[0], n_coef))

    index_pairs = [
        (degree, order)
        for order in range(0, sh_order_max + 1, 2)
        for degree in range(-order, order + 1)
    ]
    for column, (degree, order) in enumerate(index_pairs):
        basis[:, column] = real_sh_descoteaux_from_index(
            degree, order, polar, azimuth
        )

    return basis
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def lb_forecast(sh_order_max):
    r"""Returns the Laplace-Beltrami regularization matrix for FORECAST

    Parameters
    ----------
    sh_order_max : int
        maximal SH order; even orders from 0 up to this value are used

    Returns
    -------
    2d ndarray
        diagonal matrix in which each even order ``j`` contributes
        ``2 * j + 1`` consecutive diagonal entries equal to
        ``(j * (j + 1)) ** 2``
    """
    n_coef = (sh_order_max + 1) * (sh_order_max + 2) // 2
    diagonal = np.zeros(n_coef)

    orders = np.arange(0, sh_order_max + 1, 2)
    # One weight per order, repeated once for each of its 2 * j + 1 entries.
    weights = np.repeat((orders * (orders + 1)) ** 2, 2 * orders + 1)
    diagonal[: weights.size] = weights

    return np.diag(diagonal)