Source code for dipy.reconst.sfm

"""
The Sparse Fascicle Model.

This is an implementation of the sparse fascicle model described in
:footcite:t:`Rokem2015`. The multi b-value version of this model is described
in :footcite:t:`Rokem2014`.

References
----------
.. footbibliography::
"""

from collections import OrderedDict
import gc
import warnings

import numpy as np

try:
    from numpy import nanmean
except ImportError:
    from scipy.stats import nanmean

import dipy.core.gradients as grad
from dipy.core.onetime import auto_attr
import dipy.core.optimize as opt
import dipy.data as dpd
from dipy.reconst.base import ReconstFit, ReconstModel
from dipy.reconst.cache import Cache
import dipy.sims.voxel as sims
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.multiproc import determine_num_processes
from dipy.utils.optpkg import optional_package

joblib, has_joblib, _ = optional_package("joblib")
sklearn, has_sklearn, _ = optional_package("sklearn")
lm, _, _ = optional_package("sklearn.linear_model")


# Isotropic signal models: these are models of the part of the signal that
# changes with b-value, but does not change with direction. This collection
# is extensible by inheriting from IsotropicModel/IsotropicFit below:


# First, a helper function to derive the fit signal for these models:
@warning_for_keywords()
def _to_fit_iso(data, gtab, *, mask=None):
    if mask is None:
        mask = np.ones(data.shape[:-1], dtype=bool)
    # Turn it into a 2D thing:
    if len(mask.shape) > 0:
        data = data[mask]
    else:
        # This handles the corner case of fitting a single voxel:
        data = data.reshape((-1, data.shape[0]))
    data_no_b0 = data[:, ~gtab.b0s_mask]
    nzb0 = data_no_b0 > 0
    nzb0_idx = np.where(nzb0)
    zb0_idx = np.where(~nzb0)
    if np.sum(gtab.b0s_mask) > 0:
        s0 = np.mean(data[:, gtab.b0s_mask], -1)
        to_fit = np.empty(data_no_b0.shape)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            to_fit[nzb0_idx] = data_no_b0[nzb0_idx] / s0[nzb0_idx[0]]
        to_fit[zb0_idx] = 0
    else:
        to_fit = data_no_b0
    return to_fit
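
# A short, hypothetical sketch (not part of the original module) of the shape
# contract of the helper above: for multi-voxel data it returns one row per
# voxel, holding the diffusion-weighted signal divided by that voxel's mean
# b=0 signal (when b=0 volumes are present; non-positive entries are set to
# 0). `data` and `gtab` are assumed to be a 4D DWI array and a matching
# GradientTable supplied by the caller.
def _example_to_fit_iso_shapes(data, gtab):
    """Illustrative only: demonstrate the output shape of `_to_fit_iso`."""
    rel = _to_fit_iso(data, gtab)
    n_vox = int(np.prod(data.shape[:-1]))
    n_dwi = int(np.sum(~gtab.b0s_mask))
    assert rel.shape == (n_vox, n_dwi)
    return rel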


class IsotropicModel(ReconstModel):
    """
    A base-class for the representation of isotropic signals.

    The default behavior, suitable for single b-value data, is to calculate
    the mean in each voxel as an estimate of the signal that does not depend
    on direction.
    """

    def __init__(self, gtab):
        """Initialize an IsotropicModel.

        Parameters
        ----------
        gtab : a GradientTable class instance
        """
        ReconstModel.__init__(self, gtab)

    @warning_for_keywords()
    def fit(self, data, *, mask=None, **kwargs):
        """Fit an IsotropicModel.

        This boils down to finding the mean diffusion-weighted signal in each
        voxel.

        Parameters
        ----------
        data : ndarray
        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed. Has the shape `data.shape[:-1]`. Default:
            None, which implies that all points should be analyzed.

        Returns
        -------
        IsotropicFit class instance.
        """
        # This returns as a 2D thing:
        params = np.mean(_to_fit_iso(data, self.gtab, mask=mask), -1)
        if mask is None:
            params = np.reshape(params, data.shape[:-1])
        else:
            out_params = np.zeros(data.shape[:-1])
            out_params[mask] = params
            params = out_params
        return IsotropicFit(self, params)


class IsotropicFit(ReconstFit):
    """
    A fit object for representing the isotropic signal as the mean of the
    diffusion-weighted signal.
    """

    def __init__(self, model, params):
        """Initialize an IsotropicFit object.

        Parameters
        ----------
        model : IsotropicModel class instance
            Isotropic model.
        params : ndarray
            The mean isotropic model parameters (the mean diffusion-weighted
            signal in each voxel).
        """
        self.model = model
        self.params = params

    @warning_for_keywords()
    def predict(self, *, gtab=None):
        """Predict the isotropic signal, based on a gradient table.

        In this case, the (naive!) prediction will be the mean of the
        diffusion-weighted signal in the voxels.

        Parameters
        ----------
        gtab : a GradientTable class instance, optional
            Defaults to use the gtab from the IsotropicModel from which this
            fit was derived.
        """
        if gtab is None:
            gtab = self.model.gtab
        if len(self.params.shape) == 0:
            return self.params[..., np.newaxis] + np.zeros(np.sum(~gtab.b0s_mask))
        else:
            return self.params[..., np.newaxis] + np.zeros(
                self.params.shape + (np.sum(~gtab.b0s_mask),)
            )
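

# A minimal usage sketch (hypothetical, not part of the original module):
# `data` is assumed to be a 4D DWI array and `gtab` a matching GradientTable,
# both supplied by the caller. The prediction is simply each voxel's mean
# relative diffusion-weighted signal, repeated for every non-b0 measurement.
def _example_isotropic_model_usage(data, gtab):
    """Illustrative only: fit and predict with the mean-signal model."""
    iso_fit = IsotropicModel(gtab).fit(data)
    # Shape: data.shape[:-1] + (number of non-b0 measurements,)
    return iso_fit.predict()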


class ExponentialIsotropicModel(IsotropicModel):
    """
    Representing the isotropic signal as a fit to an exponential decay
    function with b-values.
    """

    @warning_for_keywords()
    def fit(self, data, *, mask=None, **kwargs):
        """
        Parameters
        ----------
        data : ndarray
        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed. Has the shape `data.shape[:-1]`. Default:
            None, which implies that all points should be analyzed.

        Returns
        -------
        ExponentialIsotropicFit class instance.
        """
        to_fit = _to_fit_iso(data, self.gtab, mask=mask)
        # Fitting to the log-transformed relative data is much faster:
        nz_idx = to_fit > 0
        to_fit[nz_idx] = np.log(to_fit[nz_idx])
        to_fit[~nz_idx] = -np.inf
        params = -nanmean(to_fit / self.gtab.bvals[~self.gtab.b0s_mask], -1)
        if mask is None:
            params = np.reshape(params, data.shape[:-1])
        else:
            out_params = np.zeros(data.shape[:-1])
            out_params[mask] = params
            params = out_params
        return ExponentialIsotropicFit(self, params)


class ExponentialIsotropicFit(IsotropicFit):
    """
    A fit to the ExponentialIsotropicModel object, based on data.
    """

    @warning_for_keywords()
    def predict(self, *, gtab=None):
        """
        Predict the isotropic signal, based on a gradient table.

        In this case, the prediction will be for an exponential decay with
        the mean diffusivity derived from the data that was fit.

        Parameters
        ----------
        gtab : a GradientTable class instance, optional
            Defaults to use the gtab from the IsotropicModel from which this
            fit was derived.
        """
        if gtab is None:
            gtab = self.model.gtab
        if len(self.params.shape) == 0:
            return np.exp(
                -gtab.bvals[~gtab.b0s_mask]
                * (np.zeros(np.sum(~gtab.b0s_mask)) + self.params[..., np.newaxis])
            )
        else:
            return np.exp(
                -gtab.bvals[~gtab.b0s_mask]
                * (
                    np.zeros((self.params.shape[0], np.sum(~gtab.b0s_mask)))
                    + self.params[..., np.newaxis]
                )
            )
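

# A minimal sketch contrasting the two isotropic models (hypothetical; `data`
# and a multi b-value `gtab` are assumed to be supplied by the caller). The
# exponential variant estimates a mean diffusivity from -log(S/S0)/b, so its
# prediction decays with b-value, whereas IsotropicModel predicts a constant
# mean relative signal.
def _example_compare_isotropic_models(data, gtab):
    """Illustrative only: predictions of the two isotropic fits."""
    mean_pred = IsotropicModel(gtab).fit(data).predict()
    exp_pred = ExponentialIsotropicModel(gtab).fit(data).predict()
    return mean_pred, exp_pred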


@warning_for_keywords()
def sfm_design_matrix(gtab, sphere, response, *, mode="signal"):
    """
    Construct the SFM design matrix.

    Parameters
    ----------
    gtab : GradientTable or Sphere
        Sets the rows of the matrix. If the mode is 'signal', this should be
        a GradientTable. If mode is 'odf' this should be a Sphere.
    sphere : Sphere
        Sets the columns of the matrix.
    response : list of 3 elements
        The eigenvalues of a tensor which will serve as a kernel function.
    mode : str {'signal' | 'odf'}, optional
        Choose the (default) 'signal' for a design matrix containing
        predicted signal in the measurements defined by the gradient table
        for putative fascicles oriented along the vertices of the sphere.
        Otherwise, choose 'odf' for an odf convolution matrix, with values of
        the odf calculated from a tensor with the provided response
        eigenvalues, evaluated at the b-vectors in the gradient table, for
        the tensors with principal diffusion directions along the vertices of
        the sphere.

    Returns
    -------
    mat : ndarray
        A design matrix that can be used for one of the following operations:
        when the 'signal' mode is used, each column contains the putative
        signal in each of the bvectors of the `gtab` if a fascicle is
        oriented in the direction encoded by the sphere vertex corresponding
        to this column. This is used for deconvolution with a measured DWI
        signal. If the 'odf' mode is chosen, each column instead contains the
        values of the tensor ODF for a tensor with a principal diffusion
        direction corresponding to this vertex. This is used to generate odfs
        from the fits of the SFM for the purpose of tracking.

    Examples
    --------
    >>> import dipy.data as dpd
    >>> data, gtab = dpd.dsi_voxels()
    >>> sphere = dpd.get_sphere()
    >>> from dipy.reconst.sfm import sfm_design_matrix

    A canonical tensor approximating corpus-callosum voxels
    :footcite:p:`Rokem2014`:

    >>> tensor_matrix = sfm_design_matrix(gtab, sphere,
    ...                                   [0.0015, 0.0005, 0.0005])

    A 'stick' function :footcite:p:`Behrens2007`:

    >>> stick_matrix = sfm_design_matrix(gtab, sphere, [0.001, 0, 0])

    References
    ----------
    .. footbibliography::
    """
    if mode == "signal":
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            mat_gtab = grad.gradient_table(
                gtab.bvals[~gtab.b0s_mask], bvecs=gtab.bvecs[~gtab.b0s_mask]
            )
        # Preallocate:
        mat = np.empty((np.sum(~gtab.b0s_mask), sphere.vertices.shape[0]))
    elif mode == "odf":
        mat = np.empty((gtab.x.shape[0], sphere.vertices.shape[0]))

    # Calculate column-wise:
    for ii, this_dir in enumerate(sphere.vertices):
        # Rotate the canonical tensor towards this vertex and calculate the
        # signal you would have gotten in that direction:
        if mode == "signal":
            # For regressors based on the single tensor, remove $e^{-bD}$
            mat[:, ii] = sims.single_tensor(
                mat_gtab, evals=response, evecs=sims.all_tensor_evecs(this_dir)
            ) - np.exp(-mat_gtab.bvals * np.mean(response))
        elif mode == "odf":
            # Stick function
            if response[1] == 0 or response[2] == 0:
                mat[sphere.find_closest(sims.all_tensor_evecs(this_dir)[0]), ii] = 1
            else:
                mat[:, ii] = sims.single_tensor_odf(
                    gtab.vertices,
                    evals=response,
                    evecs=sims.all_tensor_evecs(this_dir),
                )
    return mat
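

# A minimal sketch (hypothetical helper, not the model's own fitting code) of
# how the 'signal' design matrix is used for deconvolution: `gtab`, `sphere`
# and `y` -- one voxel's relative (S/S0) diffusion-weighted signal with its
# isotropic component subtracted -- are assumed to come from the caller, and
# sklearn must be installed. This mirrors the per-voxel ElasticNet step used
# by SparseFascicleModel below.
def _example_design_matrix_deconvolution(gtab, sphere, y):
    """Illustrative only: solve for non-negative fascicle weights."""
    X = sfm_design_matrix(gtab, sphere, [0.0015, 0.0005, 0.0005], mode="signal")
    solver = lm.ElasticNet(l1_ratio=0.5, alpha=0.001, positive=True)
    # One weight per sphere vertex; the L1 term encourages sparsity:
    return solver.fit(X, y).coef_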


class SparseFascicleModel(ReconstModel, Cache):
    @warning_for_keywords()
    def __init__(
        self,
        gtab,
        *,
        sphere=None,
        response=(0.0015, 0.0005, 0.0005),
        solver="ElasticNet",
        l1_ratio=0.5,
        alpha=0.001,
        isotropic=None,
        seed=42,
    ):
        """
        Initialize a Sparse Fascicle Model.

        Parameters
        ----------
        gtab : GradientTable class instance
            Gradient table.
        sphere : Sphere class instance, optional
            A sphere on which coefficients will be estimated. Default:
            symmetric sphere with 362 points (from :mod:`dipy.data`).
        response : (3,) array-like, optional
            The eigenvalues of a canonical tensor to be used as the response
            function of single-fascicle signals.
        solver : string, or initialized linear model object, optional
            This will determine the algorithm used to solve the set of linear
            equations underlying this model. If it is a string, it needs to
            be one of the following: {'ElasticNet', 'NNLS'}. Otherwise, it
            can be an object that inherits from
            `dipy.optimize.SKLearnLinearSolver` or an object with a similar
            interface from Scikit-Learn: `sklearn.linear_model.ElasticNet`,
            `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`, and
            other objects that inherit from `sklearn.base.RegressorMixin`.
        l1_ratio : float, optional
            Sets the balance between L1 and L2 regularization in ElasticNet
            :footcite:p:`Zou2005`.
        alpha : float, optional
            Sets the balance between least-squares error and L1/L2
            regularization in ElasticNet :footcite:p:`Zou2005`.
        isotropic : IsotropicModel class, optional
            This is a class that implements the function that calculates the
            value of the isotropic signal. This is a value of the signal that
            is independent of direction, and therefore removed from both
            sides of the SFM equation. The default is IsotropicModel, but
            other classes can be inherited from IsotropicModel to implement
            other fits to the aspects of the data that depend on b-value, but
            not on direction.
        seed : int, optional
            Seed for the random number generator; passed to the ElasticNet
            solver as `random_state`.

        Notes
        -----
        This is an implementation of the SFM, described in
        :footcite:p:`Rokem2015`.

        References
        ----------
        .. footbibliography::
        """
        ReconstModel.__init__(self, gtab)
        if sphere is None:
            sphere = dpd.get_sphere()
        self.sphere = sphere
        self.response = np.asarray(response)
        if isotropic is None:
            isotropic = IsotropicModel
        self.isotropic = isotropic
        if solver == "ElasticNet":
            self.solver = lm.ElasticNet(
                l1_ratio=l1_ratio,
                alpha=alpha,
                positive=True,
                warm_start=False,
                random_state=seed,
            )
        elif solver in ("NNLS", "nnls"):
            self.solver = opt.NonNegativeLeastSquares()
        elif isinstance(solver, opt.SKLearnLinearSolver) or (
            has_sklearn and isinstance(solver, sklearn.base.RegressorMixin)
        ):
            self.solver = solver
        else:
            # If sklearn is unavailable, point the user at the 'nnls'
            # fallback before raising:
            if not has_sklearn:
                w = sklearn._msg + "\nAlternatively, you can use 'nnls' method "
                w += "to fit the SparseFascicleModel"
                warnings.warn(w, stacklevel=2)
            e_s = "The `solver` key-word argument needs to be: "
            e_s += "'ElasticNet', 'NNLS', or a "
            e_s += "`dipy.optimize.SKLearnLinearSolver` object"
            raise ValueError(e_s)

    @auto_attr
    def design_matrix(self):
        """
        The design matrix for a SFM.

        Returns
        -------
        ndarray
            The design matrix, where each column is a rotated version of the
            response function.
        """
        return sfm_design_matrix(self.gtab, self.sphere, self.response, mode="signal")

    @warning_for_keywords()
    def _fit_solver2voxels(self, isopredict, vox_data, vox, *, parallel=False):
        # In voxels in which S0 is 0, we just want to keep the parameters at
        # all-zeros, and avoid nasty sklearn errors:
        if not (np.any(~np.isfinite(vox_data)) or np.all(vox_data == 0)):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if parallel:
                    coef = {
                        vox: self.solver.fit(
                            self.design_matrix, vox_data - isopredict[vox]
                        ).coef_
                    }
                else:
                    coef = self.solver.fit(
                        self.design_matrix, vox_data - isopredict[vox]
                    ).coef_
        else:
            if parallel:
                return {vox: np.zeros(self.design_matrix.shape[-1])}
            else:
                return np.zeros(self.design_matrix.shape[-1])
        return coef

    @warning_for_keywords()
    def fit(
        self, data, *, mask=None, num_processes=1, parallel_backend="multiprocessing"
    ):
        """
        Fit the SparseFascicleModel object to data.

        Parameters
        ----------
        data : array
            The measured signal.
        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed. Has the shape `data.shape[:-1]`. Default:
            None, which implies that all points should be analyzed.
        num_processes : int, optional
            Split the `fit` calculation to a pool of children processes using
            joblib. This only applies to 4D `data` arrays. Default is 1,
            which does not require joblib and will run `fit` serially. If < 0
            the maximal number of cores minus ``num_processes + 1`` is used
            (enter -1 to use as many cores as possible). 0 raises an error.
        parallel_backend : str, ParallelBackendBase instance or None, optional
            Specify the parallelization backend implementation. Supported
            backends are:

            - "loky" used by default, can induce some communication and
              memory overhead when exchanging input and output data with the
              worker Python processes.
            - "multiprocessing" previous process-based backend based on
              `multiprocessing.Pool`. Less robust than `loky`.
            - "threading" is a very low-overhead backend but it suffers from
              the Python Global Interpreter Lock if the called function
              relies a lot on Python objects. "threading" is mostly useful
              when the execution bottleneck is a compiled extension that
              explicitly releases the GIL (for instance a Cython loop wrapped
              in a "with nogil" block or an expensive call to a library such
              as NumPy).

        Returns
        -------
        SparseFascicleFit object
        """
        if mask is None:
            # Flatten it to 2D either way:
            data_in_mask = np.reshape(data, (-1, data.shape[-1]))
        else:
            # Check for valid shape of the mask
            if mask.shape != data.shape[:-1]:
                raise ValueError("Mask is not the same shape as data.")
            mask = np.asarray(mask, dtype=bool)
            data_in_mask = np.reshape(data[mask], (-1, data.shape[-1]))

        # Fitting is done on the relative signal (S/S0):
        flat_S0 = np.mean(data_in_mask[..., self.gtab.b0s_mask], -1)
        if not flat_S0.size or not flat_S0.max():
            flat_S = np.zeros(data_in_mask[..., ~self.gtab.b0s_mask].shape)
        else:
            flat_S = data_in_mask[..., ~self.gtab.b0s_mask] / flat_S0[..., None]
        isotropic = self.isotropic(self.gtab).fit(data, mask=mask)
        flat_params = np.zeros((data_in_mask.shape[0], self.design_matrix.shape[-1]))
        del data_in_mask
        gc.collect()

        isopredict = isotropic.predict()
        if mask is None:
            isopredict = np.reshape(isopredict, (-1, isopredict.shape[-1]))
        else:
            isopredict = isopredict[mask]

        if not num_processes:
            num_processes = determine_num_processes(num_processes)

        if num_processes > 1 and has_joblib:
            with joblib.Parallel(
                n_jobs=num_processes, backend=parallel_backend, mmap_mode="r+"
            ) as parallel:
                out = parallel(
                    joblib.delayed(self._fit_solver2voxels)(
                        isopredict, vox_data, vox, parallel=True
                    )
                    for vox, vox_data in enumerate(flat_S)
                )
                del parallel

            flat_params_dict = {}
            for d in out:
                flat_params_dict.update(d)
            flat_params = np.concatenate(
                [
                    np.array(i).reshape(1, flat_params.shape[1])
                    for i in list(
                        OrderedDict(
                            sorted(flat_params_dict.items(), key=lambda x: int(x[0]))
                        ).values()
                    )
                ]
            )
        else:
            for vox, vox_data in enumerate(flat_S):
                flat_params[vox] = self._fit_solver2voxels(
                    isopredict, vox_data, vox, parallel=False
                )

        del isopredict, flat_S
        gc.collect()

        if mask is None:
            out_shape = data.shape[:-1] + (-1,)
            beta = flat_params.reshape(out_shape)
            S0 = flat_S0.reshape(data.shape[:-1])
        else:
            beta = np.zeros(data.shape[:-1] + (self.design_matrix.shape[-1],))
            beta[mask, :] = flat_params
            S0 = np.zeros(data.shape[:-1])
            S0[mask] = flat_S0

        return SparseFascicleFit(self, beta, S0, isotropic)
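

# A minimal end-to-end usage sketch (hypothetical; `data` is a 4D DWI array
# and `gtab` a matching GradientTable loaded elsewhere, for example with
# `dipy.io.image.load_nifti` and `dipy.core.gradients.gradient_table`). The
# default ElasticNet solver requires sklearn; solver="NNLS" is the
# sklearn-free alternative.
def _example_sfm_fit(data, gtab, mask=None):
    """Illustrative only: fit the SFM and return the per-voxel weights."""
    sf_model = SparseFascicleModel(gtab)
    sf_fit = sf_model.fit(data, mask=mask)
    # `beta` holds one non-negative weight per sphere vertex in each voxel;
    # `S0` holds the mean non-diffusion-weighted signal:
    return sf_fit.beta, sf_fit.S0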


class SparseFascicleFit(ReconstFit):
    def __init__(self, model, beta, S0, iso):
        """
        Initialize a SparseFascicleFit class instance.

        Parameters
        ----------
        model : a SparseFascicleModel object.
        beta : ndarray
            The parameters of fit to data.
        S0 : ndarray
            The mean non-diffusion-weighted signal.
        iso : IsotropicFit class instance
            A representation of the isotropic signal, together with
            parameters of the isotropic signal in each voxel, that is capable
            of deriving/predicting an isotropic signal, based on a
            gradient-table.
        """
        self.model = model
        self.beta = beta
        self.S0 = S0
        self.iso = iso

    def odf(self, sphere):
        """
        The orientation distribution function of the SFM.

        Parameters
        ----------
        sphere : Sphere
            The points in which the ODF is evaluated.

        Returns
        -------
        odf : ndarray of shape (x, y, z, sphere.vertices.shape[0])
        """
        odf_matrix = self.model.cache_get("odf_matrix", key=sphere)
        if odf_matrix is None:
            odf_matrix = sfm_design_matrix(
                sphere, self.model.sphere, self.model.response, mode="odf"
            )
            self.model.cache_set("odf_matrix", key=sphere, value=odf_matrix)

        return np.dot(
            odf_matrix, self.beta.reshape(-1, self.beta.shape[-1]).T
        ).T.reshape(self.beta.shape[:-1] + (odf_matrix.shape[0],))

    @warning_for_keywords()
    def predict(self, *, gtab=None, response=None, S0=None):
        """
        Predict the signal based on the SFM parameters.

        Parameters
        ----------
        gtab : GradientTable, optional
            The bvecs/bvals to predict the signal on. Default: the gtab from
            the model object.
        response : list of 3 elements, optional
            The eigenvalues of a tensor which will serve as a kernel
            function. Default: the response (`model.response`) of the model
            object.
        S0 : float or array, optional
            The non-diffusion-weighted signal. Default: use the S0 of the
            data.

        Returns
        -------
        pred_sig : ndarray
            The signal predicted in each voxel/direction.
        """
        if response is None:
            response = self.model.response
        if gtab is None:
            _matrix = self.model.design_matrix
            gtab = self.model.gtab
        # The only thing we can't change at this point is the sphere we use
        # (which sets the width of our design matrix):
        else:
            _matrix = sfm_design_matrix(gtab, self.model.sphere, response)

        # Get them all at once:
        pred_weighted = np.dot(
            _matrix, self.beta.reshape(-1, self.beta.shape[-1]).T
        ).T.reshape(self.beta.shape[:-1] + (_matrix.shape[0],))

        if S0 is None:
            S0 = self.S0
        if isinstance(S0, np.ndarray):
            S0 = S0[..., None]

        pre_pred_sig = S0 * (
            pred_weighted + self.iso.predict(gtab=gtab).reshape(pred_weighted.shape)
        )
        pred_sig = np.zeros(pre_pred_sig.shape[:-1] + (gtab.bvals.shape[0],))
        pred_sig[..., ~gtab.b0s_mask] = pre_pred_sig
        pred_sig[..., gtab.b0s_mask] = S0

        return pred_sig.squeeze()
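

# A minimal sketch of using a SparseFascicleFit for ODF evaluation and signal
# prediction (hypothetical; `sf_fit` is a SparseFascicleFit as returned by
# SparseFascicleModel.fit, and `data` the 4D array it was fit to, with no
# singleton spatial dimensions). Comparing the prediction with the measured
# data gives a simple goodness-of-fit check.
def _example_sfm_odf_and_predict(sf_fit, data):
    """Illustrative only: evaluate ODFs and predict the fitted signal."""
    sphere = dpd.get_sphere()
    # ODF values on the sphere vertices, usable for peak extraction/tracking:
    odf = sf_fit.odf(sphere)
    # Predicted signal on the model's own gradient table:
    pred = sf_fit.predict()
    rmse = np.sqrt(np.mean((pred - data) ** 2, axis=-1))
    return odf, rmse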