# Source code for dipy.reconst.forecast

from warnings import warn

import numpy as np
from scipy.optimize import leastsq
from scipy.special import gamma, hyp1f1

from dipy.core.geometry import cart2sphere
from dipy.data import default_sphere
from dipy.reconst.cache import Cache
from dipy.reconst.csdeconv import csdeconv
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.odf import OdfFit, OdfModel
from dipy.reconst.shm import real_sh_descoteaux_from_index
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package

# Optional dependency: CVXPY is only used by the positivity-constrained
# ('POS') deconvolution of ForecastModel; ``have_cvxpy`` reports whether it
# could be imported, and ``cvxpy.import_error()`` raises the deferred error.
cvxpy, have_cvxpy, _ = optional_package("cvxpy", min_version="1.4.1")


class ForecastModel(OdfModel, Cache):
    r"""Fiber ORientation Estimated using Continuous Axially Symmetric Tensors
    (FORECAST).

    FORECAST :footcite:p:`Anderson2005`, :footcite:p:`Kaden2016a`,
    :footcite:p:`Zucchelli2017` is a Spherical Deconvolution reconstruction
    model for multi-shell diffusion data which enables the calculation of a
    voxel adaptive response function using the Spherical Mean Technique (SMT)
    :footcite:p:`Kaden2016a`, :footcite:p:`Zucchelli2017`.

    With FORECAST it is possible to calculate crossing invariant parallel
    diffusivity, perpendicular diffusivity, mean diffusivity, and fractional
    anisotropy :footcite:p:`Kaden2016a`.

    References
    ----------
    .. footbibliography::

    Notes
    -----
    The implementation of FORECAST may require CVXPY (https://www.cvxpy.org/).
    """

    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(
        self,
        gtab,
        *,
        sh_order_max=8,
        lambda_lb=1e-3,
        dec_alg="CSD",
        sphere=None,
        lambda_csd=1.0,
    ):
        r"""Analytical and continuous modeling of the diffusion signal with
        respect to the FORECAST basis.

        This implementation is a modification of the original FORECAST
        model presented in :footcite:p:`Anderson2005` adapted for multi-shell
        data as in :footcite:p:`Kaden2016a`, :footcite:p:`Zucchelli2017`.

        The main idea is to model the diffusion signal as the combination of a
        single fiber response function $F(\mathbf{b})$ times the fODF
        $\rho(\mathbf{v})$

        .. math::

            E(\mathbf{b}) = \int_{\mathbf{v} \in \mathcal{S}^2}
            \rho(\mathbf{v}) F({\mathbf{b}} | \mathbf{v}) d \mathbf{v}

        where $\mathbf{b}$ is the b-vector (b-value times gradient direction)
        and $\mathbf{v}$ is a unit vector representing a fiber direction.

        In FORECAST $\rho$ is modeled using real symmetric Spherical Harmonics
        (SH) and $F(\mathbf{b})$ is an axially symmetric tensor.

        Parameters
        ----------
        gtab : GradientTable,
            gradient directions and bvalues container class.
        sh_order_max : unsigned int, optional
            an even integer that represents the maximal SH order ($l$) of the
            basis (max 12)
        lambda_lb: float, optional
            Laplace-Beltrami regularization weight.
        dec_alg : str, optional
            Spherical deconvolution algorithm. The possible values are
            Weighted Least Squares ('WLS'), Positivity Constraints using CVXPY
            ('POS') and the Constrained Spherical Deconvolution algorithm
            ('CSD'). Any other value falls back to 'WLS'. Default is 'CSD'.
        sphere : array, shape (N,3), or Sphere, optional
            sphere points where to enforce positivity when 'POS' or 'CSD'
            dec_alg are selected. A Sphere object is also accepted, in which
            case its ``vertices`` are used.
        lambda_csd : float, optional
            CSD regularization weight; it scales the positivity-constraint
            matrix passed to ``csdeconv`` when dec_alg is 'CSD'.

        References
        ----------
        .. footbibliography::

        Examples
        --------
        In this example, where the data, gradient table and sphere tessellation
        used for reconstruction are provided, we model the diffusion signal
        with respect to the FORECAST and compute the fODF, parallel and
        perpendicular diffusivity.

        >>> import warnings
        >>> from dipy.data import default_sphere, get_3shell_gtab
        >>> gtab = get_3shell_gtab()
        >>> from dipy.sims.voxel import multi_tensor
        >>> mevals = np.array(([0.0017, 0.0003, 0.0003],
        ...                    [0.0017, 0.0003, 0.0003]))
        >>> angl = [(0, 0), (60, 0)]
        >>> data, sticks = multi_tensor(gtab,
        ...                             mevals,
        ...                             S0=100.0,
        ...                             angles=angl,
        ...                             fractions=[50, 50],
        ...                             snr=None)
        >>> from dipy.reconst.forecast import ForecastModel
        >>> from dipy.reconst.shm import descoteaux07_legacy_msg
        >>> with warnings.catch_warnings():
        ...     warnings.filterwarnings(
        ...         "ignore", message=descoteaux07_legacy_msg,
        ...         category=PendingDeprecationWarning)
        ...     fm = ForecastModel(gtab, sh_order_max=6)
        >>> f_fit = fm.fit(data)
        >>> d_par = f_fit.dpar
        >>> d_perp = f_fit.dperp
        >>> with warnings.catch_warnings():
        ...     warnings.filterwarnings(
        ...         "ignore", message=descoteaux07_legacy_msg,
        ...         category=PendingDeprecationWarning)
        ...     fodf = f_fit.odf(default_sphere)
        """  # noqa: E501
        OdfModel.__init__(self, gtab)

        # round the bvals in order to avoid numerical errors
        self.bvals = np.round(gtab.bvals / 100) * 100
        self.bvecs = gtab.bvecs

        if 0 <= sh_order_max <= 12 and not bool(sh_order_max % 2):
            self.sh_order_max = sh_order_max
        else:
            msg = "sh_order_max must be a non-zero even positive number "
            msg += "between 2 and 12"
            raise ValueError(msg)

        if sphere is None:
            sphere = default_sphere
            # Only even SH degrees are used (antipodally symmetric basis), so
            # presumably the first half of the default vertices -- one
            # hemisphere -- is enough to enforce positivity.
            self.vertices = sphere.vertices[0 : int(sphere.vertices.shape[0] / 2), :]
        elif hasattr(sphere, "vertices"):
            # Accept a Sphere object as well as a raw (N, 3) array of points.
            self.vertices = sphere.vertices
        else:
            self.vertices = sphere

        # All b0 volumes are collapsed into a single b=0 sample placed first.
        self.b0s_mask = self.bvals == 0
        self.one_0_bvals = np.r_[0, self.bvals[~self.b0s_mask]]
        self.one_0_bvecs = np.r_[
            np.array([0, 0, 0]).reshape(1, 3), self.bvecs[~self.b0s_mask, :]
        ]

        self.rho = rho_matrix(self.sh_order_max, self.one_0_bvecs)

        # signal regularization matrix
        self.srm = rho_matrix(4, self.one_0_bvecs)
        self.lb_matrix_signal = lb_forecast(4)

        self.b_unique = np.sort(np.unique(self.bvals[self.bvals > 0]))

        self.wls = True
        self.csd = False
        self.pos = False

        if dec_alg.upper() == "POS":
            if not have_cvxpy:
                cvxpy.import_error()
            self.wls = False
            self.pos = True

        if dec_alg.upper() == "CSD":
            self.csd = True

        self.lb_matrix = lb_forecast(self.sh_order_max)
        self.lambda_lb = lambda_lb
        self.lambda_csd = lambda_csd
        self.fod = rho_matrix(sh_order_max, self.vertices)

    @multi_voxel_fit
    def fit(self, data, **kwargs):
        """Fit the FORECAST model to the signal of a voxel.

        The ``multi_voxel_fit`` decorator presumably dispatches this over the
        voxels of multi-dimensional input.

        Parameters
        ----------
        data : 1d ndarray
            diffusion signal of one voxel, ordered like ``gtab.bvals``.

        Returns
        -------
        ForecastFit
            fitted SH coefficients and parallel/perpendicular diffusivities.
        """
        data_b0 = data[self.b0s_mask].mean()
        data_single_b0 = np.r_[data_b0, data[~self.b0s_mask]] / data_b0

        # calculates the mean signal at each b_values
        means = find_signal_means(
            self.b_unique,
            data_single_b0,
            self.one_0_bvals,
            self.srm,
            self.lb_matrix_signal,
        )

        # average diffusivity initialization
        x = np.array([np.pi / 4, np.pi / 4])
        x, _status = leastsq(forecast_error_func, x, args=(self.b_unique, means))

        # transform to bound the diffusivities from 0 to 3e-03
        d_par = np.cos(x[0]) ** 2 * 3e-03
        d_perp = np.cos(x[1]) ** 2 * 3e-03
        if d_perp >= d_par:
            d_par, d_perp = d_perp, d_par

        # Round to avoid memory explosion. The separator keeps keys unique:
        # plain concatenation mapped e.g. (12, 10) and (121, 0) to "1210".
        par_key = int(np.round(d_par * 1e05))
        perp_key = int(np.round(d_perp * 1e05))
        diff_key = f"{par_key}_{perp_key}"
        M_diff = self.cache_get("forecast_matrix", key=diff_key)
        if M_diff is None:
            M_diff = forecast_matrix(self.sh_order_max, d_par, d_perp, self.one_0_bvals)
            self.cache_set("forecast_matrix", key=diff_key, value=M_diff)

        M = M_diff * self.rho
        c0 = np.sqrt(1.0 / (4 * np.pi))

        # coefficients vector initialization (isotropic fODF, used as-is when
        # the fitted diffusivities are not anisotropic)
        n_c = int((self.sh_order_max + 1) * (self.sh_order_max + 2) / 2)
        coef = np.zeros(n_c)
        coef[0] = c0

        if par_key > perp_key:
            # The branches are exclusive; the priority (pos > csd > wls)
            # matches the order in which later solvers used to overwrite
            # earlier ones, without running the discarded solves.
            if self.pos:
                c = cvxpy.Variable(M.shape[1])
                design_matrix = cvxpy.Constant(M) @ c
                objective = cvxpy.Minimize(
                    cvxpy.sum_squares(design_matrix - data_single_b0)
                    + self.lambda_lb * cvxpy.quad_form(c, self.lb_matrix)
                )
                constraints = [c[0] == c0, self.fod @ c >= 0]
                prob = cvxpy.Problem(objective, constraints)
                try:
                    prob.solve(solver=cvxpy.OSQP, eps_abs=1e-05, eps_rel=1e-05)
                except Exception:
                    # Deliberate best-effort: solver errors use the fallback.
                    solved = False
                else:
                    # An infeasible/unsolved problem may return without raising
                    # while leaving the variable's value as None.
                    solved = c.value is not None
                if solved:
                    coef = np.asarray(c.value).squeeze()
                else:
                    warn("Optimization did not find a solution", stacklevel=2)
                    coef = np.zeros(M.shape[1])
                    coef[0] = c0
            elif self.csd:
                # csdeconv expects its constraint matrix pre-scaled by the
                # regularization weight.
                coef, _ = csdeconv(
                    data_single_b0,
                    M,
                    self.lambda_csd * self.fod,
                    tau=0.1,
                    convergence=50,
                )
                coef = coef / coef[0] * c0
            elif self.wls:
                # c0 is fixed; solve the regularized LS for the other coefs.
                data_r = data_single_b0 - M[:, 0] * c0
                Mr = M[:, 1:]
                Lr = self.lb_matrix[1:, 1:]
                pseudo_inv = np.dot(
                    np.linalg.inv(np.dot(Mr.T, Mr) + self.lambda_lb * Lr), Mr.T
                )
                coef = np.r_[c0, np.dot(pseudo_inv, data_r)]

        return ForecastFit(self, data, coef, d_par, d_perp)
class ForecastFit(OdfFit):
    def __init__(self, model, data, sh_coef, d_par, d_perp):
        """Calculates diffusion properties for a single voxel

        Parameters
        ----------
        model : ForecastModel,
            the model that produced this fit
        data : 1d ndarray,
            fitted data
        sh_coef : 1d ndarray,
            forecast sh coefficients
        d_par : float,
            parallel diffusivity
        d_perp : float,
            perpendicular diffusivity
        """
        OdfFit.__init__(self, model, data)
        self.model = model
        self._sh_coef = sh_coef
        self.gtab = model.gtab
        self.sh_order_max = model.sh_order_max
        self.d_par = d_par
        self.d_perp = d_perp
        # SH basis evaluated on the sphere last passed to ``odf`` (and that
        # sphere), rebuilt whenever a different sphere object is requested.
        self.rho = None
        self._rho_sphere = None

    @warning_for_keywords()
    def odf(self, sphere, *, clip_negative=True):
        r"""Calculates the fODF for a given discrete sphere.

        Parameters
        ----------
        sphere : Sphere,
            the odf sphere
        clip_negative : boolean, optional
            if True clip the negative odf values to 0, default True

        Returns
        -------
        odf : ndarray
            fODF amplitudes evaluated on ``sphere.vertices``.
        """
        # A cached basis is only valid for the sphere it was built from.
        if self.rho is None or getattr(self, "_rho_sphere", None) is not sphere:
            self.rho = rho_matrix(self.sh_order_max, sphere.vertices)
            self._rho_sphere = sphere
        odf = np.dot(self.rho, self._sh_coef)
        if clip_negative:
            # np.clip(odf, 0, odf.max()) returned negative values when every
            # amplitude was negative; this clamps strictly at zero.
            odf = np.maximum(odf, 0)
        return odf

    def fractional_anisotropy(self):
        r"""Calculates the fractional anisotropy.

        Uses the eigenvalues (d_par, d_perp, d_perp) of the axially symmetric
        tensor.
        """
        fa = np.sqrt(
            0.5
            * (2 * (self.d_par - self.d_perp) ** 2)
            / (self.d_par**2 + 2 * self.d_perp**2)
        )
        return fa

    def mean_diffusivity(self):
        r"""Calculates the mean diffusivity."""
        md = (self.d_par + 2 * self.d_perp) / 3.0
        return md

    @warning_for_keywords()
    def predict(self, *, gtab=None, S0=1.0):
        r"""Predict the diffusion signal from the fitted FORECAST model.

        Parameters
        ----------
        gtab : GradientTable, optional
            gradient directions and bvalues container class. Defaults to the
            gradient table of the fitted model.
        S0 : float, optional
            the signal at b-value=0

        Returns
        -------
        S : ndarray
            predicted signal, one value per entry of ``gtab.bvals``.
        """
        if gtab is None:
            gtab = self.gtab
        M_diff = forecast_matrix(self.sh_order_max, self.d_par, self.d_perp, gtab.bvals)
        rho = rho_matrix(self.sh_order_max, gtab.bvecs)
        M = M_diff * rho
        S = S0 * np.dot(M, self._sh_coef)
        return S

    @property
    def sh_coeff(self):
        """The FORECAST SH coefficients"""
        return self._sh_coef

    @property
    def dpar(self):
        """The parallel diffusivity"""
        return self.d_par

    @property
    def dperp(self):
        """The perpendicular diffusivity"""
        return self.d_perp
@warning_for_keywords()
def find_signal_means(b_unique, data_norm, bvals, rho, lb_matrix, *, w=1e-03):
    r"""Calculate the mean signal for each shell.

    Shells sampled by more than 20 directions are fitted with a regularized
    SH expansion and their mean is read from the constant (first) coefficient;
    sparser shells fall back to the arithmetic average of their samples.

    Parameters
    ----------
    b_unique : 1d ndarray,
        unique b-values in a vector excluding zero
    data_norm : 1d ndarray,
        normalized diffusion signal
    bvals : 1d ndarray,
        the b-values
    rho : 2d ndarray,
        SH basis matrix for fitting the signal on each shell
    lb_matrix : 2d ndarray,
        Laplace-Beltrami regularization matrix
    w : float,
        weight for the Laplace-Beltrami regularization

    Returns
    -------
    means : 1d ndarray
        the average of the signal for each b-values
    """
    means = np.zeros(len(b_unique))
    for shell_idx, bval in enumerate(b_unique):
        on_shell = bvals == bval
        shell_signal = data_norm[on_shell]

        if np.sum(on_shell) <= 20:
            # Too few samples for the SH fit: plain average instead.
            means[shell_idx] = shell_signal.mean()
            continue

        basis = rho[on_shell, :]
        regularized_gram = np.dot(basis.T, basis) + w * lb_matrix
        sh_fit = np.linalg.multi_dot(
            [np.linalg.inv(regularized_gram), basis.T, shell_signal]
        )
        # Assuming column 0 of rho is the constant l=0 basis function, its
        # coefficient divided by sqrt(4*pi) is the spherical mean.
        means[shell_idx] = sh_fit[0] / np.sqrt(4 * np.pi)
    return means
def forecast_error_func(x, b_unique, E):
    r"""Residuals between the measured shell means and the FORECAST/SMT model.

    The two entries of ``x`` are mapped to diffusivities through
    ``cos(x) ** 2 * 3e-03`` (bounding them to [0, 3e-03]); the larger one is
    taken as the parallel diffusivity.

    Parameters
    ----------
    x : sequence of 2 floats
        unconstrained optimization parameters.
    b_unique : 1d ndarray
        unique non-zero b-values.
    E : 1d ndarray
        measured mean signal for each value of ``b_unique``.

    Returns
    -------
    v : 1d ndarray
        ``E`` minus the model-predicted mean signal.
    """
    first = np.cos(x[0]) ** 2 * 3e-03
    second = np.cos(x[1]) ** 2 * 3e-03
    if second >= first:
        d_par, d_perp = second, first
    else:
        d_par, d_perp = first, second

    anisotropic_b = b_unique * (d_par - d_perp)
    predicted = 0.5 * np.exp(-b_unique * d_perp) * psi_l(0, anisotropic_b)
    return E - predicted
def psi_l(ell, b):
    r"""Radial FORECAST function of (even) SH degree ``ell``.

    Computes, with $n = \lfloor \ell / 2 \rfloor$,

    .. math::

        \psi_\ell(b) = (-b)^n \frac{\Gamma(n + 1/2)}{\Gamma(2n + 3/2)}
        \, {}_1F_1(n + 1/2; 2n + 3/2; -b)

    Parameters
    ----------
    ell : int
        SH degree.
    b : float or array_like
        argument (in FORECAST, b-value times the diffusivity difference).
        Integer input is promoted to float.

    Returns
    -------
    v : float or ndarray
        $\psi_\ell$ evaluated at ``b``.
    """
    b = np.asarray(b)
    if not np.issubdtype(b.dtype, np.inexact):
        # Integer powers can overflow silently, and an integer result cannot
        # absorb the float factors below.
        b = b.astype(float)
    n = ell // 2
    # Out-of-place products: no in-place cast of float factors into ``v``.
    v = (-b) ** n
    v = v * (gamma(n + 1.0 / 2) / gamma(2 * n + 3.0 / 2))
    v = v * hyp1f1(n + 1.0 / 2, 2 * n + 3.0 / 2, -b)
    return v
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def forecast_matrix(sh_order_max, d_par, d_perp, bvals):
    r"""Compute the FORECAST radial matrix.

    Column ``j`` holds the radial factor
    ``2 * pi * exp(-b * d_perp) * psi_l(l, b * (d_par - d_perp))`` for the
    degree ``l`` of the j-th SH coefficient (even degrees, orders ``-l..l``).

    Parameters
    ----------
    sh_order_max : int
        maximal (even) SH order.
    d_par, d_perp : float
        parallel and perpendicular diffusivities.
    bvals : array_like, shape (N,)
        b-values; lists are accepted.

    Returns
    -------
    M : ndarray, shape (N, (sh_order_max + 1) * (sh_order_max + 2) / 2)
    """
    bvals = np.asarray(bvals)
    n_c = int((sh_order_max + 1) * (sh_order_max + 2) / 2)
    M = np.zeros((bvals.shape[0], n_c))

    # Column-independent factors, computed once instead of per column.
    radial = 2 * np.pi * np.exp(-bvals * d_perp)
    b_aniso = bvals * (d_par - d_perp)

    counter = 0
    for ell in range(0, sh_order_max + 1, 2):
        # The radial profile depends only on the degree, so all 2*ell + 1
        # orders of this degree share one column value.
        n_orders = 2 * ell + 1
        M[:, counter : counter + n_orders] = (radial * psi_l(ell, b_aniso))[:, None]
        counter += n_orders
    return M
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def rho_matrix(sh_order_max, vecs):
    r"""Compute the SH matrix $\rho$.

    Evaluates the real (Descoteaux) SH basis of even degrees up to
    ``sh_order_max`` at the directions ``vecs``.

    Parameters
    ----------
    sh_order_max : int
        maximal (even) SH order.
    vecs : ndarray, shape (N, 3)
        Cartesian direction vectors.

    Returns
    -------
    rho : ndarray, shape (N, (sh_order_max + 1) * (sh_order_max + 2) / 2)
    """
    _, theta, phi = cart2sphere(vecs[:, 0], vecs[:, 1], vecs[:, 2])
    # An undefined polar angle (presumably from zero-length vectors such as
    # the b=0 row) is replaced by 0.
    theta[np.isnan(theta)] = 0

    n_coef = int((sh_order_max + 1) * (sh_order_max + 2) / 2)
    degree_order_pairs = [
        (order, degree)
        for degree in range(0, sh_order_max + 1, 2)
        for order in range(-degree, degree + 1)
    ]
    basis = np.zeros((vecs.shape[0], n_coef))
    for column, (order, degree) in enumerate(degree_order_pairs):
        basis[:, column] = real_sh_descoteaux_from_index(order, degree, theta, phi)
    return basis
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def lb_forecast(sh_order_max):
    r"""Returns the Laplace-Beltrami regularization matrix for FORECAST.

    The matrix is diagonal; every coefficient of even degree ``l`` is
    penalized by ``(l * (l + 1)) ** 2``.

    Parameters
    ----------
    sh_order_max : int
        maximal (even) SH order.

    Returns
    -------
    ndarray, shape (n_c, n_c)
        with ``n_c = (sh_order_max + 1) * (sh_order_max + 2) / 2``.
    """
    n_coef = int((sh_order_max + 1) * (sh_order_max + 2) / 2)
    # Degree of each coefficient, in the same order used by rho_matrix.
    degrees = [
        degree for degree in range(0, sh_order_max + 1, 2) for _ in range(2 * degree + 1)
    ]
    penalties = np.zeros(n_coef)
    penalties[: len(degrees)] = [(degree * (degree + 1)) ** 2 for degree in degrees]
    return np.diag(penalties)