Source code for dipy.reconst.multi_voxel

"""Tools to easily make multi voxel models"""

from functools import partial
import multiprocessing

import numpy as np
from tqdm import tqdm

from dipy.core.ndindex import ndindex
from dipy.reconst.base import ReconstFit
from dipy.reconst.quick_squash import quick_squash as _squash
from dipy.utils.multiproc import determine_num_processes
from dipy.utils.parallel import auto_ray_chunk_size, paramap


def _assemble_results(params_list, *, fit_class):
    """Build a flat array of fit objects from raw per-chunk parameter dicts.

    Generic assembly for any batched model that declares a ``_fit_class``
    attribute.  Concatenates numpy arrays across chunks, then constructs
    fit objects once in the caller process, avoiding expensive
    Python-object serialization through the Ray object store.

    Parameters
    ----------
    params_list : list of dict
        Each element is a ``dict`` of numpy arrays returned by ``fit``
        when called with ``_raw=True``.  All arrays within a dict share
        the same first dimension (the chunk size).
    fit_class : type
        Fit class to instantiate for each voxel as ``fit_class(None, params)``.

    Returns
    -------
    fits : ndarray of object, shape (total_voxels,)
        Flat array of fit objects, one per voxel across all chunks.
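
    Examples
    --------
    A minimal sketch with a hypothetical stand-in fit class (not part of
    DIPY), merging two raw chunks into a flat array of five fits:

    >>> class DummyFit:  # stand-in fit class, illustrative only
    ...     def __init__(self, model, params):
    ...         self.params = params
    >>> chunks = [{"evals": np.zeros((3, 2))}, {"evals": np.ones((2, 2))}]
    >>> fits = _assemble_results(chunks, fit_class=DummyFit)
    >>> fits.shape
    (5,)
    >>> fits[4].params["evals"]
    array([1., 1.])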
    """
    keys = list(params_list[0].keys())
    # Concatenate each parameter across chunks; ``None`` placeholders
    # propagate unchanged.
    merged = {}
    for k in keys:
        arrays = [p[k] for p in params_list]
        merged[k] = None if arrays[0] is None else np.concatenate(arrays, axis=0)

    # Total voxel count, taken from the first non-None parameter.
    n_vox = next(v.shape[0] for v in merged.values() if v is not None)
    fits = np.empty(n_vox, dtype=object)
    for i in range(n_vox):
        p = {}
        for k in keys:
            val = merged[k]
            if val is None:
                p[k] = None
            else:
                v = val[i]
                # Normalize 0-d arrays and numpy scalars to plain floats so
                # per-voxel params match what the single-voxel path produces.
                if (isinstance(v, np.ndarray) and v.ndim == 0) or isinstance(
                    v, (np.floating, np.integer)
                ):
                    p[k] = float(v)
                else:
                    p[k] = v
        fits[i] = fit_class(None, p)
    return fits


def _parallel_fit_worker(vox_data, fit_func, **kwargs):
    """Process a chunk of voxel data.

    When ``_batched`` is True the entire chunk is handed to *fit_func*
    in a single batched call (the ``fit`` method handles the 2-D input
    directly and returns an array of fit objects, or a raw dict when
    ``_raw=True``).  Otherwise each voxel is fitted individually.

    Shared objects arrive as direct keyword arguments prefixed with
    ``_sobj_`` (e.g. ``_sobj__index``).  Ray resolves ``ObjectRef``
    values before invoking the worker, so no explicit ``ray.get()`` is
    required here.

    Parameters
    ----------
    vox_data : ndarray, shape (n_voxels, ...)
        The data to fit.
    fit_func : callable
        ``partial(single_voxel_fit, model)``, used for both the per-voxel
        and batched (``_batched=True``) paths.
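
    Examples
    --------
    A minimal sketch of the per-voxel path with a stand-in single-voxel
    function (names here are illustrative, not DIPY API):

    >>> def toy_fit(model, data):  # stand-in for single_voxel_fit
    ...     return float(data.sum())
    >>> worker = partial(toy_fit, None)
    >>> _parallel_fit_worker(np.ones((4, 3)), worker)
    [3.0, 3.0, 3.0, 3.0]

    A shared object passed as ``_sobj_gtab`` is re-attached to the model
    before fitting (again illustrative):

    >>> class ToyModel:
    ...     gtab = None
    >>> model = ToyModel()
    >>> worker = partial(toy_fit, model)
    >>> _ = _parallel_fit_worker(np.ones((2, 3)), worker, _sobj_gtab="tbl")
    >>> model.gtab
    'tbl'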
    """
    _prefix = "_sobj_"
    shared_objs = {
        k[len(_prefix) :]: kwargs.pop(k) for k in list(kwargs) if k.startswith(_prefix)
    }
    if shared_objs:
        if fit_func.args:
            model = fit_func.args[0]
        elif hasattr(fit_func.func, "__self__"):
            model = fit_func.func.__self__
        else:
            raise ValueError(
                "_parallel_fit_worker: could not locate model instance in "
                "fit_func — shared objects were never applied."
            )
        for name, val in shared_objs.items():
            setattr(model, name, val)

    batched = kwargs.pop("_batched", False)
    vox_weights = kwargs.pop("weights", None)
    if batched:
        if type(vox_weights) is np.ndarray:
            return fit_func(vox_data, weights=vox_weights, **kwargs)
        return fit_func(vox_data, **kwargs)

    if type(vox_weights) is np.ndarray:
        return [
            fit_func(data, weights=weights, **kwargs)
            for data, weights in zip(vox_data, vox_weights)
        ]
    return [fit_func(data, **kwargs) for data in vox_data]


def multi_voxel_fit(
    _func=None,
    *,
    batched=False,
    shared_obj=None,
    chunk_size=None,
):
    """Method decorator to turn a single-voxel model fit definition into
    a multi-voxel model fit definition.

    Supports two calling styles::

        @multi_voxel_fit                 # existing models, unchanged
        @multi_voxel_fit(batched=True)   # batched models (e.g. FORCE)

    When ``batched=True`` the decorated ``fit`` method must accept both
    1-D (single voxel) and 2-D (batch) input and return a single fit
    object or a 1-D object array of fit objects respectively.

    **Raw-dict protocol**: if the model declares a ``_fit_class``
    attribute, the decorator automatically injects ``_raw=True`` into
    worker kwargs so that ``fit`` returns a plain ``dict`` of numpy
    arrays instead of Python fit objects.  After all chunks are
    collected, the generic :func:`_assemble_results` helper is called
    once in the caller process to build the final fit array via
    ``_fit_class(None, params)``.  This avoids serializing Python
    objects through the Ray object store and is available to any
    batched model that sets ``_fit_class``.

    Parameters
    ----------
    _func : callable, optional
        When the decorator is used without parentheses
        (``@multi_voxel_fit``) Python passes the decorated function
        here directly.
    batched : bool, optional
        When True the fit method handles batched 2-D input itself.
    shared_obj : tuple of str, optional
        Names of model attributes to place into the Ray object store
        once and reuse across workers (avoids per-task serialization
        of large arrays).  Example: ``("_penalty_array", "_index",
        "simulations")``.  Only active when ``engine="ray"``.
    chunk_size : int or dict, optional
        Number of voxels per chunk.  Accepts either a single ``int``
        (applied to all engines) or a ``dict`` mapping engine names to
        chunk sizes, e.g. ``{"serial": 10_000, "ray": 100_000}``.
        Overridden at call time by the ``vox_per_chunk`` keyword
        argument.  Defaults to 10,000 for the serial-batched path and
        ``n_vox // n_jobs`` for parallel engines.
    """

    def decorator(single_voxel_fit):
        def new_fit(self, data, *, mask=None, **kwargs):
            """Fit method for every voxel in data."""
            # If only one voxel, just return a standard fit, passing
            # through the function's keyword arguments (no mask needed).
            if data.ndim == 1:
                svf = single_voxel_fit(self, data, **kwargs)
                # If the fit method does not return extra, we cannot
                # return extra either.
                if isinstance(svf, tuple):
                    svf, extra = svf
                    return svf, extra
                else:
                    return svf

            # Make a mask if mask is None
            if mask is None:
                mask = np.ones(data.shape[:-1], bool)
            # Check the shape of the mask if mask is not None
            elif mask.shape != data.shape[:-1]:
                raise ValueError("mask and data shape do not match")

            # Get weights from kwargs if provided
            weights = kwargs.get("weights")
            weights_is_array = type(weights) is np.ndarray

            # Fit data where mask is True
            fit_array = np.empty(data.shape[:-1], dtype=object)
            return_extra = False

            # Default to serial execution:
            engine = kwargs.get("engine", "serial")

            def _shared_obj_nbytes():
                total = 0
                for name in shared_obj or ():
                    obj = getattr(self, name, None)
                    if obj is None:
                        continue
                    if hasattr(obj, "nbytes"):
                        total += obj.nbytes
                    elif isinstance(obj, dict):
                        total += sum(
                            v.nbytes for v in obj.values() if hasattr(v, "nbytes")
                        )
                return total

            def _resolve_chunk_size(*, default, n_jobs=None, n_vox=None):
                explicit = kwargs.get("vox_per_chunk")
                if explicit is not None:
                    return explicit
                val = (
                    chunk_size.get(engine, default)
                    if isinstance(chunk_size, dict)
                    else (chunk_size if chunk_size is not None else default)
                )
                if val == "auto":
                    if engine == "ray" and n_jobs is not None:
                        return auto_ray_chunk_size(
                            n_jobs=n_jobs,
                            n_gradients=data.shape[-1],
                            n_vox=n_vox,
                            shared_obj_nbytes=_shared_obj_nbytes(),
                        )
                    return default
                return val

            use_raw = batched and hasattr(self, "_fit_class")

            if engine == "serial" and batched:
                # Batched serial path: pass whole chunks to fit() at once.
                data_to_fit = data[np.where(mask)]
                vox_per_chunk = _resolve_chunk_size(default=10000)
                n_vox = data_to_fit.shape[0]
                all_chunk_results = []
                bar = tqdm(
                    total=n_vox,
                    position=0,
                    disable=not kwargs.get("verbose", False),
                )
                bar.set_description("Fitting (batched serial)")
                fit_kwargs = {
                    k: v
                    for k, v in kwargs.items()
                    if k not in ("engine", "n_jobs", "vox_per_chunk", "verbose")
                }
                if use_raw:
                    fit_kwargs["_raw"] = True
                for start in range(0, n_vox, vox_per_chunk):
                    chunk = data_to_fit[start : start + vox_per_chunk]
                    chunk_result = single_voxel_fit(self, chunk, **fit_kwargs)
                    all_chunk_results.append(chunk_result)
                    bar.update(len(chunk))
                bar.close()
                if use_raw:
                    tmp_fit_array = _assemble_results(
                        all_chunk_results, fit_class=self._fit_class
                    )
                else:
                    tmp_fit_array = np.concatenate(all_chunk_results)
                fit_array[np.where(mask)] = tmp_fit_array
            elif engine == "serial":
                extra_list = []
                bar = tqdm(
                    total=np.sum(mask),
                    position=0,
                    disable=not kwargs.get("verbose", False),
                )
                bar.set_description(
                    "Fitting reconstruction model using serial execution"
                )
                for ijk in ndindex(data.shape[:-1]):
                    if mask[ijk]:
                        if weights_is_array:
                            kwargs["weights"] = weights[ijk]
                        svf = single_voxel_fit(self, data[ijk], **kwargs)
                        # Not all fit methods return extra; handle that here.
                        if isinstance(svf, tuple):
                            fit_array[ijk], extra = svf
                            return_extra = True
                        else:
                            fit_array[ijk], extra = svf, None
                        extra_list.append(extra)
                        bar.update()
                bar.close()
            else:
                data_to_fit = data[np.where(mask)]
                if weights_is_array:
                    weights_to_fit = weights[np.where(mask)]
                n_jobs = kwargs.get("n_jobs", max(multiprocessing.cpu_count() - 1, 1))
                n_jobs_eff = determine_num_processes(n_jobs if n_jobs != 0 else 1)
                n_vox = data_to_fit.shape[0]
                vox_per_chunk = _resolve_chunk_size(
                    default=max(n_vox // n_jobs_eff, 1),
                    n_jobs=n_jobs_eff,
                    n_vox=n_vox,
                )
                chunks = [
                    data_to_fit[ii : ii + vox_per_chunk]
                    for ii in range(0, data_to_fit.shape[0], vox_per_chunk)
                ]
                # Extract shared objects *before* creating the partial so
                # that ``self`` is lightweight when serialized by Ray.
                shared_objects = None
                if engine == "ray" and shared_obj:
                    shared_objects = {name: getattr(self, name) for name in shared_obj}
                    for name in shared_obj:
                        setattr(self, name, None)
                try:
                    fit_func = partial(single_voxel_fit, self)
                    # Build per-chunk kwargs
                    kwargs_chunks = []
                    for ii in range(0, data_to_fit.shape[0], vox_per_chunk):
                        kw = kwargs.copy()
                        if batched:
                            kw["_batched"] = True
                        if use_raw:
                            kw["_raw"] = True
                        if weights_is_array:
                            kw["weights"] = weights_to_fit[ii : ii + vox_per_chunk]
                        kwargs_chunks.append(kw)
                    parallel_kwargs = {}
                    for kk in [
                        "n_jobs",
                        "vox_per_chunk",
                        "engine",
                        "verbose",
                        "inflight_cap",
                    ]:
                        if kk in kwargs:
                            parallel_kwargs[kk] = kwargs[kk]
                    if shared_objects is not None:
                        parallel_kwargs["shared_objects"] = shared_objects
                    mvf = paramap(
                        _parallel_fit_worker,
                        chunks,
                        func_args=[fit_func],
                        func_kwargs=kwargs_chunks,
                        **parallel_kwargs,
                    )
                finally:
                    # Always restore the model, even on error.
                    if shared_objects is not None:
                        for name, val in shared_objects.items():
                            setattr(self, name, val)
                if batched:
                    if use_raw:
                        tmp_fit_array = _assemble_results(
                            mvf, fit_class=self._fit_class
                        )
                    else:
                        tmp_fit_array = np.concatenate(mvf)
                    fit_array[np.where(mask)] = tmp_fit_array
                    extra_list = None
                elif isinstance(mvf[0][0], tuple):
                    tmp_fit_array = np.concatenate(
                        [[svf[0] for svf in mvf_ch] for mvf_ch in mvf]
                    )
                    tmp_extra = np.concatenate(
                        [[svf[1] for svf in mvf_ch] for mvf_ch in mvf]
                    ).tolist()
                    fit_array[np.where(mask)], extra_list = tmp_fit_array, tmp_extra
                    return_extra = True
                else:
                    tmp_fit_array = np.concatenate(mvf)
                    fit_array[np.where(mask)], extra_list = tmp_fit_array, None

            # Redefine extra to be a single dictionary
            if return_extra:
                if extra_list[0] is not None:
                    extra_mask = {
                        key: np.vstack([e[key] for e in extra_list])
                        for key in extra_list[0]
                    }
                    extra = {}
                    for key in extra_mask:
                        extra[key] = np.zeros(data.shape)
                        extra[key][mask == 1] = extra_mask[key]
                else:
                    extra = None

            # If the fit method does not return extra, assume we cannot
            # return extra.
            if return_extra:
                return MultiVoxelFit(self, fit_array, mask), extra
            else:
                return MultiVoxelFit(self, fit_array, mask)

        return new_fit

    if _func is not None:
        # Decorator used without parentheses: @multi_voxel_fit
        return decorator(_func)
    # Decorator used with parentheses: @multi_voxel_fit(batched=True)
    return decorator
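

# A minimal usage sketch (``ToyModel`` is hypothetical, not part of DIPY):
# the decorator promotes a single-voxel ``fit`` to full 4-D volumes::
#
#     class ToyModel:
#         @multi_voxel_fit
#         def fit(self, data):
#             return ReconstFit(self, data.mean())
#
#     mvfit = ToyModel().fit(np.ones((3, 3, 3, 10)))
#     mvfit.shape  # -> (3, 3, 3)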


class MultiVoxelFit(ReconstFit):
    """Holds an array of fits and allows access to their attributes and
    methods."""

    def __init__(self, model, fit_array, mask):
        self.model = model
        self.fit_array = fit_array
        self.mask = mask

    @property
    def shape(self):
        return self.fit_array.shape

    def __getattr__(self, attr):
        result = CallableArray(self.fit_array.shape, dtype=object)
        for ijk in ndindex(result.shape):
            if self.mask[ijk]:
                result[ijk] = getattr(self.fit_array[ijk], attr)
        return _squash(result, self.mask)

    def __getitem__(self, index):
        item = self.fit_array[index]
        if isinstance(item, np.ndarray):
            return MultiVoxelFit(self.model, item, self.mask[index])
        else:
            return item

    def predict(self, *args, **kwargs):
        """
        Predict for the multi-voxel object using each single-object's
        prediction API, with S0 provided from an array.
        """
        S0 = kwargs.get("S0", np.ones(self.fit_array.shape))
        idx = ndindex(self.fit_array.shape)
        ijk = next(idx)

        def gimme_S0(S0, ijk):
            if isinstance(S0, np.ndarray):
                return S0[ijk]
            else:
                return S0

        # If we have a mask, we might have some Nones up front, skip those:
        while self.fit_array[ijk] is None:
            ijk = next(idx)

        # Take S0 at the first unmasked voxel, so it matches first_pred:
        kwargs["S0"] = gimme_S0(S0, ijk)

        if not hasattr(self.fit_array[ijk], "predict"):
            msg = "This model does not have prediction implemented yet"
            raise NotImplementedError(msg)

        first_pred = self.fit_array[ijk].predict(*args, **kwargs)
        result = np.zeros(self.fit_array.shape + (first_pred.shape[-1],))
        result[ijk] = first_pred
        for ijk in idx:
            kwargs["S0"] = gimme_S0(S0, ijk)
            # If it's masked, we predict a 0:
            if self.fit_array[ijk] is None:
                result[ijk] *= 0
            else:
                result[ijk] = self.fit_array[ijk].predict(*args, **kwargs)

        return result
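

# A minimal attribute-access sketch (values are illustrative): gathering
# ``fit.data`` from every voxel returns a squashed ndarray rather than an
# object array::
#
#     fits = np.empty((2, 2), dtype=object)
#     for idx in np.ndindex(2, 2):
#         fits[idx] = ReconstFit(None, np.arange(3.0))
#     mvf = MultiVoxelFit(None, fits, np.ones((2, 2), bool))
#     mvf.data.shape  # -> (2, 2, 3)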


class CallableArray(np.ndarray):
    """An array which can be called like a function"""

    def __call__(self, *args, **kwargs):
        result = np.empty(self.shape, dtype=object)
        for ijk in ndindex(self.shape):
            item = self[ijk]
            if item is not None:
                result[ijk] = item(*args, **kwargs)
        return _squash(result)
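

# A minimal sketch of ``CallableArray`` (illustrative): viewing an object
# array of callables lets the whole array be invoked element-wise::
#
#     funcs = np.array([np.sin, np.cos], dtype=object).view(CallableArray)
#     funcs(0.0)  # -> array([0., 1.])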