Source code for dipy.reconst.mcsd

import numbers
import warnings

import numpy as np

from dipy.core import geometry as geo
from dipy.core.gradients import (
    GradientTable,
    get_bval_indices,
    gradient_table,
    unique_bvals_tolerance,
)
from dipy.data import default_sphere
from dipy.reconst import shm
from dipy.reconst.csdeconv import response_from_mask_ssst
from dipy.reconst.dti import TensorModel, fractional_anisotropy, mean_diffusivity
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.utils import _mask_from_roi, _roi_in_volume
from dipy.sims.voxel import single_tensor
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package

# cvxpy is loaded through `optional_package` with a minimum version of 1.4.1;
# `have_cvxpy` presumably flags whether the import succeeded (confirm against
# `dipy.utils.optpkg`). `solve_qp` in this module uses the `cvxpy` name.
cvxpy, have_cvxpy, _ = optional_package("cvxpy", min_version="1.4.1")

# 0.5 / sqrt(pi), i.e. 1 / (2 * sqrt(pi)). This module uses it as the constant
# value of every isotropic-compartment column of the multi-tissue basis, and
# as the divisor that turns the leading coefficients into volume fractions.
SH_CONST = 0.5 / np.sqrt(np.pi)


@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def multi_tissue_basis(gtab, sh_order_max, iso_comp):
    """
    Assemble the design basis of the multi-shell multi-tissue CSD model.

    The basis has one constant column per isotropic compartment followed by
    the real (Descoteaux) spherical harmonics evaluated on the gradient
    directions.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    sh_order_max : int
        Maximal spherical harmonics order (l).
    iso_comp : int
        Number of tissue compartments for running the MSMT-CSD. At least 2
        compartments are required.

    Returns
    -------
    B : ndarray
        Matrix of the spherical harmonics model used to fit the data. The
        first `iso_comp` columns are the isotropic compartments.
    m_values : ndarray
        ``|m_value| <= l_value``
        The phase factor ($m$) of each harmonic.
    l_values : ndarray
        ``l_value >= 0``
        The order ($l$) of each harmonic.

    Raises
    ------
    ValueError
        If `iso_comp` is smaller than 2.
    """
    if iso_comp < 2:
        msg = "Multi-tissue CSD requires at least 2 tissue compartments"
        raise ValueError(msg)

    _, polar, azimuth = geo.cart2sphere(*gtab.gradients.T)
    m_values, l_values = shm.sph_harm_ind_list(sh_order_max)
    aniso_part = shm.real_sh_descoteaux_from_index(
        m_values, l_values, polar[:, None], azimuth[:, None]
    )
    # b0 volumes carry no angular information: silence every l > 0 term.
    aniso_part[np.ix_(gtab.b0s_mask, l_values > 0)] = 0.0

    iso_part = np.full([aniso_part.shape[0], iso_comp], SH_CONST)
    B = np.concatenate([iso_part, aniso_part], axis=1)
    return B, m_values, l_values
class MultiShellResponse:
    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(self, response, sh_order_max, shells, *, S0=None):
        """Multi-shell response function for multiple tissues.

        `multi_shell_fiber_response` builds an instance with the expected
        layout for a three compartments model; refer to it to understand the
        inputs of this class.

        Parameters
        ----------
        response : ndarray
            Multi-shell fiber response, one row per shell. The isotropic
            compartments come first (in the same tissue order as `S0`),
            followed by one column per even harmonic order up to
            `sh_order_max`.
        sh_order_max : int
            Maximal spherical harmonics order (l).
        shells : ndarray
            Shells of the data. Within this module it is given the unique
            b-values and is matched against the b-values of a gradient table.
        S0 : array (3,), optional
            Signal with no diffusion weighting for each tissue compartment,
            in the same tissue order as `response`. It is only stored here so
            that it can be used for predicting from a fit model later on.

        Raises
        ------
        ValueError
            If `response` has too few columns to hold at least one isotropic
            compartment for the given `sh_order_max`.
        """
        # `iso` is derived from these two attributes, so set them first.
        self.response = response
        self.sh_order_max = sh_order_max

        self.shells = shells
        self.S0 = S0

        # Only the even orders and m = 0 are stored in the response.
        self.l_values = np.arange(0, sh_order_max + 1, 2)
        self.m_values = np.zeros_like(self.l_values)

        if self.iso < 1:
            raise ValueError("sh_order_max and shape of response do not agree")

    @property
    def iso(self):
        """Number of isotropic compartments (leading columns of `response`)."""
        n_columns = self.response.shape[1]
        n_even_orders = (self.sh_order_max // 2) + 1
        return n_columns - n_even_orders
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def _inflate_response(response, gtab, sh_order_max, delta):
    """Inflate the response into the `multiplier_matrix` of the
    `MultiShellDeconvModel`.

    Each gradient is mapped to the row of the closest shell and each basis
    column to the response column of its compartment or harmonic order.

    Parameters
    ----------
    response : MultiShellResponse object
        Response function.
    gtab : GradientTable
        Gradient table.
    sh_order_max : ndarray of int ``>= 0``
        The order ($l$) of every harmonic of the basis (the `l_values`
        returned by `multi_tissue_basis`). All orders must be even.
    delta : ndarray
        Delta generated from `_basic_delta`.

    Returns
    -------
    ndarray
        Matrix with one row per gradient of `gtab` and one column per
        isotropic compartment and harmonic.

    Raises
    ------
    ValueError
        If an order is odd or exceeds the order of `response`.
    """
    # The response holds one column per even order up to
    # `response.sh_order_max`; any higher order would index past its last
    # column, so reject it here with a clear error.
    if (
        np.any((sh_order_max % 2) != 0)
        or sh_order_max.max() > response.sh_order_max
    ):
        raise ValueError("Response and n do not match")

    iso = response.iso
    # Column lookup: isotropic compartments map to themselves, harmonics of
    # order l map to column l // 2 of the anisotropic part.
    n_idx = np.empty(len(sh_order_max) + iso, dtype=int)
    n_idx[:iso] = np.arange(0, iso)
    n_idx[iso:] = sh_order_max // 2 + iso

    # Row lookup: closest shell for every gradient.
    diff = np.abs(response.shells[:, None] - gtab.bvals)
    b_idx = np.argmin(diff, axis=0)

    kernel = response.response / delta
    return kernel[np.ix_(b_idx, n_idx)]


def _basic_delta(iso, m_value, l_value, theta, phi):
    """Simple delta function.

    Parameters
    ----------
    iso : int
        Number of tissue compartments for running the MSMT-CSD. Minimum
        number of compartments required is 2.
    m_value : int ``|m| <= l``
        The phase factor ($m$) of the harmonic.
    l_value : int ``>= 0``
        The order ($l$) of the harmonic.
    theta : array_like
        inclination or polar angle
    phi : array_like
        azimuth angle

    Returns
    -------
    ndarray
        `iso` entries equal to `SH_CONST` followed by the output of
        `shm.gen_dirac`.
    """
    wm_d = shm.gen_dirac(m_value, l_value, theta, phi)
    iso_d = [SH_CONST] * iso
    return np.concatenate([iso_d, wm_d])
class MultiShellDeconvModel(shm.SphHarmModel):
    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(
        self,
        gtab,
        response,
        reg_sphere=default_sphere,
        *,
        sh_order_max=8,
        iso=2,
        tol=20,
    ):
        r"""
        Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
        (MSMT-CSD) :footcite:p:`Jeurissen2014`.

        This method extends the CSD model proposed in
        :footcite:p:`Tournier2007` by the estimation of multiple response
        functions as a function of multiple b-values and multiple tissue
        types.

        Spherical deconvolution computes a fiber orientation distribution
        (FOD), also called fiber ODF (fODF) :footcite:p:`Tournier2007`. The
        fODF is derived from different tissue types and thus overcomes the
        overestimation of WM in GM and CSF areas.

        The response function is based on the different tissue types and is
        provided as input to the MultiShellDeconvModel. It will be used as
        deconvolution kernel, as described in :footcite:p:`Tournier2007`.

        Parameters
        ----------
        gtab : GradientTable
            Gradient table.
        response : ndarray or MultiShellResponse object
            Pre-computed multi-shell fiber response function in the form of a
            MultiShellResponse object, or simple response function as a
            ndarray. The latter must be of shape (3, len(bvals)-1, 4),
            because it will be converted into a MultiShellResponse object via
            the `multi_shell_fiber_response` method (important note: the
            function `unique_bvals_tolerance` is used here to select unique
            bvalues from gtab as input). The first axis is the tissue (WM,
            GM, CSF). For every tissue and b-value, the first three values
            are the eigen-values and the fourth is the signal value for the
            response function without diffusion weighting (S0). Note that in
            order to use more than three compartments, one must create a
            MultiShellResponse object on the side.
        reg_sphere : Sphere, optional
            sphere used to build the regularization B matrix.
        sh_order_max : int, optional
            Maximal spherical harmonics order (l).
        iso : int, optional
            Number of tissue compartments for running the MSMT-CSD. Minimum
            number of compartments required is 2. Must agree with
            `response.iso` when a MultiShellResponse is given.
        tol : int, optional
            Tolerance gap for b-values clustering.

        Raises
        ------
        ValueError
            If `iso` is below 2, if an ndarray response is combined with
            `iso > 2` or has the wrong shape, or if `iso` disagrees with the
            number of isotropic compartments of the response.

        References
        ----------
        .. footbibliography::
        """
        if not iso >= 2:
            msg = "Multi-tissue CSD requires at least 2 tissue compartments"
            raise ValueError(msg)

        super().__init__(gtab)

        if not isinstance(response, MultiShellResponse):
            bvals = unique_bvals_tolerance(gtab.bvals, tol=tol)
            if iso > 2:
                msg = (
                    "Too many compartments for this kind of response input. "
                    "It must be two tissue compartments."
                )
                raise ValueError(msg)
            if response.shape != (3, len(bvals) - 1, 4):
                msg = (
                    "Response must be of shape (3, len(bvals)-1, 4) or be a "
                    "MultiShellResponse object."
                )
                raise ValueError(msg)
            response = multi_shell_fiber_response(
                sh_order_max,
                bvals=bvals,
                wm_rf=response[0],
                gm_rf=response[1],
                csf_rf=response[2],
            )

        # The basis is built with `iso` isotropic columns while the inflated
        # response has `response.iso` of them; a mismatch could only fail
        # later as an opaque broadcasting error, so report it explicitly.
        if response.iso != iso:
            msg = (
                "iso ({0}) does not match the number of isotropic "
                "compartments of the response ({1})."
            ).format(iso, response.iso)
            raise ValueError(msg)

        B, m_values, l_values = multi_tissue_basis(gtab, sh_order_max, iso)

        delta = _basic_delta(
            response.iso, response.m_values, response.l_values, 0.0, 0.0
        )
        self.delta = delta
        multiplier_matrix = _inflate_response(response, gtab, l_values, delta)

        # Regularization matrix: identity on the isotropic compartments and
        # the SH basis sampled on `reg_sphere` for the anisotropic part.
        r, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
        odf_reg, _, _ = shm.real_sh_descoteaux(sh_order_max, theta, phi)
        reg = np.zeros([i + iso for i in odf_reg.shape])
        reg[:iso, :iso] = np.eye(iso)
        reg[iso:, iso:] = odf_reg

        X = B * multiplier_matrix

        self.fitter = QpFitter(X, reg)
        self.sh_order_max = sh_order_max
        self._X = X
        self.sphere = reg_sphere
        self.gtab = gtab
        self.B_dwi = B
        self.m_values = m_values
        self.l_values = l_values
        self.response = response

    @warning_for_keywords()
    def predict(self, params, *, gtab=None, S0=None):
        """Compute a signal prediction given spherical harmonic coefficients
        for the provided GradientTable class instance.

        Parameters
        ----------
        params : ndarray
            The spherical harmonic representation of the FOD from which to
            make the signal prediction.
        gtab : GradientTable, optional
            The gradients for which the signal will be predicted. Use the
            model's gradient table by default.
        S0 : ndarray or float, optional
            The non diffusion-weighted signal value. Only ``None``, 0 and 1
            (scalar or element-wise) are supported; they all mean "no
            scaling".

        Returns
        -------
        pred_sig : ndarray
            Predicted signal, ``np.dot(params, X.T)``.

        Raises
        ------
        NotImplementedError
            If `S0` holds any value other than 0 or 1.
        """
        if gtab is None or gtab is self.gtab:
            X = self._X
        else:
            iso = self.response.iso
            B, _, l_values = multi_tissue_basis(gtab, self.sh_order_max, iso)
            multiplier_matrix = _inflate_response(
                self.response, gtab, l_values, self.delta
            )
            X = B * multiplier_matrix

        scaling = 1.0
        if S0 is not None:
            # Element-wise test so that array-valued S0 does not trip over the
            # ambiguous truth value of an ndarray. The S0=1. case comes from
            # fit.predict().
            s0 = np.asarray(S0)
            if np.any((s0 != 0) & (s0 != 1.0)):
                # Scaling by S0 is not implemented yet because it would
                # require access to the volume fractions (vf) from the fit:
                # the per-voxel response scaling would be the sum over
                # tissues of vf * self.response.S0, and `scaling` would be
                # S0 divided by it (0 where it is not > 1), broadcast over
                # the gradients. It can also be computed externally.
                raise NotImplementedError

        pred_sig = scaling * np.dot(params, X.T)
        return pred_sig

    @multi_voxel_fit
    def fit(self, data, verbose=True, **kwargs):
        """Fits the model to diffusion data and returns the model fit.

        Sometimes the solving process of some voxels can end in a SolverError
        from cvxpy. This might be attributed to the response functions not
        being tuned properly, as the solving process is very sensitive to it.
        The method will fill the problematic voxels with a NaN value, so that
        it is traceable. The user should check for the number of NaN values
        and could then fill the problematic voxels with zeros, for example.
        Running a fit again only on those problematic voxels can also work.

        Parameters
        ----------
        data : ndarray
            The diffusion data to fit the model on.
        verbose : bool, optional
            Whether to show warnings when a SolverError appears or not.

        Returns
        -------
        MSDeconvFit
            Fit object holding the estimated coefficients.
        """
        coeff = self.fitter(data)
        if verbose and np.isnan(coeff[..., 0]):
            msg = (
                "Voxel could not be solved properly and ended up with a "
                "SolverError. Proceeding to fill it with NaN values. "
            )
            warnings.warn(msg, UserWarning, stacklevel=2)

        return MSDeconvFit(self, coeff, None)
class MSDeconvFit(shm.SphHarmFit):
    def __init__(self, model, coeff, mask):
        """
        Fit result of a MultiShellDeconvModel.

        Derives from SphHarmFit, which represents the fit of diffusion data
        to a spherical harmonic model.

        Parameters
        ----------
        model : object
            MultiShellDeconvModel
        coeff : array
            Coefficients of the fit: the isotropic compartments first, then
            the spherical harmonic coefficients for the ODF.
        mask : ndarray
            Mask for fitting
        """
        self.model = model
        self.mask = mask
        self._shm_coef = coeff

    @property
    def shm_coeff(self):
        """Coefficients without the leading isotropic compartments."""
        n_iso = self.model.response.iso
        return self._shm_coef[..., n_iso:]

    @property
    def all_shm_coeff(self):
        """All coefficients, isotropic compartments included."""
        return self._shm_coef

    @property
    def volume_fractions(self):
        """First ``iso + 1`` coefficients divided by `SH_CONST`."""
        n_classes = self.model.response.iso + 1
        leading = self._shm_coef[..., :n_classes]
        return leading / SH_CONST
def solve_qp(P, Q, G, H):
    r"""
    Helper function to set up and solve the Quadratic Program (QP) in CVXPY.

    A QP problem has the following form:

        minimize      1/2 x' P x + Q' x
        subject to    G x <= H

    The problem is handed to ``cvxpy.Problem.solve()`` without forcing a
    solver, so CVXPY selects one.

    Parameters
    ----------
    P : ndarray
        n x n matrix for the primal QP objective function. It is passed to
        ``cvxpy.quad_form`` with the PSD assumption enabled.
    Q : ndarray
        n x 1 matrix for the primal QP objective function.
    G : ndarray
        m x n matrix for the inequality constraint.
    H : ndarray
        m x 1 matrix for the inequality constraint.

    Returns
    -------
    x : array
        Optimal solution to the QP problem, of shape ``(n,)``. It is filled
        with NaN when the solver raises a ``SolverError`` or finishes without
        producing a solution.
    """
    n_unknowns = Q.shape[0]

    x = cvxpy.Variable(n_unknowns)
    P = cvxpy.Constant(P)
    objective = cvxpy.Minimize(0.5 * cvxpy.quad_form(x, P, True) + Q @ x)
    constraints = [G @ x <= H]

    # setting up the problem
    prob = cvxpy.Problem(objective, constraints)
    try:
        prob.solve()
    except cvxpy.error.SolverError:
        return np.full((n_unknowns,), np.nan)

    # A solve that ends without a solution leaves `x.value` as None; report
    # it the same way as a solver failure instead of failing in the reshape.
    if x.value is None:
        return np.full((n_unknowns,), np.nan)

    return np.array(x.value).reshape((n_unknowns,))
class QpFitter:
    def __init__(self, X, reg):
        r"""
        Fit the model with the quadratic programming solver `solve_qp`.

        The initialization for the model is done using the warm-start by
        default in `CVXPY`.

        Parameters
        ----------
        X : ndarray
            Matrix to be fit by the QP solver calculated in
            `MultiShellDeconvModel`
        reg : ndarray
            the regularization B matrix calculated in `MultiShellDeconvModel`
        """
        gram = np.dot(X.T, X)

        self._X = X
        self._reg = reg
        self._P = gram

        # Arrays handed to `solve_qp`: quadratic term, constraint matrix
        # (negated regularization) and constraint bound.
        self._P_mat = np.array(gram)
        self._reg_mat = np.array(-reg)
        self._h_mat = np.array([0])

    def __call__(self, signal):
        """Solve the QP for one signal and return the resulting coefficients."""
        linear_term = np.array(-np.dot(self._X.T, signal))
        return solve_qp(self._P_mat, linear_term, self._reg_mat, self._h_mat)
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def multi_shell_fiber_response(
    sh_order_max, bvals, wm_rf, gm_rf, csf_rf, *, sphere=None, tol=20, btens=None
):
    """Fiber response function estimation for multi-shell data.

    For every shell, the WM response is obtained by simulating a single
    tensor on a dense sphere and fitting the m = 0 even-order harmonics to
    it; the GM and CSF responses are isotropic mono-exponential decays.

    Parameters
    ----------
    sh_order_max : int
        Maximum spherical harmonics order (l).
    bvals : ndarray
        Array containing the b-values. Must be unique b-values, like
        outputted by `dipy.core.gradients.unique_bvals_tolerance`.
    wm_rf : (N-1, 4) ndarray
        Response function of the WM tissue, for each bvals,
        where N is the number of unique b-values including the b0.
    gm_rf : (N-1, 4) ndarray
        Response function of the GM tissue, for each bvals.
    csf_rf : (N-1, 4) ndarray
        Response function of the CSF tissue, for each bvals.
    sphere : `dipy.core.Sphere` instance, optional
        Sphere where the signal will be evaluated. `default_sphere` is used
        when it is None; the sphere is subdivided once before use.
    tol : int, optional
        Tolerance gap for b-values clustering. The first b-value is treated
        as a b0 when it is below `tol`.
    btens : can be any of two options, optional

        1. an array of strings of shape (N,) specifying
           encoding tensor shape associated with all unique b-values
           separately. N corresponds to the number of unique b-values,
           including the b0. Options for elements in array: 'LTE',
           'PTE', 'STE', 'CTE' corresponding to linear, planar, spherical, and
           "cigar-shaped" tensor encoding.
        2. an array of shape (N,3,3) specifying the b-tensor of each unique
           b-values exactly. N corresponds to the number of unique b-values,
           including the b0.

        Every shell uses 'LTE' when it is None.

    Returns
    -------
    MultiShellResponse
        MultiShellResponse object. Columns are ordered CSF, GM, then the WM
        harmonics.

    Raises
    ------
    ValueError
        If `btens` is given with a length different from `bvals`.
    """
    bvals = np.array(bvals, copy=True)
    n_shells = len(bvals)

    if btens is None:
        btens = np.repeat(["LTE"], n_shells)
    elif len(btens) != n_shells:
        msg = """bvals and btens parameters must have the same dimension."""
        raise ValueError(msg)

    # Principal eigenvector along z, the other two along x and y.
    evecs = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

    l_values = np.arange(0, sh_order_max + 1, 2)
    m_values = np.zeros_like(l_values)

    if sphere is None:
        sphere = default_sphere
    big_sphere = sphere.subdivide()
    theta, phi = big_sphere.theta, big_sphere.phi

    B = shm.real_sh_descoteaux_from_index(
        m_values, l_values, theta[:, None], phi[:, None]
    )
    A = shm.real_sh_descoteaux_from_index(0, 0, 0, 0)

    def wm_coefficients(bvalue, shell_btens, rf_row):
        # Simulate the WM signal of one shell and fit the harmonics to it.
        shell_gtab = GradientTable(big_sphere.vertices * bvalue, btens=shell_btens)
        signal = single_tensor(
            shell_gtab, rf_row[3], evals=rf_row[:3], evecs=evecs, snr=None
        )
        return np.linalg.lstsq(B, signal, rcond=None)[0]

    response = np.empty([n_shells, len(l_values) + 2])

    has_b0 = bvals[0] < tol
    if has_b0:
        # b0 row: no diffusion weighting, so GM and CSF are their plain S0.
        response[0, 2:] = wm_coefficients(0, btens[0], wm_rf[0])
        response[0, 1] = gm_rf[0, 3] / A
        response[0, 0] = csf_rf[0, 3] / A
    else:
        warnings.warn(
            "No b0 given. \nProceeding either way.", UserWarning, stacklevel=2
        )

    # With a b0, response row k + 1 is paired with response-function row k.
    first_row = 1 if has_b0 else 0
    for i, bvalue in enumerate(bvals[first_row:]):
        row = i + first_row
        response[row, 2:] = wm_coefficients(bvalue, btens[row], wm_rf[i])
        response[row, 1] = gm_rf[i, 3] * np.exp(-bvalue * gm_rf[i, 0]) / A
        response[row, 0] = csf_rf[i, 3] * np.exp(-bvalue * csf_rf[i, 0]) / A

    S0 = [csf_rf[0, 3], gm_rf[0, 3], wm_rf[0, 3]]

    return MultiShellResponse(response, sh_order_max, bvals, S0=S0)
@warning_for_keywords()
def mask_for_response_msmt(
    gtab,
    data,
    *,
    roi_center=None,
    roi_radii=10,
    wm_fa_thr=0.7,
    gm_fa_thr=0.2,
    csf_fa_thr=0.1,
    gm_md_thr=0.0007,
    csf_md_thr=0.002,
):
    """Computation of masks for multi-shell multi-tissue (msmt) response
    function using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data (4D)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    mask_wm : ndarray
        Mask of voxels within the ROI and with FA above the FA threshold
        for WM.
    mask_gm : ndarray
        Mask of voxels within the ROI and with non-zero FA below the FA
        threshold for GM and with MD below the MD threshold for GM.
    mask_csf : ndarray
        Mask of voxels within the ROI and with non-zero FA below the FA
        threshold for CSF and with non-zero MD below the MD threshold for
        CSF.

    Raises
    ------
    ValueError
        If `data` has fewer than 4 dimensions.

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This function aims to accomplish that by
    returning a mask of voxels within a ROI and who respect some threshold
    constraints, for each tissue. More precisely, the WM mask must have a FA
    value above a given threshold. The GM mask and CSF mask must have a FA
    below given thresholds and a MD below other thresholds. To get the FA and
    MD, we need to fit a Tensor model to the datasets. A warning is emitted
    for every mask that ends up empty.
    """
    if len(data.shape) < 4:
        raise ValueError(
            "Data must be 4D (3D image + directions). \n"
            "To use a 2D image, please reshape it into a (N, N, 1, ndirs) array."
        )

    if isinstance(roi_radii, numbers.Number):
        roi_radii = (roi_radii,) * 3
    if roi_center is None:
        roi_center = np.array(data.shape[:3]) // 2

    roi_radii = _roi_in_volume(
        data.shape, np.asarray(roi_center), np.asarray(roi_radii)
    )
    roi_mask = _mask_from_roi(data.shape[:3], roi_center, roi_radii)

    shell_bvals = unique_bvals_tolerance(gtab.bvals)
    if not np.all(shell_bvals <= 1200):
        warnings.warn(
            "Some b-values are higher than 1200. The DTI fit might be affected.",
            UserWarning,
            stacklevel=2,
        )

    tenfit = TensorModel(gtab).fit(data, mask=roi_mask)
    fa = fractional_anisotropy(tenfit.evals)
    fa[np.isnan(fa)] = 0
    md = mean_diffusivity(tenfit.evals)
    md[np.isnan(md)] = 0

    fa_nonzero = fa > 0

    mask_wm = (fa > wm_fa_thr).astype(np.int64)
    mask_wm *= roi_mask

    gm_candidates = (md < gm_md_thr) & (fa < gm_fa_thr) & fa_nonzero
    mask_gm = gm_candidates.astype(np.int64)
    mask_gm *= roi_mask

    csf_candidates = (md < csf_md_thr) & (md > 0) & (fa < csf_fa_thr) & fa_nonzero
    mask_csf = csf_candidates.astype(np.int64)
    mask_csf *= roi_mask

    template = (
        "No voxel with a {0} than {1} were found. \n"
        "Try a larger roi or a {2} threshold for {3}."
    )
    # (mask, tissue name, [(criterion, threshold, advice), ...])
    empty_mask_hints = (
        (mask_wm, "WM", (("FA higher", wm_fa_thr, "lower FA"),)),
        (
            mask_gm,
            "GM",
            (
                ("FA lower", gm_fa_thr, "higher FA"),
                ("MD lower", gm_md_thr, "higher MD"),
            ),
        ),
        (
            mask_csf,
            "CSF",
            (
                ("FA lower", csf_fa_thr, "higher FA"),
                ("MD lower", csf_md_thr, "higher MD"),
            ),
        ),
    )
    for tissue_mask, tissue, hints in empty_mask_hints:
        if np.sum(tissue_mask) == 0:
            for criterion, threshold, advice in hints:
                warnings.warn(
                    template.format(criterion, str(threshold), advice, tissue),
                    UserWarning,
                    stacklevel=2,
                )

    return mask_wm, mask_gm, mask_csf
@warning_for_keywords()
def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, *, tol=20):
    """Computation of multi-shell multi-tissue (msmt) response
    functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    For the responses, we base our approach on the function
    csdeconv.response_from_mask_ssst(), with the added layers of multishell and
    multi-tissue (see the ssst function for more information about the
    computation of the ssst response function). For each tissue mask, every
    non-b0 shell (clustered using the tolerance gap) is paired with the mean
    b0 volume to form a single-shell dataset, and the single-shell response
    of that dataset gives one row of the tissue response.
    """
    bvals, bvecs, btens = gtab.bvals, gtab.bvecs, gtab.btens

    shell_bvals = unique_bvals_tolerance(bvals, tol=tol)
    b0_indices = get_bval_indices(bvals, shell_bvals[0], tol=tol)
    first_b0 = b0_indices[0]
    # Mean b0 volume, kept 4D so it can be stacked with the shell volumes.
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    tissue_responses = []
    for mask in (mask_wm, mask_gm, mask_csf):
        shell_responses = []
        for bval in shell_bvals[1:]:
            indices = get_bval_indices(bvals, bval, tol=tol)

            # Single-shell acquisition: one b0 followed by the shell volumes.
            bvals_sub = np.concatenate([[0], bvals[indices]])
            bvecs_sub = np.concatenate([[bvecs[first_b0]], bvecs[indices]])
            btens_sub = None
            if btens is not None:
                btens_b0 = btens[first_b0].reshape((1, 3, 3))
                btens_sub = np.concatenate([btens_b0, btens[indices]])
            data_conc = np.concatenate([b0_map, data[..., indices]], axis=3)

            shell_gtab = gradient_table(bvals_sub, bvecs=bvecs_sub, btens=btens_sub)
            response, _ = response_from_mask_ssst(shell_gtab, data_conc, mask)

            evals, s0 = response[0], response[1]
            shell_responses.append(list(np.concatenate([evals, [s0]])))

        tissue_responses.append(list(shell_responses))

    wm_response, gm_response, csf_response = (
        np.asarray(rows) for rows in tissue_responses
    )
    return wm_response, gm_response, csf_response
@warning_for_keywords()
def auto_response_msmt(
    gtab,
    data,
    *,
    tol=20,
    roi_center=None,
    roi_radii=10,
    wm_fa_thr=0.7,
    gm_fa_thr=0.3,
    csf_fa_thr=0.15,
    gm_md_thr=0.001,
    csf_md_thr=0.0032,
):
    """Automatic estimation of multi-shell multi-tissue (msmt) response
    functions using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data
    tol : int, optional
        Tolerance gap for b-values clustering. It is used both for the
        b-value check below and for the response computation.
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. We get this information from
    mcsd.mask_for_response_msmt(), which returns masks of selected voxels
    (more details are available in the description of the function).

    With the masks, we compute the response functions by using
    mcsd.response_from_mask_msmt(), which returns the `response` for each
    tissue (more details are available in the description of the function).

    Note that the default FA and MD thresholds of this function differ from
    the defaults of mcsd.mask_for_response_msmt(); the values given here are
    the ones forwarded to it.
    """
    # Cluster with the caller's tolerance so this check sees the same shells
    # as the response computation below.
    list_bvals = unique_bvals_tolerance(gtab.bvals, tol=tol)
    if not np.all(list_bvals <= 1200):
        msg_bvals = (
            "Some b-values are higher than 1200. The DTI fit might be "
            "affected. \n"
            "It is advised to use mask_for_response_msmt with bvalues lower "
            "than 1200, followed by response_from_mask_msmt with all bvalues "
            "to overcome this."
        )
        warnings.warn(msg_bvals, UserWarning, stacklevel=2)

    mask_wm, mask_gm, mask_csf = mask_for_response_msmt(
        gtab,
        data,
        roi_center=roi_center,
        roi_radii=roi_radii,
        wm_fa_thr=wm_fa_thr,
        gm_fa_thr=gm_fa_thr,
        csf_fa_thr=csf_fa_thr,
        gm_md_thr=gm_md_thr,
        csf_md_thr=csf_md_thr,
    )
    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, mask_wm, mask_gm, mask_csf, tol=tol
    )

    return response_wm, response_gm, response_csf