Source code for dipy.reconst.multi_voxel

"""Tools to easily make multi voxel models"""

from functools import partial
import multiprocessing

import numpy as np
from tqdm import tqdm

from dipy.core.ndindex import ndindex
from dipy.reconst.base import ReconstFit
from dipy.reconst.quick_squash import quick_squash as _squash
from dipy.utils.parallel import paramap


def _parallel_fit_worker(vox_data, single_voxel_fit, **kwargs):
    """
    Works on a chunk of voxel data to create a list of
    single voxel fits.

    Parameters
    ----------
    vox_data : ndarray, shape (n_voxels, ...)
        The data to fit.

    single_voxel_fit : callable
        The fit function to use on each voxel.

    **kwargs : dict
        Keyword arguments forwarded to ``single_voxel_fit``. If ``weights``
        is a plain ``np.ndarray`` it is zipped with ``vox_data`` so that each
        voxel receives its own entry. Any other ``weights`` value (including
        an explicit ``None``) is forwarded unchanged to every voxel, like the
        serial fitting loop does.

    Returns
    -------
    list
        The value returned by ``single_voxel_fit`` for each voxel, in order.
    """
    # Exact type check (not isinstance) mirrors how the multi-voxel fit
    # decides whether to slice weights per voxel before dispatching chunks.
    if type(kwargs.get("weights")) is not np.ndarray:
        # Shared (or absent) weights: keep them in kwargs for every voxel.
        return [single_voxel_fit(data, **kwargs) for data in vox_data]

    vox_weights = kwargs.pop("weights")
    return [
        single_voxel_fit(data, weights=weights, **kwargs)
        for data, weights in zip(vox_data, vox_weights)
    ]


def multi_voxel_fit(single_voxel_fit):
    """Method decorator to turn a single voxel model fit definition into a
    multi voxel model fit definition.

    Parameters
    ----------
    single_voxel_fit : callable
        Method called as ``single_voxel_fit(self, data, **kwargs)`` on the
        signal of one voxel. It may return a fit object, or a
        ``(fit, extra)`` tuple where ``extra`` is a dict of per-voxel values.

    Returns
    -------
    new_fit : callable
        Method ``new_fit(self, data, *, mask=None, **kwargs)`` applying
        ``single_voxel_fit`` to every (masked-in) voxel of ``data``.
    """

    def new_fit(self, data, *, mask=None, **kwargs):
        """Fit method for every voxel in data.

        Parameters
        ----------
        data : ndarray
            Signal to fit; the last axis holds the signal of each voxel. A 1D
            array is fitted directly as a single voxel and its fit (or
            ``(fit, extra)`` tuple) is returned as is.
        mask : ndarray, optional
            Array of shape ``data.shape[:-1]``; only voxels where it is truthy
            are fitted. By default every voxel is fitted.
        **kwargs
            Forwarded to ``single_voxel_fit``. Also read here:

            - ``weights``: a plain ``np.ndarray`` is indexed with each voxel's
              index; any other value is passed unchanged to every voxel.
            - ``engine``: ``"serial"`` (default) fits voxels in a loop; any
              other value is handed to ``paramap`` for parallel fitting.
            - ``n_jobs``, ``vox_per_chunk``: how parallel work is split.
            - ``verbose``: in serial mode, show a progress bar when true.

        Returns
        -------
        MultiVoxelFit or tuple
            ``(MultiVoxelFit, extra)`` when ``single_voxel_fit`` returns
            ``(fit, extra)`` tuples; ``extra`` is then either ``None`` or a
            dict of arrays shaped like ``data`` (zeros outside the mask).

        Raises
        ------
        ValueError
            If ``mask`` does not have shape ``data.shape[:-1]``.
        """
        # If only one voxel just return a standard fit, passing through
        # the function's keyword arguments (no mask needed).
        if data.ndim == 1:
            svf = single_voxel_fit(self, data, **kwargs)
            # If the fit method does not return extra, cannot return extra.
            if isinstance(svf, tuple):
                svf, extra = svf
                return svf, extra
            return svf

        if mask is None:
            mask = np.ones(data.shape[:-1], bool)
        elif mask.shape != data.shape[:-1]:
            raise ValueError("mask and data shape do not match")

        weights = kwargs.get("weights")
        # Only an exact ndarray is split per voxel (the parallel worker uses
        # the same test); other values are shared by all voxels.
        weights_is_array = type(weights) is np.ndarray

        fit_array = np.empty(data.shape[:-1], dtype=object)
        extra_list = []
        return_extra = False

        # Default to serial execution.
        engine = kwargs.get("engine", "serial")
        if engine == "serial":
            # ``verbose=True`` shows the progress bar; it is hidden by
            # default. The context manager closes the bar even on errors.
            with tqdm(
                total=np.sum(mask),
                position=0,
                disable=not kwargs.get("verbose", False),
            ) as bar:
                bar.set_description(
                    "Fitting reconstruction model using serial execution"
                )
                for ijk in ndindex(data.shape[:-1]):
                    if not mask[ijk]:
                        continue
                    if weights_is_array:
                        kwargs["weights"] = weights[ijk]
                    svf = single_voxel_fit(self, data[ijk], **kwargs)
                    # Not all fit methods return extra; handle both forms.
                    if isinstance(svf, tuple):
                        fit_array[ijk], extra = svf
                        return_extra = True
                    else:
                        fit_array[ijk], extra = svf, None
                    extra_list.append(extra)
                    bar.update()
        else:
            voxel_idx = np.where(mask)
            data_to_fit = data[voxel_idx]
            n_voxels = data_to_fit.shape[0]
            # With an all-False mask there is nothing to distribute; skip
            # paramap so the result matches the serial path (all-None fits).
            if n_voxels > 0:
                if weights_is_array:
                    weights_to_fit = weights[voxel_idx]
                single_voxel_with_self = partial(single_voxel_fit, self)
                n_jobs = kwargs.get("n_jobs", multiprocessing.cpu_count() - 1)
                # The default is 0 on a single-core machine and callers may
                # pass n_jobs <= 0 or None: never divide by a non-positive
                # worker count when sizing chunks.
                n_workers = (
                    n_jobs
                    if n_jobs is not None and n_jobs > 0
                    else multiprocessing.cpu_count()
                )
                vox_per_chunk = kwargs.get(
                    "vox_per_chunk", max(n_voxels // n_workers, 1)
                )
                starts = range(0, n_voxels, vox_per_chunk)
                chunks = [data_to_fit[ii : ii + vox_per_chunk] for ii in starts]
                # One kwargs dict per chunk (paramap's ``func_kwargs``), each
                # carrying the matching slice of per-voxel weights, if any.
                kwargs_chunks = []
                for ii in starts:
                    kw = kwargs.copy()
                    if weights_is_array:
                        kw["weights"] = weights_to_fit[ii : ii + vox_per_chunk]
                    kwargs_chunks.append(kw)
                parallel_kwargs = {
                    kk: kwargs[kk]
                    for kk in ("n_jobs", "vox_per_chunk", "engine", "verbose")
                    if kk in kwargs
                }
                mvf = paramap(
                    _parallel_fit_worker,
                    chunks,
                    func_args=[single_voxel_with_self],
                    func_kwargs=kwargs_chunks,
                    **parallel_kwargs,
                )
                if isinstance(mvf[0][0], tuple):
                    tmp_fit_array = np.concatenate(
                        [[svf[0] for svf in mvf_ch] for mvf_ch in mvf]
                    )
                    extra_list = np.concatenate(
                        [[svf[1] for svf in mvf_ch] for mvf_ch in mvf]
                    ).tolist()
                    return_extra = True
                else:
                    tmp_fit_array = np.concatenate(mvf)
                    extra_list = None
                fit_array[voxel_idx] = tmp_fit_array

        # If the fit method does not return extra, assume we cannot return it.
        if not return_extra:
            return MultiVoxelFit(self, fit_array, mask)

        # Merge the per-voxel ``extra`` dicts into a single dict of arrays
        # shaped like ``data``, zero outside the fitted voxels.
        extra = None
        if extra_list[0] is not None:
            # Same truthiness test that selected the voxels to fit.
            fitted = mask.astype(bool)
            extra = {}
            for key in extra_list[0]:
                values = np.vstack([e[key] for e in extra_list])
                extra[key] = np.zeros(data.shape)
                extra[key][fitted] = values
        return MultiVoxelFit(self, fit_array, mask), extra

    return new_fit
class MultiVoxelFit(ReconstFit):
    """Holds an array of fits and allows access to their attributes and
    methods.

    Attribute lookups that fail on the instance itself are forwarded to every
    voxel fit selected by ``mask``; the collected values are passed to
    ``_squash`` together with the mask.
    """

    def __init__(self, model, fit_array, mask):
        self.model = model
        self.fit_array = fit_array
        self.mask = mask

    @property
    def shape(self):
        return self.fit_array.shape

    def __getattr__(self, attr):
        """Gather ``attr`` from each masked-in voxel fit.

        Python only calls this when normal attribute lookup fails.
        """
        # Never forward Python protocol lookups (``__deepcopy__``,
        # ``__setstate__``, ...) or this object's own state. The former would
        # otherwise be answered by the voxel fits (e.g. deepcopy would return
        # an array of copied fits); the latter are only missing on a
        # half-built instance (unpickling/copying), where forwarding would
        # recurse forever through ``self.fit_array``.
        if (attr.startswith("__") and attr.endswith("__")) or attr in (
            "model",
            "fit_array",
            "mask",
        ):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        result = CallableArray(self.fit_array.shape, dtype=object)
        for ijk in ndindex(result.shape):
            if self.mask[ijk]:
                result[ijk] = getattr(self.fit_array[ijk], attr)
        return _squash(result, self.mask)

    def __getitem__(self, index):
        item = self.fit_array[index]
        if isinstance(item, np.ndarray):
            return MultiVoxelFit(self.model, item, self.mask[index])
        return item

    def predict(self, *args, **kwargs):
        """
        Predict for the multi-voxel object using each single-object's
        prediction API, with S0 provided from an array.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to each voxel fit's ``predict``. ``S0`` may be an
            ndarray (indexed per voxel) or any other value (used for every
            voxel); it defaults to an array of ones.

        Returns
        -------
        ndarray
            Shape ``self.fit_array.shape + (n,)``, where ``n`` is the size of
            the last axis of the first voxel's prediction. Voxels without a
            fit (``None``) are left as zeros.

        Raises
        ------
        ValueError
            If no voxel has a fit, so there is nothing to predict.
        NotImplementedError
            If the voxel fits do not implement ``predict``.
        """
        S0 = kwargs.get("S0", np.ones(self.fit_array.shape))

        def gimme_S0(S0, ijk):
            if isinstance(S0, np.ndarray):
                return S0[ijk]
            return S0

        idx = ndindex(self.fit_array.shape)
        # If we have a mask, we might have some Nones up front: skip those.
        ijk = None
        for candidate in idx:
            if self.fit_array[candidate] is not None:
                ijk = candidate
                break
        if ijk is None:
            raise ValueError("No fitted voxels available for prediction")

        first_fit = self.fit_array[ijk]
        if not hasattr(first_fit, "predict"):
            msg = "This model does not have prediction implemented yet"
            raise NotImplementedError(msg)
        # S0 must come from the voxel actually being predicted, i.e. after
        # the leading masked voxels have been skipped.
        kwargs["S0"] = gimme_S0(S0, ijk)
        first_pred = first_fit.predict(*args, **kwargs)
        result = np.zeros(self.fit_array.shape + (first_pred.shape[-1],))
        result[ijk] = first_pred
        # ``idx`` resumes right after the first fitted voxel.
        for ijk in idx:
            fit = self.fit_array[ijk]
            # Masked voxels keep their zero prediction.
            if fit is None:
                continue
            kwargs["S0"] = gimme_S0(S0, ijk)
            result[ijk] = fit.predict(*args, **kwargs)
        return result
class CallableArray(np.ndarray):
    """An array which can be called like a function.

    Calling the array calls every element that is not ``None`` with the same
    arguments, stores each return value at that element's position in a new
    object array, and passes that array through ``_squash``.
    """

    def __call__(self, *args, **kwargs):
        outputs = np.empty(self.shape, dtype=object)
        for position in ndindex(self.shape):
            func = self[position]
            if func is None:
                # Empty slots (e.g. masked voxels) stay None in the output.
                continue
            outputs[position] = func(*args, **kwargs)
        return _squash(outputs)