Source code for dipy.reconst.force

import json
import os
from pathlib import Path
import sys
import warnings

import numpy as np

from dipy.reconst._force_search import search_inner_product as _cython_search
from dipy.reconst.base import ReconstFit, ReconstModel
from dipy.reconst.multi_voxel import multi_voxel_fit

# Named constants
# Tiny positive offset added to denominators (ODF max normalisation) and to
# log() arguments (entropy) throughout this module, so that all-zero inputs
# neither divide by zero nor evaluate log(0).
EPSILON = 1.0e-12


def _get_force_cache_dir():
    """Return the FORCE simulation cache directory inside .dipy.

    Uses ``DIPY_HOME`` environment variable if set, otherwise defaults
    to ``~/.dipy/force_simulations``.

    Returns
    -------
    cache_dir : Path
        Path to the cache directory (created if it does not exist).
    """
    if "DIPY_HOME" in os.environ:
        dipy_home = Path(os.environ["DIPY_HOME"])
    else:
        dipy_home = Path("~").expanduser() / ".dipy"
    cache_dir = dipy_home / "force_simulations"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


def _gtab_matches(entry_bvals, entry_bvecs, gtab, *, bval_tol=10.0, bvec_tol=1e-3):
    """Check whether stored bvals/bvecs match a GradientTable.

    Parameters
    ----------
    entry_bvals : list
        Stored b-values from the cache registry.
    entry_bvecs : list of list
        Stored b-vectors from the cache registry.
    gtab : GradientTable
        Gradient table to compare against.
    bval_tol : float, optional
        Absolute tolerance for b-value comparison.
    bvec_tol : float, optional
        Absolute tolerance for b-vector coordinate comparison.

    Returns
    -------
    match : bool
        True if the stored and passed bvals/bvecs agree within tolerance.
    """
    stored_bvals = np.asarray(entry_bvals, dtype=np.float64)
    stored_bvecs = np.asarray(entry_bvecs, dtype=np.float64)
    current_bvals = np.asarray(gtab.bvals, dtype=np.float64)
    current_bvecs = np.asarray(gtab.bvecs, dtype=np.float64)

    if stored_bvals.shape != current_bvals.shape:
        return False
    if stored_bvecs.shape != current_bvecs.shape:
        return False

    return np.allclose(stored_bvals, current_bvals, atol=bval_tol) and np.allclose(
        stored_bvecs, current_bvecs, atol=bvec_tol
    )


def _diffusivity_matches(entry_config, current_config):
    """Check whether two diffusivity configurations are equivalent.

    Parameters
    ----------
    entry_config : dict
        Stored diffusivity configuration.
    current_config : dict
        Current diffusivity configuration.

    Returns
    -------
    match : bool
        True if all keys and values are identical.
    """
    if set(entry_config.keys()) != set(current_config.keys()):
        return False
    for key in entry_config:
        stored = entry_config[key]
        current = current_config[key]
        # Both may be lists/tuples (ranges) or scalars
        if isinstance(stored, (list, tuple)):
            if not isinstance(current, (list, tuple)):
                return False
            if len(stored) != len(current):
                return False
            if not all(np.isclose(s, c) for s, c in zip(stored, current)):
                return False
        else:
            if not np.isclose(stored, current):
                return False
    return True


def _load_cache_registry(cache_dir):
    """Load the cache registry JSON from *cache_dir*.

    Returns an empty list if the file does not exist yet.
    """
    registry_path = cache_dir / "cache_registry.json"
    if registry_path.exists():
        with open(registry_path, "r") as f:
            return json.load(f)
    return []


def _save_cache_registry(cache_dir, registry):
    """Persist *registry* as JSON in *cache_dir*."""
    registry_path = cache_dir / "cache_registry.json"
    with open(registry_path, "w") as f:
        json.dump(registry, f, indent=2)


def _locked_registry_update(cache_dir, update_fn):
    """Run a read-modify-write of the cache registry under a file lock.

    The lock is taken on ``cache_registry.lock`` in *cache_dir*. On Windows
    ``msvcrt.locking`` locks one byte; elsewhere ``fcntl.flock`` takes an
    exclusive lock. The lock is released even if *update_fn* or the save
    raises.

    Parameters
    ----------
    cache_dir : Path
        Cache directory.
    update_fn : callable
        Receives the current registry list and returns the list to save.
    """
    lock_path = cache_dir / "cache_registry.lock"
    on_windows = sys.platform == "win32"

    with open(lock_path, "w") as lock_fh:
        if on_windows:
            import msvcrt

            msvcrt.locking(lock_fh.fileno(), msvcrt.LK_LOCK, 1)
        else:
            import fcntl

            fcntl.flock(lock_fh, fcntl.LOCK_EX)

        try:
            current = _load_cache_registry(cache_dir)
            _save_cache_registry(cache_dir, update_fn(current))
        finally:
            if on_windows:
                # msvcrt unlocks the byte range at the current file position,
                # so rewind to where the lock was taken.
                lock_fh.seek(0)
                msvcrt.locking(lock_fh.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                fcntl.flock(lock_fh, fcntl.LOCK_UN)


def _find_cached_simulation(cache_dir, gtab, diffusivity_config, num_simulations):
    """Return the first registered simulation file matching the request.

    An entry matches when its ``num_simulations`` is equal, its stored
    bvals/bvecs match *gtab* (see ``_gtab_matches``) and its stored
    configuration matches *diffusivity_config* (see
    ``_diffusivity_matches``). Matching entries whose file no longer exists
    are skipped. The registry is read without taking the registry lock.

    Parameters
    ----------
    cache_dir : Path
        Cache directory.
    gtab : GradientTable
        Gradient table.
    diffusivity_config : dict
        Configuration used for generation.
    num_simulations : int
        Number of simulations requested.

    Returns
    -------
    path : str or None
        Path to the cached ``.npz`` file, or None if nothing matches.
    """
    for record in _load_cache_registry(cache_dir):
        is_match = (
            record["num_simulations"] == num_simulations
            and _gtab_matches(record["bvals"], record["bvecs"], gtab)
            and _diffusivity_matches(record["diffusivity_config"], diffusivity_config)
        )
        if not is_match:
            continue
        npz_path = cache_dir / record["filename"]
        if npz_path.exists():
            return str(npz_path)
    return None


def _register_cached_simulation(
    cache_dir, gtab, diffusivity_config, num_simulations, filename
):
    """Add a new entry to the cache registry.

    Parameters
    ----------
    cache_dir : Path
        Cache directory.
    gtab : GradientTable
        Gradient table; its ``bvals`` and ``bvecs`` are stored as float lists.
    diffusivity_config : dict
        Configuration to store with the entry.
    num_simulations : int
        Number of simulations.
    filename : str
        Filename of the saved ``.npz`` inside *cache_dir*.

    Notes
    -----
    Configuration values are converted recursively to plain JSON types.
    Arrays become lists, NumPy scalars become Python scalars, and tuples
    become lists, including when they are nested inside lists, tuples or
    dicts. The entry is appended under the registry file lock.
    """

    def _to_json(obj):
        # Recursive so that e.g. a tuple of np.float32 values serialises:
        # json cannot encode np.float32 / np.int64 / np.bool_.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (list, tuple)):
            return [_to_json(item) for item in obj]
        if isinstance(obj, dict):
            return {key: _to_json(value) for key, value in obj.items()}
        return obj

    config_json = {k: _to_json(v) for k, v in diffusivity_config.items()}

    entry = {
        "bvals": np.asarray(gtab.bvals, dtype=np.float64).tolist(),
        "bvecs": np.asarray(gtab.bvecs, dtype=np.float64).tolist(),
        "diffusivity_config": config_json,
        "num_simulations": int(num_simulations),
        "filename": filename,
    }

    def _append(registry):
        registry.append(entry)
        return registry

    _locked_registry_update(cache_dir, _append)


class SignalIndex:
    """Inner-product nearest-neighbour index over a dense float32 matrix.

    Stored vectors are kept in a single C-contiguous float32 array. Queries
    are answered by ``_cython_search`` (``search_inner_product`` from
    ``dipy.reconst._force_search``), which per this module's original
    description uses SciPy BLAS for the product and a streaming heap for
    top-k selection.

    Parameters
    ----------
    d : int
        Dimension of vectors.
    """

    def __init__(self, d):
        if d <= 0:
            raise ValueError(f"Dimension must be positive, got {d}")
        self.d = int(d)
        self.ntotal = 0
        self._xb = None

    @staticmethod
    def _as_matrix(x, d, what):
        """Coerce *x* to a C-contiguous float32 ``(n, d)`` matrix.

        A 1-D input of length *d* becomes a single row. *what* ("Vector" or
        "Query") prefixes the dimension-mismatch error message.
        """
        mat = np.ascontiguousarray(x, dtype=np.float32)
        if mat.ndim == 1:
            if len(mat) != d:
                raise ValueError(
                    f"{what} dimension {len(mat)} != index dimension {d}"
                )
            mat = mat.reshape(1, -1)
        if mat.ndim != 2:
            raise ValueError(f"Expected 1D or 2D array, got {mat.ndim}D")
        if mat.shape[1] != d:
            raise ValueError(
                f"{what} dimension {mat.shape[1]} != index dimension {d}"
            )
        return mat

    def add(self, x):
        """Append vectors to the index.

        Parameters
        ----------
        x : array-like (n, d) or (d,)
            Vectors to add; converted to float32 C-contiguous.

        Notes
        -----
        Each call rebuilds the stored array with ``np.vstack``, so many small
        ``add`` calls cost O(n²) in total; load in one bulk call instead.
        """
        rows = self._as_matrix(x, self.d, "Vector")
        if self._xb is not None:
            self._xb = np.vstack([self._xb, rows])
        else:
            # Copy so the index never aliases the caller's buffer.
            self._xb = rows.copy()
        self.ntotal = len(self._xb)

    def search(self, x, k):
        """Return the *k* stored vectors with the largest inner product.

        Parameters
        ----------
        x : array-like (n, d) or (d,)
            Query vectors.
        k : int
            Number of neighbours; values above ``ntotal`` are clamped with a
            ``UserWarning``.

        Returns
        -------
        distances : ndarray (n, k)
            Inner products (descending order).
        indices : ndarray (n, k)
            Neighbour indices.
        """
        if self.ntotal == 0:
            raise RuntimeError("Cannot search empty index")
        queries = self._as_matrix(x, self.d, "Query")
        if k <= 0:
            raise ValueError(f"k must be positive, got {k}")
        if k > self.ntotal:
            warnings.warn(
                f"k={k} exceeds index size ({self.ntotal}); "
                f"clamping to {self.ntotal}.",
                UserWarning,
                stacklevel=2,
            )
            k = self.ntotal
        distances, indices = _cython_search(queries, self._xb, k)
        return distances, indices

    def reset(self):
        """Remove all vectors from the index."""
        self._xb = None
        self.ntotal = 0

    def __repr__(self):
        return f"SignalIndex(d={self.d}, ntotal={self.ntotal})"


def normalize_signals(signals):
    """Scale every row of *signals* to unit L2 norm (float32).

    Rows whose norm is zero are left as all-zero rows instead of producing
    NaNs.

    Parameters
    ----------
    signals : ndarray (N, M)
        Signal array with N samples and M measurements.

    Returns
    -------
    normalized : ndarray (N, M)
        C-contiguous float32 array of L2-normalised rows.
    """
    arr = np.asarray(signals, dtype=np.float32)
    magnitude = np.linalg.norm(arr, axis=1, keepdims=True)
    safe_magnitude = np.where(magnitude == 0, np.float32(1.0), magnitude)
    return np.ascontiguousarray(arr / safe_magnitude)


def create_signal_index(signals_norm):
    """Build a ``SignalIndex`` holding *signals_norm*.

    The index dimension is taken from the second axis of *signals_norm*.

    Parameters
    ----------
    signals_norm : ndarray (N, M)
        L2-normalized library signals (inner product then equals cosine
        similarity for normalised queries).

    Returns
    -------
    index : SignalIndex
        Search index containing all N rows.
    """
    index = SignalIndex(signals_norm.shape[1])
    index.add(signals_norm)
    return index


def softmax_stable(x, *, axis=1):
    """Softmax along *axis*, shifted by the maximum to avoid overflow.

    Parameters
    ----------
    x : ndarray
        Input array.
    axis : int, optional
        Axis along which to compute softmax.

    Returns
    -------
    softmax : ndarray
        Softmax probabilities (sum to 1 along *axis*).
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total


def compute_uncertainty_ambiguity(scores):
    """Summarise each row of match scores by spread and ambiguity.

    Parameters
    ----------
    scores : ndarray (N, K)
        Penalized scores for K neighbors.

    Returns
    -------
    uncertainty : ndarray (N,)
        Interquartile range (75th minus 25th percentile) of each row, float32.
    ambiguity : ndarray (N,)
        Fraction of scores strictly above the midpoint between the row's
        minimum and maximum, float32.
    """
    upper = np.percentile(scores, 75, axis=1)
    lower = np.percentile(scores, 25, axis=1)
    iqr = (upper - lower).astype(np.float32)

    midpoint = 0.5 * (np.max(scores, axis=1) + np.min(scores, axis=1))
    above = scores > midpoint[:, np.newaxis]
    fraction_above = (np.sum(above, axis=1) / scores.shape[1]).astype(np.float32)

    return iqr, fraction_above


def postprocess_peaks(preds, target_sphere, fracs, *, max_peaks=5):
    """Convert per-vertex fiber labels into fixed-size peak arrays.

    For every voxel, the sphere vertices whose label equals 1 are taken as
    peak directions, in increasing vertex-index order, up to *max_peaks*.
    Peak values are copied from the first *max_peaks* entries of the
    voxel's fraction vector.

    Parameters
    ----------
    preds : ndarray (..., n_vertices)
        Per-vertex labels; an entry of 1 marks a peak vertex.
    target_sphere : Sphere
        Sphere whose ``vertices`` (n_vertices, 3) give the peak directions.
    fracs : ndarray (..., F)
        Per-voxel fiber fractions.
    max_peaks : int, optional
        Number of peak slots per voxel (default 5).

    Returns
    -------
    peaks_output : ndarray (..., max_peaks, 3), float32
        Peak directions; unused slots are zero.
    peak_indices : ndarray (..., max_peaks), int32
        Vertex index of each peak; unused slots are -1.
    peak_vals : ndarray (..., max_peaks), float32
        Fractions copied from *fracs*; missing entries are zero.
    """
    preds = np.asarray(preds)
    fracs = np.asarray(fracs)
    original_shape = preds.shape[:-1]
    preds_flat = preds.reshape(-1, preds.shape[-1])
    fracs_flat = fracs.reshape(-1, fracs.shape[-1])
    n_voxels = preds_flat.shape[0]
    vertices = np.asarray(target_sphere.vertices)

    peaks_output = np.zeros((n_voxels, max_peaks, 3), dtype=np.float32)
    peak_indices = np.full((n_voxels, max_peaks), -1, dtype=np.int32)
    peak_vals = np.zeros((n_voxels, max_peaks), dtype=np.float32)

    # Vectorised replacement of a per-voxel loop: the running count of active
    # vertices in each row gives every peak its slot (0 for the first one in
    # vertex order, 1 for the next, ...); slots beyond max_peaks are dropped.
    active = preds_flat == 1
    slot_of = np.cumsum(active, axis=1) - 1
    keep = active & (slot_of < max_peaks)
    vox, vert = np.nonzero(keep)
    slot = slot_of[vox, vert]
    peaks_output[vox, slot] = vertices[vert]
    peak_indices[vox, slot] = vert

    num_fracs = min(max_peaks, fracs_flat.shape[1])
    peak_vals[:, :num_fracs] = fracs_flat[:n_voxels, :num_fracs]

    peaks_output = peaks_output.reshape((*original_shape, max_peaks, 3))
    peak_indices = peak_indices.reshape((*original_shape, max_peaks))
    peak_vals = peak_vals.reshape((*original_shape, max_peaks))
    return peaks_output, peak_indices, peak_vals


class FORCEModel(ReconstModel):
    """FORCE reconstruction model for signal matching based microstructure."""

    def __init__(
        self,
        gtab,
        *,
        simulations=None,
        penalty=1e-5,
        n_neighbors=50,
        use_posterior=False,
        posterior_beta=2000.0,
        compute_odf=False,
        verbose=False,
    ):
        r"""
        FORCE (FORward modeling for Complex microstructure Estimation) model
        :footcite:p:`Shah2025`.

        FORCE is a forward modeling paradigm that reframes how diffusion data
        is analyzed. Instead of inverting the measured signal, FORCE simulates
        a large set of biologically plausible intra-voxel fiber configurations
        and tissue compositions. It then identifies the best-matching
        simulation for each voxel by operating directly in the signal space.

        Parameters
        ----------
        gtab : GradientTable
            Gradient table.
        simulations : dict or None, optional
            Pre-computed FORCE simulations with signals and parameters.
            If None, call generate() to create simulations.
        penalty : float, optional
            Penalty weight for fiber complexity (multiplied by each library
            entry's number of fibers and subtracted from its match score).
        n_neighbors : int, optional
            Number of neighbors for matching.
        use_posterior : bool, optional
            Use posterior averaging instead of best match.
        posterior_beta : float, optional
            Inverse temperature of the posterior: scores are multiplied by
            this value before the softmax, so larger values concentrate the
            weights on the best matches.
        compute_odf : bool, optional
            Compute posterior ODF maps. NOTE(review): stored but not read by
            the methods of this class; ODFs are returned whenever the
            simulation library contains them.
        verbose : bool, optional
            Show progress bar and status messages.

        Notes
        -----
        The fit method uses the @multi_voxel_fit decorator which supports
        parallel execution. Pass `engine` and `n_jobs` kwargs to the fit
        method:

        Available engines: "serial", "ray", "joblib", "dask".

        References
        ----------
        .. footbibliography::
        """
        self.gtab = gtab
        self.simulations = simulations
        self.penalty = penalty
        self.n_neighbors = n_neighbors
        self.use_posterior = use_posterior
        self.posterior_beta = posterior_beta
        self.compute_odf = compute_odf
        self.verbose = verbose
        self._index = None
        self._penalty_array = None

        if simulations is not None:
            self._prepare_library()

    def generate(
        self,
        *,
        num_simulations=500000,
        output_path=None,
        num_cpus=1,
        wm_threshold=0.5,
        tortuosity=False,
        odi_range=(0.01, 0.3),
        diffusivity_config=None,
        compute_dti=True,
        compute_dki=False,
        verbose=False,
        use_cache=True,
    ):
        """Generate simulations for matching.

        When ``output_path`` is ``None`` and ``use_cache`` is ``True``,
        simulations are cached in ``~/.dipy/force_simulations/`` (or
        ``$DIPY_HOME/force_simulations/``). A registry file
        (``cache_registry.json``) records the following for each cached file:

        * the bvals and bvecs;
        * the number of simulations;
        * a configuration dict holding the diffusivity ranges together with
          the generation options that change the simulated library
          (``wm_threshold``, ``tortuosity``, ``odi_range``, ``compute_dti``,
          ``compute_dki``).

        Suppose a cached simulation matches the current gradient table (within
        tolerance) and all of these options. It is then loaded from disk and
        generation is skipped.

        Set ``use_cache=False`` to force regeneration even when a matching
        cached simulation exists. When ``output_path`` is None the regenerated
        library is still written to, and registered in, the cache.

        Parameters
        ----------
        num_simulations : int, optional
            Number of simulated voxels.
        output_path : str, optional
            Path to save simulations (.npz). When None, saves to
            ``~/.dipy/force_simulations/`` and uses caching.
        num_cpus : int, optional
            Number of CPU cores for parallel processing.
        wm_threshold : float, optional
            Minimum WM fraction to include fiber labels.
        tortuosity : bool, optional
            Use tortuosity constraint for perpendicular diffusivity.
        odi_range : tuple, optional
            (min, max) orientation dispersion index range.
        diffusivity_config : dict, optional
            Custom diffusivity ranges.
        compute_dti : bool, optional
            Compute DTI metrics (FA, MD, RD).
        compute_dki : bool, optional
            Compute DKI metrics (AK, RK, MK, KFA).
        verbose : bool, optional
            Enable progress output.
        use_cache : bool, optional
            Whether to use cached simulations when ``output_path`` is None.
            Set to ``False`` to always regenerate.

        Returns
        -------
        self : FORCEModel
            Model with simulations loaded.
        """
        from dipy.sims.force import (
            generate_force_simulations,
            get_default_diffusivity_config,
            load_force_simulations,
            save_force_simulations,
        )

        # Resolve the diffusivity config that will actually be used
        resolved_config = (
            diffusivity_config
            if diffusivity_config is not None
            else get_default_diffusivity_config()
        )

        # Cache key: the diffusivity ranges plus every generation option that
        # alters the simulated library. Without these extra entries, a library
        # generated with e.g. compute_dki=False or a different odi_range would
        # be served for a request that differs only in those options. The
        # underscore prefix keeps them apart from diffusivity keys. Registry
        # entries written without them never match, which is a plain cache miss.
        cache_key = dict(resolved_config)
        cache_key.update(
            {
                "_wm_threshold": float(wm_threshold),
                "_tortuosity": bool(tortuosity),
                "_odi_range": [float(v) for v in np.ravel(odi_range)],
                "_compute_dti": bool(compute_dti),
                "_compute_dki": bool(compute_dki),
            }
        )

        cache_dir = _get_force_cache_dir() if output_path is None else None

        # --- Cache lookup when no explicit output_path is given ----------
        if cache_dir is not None and use_cache:
            cached = _find_cached_simulation(
                cache_dir,
                self.gtab,
                cache_key,
                num_simulations,
            )
            if cached is not None:
                if verbose:
                    print(f"[FORCE] Loading cached simulations from {cached}")
                self.simulations = load_force_simulations(cached)
                self._prepare_library()
                return self

        # --- Generate new simulations -----------------------------------
        self.simulations = generate_force_simulations(
            self.gtab,
            num_simulations=num_simulations,
            num_cpus=num_cpus,
            wm_threshold=wm_threshold,
            tortuosity=tortuosity,
            odi_range=odi_range,
            diffusivity_config=diffusivity_config,
            compute_dti=compute_dti,
            compute_dki=compute_dki,
            verbose=verbose,
        )

        if output_path is not None:
            save_force_simulations(self.simulations, output_path)
        else:
            filename_holder = {}

            def _reserve_filename(registry):
                # Runs while holding the registry lock. Choose the first name
                # that is neither registered nor present on disk, then create
                # an empty placeholder so that the next process to take the
                # lock sees it as taken (the registry entry itself is only
                # added after the file has been written).
                used = {entry.get("filename") for entry in registry}
                idx = len(registry)
                while (
                    f"force_sim_{idx}.npz" in used
                    or (cache_dir / f"force_sim_{idx}.npz").exists()
                ):
                    idx += 1
                fname = f"force_sim_{idx}.npz"
                (cache_dir / fname).touch()
                filename_holder["filename"] = fname
                return registry

            _locked_registry_update(cache_dir, _reserve_filename)
            filename = filename_holder["filename"]
            cache_path = cache_dir / filename
            try:
                save_force_simulations(self.simulations, str(cache_path))
            except BaseException:
                # Do not leave an empty, unregistered placeholder behind.
                cache_path.unlink(missing_ok=True)
                raise
            _register_cached_simulation(
                cache_dir,
                self.gtab,
                cache_key,
                num_simulations,
                filename,
            )
            if verbose:
                print(f"[FORCE] Cached simulations to {cache_path}")

        self._prepare_library()
        return self

    def load(self, input_path):
        """Load pre-computed simulations from file.

        Parameters
        ----------
        input_path : str
            Path to simulations file (.npz).

        Returns
        -------
        self : FORCEModel
            Model with simulations loaded.
        """
        from dipy.sims.force import load_force_simulations

        self.simulations = load_force_simulations(input_path)
        self._prepare_library()
        return self

    def _prepare_library(self):
        """Build the search index and per-entry penalty from the simulations.

        Library signals are L2-normalised row-wise (zero rows kept as zeros)
        and loaded into a ``SignalIndex``. The penalty array is
        ``penalty * num_fibers``; it is all zeros when the simulations have
        no ``"num_fibers"`` entry.
        """
        signals = self.simulations["signals"]

        # Normalize library signals
        lib_norm = np.linalg.norm(signals, axis=1, keepdims=True)
        lib_norm[lib_norm == 0] = 1.0
        signals_norm = np.ascontiguousarray((signals / lib_norm).astype(np.float32))

        # Build index
        self._index = create_signal_index(signals_norm)

        # Penalty array
        num_fibers = self.simulations.get(
            "num_fibers", np.zeros(len(signals), dtype=np.float32)
        )
        self._penalty_array = (self.penalty * num_fibers).astype(np.float32)

    @staticmethod
    def _fetch_params_batched(lib_idx, d):
        """Vectorised parameter look-up for best-match indices.

        Parameters
        ----------
        lib_idx : ndarray (N,)
            Library indices of the best match per voxel.
        d : dict
            Simulation dictionary.

        Returns
        -------
        params : dict of ndarray
            Optional keys: ``ufa_wm``/``ufa_voxel`` (if ``"ufa_wm"`` is in
            *d*) and ``ak``/``rk``/``mk``/``kfa`` (if ``"ak"`` is in *d*).
            ``odf`` is None when *d* has no ``"odfs"``.
        """
        params = {
            "fa": d["fa"][lib_idx].astype(np.float32),
            "md": d["md"][lib_idx].astype(np.float32),
            "rd": d["rd"][lib_idx].astype(np.float32),
            "wm_fraction": d["wm_fraction"][lib_idx].astype(np.float32),
            "gm_fraction": d["gm_fraction"][lib_idx].astype(np.float32),
            "csf_fraction": d["csf_fraction"][lib_idx].astype(np.float32),
            "num_fibers": d["num_fibers"][lib_idx].astype(np.float32),
            "dispersion": d["dispersion"][lib_idx].astype(np.float32),
            "nd": d["nd"][lib_idx].astype(np.float32),
            "labels": d["labels"][lib_idx].astype(np.int8),
            "fracs": d["fraction_array"][lib_idx].astype(np.float32),
        }
        if "ufa_wm" in d:
            params["ufa_wm"] = d["ufa_wm"][lib_idx].astype(np.float32)
            params["ufa_voxel"] = d["ufa_voxel"][lib_idx].astype(np.float32)
        if "ak" in d:
            params["ak"] = d["ak"][lib_idx].astype(np.float32)
            params["rk"] = d["rk"][lib_idx].astype(np.float32)
            params["mk"] = d["mk"][lib_idx].astype(np.float32)
            params["kfa"] = d["kfa"][lib_idx].astype(np.float32)
        params["odf"] = d["odfs"][lib_idx].astype(np.float32) if "odfs" in d else None
        params["predicted_signal"] = d["signals"][lib_idx].astype(np.float32)
        return params

    @staticmethod
    def _posterior_params_batched(neighbors, W, d, lib_idx):
        """Vectorised posterior-averaging over neighbours.

        Parameters
        ----------
        neighbors : ndarray (N, K)
            Neighbour indices.
        W : ndarray (N, K)
            Posterior weights.
        d : dict
            Simulation dictionary.
        lib_idx : ndarray (N,)
            Library index of the best match per voxel. ``labels`` and
            ``fracs`` are taken from this entry rather than averaged.

        Returns
        -------
        params : dict of ndarray
            Scalar maps are weighted averages over the K neighbours. ``odf``
            is the posterior ODF (see ``posterior_odf``), or None when *d*
            has no ``"odfs"``.
        """

        def _wavg(field):
            return np.sum(W * d[field][neighbors], axis=1).astype(np.float32)

        params = {
            "fa": _wavg("fa"),
            "md": _wavg("md"),
            "rd": _wavg("rd"),
            "wm_fraction": _wavg("wm_fraction"),
            "gm_fraction": _wavg("gm_fraction"),
            "csf_fraction": _wavg("csf_fraction"),
            "num_fibers": _wavg("num_fibers"),
            "dispersion": _wavg("dispersion"),
            "nd": _wavg("nd"),
            "labels": d["labels"][lib_idx].astype(np.int8),
            "fracs": d["fraction_array"][lib_idx].astype(np.float32),
        }
        if "ufa_wm" in d:
            params["ufa_wm"] = _wavg("ufa_wm")
            params["ufa_voxel"] = _wavg("ufa_voxel")
        if "ak" in d:
            params["ak"] = _wavg("ak")
            params["rk"] = _wavg("rk")
            params["mk"] = _wavg("mk")
            params["kfa"] = _wavg("kfa")

        # Posterior ODF: reuse the module-level helper instead of a copy of
        # its loop (max-normalise each neighbour ODF, weight, renormalise).
        if "odfs" in d:
            params["odf"] = posterior_odf(d["odfs"], W, neighbors, d["odfs"].shape[1])
        else:
            params["odf"] = None

        # Posterior mean signal
        params["predicted_signal"] = posterior_mean_signal(d["signals"], W, neighbors)
        return params

    @multi_voxel_fit(
        batched=True,
        shared_obj=("_penalty_array", "_index", "simulations"),
        chunk_size={"serial": 10_000, "ray": "auto"},
    )
    def fit(self, data, *, mask=None, **kwargs):
        """Fit model to data.

        Parameters
        ----------
        data : ndarray
            Diffusion data for a single voxel (1D) or multiple voxels (ND).
        mask : ndarray, optional
            Brain mask (for multi-voxel data).
        **kwargs : dict
            Additional arguments passed to multi_voxel_fit decorator:

            - engine : str, optional
                Parallel engine: "serial", "ray", "joblib", "dask".
            - n_jobs : int, optional
                Number of parallel jobs.
            - verbose : bool, optional
                Show progress bar.

        Returns
        -------
        fit : FORCEFit or ndarray of FORCEFit
            Fitted model for a single voxel (1-D input) or an object array of
            fitted models for a batch of voxels (2-D input).

        Notes
        -----
        This method is decorated with @multi_voxel_fit(batched=True) which
        handles multi-voxel dispatch, mask application, and aggregation into
        a MultiVoxelFit. The method itself handles both 1-D (single voxel)
        and 2-D (batch) inputs directly.

        For parallel execution, use engine="ray", n_jobs=4 arguments in
        model fit() call.

        **Memory warning (joblib / dask engines):** When engine="joblib" or
        engine="dask", the full simulation library (including the signal
        matrix and search index, ~120-400 MB for 100k simulations) is pickled
        and sent to *every* worker chunk. With 8 workers this can consume
        several gigabytes of RAM. For num_simulations > ~10 000 use
        engine="ray" instead, which places the library in a shared object
        store and avoids redundant copies across workers.
        """
        if self.simulations is None:
            raise RuntimeError(
                "No simulations loaded. Call generate() or provide simulations."
            )
        if self._index is None:
            raise RuntimeError(
                "Signal index is not prepared. Call _prepare_library() or "
                "reload simulations via generate() or load()."
            )

        single = data.ndim == 1
        data2d = data.reshape(1, -1) if single else data
        data2d = np.ascontiguousarray(data2d, dtype=np.float32)

        # Normalise queries the same way as the library (cosine similarity).
        norms = np.linalg.norm(data2d, axis=1, keepdims=True).astype(np.float32)
        norms[norms == 0] = 1.0
        query_norm = np.ascontiguousarray(data2d / norms)

        D, neighbors = self._index.search(query_norm, k=self.n_neighbors)
        # Penalise complex (multi-fiber) library entries.
        S = D - self._penalty_array[neighbors]

        U, A = compute_uncertainty_ambiguity(S)

        d = self.simulations
        n_vox = data2d.shape[0]
        best = np.argmax(S, axis=1)
        lib_idx = neighbors[np.arange(n_vox), best]

        if self.use_posterior:
            W = softmax_stable(self.posterior_beta * S, axis=1)
            params_arrays = self._posterior_params_batched(neighbors, W, d, lib_idx)
            params_arrays["uncertainty"] = U
            params_arrays["ambiguity"] = A
            params_arrays["entropy"] = compute_entropy(W)
        else:
            params_arrays = self._fetch_params_batched(lib_idx, d)
            params_arrays["uncertainty"] = U
            params_arrays["ambiguity"] = A
            # Entropy is only defined for posterior mode.
            params_arrays["entropy"] = np.full(n_vox, np.nan, dtype=np.float32)

        if kwargs.pop("_raw", False):
            return params_arrays

        # Split the batched arrays into one FORCEFit per voxel; scalar NumPy
        # values become Python floats.
        fits = np.empty(n_vox, dtype=object)
        keys = list(params_arrays.keys())
        for i in range(n_vox):
            p = {}
            for k in keys:
                val = params_arrays[k]
                if val is None:
                    p[k] = None
                else:
                    v = val[i]
                    if isinstance(v, np.ndarray) and v.ndim == 0:
                        p[k] = float(v)
                    elif isinstance(v, (np.floating, np.integer)):
                        p[k] = float(v)
                    else:
                        p[k] = v
            entropy_val = p.get("entropy", 0.0)
            if entropy_val is not None and np.isnan(entropy_val):
                p["entropy"] = None
            fits[i] = FORCEFit(None, p)

        return fits[0] if single else fits


class FORCEFit(ReconstFit):
    """FORCE model fit results for a single voxel."""

    def __init__(self, model, params):
        """Initialize a FORCEFit class instance.

        Parameters
        ----------
        model : FORCEModel or None
            Model that produced the fit (``FORCEModel.fit`` passes None).
        params : dict
            Per-voxel parameter values keyed by name ("fa", "md", ...).
            A shallow copy is stored, so the caller's dict is not modified.
            A scalar NaN ``"entropy"`` (best-match mode) is stored as None.
        """
        # Copy: the constructor used to rewrite "entropy" in the caller's dict.
        params = dict(params)
        entropy = params.get("entropy")
        # np.ndim guard: np.isnan on an array-valued entropy would be ambiguous
        # in a boolean context.
        if entropy is not None and np.ndim(entropy) == 0 and np.isnan(entropy):
            params["entropy"] = None
        self.model = model
        self._params = params

    @property
    def fa(self):
        """Fractional anisotropy."""
        return self._params["fa"]

    @property
    def md(self):
        """Mean diffusivity."""
        return self._params["md"]

    @property
    def rd(self):
        """Radial diffusivity."""
        return self._params["rd"]

    @property
    def wm_fraction(self):
        """White matter fraction."""
        return self._params["wm_fraction"]

    @property
    def gm_fraction(self):
        """Gray matter fraction."""
        return self._params["gm_fraction"]

    @property
    def csf_fraction(self):
        """CSF fraction."""
        return self._params["csf_fraction"]

    @property
    def num_fibers(self):
        """Number of fibers."""
        return self._params["num_fibers"]

    @property
    def dispersion(self):
        """Orientation dispersion."""
        return self._params["dispersion"]

    @property
    def nd(self):
        """Neurite density."""
        return self._params["nd"]

    @property
    def ufa_wm(self):
        """microFA in white matter (None if not simulated)."""
        return self._params.get("ufa_wm", None)

    @property
    def ufa_voxel(self):
        """Voxel-averaged microFA (None if not simulated)."""
        return self._params.get("ufa_voxel", None)

    @property
    def ak(self):
        """Axial kurtosis (DKI; None if not simulated)."""
        return self._params.get("ak", None)

    @property
    def rk(self):
        """Radial kurtosis (DKI; None if not simulated)."""
        return self._params.get("rk", None)

    @property
    def mk(self):
        """Mean kurtosis (DKI; None if not simulated)."""
        return self._params.get("mk", None)

    @property
    def kfa(self):
        """Kurtosis fractional anisotropy (DKI; None if not simulated)."""
        return self._params.get("kfa", None)

    @property
    def odf(self):
        """Orientation distribution function (None if unavailable)."""
        return self._params.get("odf", None)

    @property
    def predicted_signal(self):
        """Predicted signal from matched library entry (cleaned DWI)."""
        return self._params.get("predicted_signal", None)

    @property
    def uncertainty(self):
        """Uncertainty (IQR of penalized scores)."""
        return self._params["uncertainty"]

    @property
    def ambiguity(self):
        """Ambiguity (fraction above half-max)."""
        return self._params["ambiguity"]

    @property
    def entropy(self):
        """Entropy (posterior mode only; None otherwise)."""
        return self._params.get("entropy", None)

    @property
    def label(self):
        """Fiber configuration label."""
        return self._params.get("labels", None)

    @property
    def fracs(self):
        """Fiber fractions."""
        return self._params.get("fracs", None)


# Resolve forward reference: FORCEModel is defined before FORCEFit, so the
# fit class can only be attached once both classes exist.
# NOTE(review): presumably read by the multi_voxel_fit machinery when wrapping
# batched results — confirm in dipy.reconst.multi_voxel.
setattr(FORCEModel, "_fit_class", FORCEFit)


def compute_entropy(weights):
    """Shannon entropy (natural log) of each row of posterior weights.

    ``EPSILON`` is added inside the logarithm so zero weights contribute 0
    rather than NaN.

    Parameters
    ----------
    weights : ndarray (N, K)
        Posterior weights for K neighbors.

    Returns
    -------
    entropy : ndarray (N,)
        Entropy of each row, float32.
    """
    log_weights = np.log(weights + EPSILON)
    negative_sum = -np.sum(weights * log_weights, axis=1)
    return negative_sum.astype(np.float32)


def posterior_mean_signal(signals, weights, indices):
    """Weight-average the library signals of each query's neighbours.

    Accumulates one neighbour column at a time into a float32 buffer, so no
    ``(N_query, K, M)`` intermediate is materialised.

    Parameters
    ----------
    signals : ndarray (N_lib, M)
        Library signals.
    weights : ndarray (N_query, K)
        Posterior weights.
    indices : ndarray (N_query, K)
        Neighbor indices.

    Returns
    -------
    mean_signal : ndarray (N_query, M)
        Posterior mean signals, float32.
    """
    num_neighbours = indices.shape[1]
    mean_signal = np.zeros((indices.shape[0], signals.shape[1]), dtype=np.float32)
    for col in range(num_neighbours):
        neighbour_signals = signals[indices[:, col]]
        mean_signal += weights[:, col : col + 1] * neighbour_signals
    return mean_signal


def posterior_odf(odfs, weights, indices, n_dirs):
    """Weighted average of max-normalised neighbour ODFs.

    Each neighbour ODF is divided by its own maximum (plus ``EPSILON``)
    before weighting. The accumulated result is normalised the same way, so
    every output row peaks at (about) 1.

    Parameters
    ----------
    odfs : ndarray (N_lib, D)
        Library ODFs.
    weights : ndarray (N_query, K)
        Posterior weights.
    indices : ndarray (N_query, K)
        Neighbor indices.
    n_dirs : int
        Number of sphere directions (D).

    Returns
    -------
    odf : ndarray (N_query, D)
        Posterior ODFs, float32.
    """
    num_neighbours = indices.shape[1]
    combined = np.zeros((indices.shape[0], n_dirs), dtype=np.float32)
    for col in range(num_neighbours):
        neighbour_odf = odfs[indices[:, col]].astype(np.float32)
        neighbour_odf /= np.max(neighbour_odf, axis=1, keepdims=True) + EPSILON
        combined += weights[:, col : col + 1] * neighbour_odf
    combined /= np.max(combined, axis=1, keepdims=True) + EPSILON
    return combined


def force_peaks(fitted_object, *, mask=None, sh_order=8):
    """Create a PeaksAndMetrics object from a FORCEFit or MultiVoxelFit.

    Parameters
    ----------
    fitted_object : FORCEFit or MultiVoxelFit
        The result of ``model.fit()``.
    mask : ndarray, optional
        Brain mask over the spatial dimensions of a multi-voxel fit. Any
        array-like is accepted and cast to bool, so non-zero entries select
        voxels. Ignored for single-voxel fits.
    sh_order : int, optional
        Spherical harmonics order for the coefficients.

    Returns
    -------
    peaks : PeaksAndMetrics
        ``peak_dirs`` (..., 5, 3), ``peak_values`` (..., 5),
        ``peak_indices`` (..., 5; -1 for empty slots), ``shm_coeff`` (None
        when no usable ODF is available) and ``sphere``.
    """
    from dipy.direction.peaks import PeaksAndMetrics
    from dipy.reconst.shm import sf_to_sh
    from dipy.sims.force import default_sphere

    labels = fitted_object.label
    fracs = fitted_object.fracs
    odf = fitted_object.odf

    if labels.ndim <= 1:
        # Single voxel: add a leading voxel axis for postprocess_peaks and
        # strip it again afterwards.
        p_out, p_ind, p_val = postprocess_peaks(
            labels[None, :], default_sphere, fracs[None, :]
        )
        res_dirs, res_inds, res_vals = p_out[0], p_ind[0], p_val[0]
        res_sh = (
            sf_to_sh(odf, default_sphere, sh_order=sh_order)
            if odf is not None
            else None
        )
    else:
        # Multi-voxel / CLI case: labels is (..., n_vertices).
        original_shape = labels.shape[:-1]
        if mask is not None:
            # A 0/1 integer mask (e.g. as loaded from NIfTI) would otherwise be
            # treated as integer fancy indices instead of a voxel selection.
            mask = np.asarray(mask, dtype=bool)
            labels_to_proc = labels[mask]
            fracs_to_proc = fracs[mask]
        else:
            labels_to_proc = labels.reshape(-1, labels.shape[-1])
            fracs_to_proc = fracs.reshape(-1, fracs.shape[-1])

        p_out, p_ind, p_val = postprocess_peaks(
            labels_to_proc, default_sphere, fracs_to_proc
        )

        if mask is not None:
            # Voxels outside the mask keep the "no peak" fill values.
            res_dirs = np.zeros((*original_shape, 5, 3), dtype=np.float32)
            res_inds = np.full((*original_shape, 5), -1, dtype=np.int32)
            res_vals = np.zeros((*original_shape, 5), dtype=np.float32)
            res_dirs[mask] = p_out
            res_inds[mask] = p_ind
            res_vals[mask] = p_val
        else:
            res_dirs = p_out.reshape((*original_shape, 5, 3))
            res_inds = p_ind.reshape((*original_shape, 5))
            res_vals = p_val.reshape((*original_shape, 5))

        res_sh = None
        odf_dtype = getattr(odf, "dtype", None)
        if (
            odf is not None
            and odf_dtype is not None
            and np.issubdtype(odf_dtype, np.floating)
        ):
            # Min-max rescale only those ODFs whose maximum exceeds 1; the
            # others are passed through unchanged.
            v_max = np.max(odf, axis=-1, keepdims=True)
            v_min = np.min(odf, axis=-1, keepdims=True)
            needs_rescale = v_max > 1.0
            normalized_odf = (odf - v_min) / ((v_max - v_min) + EPSILON)
            odf = np.where(needs_rescale, normalized_odf, odf)
            res_sh = sf_to_sh(odf, default_sphere, sh_order=sh_order)

    peaks = PeaksAndMetrics()
    peaks.peak_dirs = res_dirs
    peaks.peak_values = res_vals
    peaks.peak_indices = res_inds
    peaks.shm_coeff = res_sh
    peaks.sphere = default_sphere
    return peaks