from itertools import repeat
import multiprocessing as mp
from os import path
import tempfile
import warnings
import numpy as np
import scipy.optimize as opt
from dipy.core.interpolation import trilinear_interpolate4d
from dipy.core.ndindex import ndindex
from dipy.core.sphere import Sphere
from dipy.data import default_sphere
from dipy.reconst.dirspeed import peak_directions
from dipy.reconst.eudx_direction_getter import EuDXDirectionGetter
from dipy.reconst.odf import gfa
from dipy.reconst.recspeed import (
local_maxima,
remove_similar_vertices,
search_descending,
)
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.multiproc import determine_num_processes
[docs]
@warning_for_keywords()
def peak_directions_nl(
    sphere_eval,
    *,
    relative_peak_threshold=0.25,
    min_separation_angle=25,
    sphere=default_sphere,
    xtol=1e-7,
):
    """Find peak directions of a spherical function by non-linear search.

    ``sphere_eval`` is first sampled on ``sphere``; every discrete local
    maximum then seeds one ``scipy.optimize.fmin`` run over ``(theta, phi)``.
    The refined peaks are sorted by value, thresholded and pruned of
    near-duplicates.

    Parameters
    ----------
    sphere_eval : callable
        A function which can be evaluated on a sphere.
    relative_peak_threshold : float
        Only return peaks greater than ``relative_peak_threshold * m`` where m
        is the largest peak.
    min_separation_angle : float in [0, 90]
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned.
    sphere : Sphere
        A discrete Sphere. The points on the sphere will be used for initial
        estimate of maximums.
    xtol : float
        Relative tolerance for optimization.

    Returns
    -------
    directions : array (N, 3)
        Points on the sphere corresponding to N local maxima on the sphere.
    values : array (N,)
        Value of sphere_eval at each point on directions.

    """

    def _negated_eval(angles):
        # fmin minimizes, so the objective is the negated spherical function.
        return -sphere_eval(Sphere(theta=angles[0], phi=angles[1]))

    # Discrete maxima on the sampling sphere are the starting points.
    sampled = sphere_eval(sphere)
    _, seed_idx = local_maxima(sampled, sphere.edges)
    start_angles = np.column_stack([sphere.theta[seed_idx], sphere.phi[seed_idx]])

    # Refine every seed independently.
    n_seeds = len(start_angles)
    theta = np.empty(n_seeds)
    phi = np.empty(n_seeds)
    for k, start in enumerate(start_angles):
        theta[k], phi[k] = opt.fmin(_negated_eval, start, xtol=xtol, disp=False)

    # Re-evaluate the function at the refined locations.
    refined = Sphere(theta=theta, phi=phi)
    refined_values = sphere_eval(refined)

    # Largest value first.
    descending = refined_values.argsort()[::-1]
    refined_values = refined_values[descending]
    directions = refined.vertices[descending]

    # Drop peaks below the relative threshold, then peaks that sit too close
    # to a larger one.
    n_kept = search_descending(refined_values, relative_peak_threshold)
    directions, kept = remove_similar_vertices(
        directions[:n_kept], min_separation_angle, return_index=True
    )
    return directions, refined_values[kept]
def _pam_from_attrs(
    klass, sphere, peak_indices, peak_values, peak_dirs, gfa, qa, shm_coeff, B, odf
):
    """
    Build an instance of ``klass`` and populate its peak/metric attributes.

    ``klass`` is called with no arguments and each remaining argument is
    stored on the new object under an attribute of the same name. Being a
    module-level function, it can also serve as the reconstruction callable
    returned by a ``__reduce__`` method.

    Parameters
    ----------
    klass : class
        The class of object to be created.
    sphere : `Sphere` class instance.
        Sphere for discretization.
    peak_indices : ndarray
        Indices (in sphere vertices) of the peaks in each voxel.
    peak_values : ndarray
        The value of the peaks.
    peak_dirs : ndarray
        The direction of each peak.
    gfa : ndarray
        The Generalized Fractional Anisotropy in each voxel.
    qa : ndarray
        Quantitative Anisotropy in each voxel.
    shm_coeff : ndarray
        The coefficients of the spherical harmonic basis for the ODF in
        each voxel.
    B : ndarray
        The spherical harmonic matrix, for multiplication with the
        coefficients.
    odf : ndarray
        The orientation distribution function on the sphere in each voxel.

    Returns
    -------
    pam : Instance of the class `klass`.

    """
    instance = klass()
    # Attributes are assigned in a fixed order, one (name, value) pair each.
    attribute_table = (
        ("sphere", sphere),
        ("peak_dirs", peak_dirs),
        ("peak_values", peak_values),
        ("peak_indices", peak_indices),
        ("gfa", gfa),
        ("qa", qa),
        ("shm_coeff", shm_coeff),
        ("B", B),
        ("odf", odf),
    )
    for attr_name, attr_value in attribute_table:
        setattr(instance, attr_name, attr_value)
    return instance
[docs]
class PeaksAndMetrics(EuDXDirectionGetter):
    """Direction getter that carries peak and metric attributes.

    Pickling goes through :func:`_pam_from_attrs`, which rebuilds the object
    from the ``sphere``, ``peak_indices``, ``peak_values``, ``peak_dirs``,
    ``gfa``, ``qa``, ``shm_coeff``, ``B`` and ``odf`` attributes.
    """

    def __reduce__(self):
        """Return the ``(callable, args)`` pair used to rebuild this object.

        All listed attributes must be set on the instance; a missing one
        raises ``AttributeError`` here.
        """
        rebuild_args = (
            self.__class__,
            self.sphere,
            self.peak_indices,
            self.peak_values,
            self.peak_dirs,
            self.gfa,
            self.qa,
            self.shm_coeff,
            self.B,
            self.odf,
        )
        return _pam_from_attrs, rebuild_args
def _peaks_from_model_parallel(
    model,
    data,
    sphere,
    relative_peak_threshold,
    min_separation_angle,
    mask,
    return_odf,
    return_sh,
    gfa_thr,
    normalize_peaks,
    sh_order,
    sh_basis_type,
    legacy,
    npeaks,
    B,
    invB,
    num_processes,
):
    """Run ``peaks_from_model`` on chunks of ``data`` in a process pool.

    ``data`` is flattened to 2D (voxels x last axis) and split into at most
    ``num_processes ** 2`` contiguous chunks. The flattened data (and mask,
    when given) are written to ``.npy`` files in a temporary directory; each
    worker receives the file names and a ``(start, end)`` range and is run
    through ``_peaks_from_model_parallel_sub``. The per-chunk results are
    merged through memory-mapped scratch files and reshaped back to the
    spatial shape of ``data``.

    Parameters
    ----------
    model, sphere, relative_peak_threshold, min_separation_angle, \
return_odf, return_sh, gfa_thr, normalize_peaks, sh_basis_type, legacy, \
npeaks, B, invB
        Forwarded unchanged to every worker.
    data : ndarray
        Input data; the last axis is kept, all leading axes are flattened.
    mask : ndarray or None
        Optional mask; flattened and sliced with the same chunk ranges.
    sh_order : int
        Forwarded to the workers as ``sh_order_max`` and used here to size
        the merged ``shm_coeff`` array.
    num_processes : int
        Size of the process pool.

    Returns
    -------
    pam : PeaksAndMetrics
        Object with ``sphere``, ``gfa``, ``peak_dirs``, ``peak_values``,
        ``peak_indices``, ``qa``, ``shm_coeff``, ``B`` and ``odf`` set
        (``shm_coeff``/``B`` are None when ``return_sh`` is False, ``odf`` is
        None when ``return_odf`` is False).

    Notes
    -----
    This function calls ``mp.set_start_method("spawn", force=True)``, which
    changes the start method for the whole process.

    """
    shape = list(data.shape)
    data = np.reshape(data, (-1, shape[-1]))
    n = data.shape[0]
    nbr_chunks = num_processes**2
    chunk_size = int(np.ceil(n / nbr_chunks))
    # (start, end) voxel ranges; the last ``end`` may exceed ``n``, which
    # slicing tolerates.
    chunk_starts = np.arange(0, n, chunk_size)
    indices = list(zip(chunk_starts, chunk_starts + chunk_size))

    with tempfile.TemporaryDirectory() as tmpdir:
        data_file_name = path.join(tmpdir, "data.npy")
        np.save(data_file_name, data)
        if mask is not None:
            mask = mask.flatten()
            mask_file_name = path.join(tmpdir, "mask.npy")
            np.save(mask_file_name, mask)
        else:
            mask_file_name = None

        mp.set_start_method("spawn", force=True)
        pool = mp.Pool(num_processes)
        try:
            pam_res = pool.map(
                _peaks_from_model_parallel_sub,
                zip(
                    repeat((data_file_name, mask_file_name)),
                    indices,
                    repeat(model),
                    repeat(sphere),
                    repeat(relative_peak_threshold),
                    repeat(min_separation_angle),
                    repeat(return_odf),
                    repeat(return_sh),
                    repeat(gfa_thr),
                    repeat(normalize_peaks),
                    repeat(sh_order),
                    repeat(sh_basis_type),
                    repeat(legacy),
                    repeat(npeaks),
                    repeat(B),
                    repeat(invB),
                ),
            )
            pool.close()
        except BaseException:
            # Do not leave worker processes behind when a chunk fails.
            pool.terminate()
            raise
        finally:
            # Make sure all worker processes have exited before leaving the
            # context manager in order to prevent temporary file deletion
            # errors in windows.
            pool.join()

        pam = PeaksAndMetrics()
        pam.sphere = sphere
        first = pam_res[0]

        def _scratch(file_name, like, trailing=()):
            # Memory-mapped scratch array, used to reduce the memory usage
            # while the per-chunk results are being merged.
            return np.memmap(
                path.join(tmpdir, file_name),
                dtype=like.dtype,
                mode="w+",
                shape=(n,) + tuple(trailing),
            )

        pam.gfa = _scratch("gfa.npy", first.gfa)
        pam.peak_dirs = _scratch("peak_dirs.npy", first.peak_dirs, (npeaks, 3))
        pam.peak_values = _scratch("peak_values.npy", first.peak_values, (npeaks,))
        pam.peak_indices = _scratch(
            "peak_indices.npy", first.peak_indices, (npeaks,)
        )
        pam.qa = _scratch("qa.npy", first.qa, (npeaks,))

        if return_sh:
            nbr_shm_coeff = (sh_order + 2) * (sh_order + 1) // 2
            pam.shm_coeff = _scratch("shm.npy", first.shm_coeff, (nbr_shm_coeff,))
            pam.B = first.B
        else:
            pam.shm_coeff = None
            # ``B`` must exist even when unused: ``__reduce__`` reads it, and
            # the serial code path also stores ``B=None`` in this case.
            pam.B = None
            # Kept for backward compatibility with the previous behaviour.
            pam.invB = None

        if return_odf:
            pam.odf = _scratch("odf.npy", first.odf, (len(sphere.vertices),))
        else:
            pam.odf = None

        # copy subprocesses pam to a single pam (memmaps)
        for chunk_pam, (start_pos, end_pos) in zip(pam_res, indices):
            pam.gfa[start_pos:end_pos] = chunk_pam.gfa
            pam.peak_dirs[start_pos:end_pos] = chunk_pam.peak_dirs
            pam.peak_values[start_pos:end_pos] = chunk_pam.peak_values
            pam.peak_indices[start_pos:end_pos] = chunk_pam.peak_indices
            pam.qa[start_pos:end_pos] = chunk_pam.qa
            if return_sh:
                pam.shm_coeff[start_pos:end_pos] = chunk_pam.shm_coeff
            if return_odf:
                pam.odf[start_pos:end_pos] = chunk_pam.odf

        # load memmaps to arrays and reshape the metric; the last axis is
        # inferred (-1) because it differs between metrics.
        shape[-1] = -1
        pam.gfa = np.reshape(np.array(pam.gfa), shape[:-1])
        pam.peak_dirs = np.reshape(np.array(pam.peak_dirs), shape + [3])
        pam.peak_values = np.reshape(np.array(pam.peak_values), shape)
        pam.peak_indices = np.reshape(np.array(pam.peak_indices), shape)
        pam.qa = np.reshape(np.array(pam.qa), shape)
        if return_sh:
            pam.shm_coeff = np.reshape(np.array(pam.shm_coeff), shape)
        if return_odf:
            pam.odf = np.reshape(np.array(pam.odf), shape)

    return pam
def _peaks_from_model_parallel_sub(args):
    """Worker task: compute peaks for one chunk of the saved data.

    ``args`` is a 16-item sequence: the ``(data_file_name, mask_file_name)``
    pair, the ``(start_pos, end_pos)`` range, then the model, sphere and the
    remaining ``peaks_from_model`` settings. The data (and mask, if a file
    name is given) are opened memory-mapped read-only and sliced to the
    range, and ``peaks_from_model`` is run serially on that slice.
    """
    (
        (data_file_name, mask_file_name),
        (start_pos, end_pos),
        model,
        sphere,
        relative_peak_threshold,
        min_separation_angle,
        return_odf,
        return_sh,
        gfa_thr,
        normalize_peaks,
        sh_order,
        sh_basis_type,
        legacy,
        npeaks,
        B,
        invB,
    ) = args[:16]

    chunk = slice(start_pos, end_pos)
    data = np.load(data_file_name, mmap_mode="r")[chunk]
    mask = (
        None
        if mask_file_name is None
        else np.load(mask_file_name, mmap_mode="r")[chunk]
    )

    # parallel=False: the worker itself must not spawn another pool.
    return peaks_from_model(
        model,
        data,
        sphere,
        relative_peak_threshold,
        min_separation_angle,
        mask=mask,
        return_odf=return_odf,
        return_sh=return_sh,
        gfa_thr=gfa_thr,
        normalize_peaks=normalize_peaks,
        sh_order_max=sh_order,
        sh_basis_type=sh_basis_type,
        legacy=legacy,
        npeaks=npeaks,
        B=B,
        invB=invB,
        parallel=False,
        num_processes=None,
    )
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def peaks_from_model(
    model,
    data,
    sphere,
    relative_peak_threshold,
    min_separation_angle,
    *,
    mask=None,
    return_odf=False,
    return_sh=True,
    gfa_thr=0,
    normalize_peaks=False,
    sh_order_max=8,
    sh_basis_type=None,
    legacy=True,
    npeaks=5,
    B=None,
    invB=None,
    parallel=False,
    num_processes=None,
):
    """Fit the model to data and compute peaks and metrics.

    Parameters
    ----------
    model : a model instance
        `model` will be used to fit the data.
    data : ndarray
        Diffusion data.
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    relative_peak_threshold : float
        Only return peaks greater than ``relative_peak_threshold * m`` where m
        is the largest peak.
    min_separation_angle : float in [0, 90]
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned.
    mask : array, optional
        If `mask` is provided, voxels that are False in `mask` are skipped and
        no peaks are returned. Its shape must equal ``data.shape[:-1]``.
    return_odf : bool
        If True, the odfs are returned.
    return_sh : bool
        If True, the odf as spherical harmonics coefficients is returned
    gfa_thr : float
        Voxels with gfa less than `gfa_thr` are skipped, no peaks are returned.
    normalize_peaks : bool
        If true, all peak values are calculated relative to `max(odf)`.
    sh_order_max : int, optional
        Maximum SH order (l) in the SH fit. For `sh_order_max`, there
        will be
        ``(sh_order_max + 1) * (sh_order_max + 2) / 2``
        SH coefficients (default 8).
    sh_basis_type : {None, 'tournier07', 'descoteaux07'}
        ``None`` for the default DIPY basis,
        ``tournier07`` for the Tournier 2007 :footcite:p:`Tournier2007` basis,
        and ``descoteaux07`` for the Descoteaux 2007
        :footcite:p:`Descoteaux2007` basis
        (``None`` defaults to ``descoteaux07``).
    legacy: bool, optional
        True to use a legacy basis definition for backward compatibility
        with previous ``tournier07`` and ``descoteaux07`` implementations.
    npeaks : int
        Maximum number of peaks found (default 5 peaks).
    B : ndarray, optional
        Matrix that transforms spherical harmonics to spherical function
        ``sf = np.dot(sh, B)``.
    invB : ndarray, optional
        Inverse of B.
    parallel: bool
        If True, use multiprocessing to compute peaks and metric
        (default False). Temporary files are saved in the default temporary
        directory of the system. It can be changed using ``import tempfile``
        and ``tempfile.tempdir = '/path/to/tempdir'``.
    num_processes: int, optional
        If `parallel` is True, the number of subprocesses to use
        (default multiprocessing.cpu_count()). If < 0 the maximal number of
        cores minus ``num_processes + 1`` is used (enter -1 to use as many
        cores as possible). 0 raises an error.

    Returns
    -------
    pam : PeaksAndMetrics
        An object with ``sphere``, ``gfa``, ``qa``, ``peak_dirs``,
        ``peak_values``, ``peak_indices``, ``odf``, ``shm_coeff`` and ``B``
        as attributes. ``shm_coeff`` and ``B`` are None when `return_sh` is
        False; ``odf`` is None when `return_odf` is False.

    Raises
    ------
    ValueError
        If `mask` is given and its shape differs from ``data.shape[:-1]``.

    References
    ----------
    .. footbibliography::

    """
    if return_sh and (B is None or invB is None):
        B, invB = sh_to_sf_matrix(
            sphere,
            sh_order_max=sh_order_max,
            basis_type=sh_basis_type,
            return_inv=True,
            legacy=legacy,
        )

    num_processes = determine_num_processes(num_processes)

    shape = data.shape[:-1]
    # Validate the mask before choosing a code path: the parallel path
    # flattens the mask, so a wrong-shaped mask would otherwise be accepted
    # there while the serial path rejects it.
    if mask is not None and mask.shape != shape:
        raise ValueError("Mask is not the same shape as data.")

    if parallel and num_processes > 1:
        # It is mandatory to provide B and invB to the parallel function.
        # Otherwise, a call to np.linalg.pinv is made in a subprocess and
        # makes it timeout on some system.
        # see https://github.com/dipy/dipy/issues/253 for details
        return _peaks_from_model_parallel(
            model,
            data,
            sphere,
            relative_peak_threshold,
            min_separation_angle,
            mask,
            return_odf,
            return_sh,
            gfa_thr,
            normalize_peaks,
            sh_order_max,
            sh_basis_type,
            legacy,
            npeaks,
            B,
            invB,
            num_processes,
        )

    if mask is None:
        mask = np.ones(shape, dtype="bool")

    gfa_array = np.zeros(shape)
    qa_array = np.zeros((shape + (npeaks,)))

    peak_dirs = np.zeros((shape + (npeaks, 3)))
    peak_values = np.zeros((shape + (npeaks,)))
    # -1 marks "no peak" slots.
    peak_indices = np.zeros((shape + (npeaks,)), dtype=np.int32)
    peak_indices.fill(-1)

    if return_sh:
        n_shm_coeff = (sh_order_max + 2) * (sh_order_max + 1) // 2
        shm_coeff = np.zeros((shape + (n_shm_coeff,)))

    if return_odf:
        odf_array = np.zeros((shape + (len(sphere.vertices),)))

    # Largest ODF/peak value seen over all processed voxels; used to
    # normalize ``qa_array`` after the loop.
    global_max = -np.inf
    for idx in ndindex(shape):
        if not mask[idx]:
            continue

        odf = model.fit(data[idx]).odf(sphere=sphere)

        if return_sh:
            shm_coeff[idx] = np.dot(odf, invB)

        if return_odf:
            odf_array[idx] = odf

        gfa_array[idx] = gfa(odf)
        if gfa_array[idx] < gfa_thr:
            # Skipped voxel: no peaks stored, but it still contributes to
            # the normalization constant.
            global_max = max(global_max, odf.max())
            continue

        # Get peaks of odf
        direction, pk, ind = peak_directions(
            odf,
            sphere,
            relative_peak_threshold=relative_peak_threshold,
            min_separation_angle=min_separation_angle,
        )

        # Calculate peak metrics
        if pk.shape[0] != 0:
            global_max = max(global_max, pk[0])

            n = min(npeaks, pk.shape[0])
            qa_array[idx][:n] = pk[:n] - odf.min()

            peak_dirs[idx][:n] = direction[:n]
            peak_indices[idx][:n] = ind[:n]
            peak_values[idx][:n] = pk[:n]

            if normalize_peaks:
                # Guard against a zero leading peak when rescaling.
                peak_values[idx][:n] = (
                    peak_values[idx][:n] / pk[0] if pk[0] != 0 else 0
                )
                peak_dirs[idx] *= peak_values[idx][:, None]

    # NOTE(review): ``global_max`` is still -inf when no voxel was processed
    # and may be 0 for degenerate ODFs; the division is kept as-is.
    qa_array /= global_max

    return _pam_from_attrs(
        PeaksAndMetrics,
        sphere,
        peak_indices,
        peak_values,
        peak_dirs,
        gfa_array,
        qa_array,
        shm_coeff if return_sh else None,
        B if return_sh else None,
        odf_array if return_odf else None,
    )
[docs]
def reshape_peaks_for_visualization(peaks):
    """Reshape peaks for visualization.

    Reshape and convert to float32 a set of peaks for visualisation with mrtrix
    or the fibernavigator.

    Parameters
    ----------
    peaks: nd array (..., N, 3) or PeaksAndMetrics object
        The peaks to be reshaped and converted to float32. For a
        PeaksAndMetrics object its ``peak_dirs`` attribute is used.

    Returns
    -------
    peaks : nd array (..., 3*N)
        The input with its last two axes merged, as float32.

    Raises
    ------
    ValueError
        If the peaks array has fewer than two dimensions.

    """
    if isinstance(peaks, PeaksAndMetrics):
        peaks = peaks.peak_dirs

    if peaks.ndim < 2:
        raise ValueError("peaks must have at least two dimensions (..., N, 3).")

    # Build the target shape as a tuple of ints. ``np.append`` on an empty
    # leading shape returns a float array, which ``reshape`` rejects, so a
    # plain (N, 3) input used to fail. The merged length is computed
    # explicitly instead of using -1, which is ambiguous when a leading
    # dimension is 0.
    leading = tuple(peaks.shape[:-2])
    merged = peaks.shape[-2] * peaks.shape[-1]
    return peaks.reshape(leading + (merged,)).astype("float32")
[docs]
def peaks_from_positions(
    positions,
    odfs,
    sphere,
    affine,
    *,
    pmf_gen=None,
    relative_peak_threshold=0.5,
    min_separation_angle=25,
    is_symmetric=True,
    npeaks=5,
):
    """
    Extract the peaks at each positions.

    Parameters
    ----------
    positions : array, (N, 3)
        World coordinates of the N positions.
    odfs : array, (X, Y, Z, M)
        Orientation distribution function (spherical function) represented
        on a sphere of M points.
    sphere : Sphere
        A discrete Sphere. The M points on the sphere correspond to the points
        of the odfs.
    affine : array (4, 4)
        The mapping between voxel indices and the point space for positions.
    pmf_gen : PmfGen
        Probability mass function generator from voxel orientation information.
        Replaces ``odfs`` and ``sphere`` when used.
    relative_peak_threshold : float, optional
        Only peaks greater than ``min + relative_peak_threshold * scale`` are
        kept, where ``min = max(0, odf.min())`` and
        ``scale = odf.max() - min``. The ``relative_peak_threshold`` should
        be in the range [0, 1].
    min_separation_angle : float, optional
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned. The ``min_separation_angle``
        should be in the range [0, 90].
    is_symmetric : bool, optional
        If True, v is considered equal to -v.
    npeaks : int, optional
        The maximum number of peaks to extract at from each position.

    Returns
    -------
    peaks_arr : array (N, npeaks, 3)
        Peak directions per position; rows beyond the number of peaks found
        at a position are left as zeros.

    """
    # Decide once, by identity, whether the generator is the ODF source, so
    # that the sphere and the ODF values always come from the same place.
    use_pmf_gen = pmf_gen is not None

    if use_pmf_gen and (odfs is not None or sphere is not None):
        msg = (
            "``odfs`` and ``sphere`` arguments will be ignored in favor of ``pmf_gen``."
        )
        warnings.warn(msg, stacklevel=2)

    if use_pmf_gen:
        # use the sphere data from the pmf_gen
        sphere = pmf_gen.get_sphere()

    # Map world coordinates to voxel coordinates with the inverse affine.
    inv_affine = np.linalg.inv(affine)
    vox_positions = np.dot(positions, inv_affine[:3, :3].T.copy())
    vox_positions += inv_affine[:3, 3]

    peaks_arr = np.zeros((len(positions), npeaks, 3))

    if vox_positions.dtype not in [np.float64, float]:
        vox_positions = vox_positions.astype(float)

    for i, s in enumerate(vox_positions):
        if use_pmf_gen:
            odf = pmf_gen.get_pmf(s)
        else:
            odf = trilinear_interpolate4d(odfs, s)

        peaks, _, _ = peak_directions(
            odf,
            sphere,
            relative_peak_threshold=relative_peak_threshold,
            min_separation_angle=min_separation_angle,
            is_symmetric=is_symmetric,
        )
        # Keep at most ``npeaks`` directions for this position.
        nbr_peaks = min(npeaks, peaks.shape[0])
        peaks_arr[i, :nbr_peaks, :] = peaks[:nbr_peaks, :]

    return peaks_arr