# Source code for dipy.workflows.tracking

#!/usr/bin/env python3

import logging

from dipy.direction import (
    ClosestPeakDirectionGetter,
    DeterministicMaximumDirectionGetter,
    ProbabilisticDirectionGetter,
)
from dipy.io.image import load_nifti
from dipy.io.peaks import load_pam
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.tracking import utils
from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking
from dipy.tracking.stopping_criterion import (
    BinaryStoppingCriterion,
    CmcStoppingCriterion,
    ThresholdStoppingCriterion,
)
from dipy.workflows.workflow import Workflow


class LocalFiberTrackingPAMFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "track_local"

    def _get_direction_getter(self, strategy_name, pam, pmf_threshold, max_angle):
        """Get Tracking Direction Getter object.

        Parameters
        ----------
        strategy_name : str
            String representing direction getter name (case-insensitive).
        pam : instance of PeaksAndMetrics
            An object exposing ``shm_coeff`` and ``sphere`` attributes, used
            by the spherical-harmonics based strategies. For the EuDX strategy
            the object itself is returned as the direction getter.
        pmf_threshold : float
            Threshold for ODF functions.
        max_angle : float
            Maximum angle between streamline segments.

        Returns
        -------
        direction_getter : instance of DirectionGetter
            Used to get directions for fiber tracking. For "eudx" (and any
            unrecognized strategy name) this is ``pam`` itself.
        """
        strategy = strategy_name.lower()

        # Strategies that build a direction getter from the spherical
        # harmonics coefficients stored in the PAM file:
        # name -> (label used in the log message, direction getter class).
        sh_strategies = {
            "deterministic": ("Deterministic", DeterministicMaximumDirectionGetter),
            "det": ("Deterministic", DeterministicMaximumDirectionGetter),
            "probabilistic": ("Probabilistic", ProbabilisticDirectionGetter),
            "prob": ("Probabilistic", ProbabilisticDirectionGetter),
            "closestpeaks": ("ClosestPeaks", ClosestPeakDirectionGetter),
            "cp": ("ClosestPeaks", ClosestPeakDirectionGetter),
        }

        if strategy in sh_strategies:
            msg, getter_cls = sh_strategies[strategy]
            dg = getter_cls.from_shcoeff(
                pam.shm_coeff,
                sphere=pam.sphere,
                max_angle=max_angle,
                pmf_threshold=pmf_threshold,
            )
            logging.info(f"{msg} direction getter strategy selected")
        elif strategy == "eudx":
            dg = pam
            logging.info("Eudx direction getter strategy selected")
        else:
            # Unknown names keep falling back to EuDX for backward
            # compatibility, but a typo in the method name should be visible,
            # so this is logged as a warning instead of an info message.
            dg = pam
            logging.warning(
                "No direction getter defined. Eudx direction getter strategy selected"
            )
        return dg

    def _core_run(
        self,
        stopping_path,
        use_binary_mask,
        stopping_thr,
        seeding_path,
        seed_density,
        step_size,
        direction_getter,
        out_tract,
        save_seeds,
    ):
        """Run local tracking for one set of inputs and save the tractogram.

        Parameters
        ----------
        stopping_path : str
            Path to the image used to build the stopping criterion. Its affine
            is also used for seeding and tracking.
        use_binary_mask : bool
            If True, use ``BinaryStoppingCriterion(stop > stopping_thr)``;
            otherwise use ``ThresholdStoppingCriterion(stop, stopping_thr)``.
        stopping_thr : float
            Threshold applied to the stopping image.
        seeding_path : str
            Path to the seed mask image. Also used as the spatial reference of
            the saved tractogram.
        seed_density : int
            Number of seeds per dimension inside each seed-mask voxel.
        step_size : float
            Step size used for tracking.
        direction_getter : DirectionGetter
            Object returned by ``_get_direction_getter``.
        out_tract : str
            Output tractogram path.
        save_seeds : bool
            If True, store each streamline's seed in ``data_per_streamline``
            under the ``"seeds"`` key.
        """
        stop, affine = load_nifti(stopping_path)
        if use_binary_mask:
            stopping_criterion = BinaryStoppingCriterion(stop > stopping_thr)
        else:
            stopping_criterion = ThresholdStoppingCriterion(stop, stopping_thr)
        logging.info("stopping criterion done")

        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(
            seed_mask, affine, density=[seed_density, seed_density, seed_density]
        )
        logging.info("seeds done")

        tracking_result = LocalTracking(
            direction_getter,
            stopping_criterion,
            seeds,
            affine,
            step_size=step_size,
            save_seeds=save_seeds,
        )
        logging.info("LocalTracking initiated")

        if save_seeds:
            # With save_seeds the generator yields (streamline, seed) pairs.
            # Materialize it first: ``zip(*result)`` on an empty result cannot
            # be unpacked into two names and raised ValueError whenever
            # tracking produced no streamline.
            pairs = list(tracking_result)
            streamlines = [streamline for streamline, _ in pairs]
            data_per_streamline = {"seeds": [seed for _, seed in pairs]}
        else:
            streamlines = list(tracking_result)
            data_per_streamline = {}

        sft = StatefulTractogram(
            streamlines,
            seeding_path,
            Space.RASMM,
            data_per_streamline=data_per_streamline,
        )
        save_tractogram(sft, out_tract, bbox_valid_check=False)
        logging.info(f"Saved {out_tract}")

    def run(
        self,
        pam_files,
        stopping_files,
        seeding_files,
        use_binary_mask=False,
        stopping_thr=0.2,
        seed_density=1,
        step_size=0.5,
        tracking_method="eudx",
        pmf_threshold=0.1,
        max_angle=30.0,
        out_dir="",
        out_tractogram="tractogram.trk",
        save_seeds=False,
    ):
        """Workflow for Local Fiber Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        See :footcite:p:`Garyfallidis2012b` and :footcite:p:`Amirbekian2016`
        for further details about the method.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. This path may contain
            wildcards to use multiple files at once.
        stopping_files : string
            Path to images (e.g. FA) used for stopping criterion for tracking.
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        use_binary_mask : bool, optional
            If True, uses a binary stopping criterion. If the provided
            `stopping_files` are not binary, `stopping_thr` will be used to
            binarize the images.
        stopping_thr : float, optional
            Threshold applied to stopping volume's data to identify where
            tracking has to stop.
        seed_density : int, optional
            Number of seeds per dimension inside voxel.
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        step_size : float, optional
            Step size (in mm) used for tracking.
        tracking_method : string, optional
            Select direction getter strategy:

            - "eudx" (Uses the peaks saved in the pam_files, default)
            - "deterministic" or "det" for a deterministic tracking
              (Uses the sh saved in the pam_files)
            - "probabilistic" or "prob" for a Probabilistic tracking
              (Uses the sh saved in the pam_files)
            - "closestpeaks" or "cp" for a ClosestPeaks tracking
              (Uses the sh saved in the pam_files)

            Any other value falls back to "eudx" and a warning is logged.
        pmf_threshold : float, optional
            Threshold for ODF functions.
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90]).
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved.
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key.

        References
        ----------
        .. footbibliography::

        """
        io_it = self.get_io_iterator()

        for pams_path, stopping_path, seeding_path, out_tract in io_it:
            logging.info(f"Local tracking on {pams_path}")

            pam = load_pam(pams_path, verbose=False)

            dg = self._get_direction_getter(
                tracking_method, pam, pmf_threshold=pmf_threshold, max_angle=max_angle
            )

            self._core_run(
                stopping_path,
                use_binary_mask,
                stopping_thr,
                seeding_path,
                seed_density,
                step_size,
                dg,
                out_tract,
                save_seeds,
            )
class PFTrackingPAMFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "track_pft"

    def run(
        self,
        pam_files,
        wm_files,
        gm_files,
        csf_files,
        seeding_files,
        step_size=0.2,
        seed_density=1,
        pmf_threshold=0.1,
        max_angle=20.0,
        pft_back=2,
        pft_front=1,
        pft_count=15,
        out_dir="",
        out_tractogram="tractogram.trk",
        save_seeds=False,
        min_wm_pve_before_stopping=0,
    ):
        """Workflow for Particle Filtering Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        See :footcite:p:`Girard2014` for further details about the method.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. This path may contain
            wildcards to use multiple files at once.
        wm_files : string
            Path to white matter partial volume estimate for tracking (CMC).
        gm_files : string
            Path to grey matter partial volume estimate for tracking (CMC).
        csf_files : string
            Path to cerebrospinal fluid partial volume estimate for tracking
            (CMC).
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size (in mm) used for tracking.
        seed_density : int, optional
            Number of seeds per dimension inside voxel.
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions.
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90]).
        pft_back : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography. The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_front : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance. The total particle filtering
            tractography distance is equal to back_tracking_dist +
            front_tracking_dist.
        pft_count : int, optional
            Number of particles to use in the particle filter.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved.
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key.
        min_wm_pve_before_stopping : int, optional
            Minimum white matter pve (1 - stopping_criterion.include_map -
            stopping_criterion.exclude_map) to reach before allowing the
            tractography to stop.

        References
        ----------
        .. footbibliography::

        """
        io_it = self.get_io_iterator()

        for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract in io_it:
            logging.info(f"Particle Filtering tracking on {pams_path}")

            pam = load_pam(pams_path, verbose=False)

            # The white matter image provides the affine used for seeding and
            # tracking, and its voxel sizes feed the CMC criterion.
            wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
            gm, _ = load_nifti(gm_path)
            csf, _ = load_nifti(csf_path)
            avs = sum(voxel_size) / len(voxel_size)  # average voxel size

            stopping_criterion = CmcStoppingCriterion.from_pve(
                wm, gm, csf, step_size=step_size, average_voxel_size=avs
            )
            logging.info("stopping criterion done")

            seed_mask, _ = load_nifti(seeding_path)
            seeds = utils.seeds_from_mask(
                seed_mask, affine, density=[seed_density, seed_density, seed_density]
            )
            logging.info("seeds done")

            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                pam.shm_coeff,
                max_angle=max_angle,
                sphere=pam.sphere,
                pmf_threshold=pmf_threshold,
            )

            tracking_result = ParticleFilteringTracking(
                direction_getter,
                stopping_criterion,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=pft_back,
                pft_front_tracking_dist=pft_front,
                # Fixed value; not exposed as a workflow option.
                pft_max_trial=20,
                particle_count=pft_count,
                save_seeds=save_seeds,
                min_wm_pve_before_stopping=min_wm_pve_before_stopping,
            )
            logging.info("ParticleFilteringTracking initiated")

            if save_seeds:
                # With save_seeds the generator yields (streamline, seed)
                # pairs. Materialize it first: ``zip(*result)`` on an empty
                # result cannot be unpacked into two names and raised
                # ValueError whenever tracking produced no streamline.
                pairs = list(tracking_result)
                streamlines = [streamline for streamline, _ in pairs]
                data_per_streamline = {"seeds": [seed for _, seed in pairs]}
            else:
                streamlines = list(tracking_result)
                data_per_streamline = {}

            sft = StatefulTractogram(
                streamlines,
                seeding_path,
                Space.RASMM,
                data_per_streamline=data_per_streamline,
            )
            save_tractogram(sft, out_tract, bbox_valid_check=False)
            logging.info(f"Saved {out_tract}")