"""Tools to easily make multi voxel models"""
from functools import partial
import multiprocessing
import numpy as np
from tqdm import tqdm
from dipy.core.ndindex import ndindex
from dipy.reconst.base import ReconstFit
from dipy.reconst.quick_squash import quick_squash as _squash
from dipy.utils.parallel import paramap
def _parallel_fit_worker(vox_data, single_voxel_fit, **kwargs):
    """
    Works on a chunk of voxel data to create a list of
    single voxel fits.

    Parameters
    ----------
    vox_data : ndarray, shape (n_voxels, ...)
        The data to fit.
    single_voxel_fit : callable
        The fit function to use on each voxel.
    **kwargs : dict
        Forwarded to ``single_voxel_fit`` for every voxel. If ``weights`` is
        an ``np.ndarray`` it is treated as per-voxel weights (one entry along
        the first axis per row of ``vox_data``) and each voxel receives its
        own entry. Any other ``weights`` value is forwarded unchanged to every
        voxel, matching the serial code path.

    Returns
    -------
    list
        The output of ``single_voxel_fit`` for each voxel, in the order of
        ``vox_data``.
    """
    vox_weights = kwargs.get("weights")
    # Exact type check kept consistent with the ndarray check performed by
    # ``multi_voxel_fit`` when it slices weights into chunks.
    if type(vox_weights) is not np.ndarray:
        # Shared (or absent) weights: pass kwargs through untouched instead
        # of silently discarding a non-array ``weights`` value.
        return [single_voxel_fit(data, **kwargs) for data in vox_data]
    return [
        single_voxel_fit(data, **{**kwargs, "weights": weights})
        for data, weights in zip(vox_data, vox_weights)
    ]
[docs]
def multi_voxel_fit(single_voxel_fit):
    """Method decorator to turn a single voxel model fit
    definition into a multi voxel model fit definition.

    The decorated ``single_voxel_fit(self, data, **kwargs)`` fits one voxel
    (1-D ``data``) and returns either a fit object or a ``(fit, extra)``
    tuple. The resulting ``fit(self, data, *, mask=None, **kwargs)``:

    - for 1-D ``data``, returns the single-voxel result unchanged;
    - otherwise fits every voxel where ``mask`` is nonzero (default: all
      voxels; ``mask`` must have shape ``data.shape[:-1]``, else
      ``ValueError``) and returns a ``MultiVoxelFit``, or
      ``(MultiVoxelFit, extra)`` when the fits returned extras. ``extra`` is
      None, or a dict mapping each key of the per-voxel extra dicts to an
      array of shape ``data.shape`` that is zero outside the mask.

    Keyword arguments read here (all of them are also forwarded to
    ``single_voxel_fit``):

    engine : str, default "serial"
        "serial" fits voxel by voxel. Any other value dispatches chunks of
        voxels through ``paramap``, together with ``n_jobs``,
        ``vox_per_chunk``, ``engine`` and ``verbose`` when they are given.
    n_jobs : int, optional
        Used to size parallel chunks; defaults to ``max(cpu_count() - 1, 1)``.
    vox_per_chunk : int, optional
        Voxels per parallel chunk; defaults to ``n_voxels // n_jobs``
        (at least 1).
    verbose : bool, default False
        Show a progress bar during serial fitting.
    weights : optional
        If an ``np.ndarray`` whose leading dimensions match
        ``data.shape[:-1]``, each voxel receives its own weights.
    """

    def new_fit(self, data, *, mask=None, **kwargs):
        """Fit method for every voxel in data"""
        # A single voxel: plain fit with the caller's keyword arguments; the
        # result (fit, or (fit, extra)) is returned as the fit method gave it.
        if data.ndim == 1:
            return single_voxel_fit(self, data, **kwargs)

        if mask is None:
            mask = np.ones(data.shape[:-1], bool)
        elif mask.shape != data.shape[:-1]:
            raise ValueError("mask and data shape do not match")

        if kwargs.get("engine", "serial") == "serial":
            fit_array, extra_list, return_extra = _fit_serial(
                single_voxel_fit, self, data, mask, kwargs
            )
        else:
            fit_array, extra_list, return_extra = _fit_parallel(
                single_voxel_fit, self, data, mask, kwargs
            )

        multi_fit = MultiVoxelFit(self, fit_array, mask)
        # If the fit method does not return extra, we cannot return extra.
        if return_extra:
            return multi_fit, _collate_extra(extra_list, mask, data.shape)
        return multi_fit

    return new_fit


def _fit_serial(single_voxel_fit, model, data, mask, kwargs):
    """Fit the masked voxels one at a time, in ``ndindex`` order.

    Returns ``(fit_array, extra_list, return_extra)``: an object array of
    shape ``data.shape[:-1]`` holding the fits (None outside the mask), one
    ``extra`` entry per fitted voxel (None when that fit returned no extra),
    and whether any fit returned a ``(fit, extra)`` tuple.
    """
    # Work on a copy so per-voxel weights never leak into the caller's dict.
    voxel_kwargs = dict(kwargs)
    weights = kwargs.get("weights")
    weights_is_array = type(weights) is np.ndarray
    fit_array = np.empty(data.shape[:-1], dtype=object)
    extra_list = []
    return_extra = False
    # Context manager guarantees the bar is closed even if a fit raises.
    with tqdm(
        total=np.sum(mask),
        position=0,
        # verbose=True shows the bar; it stays hidden by default.
        disable=not kwargs.get("verbose", False),
    ) as bar:
        bar.set_description("Fitting reconstruction model using serial execution")
        for ijk in ndindex(data.shape[:-1]):
            if not mask[ijk]:
                continue
            if weights_is_array:
                voxel_kwargs["weights"] = weights[ijk]
            svf = single_voxel_fit(model, data[ijk], **voxel_kwargs)
            # Not all fit methods return extra, handle this here
            if isinstance(svf, tuple):
                fit_array[ijk], extra = svf
                return_extra = True
            else:
                fit_array[ijk], extra = svf, None
            extra_list.append(extra)
            # ``total`` counts masked voxels only, so only they advance the bar.
            bar.update()
    return fit_array, extra_list, return_extra


def _fit_parallel(single_voxel_fit, model, data, mask, kwargs):
    """Fit the masked voxels in chunks dispatched through ``paramap``.

    Returns the same ``(fit_array, extra_list, return_extra)`` triple as
    ``_fit_serial``; ``extra_list`` is None when the fits return no extra.
    """
    fit_array = np.empty(data.shape[:-1], dtype=object)
    masked = np.where(mask)
    data_to_fit = data[masked]
    n_voxels = data_to_fit.shape[0]
    if n_voxels == 0:
        # Empty mask: nothing to dispatch. Mirror the serial path (all-None
        # fits, no extra) instead of indexing into an empty result list.
        return fit_array, [], False

    weights = kwargs.get("weights")
    weights_is_array = type(weights) is np.ndarray
    # cpu_count() - 1 is 0 on a single-core machine; never divide by zero.
    n_jobs = kwargs.get("n_jobs", max(multiprocessing.cpu_count() - 1, 1))
    vox_per_chunk = kwargs.get("vox_per_chunk", max(n_voxels // n_jobs, 1))
    starts = range(0, n_voxels, vox_per_chunk)
    chunks = [data_to_fit[ii : ii + vox_per_chunk] for ii in starts]

    # paramap's func_kwargs: one kwargs dict per chunk, each carrying the
    # slice of per-voxel weights that matches its data chunk.
    weights_to_fit = weights[masked] if weights_is_array else None
    kwargs_chunks = []
    for ii in starts:
        kw = kwargs.copy()
        if weights_is_array:
            kw["weights"] = weights_to_fit[ii : ii + vox_per_chunk]
        kwargs_chunks.append(kw)

    parallel_kwargs = {
        kk: kwargs[kk]
        for kk in ("n_jobs", "vox_per_chunk", "engine", "verbose")
        if kk in kwargs
    }
    mvf = paramap(
        _parallel_fit_worker,
        chunks,
        func_args=[partial(single_voxel_fit, model)],
        func_kwargs=kwargs_chunks,
        **parallel_kwargs,
    )

    voxel_fits = [svf for chunk in mvf for svf in chunk]
    return_extra = isinstance(voxel_fits[0], tuple)
    if return_extra:
        extra_list = [svf[1] for svf in voxel_fits]
        voxel_fits = [svf[0] for svf in voxel_fits]
    else:
        extra_list = None
    # Fill element by element so numpy never tries to unpack fit objects
    # that happen to look like sequences.
    fit_values = np.empty(len(voxel_fits), dtype=object)
    for ii, fit in enumerate(voxel_fits):
        fit_values[ii] = fit
    fit_array[masked] = fit_values
    return fit_array, extra_list, return_extra


def _collate_extra(extra_list, mask, data_shape):
    """Scatter per-voxel ``extra`` dicts into arrays of shape ``data_shape``.

    Returns None when the fits produced no extra (first entry is None).
    """
    if not extra_list or extra_list[0] is None:
        return None
    # Same voxel selection as fitting (nonzero mask entries, np.where order),
    # rather than ``mask == 1`` which disagrees for non-boolean masks.
    masked = np.where(mask)
    extra = {}
    for key in extra_list[0]:
        values = np.zeros(data_shape)
        values[masked] = np.vstack([e[key] for e in extra_list])
        extra[key] = values
    return extra
[docs]
class MultiVoxelFit(ReconstFit):
    """Holds an array of fits and allows access to their attributes and
    methods.

    Parameters
    ----------
    model : object
        The model that produced the fits.
    fit_array : ndarray of object
        One single-voxel fit per voxel, ``None`` where nothing was fitted.
    mask : ndarray
        Same shape as ``fit_array``; nonzero where a fit exists.
    """

    # Instance attributes set in __init__; __getattr__ must never try to
    # forward these to the per-voxel fits.
    _OWN_ATTRIBUTES = ("model", "fit_array", "mask")

    def __init__(self, model, fit_array, mask):
        self.model = model
        self.fit_array = fit_array
        self.mask = mask

    @property
    def shape(self):
        """Shape of the voxel grid, i.e. ``fit_array.shape``."""
        return self.fit_array.shape

    def __getattr__(self, attr):
        """Gather ``attr`` from every fit inside the mask.

        The values are collected into a ``CallableArray`` (so fit methods
        can be called for all voxels at once) and squashed with
        ``quick_squash`` using the mask.
        """
        # __getattr__ only runs for names not found normally. On an instance
        # whose __dict__ is not populated yet (copy.copy / unpickling look up
        # __setstate__ on a bare instance), reading self.fit_array below would
        # re-enter __getattr__ forever; fail with AttributeError instead.
        if attr in MultiVoxelFit._OWN_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        result = CallableArray(self.fit_array.shape, dtype=object)
        for ijk in ndindex(result.shape):
            if self.mask[ijk]:
                result[ijk] = getattr(self.fit_array[ijk], attr)
        return _squash(result, self.mask)

    def __getitem__(self, index):
        """Index the fit array.

        Array-valued results are wrapped in a new ``MultiVoxelFit`` with the
        matching slice of the mask; a single entry (a fit or None) is
        returned as-is.
        """
        item = self.fit_array[index]
        if isinstance(item, np.ndarray):
            return MultiVoxelFit(self.model, item, self.mask[index])
        else:
            return item

    def predict(self, *args, **kwargs):
        """
        Predict for the multi-voxel object using each single-object's
        prediction API, with S0 provided from an array.

        ``S0`` (keyword, default: all ones) is either an array indexed like
        ``fit_array`` or a single value shared by every voxel; each voxel's
        ``predict`` receives its own ``S0``. Voxels without a fit (``None``)
        are predicted as 0. The output has shape
        ``fit_array.shape + (n,)``, where ``n`` is the length of the last
        axis of the first available prediction.

        Raises
        ------
        NotImplementedError
            If the first available fit has no ``predict`` method.
        ValueError
            If no voxel has a fit, so there is nothing to predict from.
        """
        S0 = kwargs.get("S0", np.ones(self.fit_array.shape))
        # 0-d arrays cannot be indexed per voxel; treat them as shared values.
        per_voxel_S0 = isinstance(S0, np.ndarray) and S0.ndim > 0

        result = None
        for ijk in ndindex(self.fit_array.shape):
            fit = self.fit_array[ijk]
            if fit is None:
                # Masked voxel: stays 0 in ``result``.
                continue
            if result is None and not hasattr(fit, "predict"):
                msg = "This model does not have prediction implemented yet"
                raise NotImplementedError(msg)
            # S0 is looked up for *this* voxel, including the first fitted
            # one (previously it came from the first index, masked or not).
            kwargs["S0"] = S0[ijk] if per_voxel_S0 else S0
            prediction = fit.predict(*args, **kwargs)
            if result is None:
                result = np.zeros(self.fit_array.shape + (prediction.shape[-1],))
            result[ijk] = prediction

        if result is None:
            raise ValueError("Cannot predict: no voxel has a fit (all masked)")
        return result
 
[docs]
class CallableArray(np.ndarray):
    """An array which can be called like a function.

    Every non-``None`` element is expected to be callable. Calling the array
    invokes each such element with the same arguments and returns the
    collected outputs combined by ``quick_squash``.
    """

    def __call__(self, *args, **kwargs):
        outputs = np.empty(self.shape, dtype=object)
        for index in ndindex(self.shape):
            element = self[index]
            if element is None:
                # Empty slots (e.g. masked voxels) keep a None output.
                continue
            outputs[index] = element(*args, **kwargs)
        return _squash(outputs)