import numbers
import warnings
import numpy as np
from dipy.core import geometry as geo
from dipy.core.gradients import (
    GradientTable,
    get_bval_indices,
    gradient_table,
    unique_bvals_tolerance,
)
from dipy.data import default_sphere
from dipy.reconst import shm
from dipy.reconst.csdeconv import response_from_mask_ssst
from dipy.reconst.dti import TensorModel, fractional_anisotropy, mean_diffusivity
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.utils import _mask_from_roi, _roi_in_volume
from dipy.sims.voxel import single_tensor
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package
cvxpy, have_cvxpy, _ = optional_package("cvxpy", min_version="1.4.1")
SH_CONST = 0.5 / np.sqrt(np.pi)
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def multi_tissue_basis(gtab, sh_order_max, iso_comp):
    """
    Builds a basis for multi-shell multi-tissue CSD model.

    The design matrix starts with ``iso_comp`` constant columns (one per
    isotropic compartment, every entry equal to ``SH_CONST``), followed by the
    real spherical harmonic basis (``shm.real_sh_descoteaux_from_index``)
    evaluated on the gradient directions of ``gtab``. On the b0 rows every
    SH column of order ``l > 0`` is set to zero, so b0 volumes are described
    only by orientation-independent terms.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    sh_order_max : int
        Maximal spherical harmonics order (l).
    iso_comp: int
        Number of isotropic tissue compartments (leading constant columns).
        Minimum number of compartments required is 2.

    Returns
    -------
    B : ndarray
        Matrix of the spherical harmonics model used to fit the data, with
        ``iso_comp`` isotropic columns placed before the SH columns.
    m_values : ndarray of int ``|m_value| <= l_value``
        The phase factor ($m$) of each SH column.
    l_values : ndarray of int ``l_value >= 0``
        The order ($l$) of each SH column.

    Raises
    ------
    ValueError
        If ``iso_comp`` is smaller than 2.
    """
    if iso_comp < 2:
        raise ValueError("Multi-tissue CSD requires at least 2 tissue compartments")

    _, theta, phi = geo.cart2sphere(*gtab.gradients.T)
    m_values, l_values = shm.sph_harm_ind_list(sh_order_max)
    sh_basis = shm.real_sh_descoteaux_from_index(
        m_values, l_values, theta[:, None], phi[:, None]
    )
    # b0 rows carry no orientation information: drop every l > 0 column there.
    sh_basis[np.ix_(gtab.b0s_mask, l_values > 0)] = 0.0

    iso_columns = np.full((sh_basis.shape[0], iso_comp), SH_CONST)
    design = np.concatenate([iso_columns, sh_basis], axis=1)
    return design, m_values, l_values
[docs]
class MultiShellResponse:
    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(self, response, sh_order_max, shells, *, S0=None):
        """Estimate Multi Shell response function for multiple tissues and
        multiple shells.

        The method `multi_shell_fiber_response` allows to create a multi-shell
        fiber response with the right format, for a three compartments model.
        It can be referred to in order to understand the inputs of this class.

        Parameters
        ----------
        response : ndarray
            Multi-shell fiber response, one row per shell (in the same order
            as `shells`). The first ``iso`` columns hold the isotropic
            compartments; the remaining ``sh_order_max // 2 + 1`` columns hold
            the zonal (m = 0) spherical harmonic coefficients of order
            l = 0, 2, ..., of the anisotropic compartment. The ordering of the
            tissues should follow the same logic as S0.
        sh_order_max : int
            Maximal spherical harmonics order (l).
        shells : array-like, (n_shells,)
            b-value of each shell, i.e. of each row of `response`. It is stored
            as an ndarray because it is broadcast against the b-values of a
            gradient table to match every gradient with its closest shell.
        S0 : array (3,)
            Signal with no diffusion weighting for each tissue compartments, in
            the same tissue order as `response`. This S0 can be used for
            predicting from a fit model later on.

        Raises
        ------
        ValueError
            If the number of columns of `response` leaves fewer than one
            isotropic compartment for the given `sh_order_max`.
        """
        self.S0 = S0
        self.response = response
        self.sh_order_max = sh_order_max
        self.l_values = np.arange(0, sh_order_max + 1, 2)
        self.m_values = np.zeros_like(self.l_values)
        # Kept as an array so that ``shells[:, None]`` broadcasting works even
        # when a plain list of b-values is supplied.
        self.shells = np.asarray(shells)
        if self.iso < 1:
            raise ValueError("sh_order_max and shape of response do not agree")

    @property
    def iso(self):
        """Number of isotropic compartments, i.e. the columns of `response`
        that precede the ``sh_order_max // 2 + 1`` anisotropic SH columns."""
        return self.response.shape[1] - (self.sh_order_max // 2) - 1
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
def _inflate_response(response, gtab, sh_order_max, delta):
    """Used to inflate the response for the `multiplier_matrix` in the
    `MultiShellDeconvModel`.

    Each gradient of ``gtab`` is matched with the response shell whose b-value
    is closest to its own. The corresponding row of ``response.response /
    delta`` is then expanded to one column per isotropic compartment, followed
    by one column per SH coefficient of the model basis (selected through the
    coefficient's order ``l``).

    Parameters
    ----------
    response : MultiShellResponse object
        Response function.
    gtab : GradientTable
        Gradient table.
    sh_order_max : ndarray of int
        Order ($l$) of every spherical harmonic coefficient of the model basis
        (the ``l_values`` returned by `multi_tissue_basis`). Despite its name,
        this is an array, not a single maximal order. All values must be even.
    delta : ndarray
        Delta generated from `_basic_delta`; ``response.response`` is divided
        by it along its last axis.

    Returns
    -------
    ndarray, (len(gtab.bvals), response.iso + len(sh_order_max))
        Multiplier matrix to be applied element-wise to the design matrix.

    Raises
    ------
    ValueError
        If an order is odd or higher than the orders stored in ``response``.
    """
    l_values = np.asarray(sh_order_max)
    iso = response.iso
    # After the isotropic columns, ``response.response`` holds one coefficient
    # per even order 0, 2, 4, ...; order ``l`` is stored at offset ``l // 2``.
    # Requesting an offset past the last stored column must be rejected here
    # instead of failing later with an IndexError.
    n_aniso_columns = response.response.shape[1] - iso
    if np.any(l_values % 2 != 0) or l_values.max() // 2 >= n_aniso_columns:
        raise ValueError("Response and n do not match")

    n_idx = np.empty(len(l_values) + iso, dtype=int)
    n_idx[:iso] = np.arange(0, iso)
    n_idx[iso:] = l_values // 2 + iso

    # For every gradient, index of the shell with the nearest b-value.
    shells = np.asarray(response.shells)
    diff = np.abs(shells[:, None] - gtab.bvals)
    b_idx = np.argmin(diff, axis=0)

    kernel = response.response / delta
    return kernel[np.ix_(b_idx, n_idx)]
def _basic_delta(iso, m_value, l_value, theta, phi):
    """Simple delta function.

    Concatenates ``iso`` copies of ``SH_CONST`` (one per isotropic
    compartment) with the SH coefficients returned by ``shm.gen_dirac`` for
    the anisotropic compartment.

    Parameters
    ----------
    iso: int
        Number of isotropic tissue compartments; one ``SH_CONST`` entry is
        prepended for each. (This value is not validated here.)
    m_value : int ``|m| <= l``
        The phase factor ($m$) of the harmonic.
    l_value : int ``>= 0``
        The order ($l$) of the harmonic.
    theta : array_like
       inclination or polar angle
    phi : array_like
       azimuth angle

    Returns
    -------
    ndarray
        1-D array: ``iso`` isotropic entries followed by the output of
        ``shm.gen_dirac(m_value, l_value, theta, phi)``.
    """
    wm_d = shm.gen_dirac(m_value, l_value, theta, phi)
    iso_d = [SH_CONST] * iso
    return np.concatenate([iso_d, wm_d])
[docs]
class MultiShellDeconvModel(shm.SphHarmModel):
    @deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
    @warning_for_keywords()
    def __init__(
        self,
        gtab,
        response,
        reg_sphere=default_sphere,
        *,
        sh_order_max=8,
        iso=2,
        tol=20,
    ):
        r"""
        Multi-Shell Multi-Tissue Constrained Spherical Deconvolution
        (MSMT-CSD) :footcite:p:`Jeurissen2014`. This method extends the CSD
        model proposed in :footcite:p:`Tournier2007` by the estimation of
        multiple response functions as a function of multiple b-values and
        multiple tissue types.

        Spherical deconvolution computes a fiber orientation distribution
        (FOD), also called fiber ODF (fODF) :footcite:p:`Tournier2007`. The fODF
        is derived from different tissue types and thus overcomes the
        overestimation of WM in GM and CSF areas.

        The response function is based on the different tissue types
        and is provided as input to the MultiShellDeconvModel.
        It will be used as deconvolution kernel, as described in
        :footcite:p:`Tournier2007`.

        Parameters
        ----------
        gtab : GradientTable
            Gradient table.
        response : ndarray or MultiShellResponse object
            Pre-computed multi-shell fiber response function in the form of a
            MultiShellResponse object, or simple response function as a ndarray.
            The latter must be of shape (3, len(bvals)-1, 4), where ``bvals``
            are the unique b-values of ``gtab`` (including the b0) selected
            with `unique_bvals_tolerance` and ``tol``. Along the first axis are
            the WM, GM and CSF responses, in that order; for each non-b0 shell
            the last axis holds the three eigenvalues followed by the signal
            without diffusion weighting (S0). Such an array is converted into a
            MultiShellResponse object via `multi_shell_fiber_response`. Note
            that in order to use more than three compartments, one must create
            a MultiShellResponse object on the side.
        reg_sphere : Sphere, optional
            Sphere used to build the regularization B matrix.
        sh_order_max : int, optional
            Maximal spherical harmonics order (l).
        iso: int, optional
            Number of isotropic tissue compartments for running the MSMT-CSD.
            Minimum number of compartments required is 2. It must equal
            ``response.iso`` for a MultiShellResponse input, and only 2 is
            supported for an ndarray input.
        tol : int, optional
            Tolerance gap for b-values clustering. It is used to find the unique
            b-values of ``gtab`` and, for an ndarray ``response``, is forwarded
            to `multi_shell_fiber_response` so the b0 shell is detected with the
            same tolerance.

        Raises
        ------
        ValueError
            If ``iso`` is smaller than 2, if an ndarray ``response`` has the
            wrong shape or is combined with ``iso > 2``, or if ``iso`` differs
            from the number of isotropic compartments of ``response``.

        References
        ----------
        .. footbibliography::
        """
        if iso < 2:
            msg = "Multi-tissue CSD requires at least 2 tissue compartments"
            raise ValueError(msg)

        super().__init__(gtab)

        if not isinstance(response, MultiShellResponse):
            bvals = unique_bvals_tolerance(gtab.bvals, tol=tol)
            if iso > 2:
                msg = """Too many compartments for this kind of response
                input. It must be two tissue compartments."""
                raise ValueError(msg)
            if response.shape != (3, len(bvals) - 1, 4):
                msg = """Response must be of shape (3, len(bvals)-1, 4) or be a
                MultiShellResponse object."""
                raise ValueError(msg)
            # ``tol`` must be forwarded: the b0 detection inside
            # `multi_shell_fiber_response` compares ``bvals[0]`` against it, and
            # ``bvals`` was clustered with this same tolerance.
            response = multi_shell_fiber_response(
                sh_order_max,
                bvals=bvals,
                wm_rf=response[0],
                gm_rf=response[1],
                csf_rf=response[2],
                tol=tol,
            )

        # The design matrix has ``iso`` isotropic columns while the multiplier
        # matrix has ``response.iso``; they must agree to be multiplied.
        if response.iso != iso:
            msg = (
                f"The response has {response.iso} isotropic compartments but "
                f"iso={iso} was requested; they must match."
            )
            raise ValueError(msg)

        B, m_values, l_values = multi_tissue_basis(gtab, sh_order_max, iso)

        delta = _basic_delta(
            response.iso, response.m_values, response.l_values, 0.0, 0.0
        )
        self.delta = delta
        multiplier_matrix = _inflate_response(response, gtab, l_values, delta)

        # Regularization matrix: identity on the isotropic coefficients and
        # the SH basis sampled on ``reg_sphere`` for the anisotropic ones.
        _, theta, phi = geo.cart2sphere(*reg_sphere.vertices.T)
        odf_reg, _, _ = shm.real_sh_descoteaux(sh_order_max, theta, phi)
        reg = np.zeros([i + iso for i in odf_reg.shape])
        reg[:iso, :iso] = np.eye(iso)
        reg[iso:, iso:] = odf_reg

        X = B * multiplier_matrix

        self.fitter = QpFitter(X, reg)
        self.sh_order_max = sh_order_max
        self._X = X
        self.sphere = reg_sphere
        self.gtab = gtab
        self.B_dwi = B
        self.m_values = m_values
        self.l_values = l_values
        self.response = response

    @warning_for_keywords()
    def predict(self, params, *, gtab=None, S0=None):
        """Compute a signal prediction given spherical harmonic coefficients
        for the provided GradientTable class instance.

        Parameters
        ----------
        params : ndarray
            The spherical harmonic representation of the FOD from which to make
            the signal prediction, including the leading isotropic coefficients
            (last axis of length ``iso`` + number of SH coefficients).
        gtab : GradientTable
            The gradients for which the signal will be predicted. Use the
            model's gradient table by default.
        S0 : float
            The non diffusion-weighted signal value. Scaling by S0 is not
            implemented: any value other than ``None``, ``0`` or ``1.0`` raises
            NotImplementedError.

        Returns
        -------
        pred_sig : ndarray
            Predicted signal, with one entry per gradient along the last axis.

        Raises
        ------
        NotImplementedError
            For an unsupported ``S0`` value.
        """
        if gtab is None or gtab is self.gtab:
            X = self._X
        else:
            iso = self.response.iso
            B, _, l_values = multi_tissue_basis(gtab, self.sh_order_max, iso)
            multiplier_matrix = _inflate_response(
                self.response, gtab, l_values, self.delta
            )
            X = B * multiplier_matrix

        scaling = 1.0
        if S0 and S0 != 1.0:  # The S0=1. case comes from fit.predict().
            # Not implemented yet because it would require access to the
            # volume fractions (vf) of the fit. A possible approach (which could
            # also be computed externally and passed as ``scaling = S0``):
            #   response_scaling = (vf[..., 0] * self.response.S0[0]
            #                       + vf[..., 1] * self.response.S0[1]
            #                       + vf[..., 2] * self.response.S0[2])
            #   scaling = np.where(response_scaling > 1, S0 / response_scaling, 0)
            #   scaling = np.expand_dims(scaling, 3)
            #   scaling = np.repeat(scaling, len(gtab.bvals), axis=3)
            raise NotImplementedError(
                "Prediction with an S0 other than 1.0 is not implemented yet."
            )
        pred_sig = scaling * np.dot(params, X.T)
        return pred_sig

    @multi_voxel_fit
    def fit(self, data, verbose=True, **kwargs):
        """Fits the model to diffusion data and returns the model fit.

        Sometimes the solving process of some voxels can end in a SolverError
        from cvxpy. This might be attributed to the response functions not
        being tuned properly, as the solving process is very sensitive to it.
        The method will fill the problematic voxels with a NaN value, so that
        it is traceable. The user should check for the number of NaN values and
        could then fill the problematic voxels with zeros, for example.
        Running a fit again only on those problematic voxels can also work.

        Parameters
        ----------
        data : ndarray
            The diffusion data to fit the model on.
        verbose : bool, optional
            Whether to show warnings when a SolverError appears or not.
        **kwargs
            Not used by the per-voxel body.

        Returns
        -------
        MSDeconvFit
            Fit holding the isotropic and SH coefficients of the voxel.
        """
        coeff = self.fitter(data)
        if verbose:
            if np.isnan(coeff[..., 0]):
                msg = """Voxel could not be solved properly and ended up with a
                SolverError. Proceeding to fill it with NaN values.
                """
                warnings.warn(msg, UserWarning, stacklevel=2)
        return MSDeconvFit(self, coeff, None)
[docs]
class MSDeconvFit(shm.SphHarmFit):
    def __init__(self, model, coeff, mask):
        """
        Container for the result of fitting a MultiShellDeconvModel.

        The stored coefficient vector starts with one coefficient per
        isotropic compartment of ``model.response`` and continues with the
        spherical harmonic coefficients of the anisotropic compartment. The
        base-class initializer is not called; only the attributes below are
        set.

        Parameters
        ----------
        model: object
            MultiShellDeconvModel that produced the fit.
        coeff : array
            Isotropic coefficients followed by the spherical harmonic
            coefficients for the ODF.
        mask: ndarray
            Mask for fitting
        """
        self.model = model
        self.mask = mask
        self._shm_coef = coeff

    @property
    def shm_coeff(self):
        """SH coefficients of the anisotropic compartment only."""
        n_iso = self.model.response.iso
        return self._shm_coef[..., n_iso:]

    @property
    def all_shm_coeff(self):
        """Every stored coefficient, isotropic ones included."""
        return self._shm_coef

    @property
    def volume_fractions(self):
        """Isotropic coefficients plus the l=0 anisotropic coefficient,
        each divided by ``SH_CONST``."""
        n_leading = self.model.response.iso + 1
        return self._shm_coef[..., :n_leading] / SH_CONST
[docs]
def solve_qp(P, Q, G, H):
    r"""
    Helper function to set up and solve the Quadratic Program (QP) in CVXPY.

    A QP problem has the following form:
    minimize      1/2 x' P x + Q' x
    subject to    G x <= H

    Here the QP solver is based on CVXPY and uses OSQP.

    Parameters
    ----------
    P : ndarray
        n x n matrix for the primal QP objective function. It is declared
        positive semi-definite to CVXPY.
    Q : ndarray
        n x 1 matrix for the primal QP objective function.
    G : ndarray
        m x n matrix for the inequality constraint.
    H : ndarray
        m x 1 matrix for the inequality constraint.

    Returns
    -------
    x : array, (n,)
        Optimal solution to the QP problem. When CVXPY raises a SolverError,
        or finishes without producing a solution (``x.value`` is None, for
        instance when the problem is reported infeasible or unbounded), an
        array filled with NaN is returned instead so the failure stays
        traceable.
    """
    n_vars = Q.shape[0]
    x = cvxpy.Variable(n_vars)
    P = cvxpy.Constant(P)
    objective = cvxpy.Minimize(
        0.5 * cvxpy.quad_form(x, P, assume_PSD=True) + Q @ x
    )
    constraints = [G @ x <= H]
    prob = cvxpy.Problem(objective, constraints)
    try:
        prob.solve()
    except cvxpy.error.SolverError:
        return np.full((n_vars,), np.nan)
    if x.value is None:
        # Solver terminated without a solution; reshaping None would raise.
        return np.full((n_vars,), np.nan)
    return np.array(x.value).reshape((n_vars,))
[docs]
class QpFitter:
    def __init__(self, X, reg):
        r"""
        Constrained least-squares fitter solved as a quadratic program through
        `solve_qp`.

        For a signal ``s``, each call minimises
        ``1/2 c' (X' X) c - (X' s)' c`` subject to ``-reg c <= 0``; this is a
        least-squares fit of ``X c`` to ``s`` under the constraint
        ``reg c >= 0``. A new QP is built and solved on every call.

        Parameters
        ----------
        X : ndarray
            Matrix to be fit by the QP solver calculated in
            `MultiShellDeconvModel`
        reg : ndarray
            the regularization B matrix calculated in `MultiShellDeconvModel`
        """
        gram = np.dot(X.T, X)
        self._X = X
        self._reg = reg
        self._P = gram
        # Solver inputs that do not depend on the signal, prepared once.
        self._P_mat = np.array(gram)
        self._reg_mat = np.array(-reg)
        self._h_mat = np.array([0])

    def __call__(self, signal):
        """Solve the QP for ``signal`` and return the fitted coefficients."""
        linear_term = np.array(-np.dot(self._X.T, signal))
        return solve_qp(self._P_mat, linear_term, self._reg_mat, self._h_mat)
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def multi_shell_fiber_response(
    sh_order_max, bvals, wm_rf, gm_rf, csf_rf, *, sphere=None, tol=20, btens=None
):
    """Fiber response function estimation for multi-shell data.

    Parameters
    ----------
    sh_order_max : int
         Maximum spherical harmonics order (l).
    bvals : ndarray
        Array containing the b-values. Must be unique b-values, like outputted
        by `dipy.core.gradients.unique_bvals_tolerance`.
    wm_rf : (N-1, 4) ndarray
        Response function of the WM tissue, for each bvals,
        where N is the number of unique b-values including the b0. Each row
        holds three eigenvalues followed by the S0 value. When ``bvals[0]`` is
        not below ``tol`` (no b0), one row per entry of ``bvals`` is needed.
    gm_rf : (N-1, 4) ndarray
        Response function of the GM tissue, for each bvals. Only the first
        eigenvalue (column 0) and S0 (column 3) are used.
    csf_rf : (N-1, 4) ndarray
        Response function of the CSF tissue, for each bvals. Only the first
        eigenvalue (column 0) and S0 (column 3) are used.
    sphere : `dipy.core.Sphere` instance, optional
        Sphere where the signal will be evaluated (after one subdivision).
        Defaults to ``default_sphere``.
    tol : int, optional
        Tolerance gap for b-values clustering; ``bvals[0] < tol`` marks the
        first b-value as the b0.
    btens : can be any of two options, optional
        1. an array of strings of shape (N,) specifying
           encoding tensor shape associated with all unique b-values
           separately. N corresponds to the number of unique b-values,
           including the b0. Options for elements in array: 'LTE',
           'PTE', 'STE', 'CTE' corresponding to linear, planar, spherical, and
           "cigar-shaped" tensor encoding.
        2. an array of shape (N,3,3) specifying the b-tensor of each unique
           b-values exactly. N corresponds to the number of unique b-values,
           including the b0.

    Returns
    -------
    MultiShellResponse
        MultiShellResponse object. Its response has one row per b-value; the
        columns are CSF, GM, then the zonal SH coefficients of the WM signal.
        Its S0 is ``[csf, gm, wm]`` taken from the first row of each input.

    Raises
    ------
    ValueError
        If ``btens`` is given with a length different from ``bvals``.
    """
    bvals = np.array(bvals, copy=True)
    if btens is None:
        btens = np.repeat(["LTE"], len(bvals))
    elif len(btens) != len(bvals):
        msg = """bvals and btens parameters must have the same dimension."""
        raise ValueError(msg)

    # Eigenvector frame: first eigenvector along z, the others along x and y.
    evecs = np.zeros((3, 3))
    evecs[:, 0] = np.array([0, 0, 1.0])
    evecs[:2, 1:] = np.eye(2)

    l_values = np.arange(0, sh_order_max + 1, 2)
    m_values = np.zeros_like(l_values)

    if sphere is None:
        sphere = default_sphere
    big_sphere = sphere.subdivide()
    theta, phi = big_sphere.theta, big_sphere.phi
    B = shm.real_sh_descoteaux_from_index(
        m_values, l_values, theta[:, None], phi[:, None]
    )
    A = shm.real_sh_descoteaux_from_index(0, 0, 0, 0)

    response = np.empty([len(bvals), len(l_values) + 2])

    if bvals[0] < tol:
        # b0 row: diffusion weighting is forced to zero.
        gtab = GradientTable(big_sphere.vertices * 0, btens=btens[0])
        wm_response = single_tensor(
            gtab, wm_rf[0, 3], evals=wm_rf[0, :3], evecs=evecs, snr=None
        )
        response[0, 2:] = np.linalg.lstsq(B, wm_response, rcond=None)[0]
        response[0, 1] = gm_rf[0, 3] / A
        response[0, 0] = csf_rf[0, 3] / A
        start = 1
    else:
        warnings.warn(
            """No b0 given. Proceeding either way.""", UserWarning, stacklevel=2
        )
        start = 0

    # Diffusion-weighted rows; row ``start + i`` uses tissue row ``i``.
    for i, bvalue in enumerate(bvals[start:]):
        row = start + i
        gtab = GradientTable(big_sphere.vertices * bvalue, btens=btens[row])
        wm_response = single_tensor(
            gtab, wm_rf[i, 3], evals=wm_rf[i, :3], evecs=evecs, snr=None
        )
        response[row, 2:] = np.linalg.lstsq(B, wm_response, rcond=None)[0]
        response[row, 1] = gm_rf[i, 3] * np.exp(-bvalue * gm_rf[i, 0]) / A
        response[row, 0] = csf_rf[i, 3] * np.exp(-bvalue * csf_rf[i, 0]) / A

    S0 = [csf_rf[0, 3], gm_rf[0, 3], wm_rf[0, 3]]
    return MultiShellResponse(response, sh_order_max, bvals, S0=S0)
[docs]
@warning_for_keywords()
def mask_for_response_msmt(
    gtab,
    data,
    *,
    roi_center=None,
    roi_radii=10,
    wm_fa_thr=0.7,
    gm_fa_thr=0.2,
    csf_fa_thr=0.1,
    gm_md_thr=0.0007,
    csf_md_thr=0.002,
):
    """Computation of masks for multi-shell multi-tissue (msmt) response
        function using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data (4D)
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    mask_wm : ndarray
        int64 mask of voxels within the ROI and with FA above the FA
        threshold for WM.
    mask_gm : ndarray
        int64 mask of voxels within the ROI, with 0 < FA below the FA
        threshold for GM and with MD below the MD threshold for GM.
    mask_csf : ndarray
        int64 mask of voxels within the ROI, with 0 < FA below the FA
        threshold for CSF and with 0 < MD below the MD threshold for CSF.

    Raises
    ------
    ValueError
        If ``data`` has fewer than 4 dimensions.

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This function aims to accomplish that by
    returning a mask of voxels within a ROI and who respect some threshold
    constraints, for each tissue. More precisely, the WM mask must have a FA
    value above a given threshold. The GM mask and CSF mask must have a FA
    below given thresholds and a MD below other thresholds. To get the FA and
    MD, we need to fit a Tensor model to the datasets. A warning is emitted
    for every tissue whose mask ends up empty.
    """
    if len(data.shape) < 4:
        msg = """Data must be 4D (3D image + directions). To use a 2D image,
        please reshape it into a (N, N, 1, ndirs) array."""
        raise ValueError(msg)

    spatial_shape = data.shape[:3]
    if isinstance(roi_radii, numbers.Number):
        roi_radii = (roi_radii,) * 3
    if roi_center is None:
        roi_center = np.array(spatial_shape) // 2
    roi_radii = _roi_in_volume(
        data.shape, np.asarray(roi_center), np.asarray(roi_radii)
    )
    roi_mask = _mask_from_roi(spatial_shape, roi_center, roi_radii)

    if not np.all(unique_bvals_tolerance(gtab.bvals) <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected."""
        warnings.warn(msg_bvals, UserWarning, stacklevel=2)

    tenfit = TensorModel(gtab).fit(data, mask=roi_mask)
    fa = fractional_anisotropy(tenfit.evals)
    fa[np.isnan(fa)] = 0
    md = mean_diffusivity(tenfit.evals)
    md[np.isnan(md)] = 0

    mask_wm = (fa > wm_fa_thr).astype(np.int64)
    mask_wm *= roi_mask

    gm_selected = (md < gm_md_thr) & (fa < gm_fa_thr) & (fa > 0)
    mask_gm = gm_selected.astype(np.int64)
    mask_gm *= roi_mask

    csf_selected = (md < csf_md_thr) & (md > 0) & (fa < csf_fa_thr) & (fa > 0)
    mask_csf = csf_selected.astype(np.int64)
    mask_csf *= roi_mask

    template = """No voxel with a {0} than {1} were found.
    Try a larger roi or a {2} threshold for {3}."""
    # (mask, [(criterion, threshold, advice, tissue), ...]) per tissue, in the
    # order the warnings are issued.
    empty_mask_hints = (
        (mask_wm, (("FA higher", wm_fa_thr, "lower FA", "WM"),)),
        (
            mask_gm,
            (
                ("FA lower", gm_fa_thr, "higher FA", "GM"),
                ("MD lower", gm_md_thr, "higher MD", "GM"),
            ),
        ),
        (
            mask_csf,
            (
                ("FA lower", csf_fa_thr, "higher FA", "CSF"),
                ("MD lower", csf_md_thr, "higher MD", "CSF"),
            ),
        ),
    )
    for tissue_mask, hints in empty_mask_hints:
        if np.sum(tissue_mask) == 0:
            for criterion, threshold, advice, tissue in hints:
                text = template.format(criterion, str(threshold), advice, tissue)
                warnings.warn(text, UserWarning, stacklevel=2)

    return mask_wm, mask_gm, mask_csf
[docs]
@warning_for_keywords()
def response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, *, tol=20):
    """Computation of multi-shell multi-tissue (msmt) response
        functions from given tissues masks.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data
    mask_wm : ndarray
        mask from where to compute the WM response function.
    mask_gm : ndarray
        mask from where to compute the GM response function.
    mask_csf : ndarray
        mask from where to compute the CSF response function.
    tol : int
        tolerance gap for b-values clustering. (Default = 20)

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals, tol))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals, tol))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals, tol))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    In msmt-CSD there is an important pre-processing step: the estimation of
    every tissue's response function. In order to do this, we look for voxels
    corresponding to WM, GM and CSF. This information can be obtained by using
    mcsd.mask_for_response_msmt() through masks of selected voxels. The present
    function uses such masks to compute the msmt response functions.

    For the responses, we base our approach on the function
    csdeconv.response_from_mask_ssst(), with the added layers of multishell and
    multi-tissue (see the ssst function for more information about the
    computation of the ssst response function). For each non-b0 shell
    (clustered using the tolerance gap), the mean b0 image is stacked in front
    of that shell's volumes, and one single-shell response is computed per
    tissue mask. The first unique b-value is treated as the b0.
    """
    bvals = gtab.bvals
    bvecs = gtab.bvecs
    btens = gtab.btens

    list_bvals = unique_bvals_tolerance(bvals, tol=tol)
    b0_indices = get_bval_indices(bvals, list_bvals[0], tol=tol)
    # Mean b0 image with a trailing singleton axis, ready to be stacked in
    # front of each shell's diffusion-weighted volumes.
    b0_map = np.mean(data[..., b0_indices], axis=-1)[..., np.newaxis]

    masks = (mask_wm, mask_gm, mask_csf)
    tissue_responses = [[] for _ in masks]

    for bval in list_bvals[1:]:
        indices = get_bval_indices(bvals, bval, tol=tol)
        bvecs_sub = np.concatenate([[bvecs[b0_indices[0]]], bvecs[indices]])
        bvals_sub = np.concatenate([[0], bvals[indices]])
        if btens is not None:
            btens_b0 = btens[b0_indices[0]].reshape((1, 3, 3))
            btens_sub = np.concatenate([btens_b0, btens[indices]])
        else:
            btens_sub = None

        # The shell data (a full copy of the relevant volumes) and its gradient
        # table do not depend on the tissue mask: build them once per shell and
        # reuse them for all three tissues.
        data_conc = np.concatenate([b0_map, data[..., indices]], axis=3)
        gtab_sub = gradient_table(bvals_sub, bvecs=bvecs_sub, btens=btens_sub)

        for responses, mask in zip(tissue_responses, masks):
            response, _ = response_from_mask_ssst(gtab_sub, data_conc, mask)
            # (evals..., S0) for this tissue and shell.
            responses.append(np.concatenate([response[0], [response[1]]]))

    wm_response = np.asarray(tissue_responses[0])
    gm_response = np.asarray(tissue_responses[1])
    csf_response = np.asarray(tissue_responses[2])
    return wm_response, gm_response, csf_response
[docs]
@warning_for_keywords()
def auto_response_msmt(
    gtab,
    data,
    *,
    tol=20,
    roi_center=None,
    roi_radii=10,
    wm_fa_thr=0.7,
    gm_fa_thr=0.3,
    csf_fa_thr=0.15,
    gm_md_thr=0.001,
    csf_md_thr=0.0032,
):
    """Automatic estimation of multi-shell multi-tissue (msmt) response
        functions using FA and MD.

    Parameters
    ----------
    gtab : GradientTable
        Gradient table.
    data : ndarray
        diffusion data
    tol : int, optional
        Tolerance gap for b-values clustering, passed to
        `response_from_mask_msmt`.
    roi_center : array-like, (3,)
        Center of ROI in data. If center is None, it is assumed that it is
        the center of the volume with shape `data.shape[:3]`.
    roi_radii : int or array-like, (3,)
        radii of cuboid ROI
    wm_fa_thr : float
        FA threshold for WM.
    gm_fa_thr : float
        FA threshold for GM.
    csf_fa_thr : float
        FA threshold for CSF.
    gm_md_thr : float
        MD threshold for GM.
    csf_md_thr : float
        MD threshold for CSF.

    Returns
    -------
    response_wm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for WM for each unique bvalues (except b0).
    response_gm : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for GM for each unique bvalues (except b0).
    response_csf : ndarray, (len(unique_bvals_tolerance(gtab.bvals))-1, 4)
        (`evals`, `S0`) for CSF for each unique bvalues (except b0).

    Notes
    -----
    This is a two-step convenience wrapper. Tissue masks are first selected
    with mcsd.mask_for_response_msmt() using the FA/MD thresholds and the ROI
    given here; the per-shell responses are then computed from those masks
    with mcsd.response_from_mask_msmt(). See both functions for details. A
    warning is emitted when some b-values exceed 1200, because the tensor fit
    used for mask selection may then be affected.
    """
    if not np.all(unique_bvals_tolerance(gtab.bvals) <= 1200):
        msg_bvals = """Some b-values are higher than 1200.
        The DTI fit might be affected. It is advised to use
        mask_for_response_msmt with bvalues lower than 1200, followed by
        response_from_mask_msmt with all bvalues to overcome this."""
        warnings.warn(msg_bvals, UserWarning, stacklevel=2)

    selection_options = {
        "roi_center": roi_center,
        "roi_radii": roi_radii,
        "wm_fa_thr": wm_fa_thr,
        "gm_fa_thr": gm_fa_thr,
        "csf_fa_thr": csf_fa_thr,
        "gm_md_thr": gm_md_thr,
        "csf_md_thr": csf_md_thr,
    }
    tissue_masks = mask_for_response_msmt(gtab, data, **selection_options)

    response_wm, response_gm, response_csf = response_from_mask_msmt(
        gtab, data, *tissue_masks, tol=tol
    )
    return response_wm, response_gm, response_csf