from itertools import repeat
import multiprocessing as mp
from os import path
import tempfile
import numpy as np
import scipy.optimize as opt
from dipy.core.interpolation import trilinear_interpolate4d
from dipy.core.ndindex import ndindex
from dipy.core.sphere import Sphere
from dipy.data import default_sphere
from dipy.reconst.eudx_direction_getter import EuDXDirectionGetter
from dipy.reconst.odf import gfa
from dipy.reconst.recspeed import (
    local_maxima,
    remove_similar_vertices,
    search_descending,
)
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.multiproc import determine_num_processes
[docs]
@warning_for_keywords()
def peak_directions_nl(
    sphere_eval,
    *,
    relative_peak_threshold=0.25,
    min_separation_angle=25,
    sphere=default_sphere,
    xtol=1e-7,
):
    """Non Linear Direction Finder.
    Parameters
    ----------
    sphere_eval : callable
        A function which can be evaluated on a sphere.
    relative_peak_threshold : float
        Only return peaks greater than ``relative_peak_threshold * m`` where m
        is the largest peak.
    min_separation_angle : float in [0, 90]
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned.
    sphere : Sphere
        A discrete Sphere. The points on the sphere will be used for initial
        estimate of maximums.
    xtol : float
        Relative tolerance for optimization.
    Returns
    -------
    directions : array (N, 3)
        Points on the sphere corresponding to N local maxima on the sphere.
    values : array (N,)
        Value of sphere_eval at each point on directions.
    """
    # Find discrete peaks for use as seeds in non-linear search
    discrete_values = sphere_eval(sphere)
    values, indices = local_maxima(discrete_values, sphere.edges)
    seeds = np.column_stack([sphere.theta[indices], sphere.phi[indices]])
    # Helper function
    def _helper(x):
        sphere = Sphere(theta=x[0], phi=x[1])
        return -sphere_eval(sphere)
    # Non-linear search
    num_seeds = len(seeds)
    theta = np.empty(num_seeds)
    phi = np.empty(num_seeds)
    for i in range(num_seeds):
        peak = opt.fmin(_helper, seeds[i], xtol=xtol, disp=False)
        theta[i], phi[i] = peak
    # Evaluate on new-found peaks
    small_sphere = Sphere(theta=theta, phi=phi)
    values = sphere_eval(small_sphere)
    # Sort in descending order
    order = values.argsort()[::-1]
    values = values[order]
    directions = small_sphere.vertices[order]
    # Remove directions that are too small
    n = search_descending(values, relative_peak_threshold)
    directions = directions[:n]
    # Remove peaks too close to each-other
    directions, idx = remove_similar_vertices(
        directions, min_separation_angle, return_index=True
    )
    values = values[idx]
    return directions, values 
[docs]
@warning_for_keywords()
def peak_directions(
    odf,
    sphere,
    *,
    relative_peak_threshold=0.5,
    min_separation_angle=25,
    is_symmetric=True,
):
    """Get the directions of odf peaks.
    Peaks are defined as points on the odf that are greater than at least one
    neighbor and greater than or equal to all neighbors. Peaks are sorted in
    descending order by their values then filtered based on their relative size
    and spacing on the sphere. An odf may have 0 peaks, for example if the odf
    is perfectly isotropic.
    Parameters
    ----------
    odf : 1d ndarray
        The odf function evaluated on the vertices of `sphere`
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    relative_peak_threshold : float in [0., 1.]
        Only peaks greater than ``min + relative_peak_threshold * scale`` are
        kept, where ``min = max(0, odf.min())`` and
        ``scale = odf.max() - min``.
    min_separation_angle : float in [0, 90]
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned.
    is_symmetric : bool, optional
        If True, v is considered equal to -v.
    Returns
    -------
    directions : (N, 3) ndarray
        N vertices for sphere, one for each peak
    values : (N,) ndarray
        peak values
    indices : (N,) ndarray
        peak indices of the directions on the sphere
    Notes
    -----
    If the odf has any negative values, they will be clipped to zeros.
    """
    values, indices = local_maxima(odf, sphere.edges)
    # If there is only one peak return
    n = len(values)
    if n == 0 or (values[0] < 0.0):
        return np.zeros((0, 3)), np.zeros(0), np.zeros(0, dtype=int)
    elif n == 1:
        return sphere.vertices[indices], values, indices
    odf_min = np.min(odf)
    odf_min = max(odf_min, 0.0)
    # because of the relative threshold this algorithm will give the same peaks
    # as if we divide (values - odf_min) with (odf_max - odf_min) or not so
    # here we skip the division to increase speed
    values_norm = values - odf_min
    # Remove small peaks
    n = search_descending(values_norm, relative_peak_threshold)
    indices = indices[:n]
    directions = sphere.vertices[indices]
    # Remove peaks too close together
    directions, uniq = remove_similar_vertices(
        directions,
        min_separation_angle,
        return_index=True,
        remove_antipodal=is_symmetric,
    )
    values = values[uniq]
    indices = indices[uniq]
    return directions, values, indices 
def _pam_from_attrs(
    klass, sphere, peak_indices, peak_values, peak_dirs, gfa, qa, shm_coeff, B, odf
):
    """
    Construct PeaksAndMetrics object (or subclass) from its attributes.
    This is also useful for pickling/unpickling of these objects (see also
    :func:`__reduce__` below).
    Parameters
    ----------
    klass : class
        The class of object to be created.
    sphere : `Sphere` class instance.
        Sphere for discretization.
    peak_indices : ndarray
        Indices (in sphere vertices) of the peaks in each voxel.
    peak_values : ndarray
        The value of the peaks.
    peak_dirs : ndarray
        The direction of each peak.
    gfa : ndarray
        The Generalized Fractional Anisotropy in each voxel.
    qa : ndarray
        Quantitative Anisotropy in each voxel.
    shm_coeff : ndarray
        The coefficients of the spherical harmonic basis for the ODF in
        each voxel.
    B : ndarray
        The spherical harmonic matrix, for multiplication with the
        coefficients.
    odf : ndarray
        The orientation distribution function on the sphere in each voxel.
    Returns
    -------
    pam : Instance of the class `klass`.
    """
    this_pam = klass()
    this_pam.sphere = sphere
    this_pam.peak_dirs = peak_dirs
    this_pam.peak_values = peak_values
    this_pam.peak_indices = peak_indices
    this_pam.gfa = gfa
    this_pam.qa = qa
    this_pam.shm_coeff = shm_coeff
    this_pam.B = B
    this_pam.odf = odf
    return this_pam
[docs]
class PeaksAndMetrics(EuDXDirectionGetter):
    def __reduce__(self):
        return _pam_from_attrs, (
            self.__class__,
            self.sphere,
            self.peak_indices,
            self.peak_values,
            self.peak_dirs,
            self.gfa,
            self.qa,
            self.shm_coeff,
            self.B,
            self.odf,
        ) 
def _peaks_from_model_parallel(
    model,
    data,
    sphere,
    relative_peak_threshold,
    min_separation_angle,
    mask,
    return_odf,
    return_sh,
    gfa_thr,
    normalize_peaks,
    sh_order,
    sh_basis_type,
    legacy,
    npeaks,
    B,
    invB,
    num_processes,
):
    shape = list(data.shape)
    data = np.reshape(data, (-1, shape[-1]))
    n = data.shape[0]
    nbr_chunks = num_processes**2
    chunk_size = int(np.ceil(n / nbr_chunks))
    indices = list(
        zip(
            np.arange(0, n, chunk_size),
            np.arange(0, n, chunk_size) + chunk_size,
        )
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        data_file_name = path.join(tmpdir, "data.npy")
        np.save(data_file_name, data)
        if mask is not None:
            mask = mask.flatten()
            mask_file_name = path.join(tmpdir, "mask.npy")
            np.save(mask_file_name, mask)
        else:
            mask_file_name = None
        mp.set_start_method("spawn", force=True)
        pool = mp.Pool(num_processes)
        pam_res = pool.map(
            _peaks_from_model_parallel_sub,
            zip(
                repeat((data_file_name, mask_file_name)),
                indices,
                repeat(model),
                repeat(sphere),
                repeat(relative_peak_threshold),
                repeat(min_separation_angle),
                repeat(return_odf),
                repeat(return_sh),
                repeat(gfa_thr),
                repeat(normalize_peaks),
                repeat(sh_order),
                repeat(sh_basis_type),
                repeat(legacy),
                repeat(npeaks),
                repeat(B),
                repeat(invB),
            ),
        )
        pool.close()
        pam = PeaksAndMetrics()
        pam.sphere = sphere
        # use memmap to reduce the memory usage
        pam.gfa = np.memmap(
            path.join(tmpdir, "gfa.npy"),
            dtype=pam_res[0].gfa.dtype,
            mode="w+",
            shape=(data.shape[0]),
        )
        pam.peak_dirs = np.memmap(
            path.join(tmpdir, "peak_dirs.npy"),
            dtype=pam_res[0].peak_dirs.dtype,
            mode="w+",
            shape=(data.shape[0], npeaks, 3),
        )
        pam.peak_values = np.memmap(
            path.join(tmpdir, "peak_values.npy"),
            dtype=pam_res[0].peak_values.dtype,
            mode="w+",
            shape=(data.shape[0], npeaks),
        )
        pam.peak_indices = np.memmap(
            path.join(tmpdir, "peak_indices.npy"),
            dtype=pam_res[0].peak_indices.dtype,
            mode="w+",
            shape=(data.shape[0], npeaks),
        )
        pam.qa = np.memmap(
            path.join(tmpdir, "qa.npy"),
            dtype=pam_res[0].qa.dtype,
            mode="w+",
            shape=(data.shape[0], npeaks),
        )
        if return_sh:
            nbr_shm_coeff = (sh_order + 2) * (sh_order + 1) // 2
            pam.shm_coeff = np.memmap(
                path.join(tmpdir, "shm.npy"),
                dtype=pam_res[0].shm_coeff.dtype,
                mode="w+",
                shape=(data.shape[0], nbr_shm_coeff),
            )
            pam.B = pam_res[0].B
        else:
            pam.shm_coeff = None
            pam.invB = None
        if return_odf:
            pam.odf = np.memmap(
                path.join(tmpdir, "odf.npy"),
                dtype=pam_res[0].odf.dtype,
                mode="w+",
                shape=(data.shape[0], len(sphere.vertices)),
            )
        else:
            pam.odf = None
        # copy subprocesses pam to a single pam (memmaps)
        for i, (start_pos, end_pos) in enumerate(indices):
            pam.gfa[start_pos:end_pos] = pam_res[i].gfa
            pam.peak_dirs[start_pos:end_pos] = pam_res[i].peak_dirs
            pam.peak_values[start_pos:end_pos] = pam_res[i].peak_values
            pam.peak_indices[start_pos:end_pos] = pam_res[i].peak_indices
            pam.qa[start_pos:end_pos] = pam_res[i].qa
            if return_sh:
                pam.shm_coeff[start_pos:end_pos] = pam_res[i].shm_coeff
            if return_odf:
                pam.odf[start_pos:end_pos] = pam_res[i].odf
        # load memmaps to arrays and reshape the metric
        shape[-1] = -1
        pam.gfa = np.reshape(np.array(pam.gfa), shape[:-1])
        pam.peak_dirs = np.reshape(np.array(pam.peak_dirs), shape + [3])
        pam.peak_values = np.reshape(np.array(pam.peak_values), shape)
        pam.peak_indices = np.reshape(np.array(pam.peak_indices), shape)
        pam.qa = np.reshape(np.array(pam.qa), shape)
        if return_sh:
            pam.shm_coeff = np.reshape(np.array(pam.shm_coeff), shape)
        if return_odf:
            pam.odf = np.reshape(np.array(pam.odf), shape)
        # Make sure all worker processes have exited before leaving context
        # manager in order to prevent temporary file deletion errors in windows
        pool.join()
    return pam
def _peaks_from_model_parallel_sub(args):
    (data_file_name, mask_file_name) = args[0]
    (start_pos, end_pos) = args[1]
    model = args[2]
    sphere = args[3]
    relative_peak_threshold = args[4]
    min_separation_angle = args[5]
    return_odf = args[6]
    return_sh = args[7]
    gfa_thr = args[8]
    normalize_peaks = args[9]
    sh_order = args[10]
    sh_basis_type = args[11]
    legacy = args[12]
    npeaks = args[13]
    B = args[14]
    invB = args[15]
    data = np.load(data_file_name, mmap_mode="r")[start_pos:end_pos]
    if mask_file_name is not None:
        mask = np.load(mask_file_name, mmap_mode="r")[start_pos:end_pos]
    else:
        mask = None
    return peaks_from_model(
        model,
        data,
        sphere,
        relative_peak_threshold,
        min_separation_angle,
        mask=mask,
        return_odf=return_odf,
        return_sh=return_sh,
        gfa_thr=gfa_thr,
        normalize_peaks=normalize_peaks,
        sh_order_max=sh_order,
        sh_basis_type=sh_basis_type,
        legacy=legacy,
        npeaks=npeaks,
        B=B,
        invB=invB,
        parallel=False,
        num_processes=None,
    )
[docs]
@deprecated_params("sh_order", new_name="sh_order_max", since="1.9", until="2.0")
@warning_for_keywords()
def peaks_from_model(
    model,
    data,
    sphere,
    relative_peak_threshold,
    min_separation_angle,
    *,
    mask=None,
    return_odf=False,
    return_sh=True,
    gfa_thr=0,
    normalize_peaks=False,
    sh_order_max=8,
    sh_basis_type=None,
    legacy=True,
    npeaks=5,
    B=None,
    invB=None,
    parallel=False,
    num_processes=None,
):
    """Fit the model to data and computes peaks and metrics
    Parameters
    ----------
    model : a model instance
        `model` will be used to fit the data.
    data : ndarray
        Diffusion data.
    sphere : Sphere
        The Sphere providing discrete directions for evaluation.
    relative_peak_threshold : float
        Only return peaks greater than ``relative_peak_threshold * m`` where m
        is the largest peak.
    min_separation_angle : float in [0, 90] The minimum distance between
        directions. If two peaks are too close only the larger of the two is
        returned.
    mask : array, optional
        If `mask` is provided, voxels that are False in `mask` are skipped and
        no peaks are returned.
    return_odf : bool
        If True, the odfs are returned.
    return_sh : bool
        If True, the odf as spherical harmonics coefficients is returned
    gfa_thr : float
        Voxels with gfa less than `gfa_thr` are skipped, no peaks are returned.
    normalize_peaks : bool
        If true, all peak values are calculated relative to `max(odf)`.
    sh_order_max : int, optional
        Maximum SH order (l) in the SH fit.  For `sh_order_max`, there
        will be
        ``(sh_order_max + 1) * (sh_order_max + 2) / 2``
        SH coefficients (default 8).
    sh_basis_type : {None, 'tournier07', 'descoteaux07'}
        ``None`` for the default DIPY basis,
        ``tournier07`` for the Tournier 2007 :footcite:p:Tournier2007` basis,
        and ``descoteaux07`` for the Descoteaux 2007 :footcite:p:Descoteaux2007`
        basis
        (``None`` defaults to ``descoteaux07``).
    legacy: bool, optional
        True to use a legacy basis definition for backward compatibility
        with previous ``tournier07`` and ``descoteaux07`` implementations.
    npeaks : int
        Maximum number of peaks found (default 5 peaks).
    B : ndarray, optional
        Matrix that transforms spherical harmonics to spherical function
        ``sf = np.dot(sh, B)``.
    invB : ndarray, optional
        Inverse of B.
    parallel: bool
        If True, use multiprocessing to compute peaks and metric
        (default False). Temporary files are saved in the default temporary
        directory of the system. It can be changed using ``import tempfile``
        and ``tempfile.tempdir = '/path/to/tempdir'``.
    num_processes: int, optional
        If `parallel` is True, the number of subprocesses to use
        (default multiprocessing.cpu_count()). If < 0 the maximal number of
        cores minus ``num_processes + 1`` is used (enter -1 to use as many
        cores as possible). 0 raises an error.
    Returns
    -------
    pam : PeaksAndMetrics
        An object with ``gfa``, ``peak_directions``, ``peak_values``,
        ``peak_indices``, ``odf``, ``shm_coeffs`` as attributes
    References
    ----------
    .. footbibliography::
    """
    if return_sh and (B is None or invB is None):
        B, invB = sh_to_sf_matrix(
            sphere,
            sh_order_max=sh_order_max,
            basis_type=sh_basis_type,
            return_inv=True,
            legacy=legacy,
        )
    num_processes = determine_num_processes(num_processes)
    if parallel and num_processes > 1:
        # It is mandatory to provide B and invB to the parallel function.
        # Otherwise, a call to np.linalg.pinv is made in a subprocess and
        # makes it timeout on some system.
        # see https://github.com/dipy/dipy/issues/253 for details
        return _peaks_from_model_parallel(
            model,
            data,
            sphere,
            relative_peak_threshold,
            min_separation_angle,
            mask,
            return_odf,
            return_sh,
            gfa_thr,
            normalize_peaks,
            sh_order_max,
            sh_basis_type,
            legacy,
            npeaks,
            B,
            invB,
            num_processes,
        )
    shape = data.shape[:-1]
    if mask is None:
        mask = np.ones(shape, dtype="bool")
    else:
        if mask.shape != shape:
            raise ValueError("Mask is not the same shape as data.")
    gfa_array = np.zeros(shape)
    qa_array = np.zeros((shape + (npeaks,)))
    peak_dirs = np.zeros((shape + (npeaks, 3)))
    peak_values = np.zeros((shape + (npeaks,)))
    peak_indices = np.zeros((shape + (npeaks,)), dtype=np.int32)
    peak_indices.fill(-1)
    if return_sh:
        n_shm_coeff = (sh_order_max + 2) * (sh_order_max + 1) // 2
        shm_coeff = np.zeros((shape + (n_shm_coeff,)))
    if return_odf:
        odf_array = np.zeros((shape + (len(sphere.vertices),)))
    global_max = -np.inf
    for idx in ndindex(shape):
        if not mask[idx]:
            continue
        odf = model.fit(data[idx]).odf(sphere=sphere)
        if return_sh:
            shm_coeff[idx] = np.dot(odf, invB)
        if return_odf:
            odf_array[idx] = odf
        gfa_array[idx] = gfa(odf)
        if gfa_array[idx] < gfa_thr:
            global_max = max(global_max, odf.max())
            continue
        # Get peaks of odf
        direction, pk, ind = peak_directions(
            odf,
            sphere,
            relative_peak_threshold=relative_peak_threshold,
            min_separation_angle=min_separation_angle,
        )
        # Calculate peak metrics
        if pk.shape[0] != 0:
            global_max = max(global_max, pk[0])
            n = min(npeaks, pk.shape[0])
            qa_array[idx][:n] = pk[:n] - odf.min()
            peak_dirs[idx][:n] = direction[:n]
            peak_indices[idx][:n] = ind[:n]
            peak_values[idx][:n] = pk[:n]
            if normalize_peaks:
                peak_values[idx][:n] /= pk[0]
                peak_dirs[idx] *= peak_values[idx][:, None]
    qa_array /= global_max
    return _pam_from_attrs(
        PeaksAndMetrics,
        sphere,
        peak_indices,
        peak_values,
        peak_dirs,
        gfa_array,
        qa_array,
        shm_coeff if return_sh else None,
        B if return_sh else None,
        odf_array if return_odf else None,
    ) 
[docs]
def reshape_peaks_for_visualization(peaks):
    """Reshape peaks for visualization.
    Reshape and convert to float32 a set of peaks for visualisation with mrtrix
    or the fibernavigator.
    Parameters
    ----------
    peaks: nd array (..., N, 3) or PeaksAndMetrics object
        The peaks to be reshaped and converted to float32.
    Returns
    -------
    peaks : nd array (..., 3*N)
    """
    if isinstance(peaks, PeaksAndMetrics):
        peaks = peaks.peak_dirs
    return peaks.reshape(np.append(peaks.shape[:-2], -1)).astype("float32") 
[docs]
def peaks_from_positions(
    positions,
    odfs,
    sphere,
    affine,
    *,
    relative_peak_threshold=0.5,
    min_separation_angle=25,
    is_symmetric=True,
    npeaks=5,
):
    """
    Extract the peaks at each positions.
    Parameters
    ----------
    position : array, (N, 3)
        World coordinates of the N positions.
    odfs : array, (X, Y, Z, M)
        Orientation distribution function (spherical function) represented
        on a sphere of M points.
    sphere : Sphere
        A discrete Sphere. The M points on the sphere correspond to the points
        of the odfs.
    affine : array (4, 4)
        The mapping between voxel indices and the point space for positions.
    relative_peak_threshold : float, optional
        Only peaks greater than ``min + relative_peak_threshold * scale`` are
        kept, where ``min = max(0, odf.min())`` and
        ``scale = odf.max() - min``. The ``relative_peak_threshold`` should
        be in the range [0, 1].
    min_separation_angle : float, optional
        The minimum distance between directions. If two peaks are too close
        only the larger of the two is returned. The ``min_separation_angle``
        should be in the range [0, 90].
    is_symmetric : bool, optional
        If True, v is considered equal to -v.
    npeaks : int, optional
        The maximum number of peaks to extract at from each position.
    Returns
    -------
    peaks_arr : array (N, npeaks, 3)
    """
    inv_affine = np.linalg.inv(affine)
    vox_positions = np.dot(positions, inv_affine[:3, :3].T.copy())
    vox_positions += inv_affine[:3, 3]
    peaks_arr = np.zeros((len(positions), npeaks, 3))
    if vox_positions.dtype not in [np.float64, float]:
        vox_positions = vox_positions.astype(float)
    for i, s in enumerate(vox_positions):
        odf = trilinear_interpolate4d(odfs, s)
        peaks, _, _ = peak_directions(
            odf,
            sphere,
            relative_peak_threshold=relative_peak_threshold,
            min_separation_angle=min_separation_angle,
            is_symmetric=is_symmetric,
        )
        nbr_peaks = min(npeaks, peaks.shape[0])
        peaks_arr[i, :nbr_peaks, :] = peaks[:nbr_peaks, :]
    return peaks_arr