"""
The Sparse Fascicle Model.
This is an implementation of the sparse fascicle model described in
:footcite:t:`Rokem2015`. The multi b-value version of this model is described
in :footcite:t:`Rokem2014`.
References
----------
.. footbibliography::
"""
from collections import OrderedDict
import gc
import warnings
import numpy as np
try:
from numpy import nanmean
except ImportError:
from scipy.stats import nanmean
import dipy.core.gradients as grad
from dipy.core.onetime import auto_attr
import dipy.core.optimize as opt
import dipy.data as dpd
from dipy.reconst.base import ReconstFit, ReconstModel
from dipy.reconst.cache import Cache
import dipy.sims.voxel as sims
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.multiproc import determine_num_processes
from dipy.utils.optpkg import optional_package
joblib, has_joblib, _ = optional_package("joblib")
sklearn, has_sklearn, _ = optional_package("sklearn")
lm, _, _ = optional_package("sklearn.linear_model")
# Isotropic signal models: these are models of the part of the signal that
# changes with b-value, but does not change with direction. This collection is
# extensible, by inheriting from IsotropicModel/IsotropicFit below:
# First, a helper function to derive the fit signal for these models:
@warning_for_keywords()
def _to_fit_iso(data, gtab, *, mask=None):
if mask is None:
mask = np.ones(data.shape[:-1], dtype=bool)
# Turn it into a 2D thing:
if len(mask.shape) > 0:
data = data[mask]
else:
# This handles the corner case of fitting a single voxel:
data = data.reshape((-1, data.shape[0]))
data_no_b0 = data[:, ~gtab.b0s_mask]
nzb0 = data_no_b0 > 0
nzb0_idx = np.where(nzb0)
zb0_idx = np.where(~nzb0)
if np.sum(gtab.b0s_mask) > 0:
s0 = np.mean(data[:, gtab.b0s_mask], -1)
to_fit = np.empty(data_no_b0.shape)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
to_fit[nzb0_idx] = data_no_b0[nzb0_idx] / s0[nzb0_idx[0]]
to_fit[zb0_idx] = 0
else:
to_fit = data_no_b0
return to_fit
class IsotropicModel(ReconstModel):
"""
A base-class for the representation of isotropic signals.
    The default behavior, suitable for single b-value data, is to calculate
    the mean in each voxel as an estimate of the signal that does not depend
    on direction.
"""
def __init__(self, gtab):
"""Initialize an IsotropicModel.
Parameters
----------
gtab : a GradientTable class instance
"""
ReconstModel.__init__(self, gtab)
@warning_for_keywords()
def fit(self, data, *, mask=None, **kwargs):
"""Fit an IsotropicModel.
This boils down to finding the mean diffusion-weighted signal in each
        voxel.
Parameters
----------
        data : ndarray
            The measured signal.
        mask : array, optional
            A boolean array used to mark the coordinates in the data that
            should be analyzed. Has the shape `data.shape[:-1]`. Default:
            None, which implies that all points should be analyzed.
Returns
-------
IsotropicFit class instance.
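
        Examples
        --------
        A minimal sketch on small, synthetic, constant-valued data (the
        values and gradient table below are illustrative only):

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> from dipy.reconst.sfm import IsotropicModel
        >>> bvals = np.array([0., 1000., 1000., 1000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> data = 100 * np.ones((2, 2, 1, 4))
        >>> fit = IsotropicModel(gtab).fit(data)
        >>> fit.predict().shape
        (2, 2, 1, 3)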
"""
# This returns as a 2D thing:
params = np.mean(_to_fit_iso(data, self.gtab, mask=mask), -1)
if mask is None:
params = np.reshape(params, data.shape[:-1])
else:
out_params = np.zeros(data.shape[:-1])
out_params[mask] = params
params = out_params
return IsotropicFit(self, params)
class IsotropicFit(ReconstFit):
"""
A fit object for representing the isotropic signal as the mean of the
diffusion-weighted signal.
"""
def __init__(self, model, params):
"""Initialize an IsotropicFit object.
Parameters
----------
model : IsotropicModel class instance
Isotropic model.
params : ndarray
The mean isotropic model parameters (the mean diffusion-weighted
signal in each voxel).
"""
self.model = model
self.params = params
@warning_for_keywords()
def predict(self, *, gtab=None):
"""Predict the isotropic signal.
Based on a gradient table. In this case, the (naive!) prediction will
be the mean of the diffusion-weighted signal in the voxels.
Parameters
----------
gtab : a GradientTable class instance, optional
            Defaults to the gtab of the IsotropicModel from which this fit
            was derived.
"""
if gtab is None:
gtab = self.model.gtab
if len(self.params.shape) == 0:
return self.params[..., np.newaxis] + np.zeros(np.sum(~gtab.b0s_mask))
else:
return self.params[..., np.newaxis] + np.zeros(
self.params.shape + (np.sum(~gtab.b0s_mask),)
)
class ExponentialIsotropicModel(IsotropicModel):
"""
    Representing the isotropic signal as an exponential decay as a function
    of b-value.
"""
@warning_for_keywords()
def fit(self, data, *, mask=None, **kwargs):
"""
Parameters
----------
        data : ndarray
            The measured signal.
mask : array, optional
A boolean array used to mark the coordinates in the data that
should be analyzed. Has the shape `data.shape[:-1]`. Default: None,
which implies that all points should be analyzed.
Returns
-------
ExponentialIsotropicFit class instance.
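
        Examples
        --------
        A minimal sketch on noiseless, synthetic mono-exponential data (the
        diffusivity and gradient table below are illustrative only):

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> from dipy.reconst.sfm import ExponentialIsotropicModel
        >>> bvals = np.array([0., 1000., 2000., 3000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> mean_d = 0.0007
        >>> data = np.exp(-bvals * mean_d)[np.newaxis, :]
        >>> fit = ExponentialIsotropicModel(gtab).fit(data)
        >>> np.allclose(fit.params, mean_d)
        True
        >>> np.allclose(fit.predict(), data[:, 1:])
        True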
"""
to_fit = _to_fit_iso(data, self.gtab, mask=mask)
# Fitting to the log-transformed relative data is much faster:
nz_idx = to_fit > 0
to_fit[nz_idx] = np.log(to_fit[nz_idx])
        # Non-positive values cannot be log-transformed; mark them as NaN so
        # that nanmean below ignores them:
        to_fit[~nz_idx] = np.nan
params = -nanmean(to_fit / self.gtab.bvals[~self.gtab.b0s_mask], -1)
if mask is None:
params = np.reshape(params, data.shape[:-1])
else:
out_params = np.zeros(data.shape[:-1])
out_params[mask] = params
params = out_params
return ExponentialIsotropicFit(self, params)
class ExponentialIsotropicFit(IsotropicFit):
"""
A fit to the ExponentialIsotropicModel object, based on data.
"""
@warning_for_keywords()
def predict(self, *, gtab=None):
"""
Predict the isotropic signal, based on a gradient table. In this case,
the prediction will be for an exponential decay with the mean
diffusivity derived from the data that was fit.
Parameters
----------
gtab : a GradientTable class instance, optional
            Defaults to the gtab of the IsotropicModel from which this fit
            was derived.
"""
if gtab is None:
gtab = self.model.gtab
if len(self.params.shape) == 0:
return np.exp(
-gtab.bvals[~gtab.b0s_mask]
* (np.zeros(np.sum(~gtab.b0s_mask)) + self.params[..., np.newaxis])
)
else:
return np.exp(
-gtab.bvals[~gtab.b0s_mask]
* (
np.zeros((self.params.shape[0], np.sum(~gtab.b0s_mask)))
+ self.params[..., np.newaxis]
)
)
@warning_for_keywords()
def sfm_design_matrix(gtab, sphere, response, *, mode="signal"):
"""
Construct the SFM design matrix
Parameters
----------
gtab : GradientTable or Sphere
        Sets the rows of the matrix. If the mode is 'signal', this should be
        a GradientTable. If the mode is 'odf', this should be a Sphere.
    sphere : Sphere
        Sets the columns of the matrix.
response : list of 3 elements
The eigenvalues of a tensor which will serve as a kernel
function.
mode : str {'signal' | 'odf'}, optional
Choose the (default) 'signal' for a design matrix containing predicted
signal in the measurements defined by the gradient table for putative
fascicles oriented along the vertices of the sphere. Otherwise, choose
'odf' for an odf convolution matrix, with values of the odf calculated
from a tensor with the provided response eigenvalues, evaluated at the
b-vectors in the gradient table, for the tensors with principal
diffusion directions along the vertices of the sphere.
Returns
-------
mat : ndarray
A design matrix that can be used for one of the following operations:
when the 'signal' mode is used, each column contains the putative
signal in each of the bvectors of the `gtab` if a fascicle is oriented
in the direction encoded by the sphere vertex corresponding to this
column. This is used for deconvolution with a measured DWI signal. If
the 'odf' mode is chosen, each column instead contains the values of
the tensor ODF for a tensor with a principal diffusion direction
corresponding to this vertex. This is used to generate odfs from the
fits of the SFM for the purpose of tracking.
Examples
--------
>>> import dipy.data as dpd
>>> data, gtab = dpd.dsi_voxels()
>>> sphere = dpd.get_sphere()
>>> from dipy.reconst.sfm import sfm_design_matrix
A canonical tensor approximating corpus-callosum voxels
    :footcite:p:`Rokem2014`:
>>> tensor_matrix = sfm_design_matrix(gtab, sphere,
... [0.0015, 0.0005, 0.0005])
    A 'stick' function :footcite:p:`Behrens2007`:
>>> stick_matrix = sfm_design_matrix(gtab, sphere, [0.001, 0, 0])
References
----------
.. footbibliography::
"""
if mode == "signal":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mat_gtab = grad.gradient_table(
gtab.bvals[~gtab.b0s_mask], bvecs=gtab.bvecs[~gtab.b0s_mask]
)
# Preallocate:
mat = np.empty((np.sum(~gtab.b0s_mask), sphere.vertices.shape[0]))
elif mode == "odf":
mat = np.empty((gtab.x.shape[0], sphere.vertices.shape[0]))
# Calculate column-wise:
for ii, this_dir in enumerate(sphere.vertices):
# Rotate the canonical tensor towards this vertex and calculate the
# signal you would have gotten in the direction
if mode == "signal":
# For regressors based on the single tensor, remove $e^{-bD}$
mat[:, ii] = sims.single_tensor(
mat_gtab, evals=response, evecs=sims.all_tensor_evecs(this_dir)
) - np.exp(-mat_gtab.bvals * np.mean(response))
elif mode == "odf":
# Stick function
if response[1] == 0 or response[2] == 0:
mat[sphere.find_closest(sims.all_tensor_evecs(this_dir)[0]), ii] = 1
else:
mat[:, ii] = sims.single_tensor_odf(
gtab.vertices, evals=response, evecs=sims.all_tensor_evecs(this_dir)
)
return mat
class SparseFascicleModel(ReconstModel, Cache):
@warning_for_keywords()
def __init__(
self,
gtab,
*,
sphere=None,
response=(0.0015, 0.0005, 0.0005),
solver="ElasticNet",
l1_ratio=0.5,
alpha=0.001,
isotropic=None,
seed=42,
):
"""
Initialize a Sparse Fascicle Model
Parameters
----------
gtab : GradientTable class instance
Gradient table.
sphere : Sphere class instance, optional
A sphere on which coefficients will be estimated. Default:
symmetric sphere with 362 points (from :mod:`dipy.data`).
response : (3,) array-like, optional
The eigenvalues of a canonical tensor to be used as the response
function of single-fascicle signals.
solver : string, or initialized linear model object.
This will determine the algorithm used to solve the set of linear
equations underlying this model. If it is a string it needs to be
one of the following: {'ElasticNet', 'NNLS'}. Otherwise, it can be
an object that inherits from `dipy.optimize.SKLearnLinearSolver`
or an object with a similar interface from Scikit Learn:
`sklearn.linear_model.ElasticNet`, `sklearn.linear_model.Lasso` or
`sklearn.linear_model.Ridge` and other objects that inherit from
`sklearn.base.RegressorMixin`.
l1_ratio : float, optional
Sets the balance between L1 and L2 regularization in ElasticNet
            :footcite:p:`Zou2005`.
alpha : float, optional
Sets the balance between least-squares error and L1/L2
regularization in ElasticNet :footcite:p`Zou2005`.
        isotropic : IsotropicModel class instance, optional
            This is a class that implements the function that calculates the
            value of the isotropic signal. This is a value of the signal that
            is independent of direction and is therefore removed from both
            sides of the SFM equation. The default is an instance of
            IsotropicModel, but other classes can inherit from IsotropicModel
            to implement other fits to the aspects of the data that depend on
            b-value, but not on direction.
        seed : int, optional
            Seed used for the `random_state` of the ElasticNet solver.
Notes
-----
This is an implementation of the SFM, described in
        :footcite:p:`Rokem2015`.
References
----------
.. footbibliography::
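
        Examples
        --------
        A minimal sketch of initializing the model and inspecting its design
        matrix. The gradient table is illustrative, and the 'NNLS' solver is
        chosen so that the example does not require scikit-learn:

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> from dipy.reconst.sfm import SparseFascicleModel
        >>> bvals = np.array([0., 1000., 1000., 1000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> sfm = SparseFascicleModel(gtab, solver='NNLS')
        >>> sfm.design_matrix.shape  # (n DW measurements, n sphere vertices)
        (3, 362)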
"""
ReconstModel.__init__(self, gtab)
if sphere is None:
sphere = dpd.get_sphere()
self.sphere = sphere
self.response = np.asarray(response)
if isotropic is None:
isotropic = IsotropicModel
self.isotropic = isotropic
if solver == "ElasticNet":
self.solver = lm.ElasticNet(
l1_ratio=l1_ratio,
alpha=alpha,
positive=True,
warm_start=False,
random_state=seed,
)
elif solver in ("NNLS", "nnls"):
self.solver = opt.NonNegativeLeastSquares()
elif (
isinstance(solver, opt.SKLearnLinearSolver)
or has_sklearn
and isinstance(solver, sklearn.base.RegressorMixin)
):
self.solver = solver
else:
# If sklearn is unavailable, we can fall back on nnls (but we also
# warn the user that we are about to do that):
if not has_sklearn:
w = sklearn._msg + "\nAlternatively, you can use 'nnls' method "
w += "to fit the SparseFascicleModel"
warnings.warn(w, stacklevel=2)
e_s = "The `solver` key-word argument needs to be: "
e_s += "'ElasticNet', 'NNLS', or a "
e_s += "`dipy.optimize.SKLearnLinearSolver` object"
raise ValueError(e_s)
@auto_attr
def design_matrix(self):
"""
        The design matrix for an SFM.
Returns
-------
ndarray
The design matrix, where each column is a rotated version of the
response function.
"""
return sfm_design_matrix(self.gtab, self.sphere, self.response, mode="signal")
@warning_for_keywords()
def _fit_solver2voxels(self, isopredict, vox_data, vox, *, parallel=False):
# In voxels in which S0 is 0, we just want to keep the
# parameters at all-zeros, and avoid nasty sklearn errors:
if not (np.any(~np.isfinite(vox_data)) or np.all(vox_data == 0)):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if parallel:
coef = {
vox: self.solver.fit(
self.design_matrix, vox_data - isopredict[vox]
).coef_
}
else:
coef = self.solver.fit(
self.design_matrix, vox_data - isopredict[vox]
).coef_
else:
if parallel:
return {vox: np.zeros(self.design_matrix.shape[-1])}
else:
return np.zeros(self.design_matrix.shape[-1])
return coef
@warning_for_keywords()
def fit(
self, data, *, mask=None, num_processes=1, parallel_backend="multiprocessing"
):
"""
Fit the SparseFascicleModel object to data.
Parameters
----------
data : array
The measured signal.
mask : array, optional
A boolean array used to mark the coordinates in the data that
should be analyzed. Has the shape `data.shape[:-1]`. Default: None,
which implies that all points should be analyzed.
num_processes : int, optional
Split the `fit` calculation to a pool of children processes using
joblib. This only applies to 4D `data` arrays. Default is 1,
which does not require joblib and will run `fit` serially.
If < 0 the maximal number of cores minus ``num_processes + 1``
is used (enter -1 to use as many cores as possible).
0 raises an error.
        parallel_backend : str, ParallelBackendBase instance or None, optional
Specify the parallelization backend implementation.
Supported backends are:
- "loky" used by default, can induce some
communication and memory overhead when exchanging input and
output data with the worker Python processes.
- "multiprocessing" previous process-based backend based on
`multiprocessing.Pool`. Less robust than `loky`.
- "threading" is a very low-overhead backend but it suffers
from the Python Global Interpreter Lock if the called function
relies a lot on Python objects. "threading" is mostly useful
when the execution bottleneck is a compiled extension that
explicitly releases the GIL (for instance a Cython loop wrapped
in a "with nogil" block or an expensive call to a library such
as NumPy).
Returns
-------
SparseFascicleFit object
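
        Examples
        --------
        A minimal sketch on small, synthetic, constant-valued data
        (illustrative values only), using the 'NNLS' solver so that
        scikit-learn is not required:

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> from dipy.reconst.sfm import SparseFascicleModel
        >>> bvals = np.array([0., 1000., 1000., 1000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> data = 100 * np.ones((2, 2, 1, 4))
        >>> sffit = SparseFascicleModel(gtab, solver='NNLS').fit(data)
        >>> sffit.beta.shape  # one coefficient per sphere vertex, per voxel
        (2, 2, 1, 362)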
"""
if mask is None:
# Flatten it to 2D either way:
data_in_mask = np.reshape(data, (-1, data.shape[-1]))
else:
# Check for valid shape of the mask
if mask.shape != data.shape[:-1]:
raise ValueError("Mask is not the same shape as data.")
mask = np.asarray(mask, dtype=bool)
data_in_mask = np.reshape(data[mask], (-1, data.shape[-1]))
# Fitting is done on the relative signal (S/S0):
flat_S0 = np.mean(data_in_mask[..., self.gtab.b0s_mask], -1)
if not flat_S0.size or not flat_S0.max():
flat_S = np.zeros(data_in_mask[..., ~self.gtab.b0s_mask].shape)
else:
flat_S = data_in_mask[..., ~self.gtab.b0s_mask] / flat_S0[..., None]
isotropic = self.isotropic(self.gtab).fit(data, mask=mask)
flat_params = np.zeros((data_in_mask.shape[0], self.design_matrix.shape[-1]))
del data_in_mask
gc.collect()
isopredict = isotropic.predict()
if mask is None:
isopredict = np.reshape(isopredict, (-1, isopredict.shape[-1]))
else:
isopredict = isopredict[mask]
if not num_processes:
num_processes = determine_num_processes(num_processes)
if num_processes > 1 and has_joblib:
with joblib.Parallel(
n_jobs=num_processes, backend=parallel_backend, mmap_mode="r+"
) as parallel:
out = parallel(
joblib.delayed(self._fit_solver2voxels)(
                        isopredict, vox_data, vox, parallel=True
)
for vox, vox_data in enumerate(flat_S)
)
del parallel
flat_params_dict = {}
for d in out:
flat_params_dict.update(d)
flat_params = np.concatenate(
[
np.array(i).reshape(1, flat_params.shape[1])
for i in list(
OrderedDict(
sorted(flat_params_dict.items(), key=lambda x: int(x[0]))
).values()
)
]
)
else:
for vox, vox_data in enumerate(flat_S):
flat_params[vox] = self._fit_solver2voxels(
isopredict, vox_data, vox, parallel=False
)
del isopredict, flat_S
gc.collect()
if mask is None:
out_shape = data.shape[:-1] + (-1,)
beta = flat_params.reshape(out_shape)
S0 = flat_S0.reshape(data.shape[:-1])
else:
beta = np.zeros(data.shape[:-1] + (self.design_matrix.shape[-1],))
beta[mask, :] = flat_params
S0 = np.zeros(data.shape[:-1])
S0[mask] = flat_S0
return SparseFascicleFit(self, beta, S0, isotropic)
class SparseFascicleFit(ReconstFit):
def __init__(self, model, beta, S0, iso):
"""
Initialize a SparseFascicleFit class instance
Parameters
----------
model : a SparseFascicleModel object.
        beta : ndarray
            The parameters of the fit to the data.
S0 : ndarray
The mean non-diffusion-weighted signal.
iso : IsotropicFit class instance
A representation of the isotropic signal, together with parameters
of the isotropic signal in each voxel, that is capable of
deriving/predicting an isotropic signal, based on a gradient-table.
"""
self.model = model
self.beta = beta
self.S0 = S0
self.iso = iso
def odf(self, sphere):
"""
The orientation distribution function of the SFM
Parameters
----------
sphere : Sphere
            The points at which the ODF is evaluated.
Returns
-------
odf : ndarray of shape (x, y, z, sphere.vertices.shape[0])
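
        Examples
        --------
        A minimal sketch on synthetic, constant-valued data (illustrative
        values only); the resulting ODFs have one value per vertex of the
        sphere on which they are evaluated:

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> import dipy.data as dpd
        >>> from dipy.reconst.sfm import SparseFascicleModel
        >>> bvals = np.array([0., 1000., 1000., 1000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> data = 100 * np.ones((2, 2, 1, 4))
        >>> sffit = SparseFascicleModel(gtab, solver='NNLS').fit(data)
        >>> sphere = dpd.get_sphere()
        >>> sffit.odf(sphere).shape
        (2, 2, 1, 362)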
"""
odf_matrix = self.model.cache_get("odf_matrix", key=sphere)
if odf_matrix is None:
odf_matrix = sfm_design_matrix(
sphere, self.model.sphere, self.model.response, mode="odf"
)
self.model.cache_set("odf_matrix", key=sphere, value=odf_matrix)
return np.dot(
odf_matrix, self.beta.reshape(-1, self.beta.shape[-1]).T
).T.reshape(self.beta.shape[:-1] + (odf_matrix.shape[0],))
@warning_for_keywords()
def predict(self, *, gtab=None, response=None, S0=None):
"""
Predict the signal based on the SFM parameters
Parameters
----------
gtab : GradientTable, optional
The bvecs/bvals to predict the signal on. Default: the gtab from
the model object.
        response : list of 3 elements, optional
            The eigenvalues of a tensor which will serve as a kernel
            function. Default: the response of the model object,
            `model.response`.
        S0 : float or array, optional
            The non-diffusion-weighted signal. Default: use the S0 of the
            data.
Returns
-------
pred_sig : ndarray
The signal predicted in each voxel/direction
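
        Examples
        --------
        A minimal sketch on synthetic, constant-valued data (illustrative
        values only). For direction-independent data, the prediction reduces
        to the isotropic signal scaled by S0:

        >>> import numpy as np
        >>> import dipy.core.gradients as grad
        >>> from dipy.reconst.sfm import SparseFascicleModel
        >>> bvals = np.array([0., 1000., 1000., 1000.])
        >>> bvecs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.]])
        >>> gtab = grad.gradient_table(bvals, bvecs=bvecs)
        >>> data = 100 * np.ones((2, 2, 1, 4))
        >>> sffit = SparseFascicleModel(gtab, solver='NNLS').fit(data)
        >>> np.allclose(sffit.predict(), 100)
        True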
"""
if response is None:
response = self.model.response
if gtab is None:
_matrix = self.model.design_matrix
gtab = self.model.gtab
# The only thing we can't change at this point is the sphere we use
# (which sets the width of our design matrix):
else:
_matrix = sfm_design_matrix(gtab, self.model.sphere, response)
# Get them all at once:
pred_weighted = np.dot(
_matrix, self.beta.reshape(-1, self.beta.shape[-1]).T
).T.reshape(self.beta.shape[:-1] + (_matrix.shape[0],))
if S0 is None:
S0 = self.S0
if isinstance(S0, np.ndarray):
S0 = S0[..., None]
pre_pred_sig = S0 * (
pred_weighted + self.iso.predict(gtab=gtab).reshape(pred_weighted.shape)
)
pred_sig = np.zeros(pre_pred_sig.shape[:-1] + (gtab.bvals.shape[0],))
pred_sig[..., ~gtab.b0s_mask] = pre_pred_sig
pred_sig[..., gtab.b0s_mask] = S0
return pred_sig.squeeze()