#!/usr/bin/env python3
import logging
from dipy.direction import (
    ClosestPeakDirectionGetter,
    DeterministicMaximumDirectionGetter,
    ProbabilisticDirectionGetter,
)
from dipy.io.image import load_nifti
from dipy.io.peaks import load_pam
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.tracking import utils
from dipy.tracking.local_tracking import LocalTracking, ParticleFilteringTracking
from dipy.tracking.stopping_criterion import (
    BinaryStoppingCriterion,
    CmcStoppingCriterion,
    ThresholdStoppingCriterion,
)
from dipy.workflows.workflow import Workflow
class LocalFiberTrackingPAMFlow(Workflow):
    """Local fiber tracking workflow driven by a saved peaks and metrics file."""

    @classmethod
    def get_short_name(cls):
        """Return the short name of this workflow, ``"track_local"``."""
        return "track_local"

    def _get_direction_getter(self, strategy_name, pam, pmf_threshold, max_angle):
        """Get Tracking Direction Getter object.

        Parameters
        ----------
        strategy_name : str
            String representing direction getter name (case-insensitive).
        pam : instance of PeaksAndMetrics
            An object with ``gfa``, ``peak_directions``, ``peak_values``,
            ``peak_indices``, ``odf``, ``shm_coeffs`` as attributes.
        pmf_threshold : float
            Threshold for ODF functions.
        max_angle : float
            Maximum angle between streamline segments.

        Returns
        -------
        direction_getter : instance of DirectionGetter
            Used to get directions for fiber tracking. For the "eudx"
            strategy, and for any unrecognized strategy name, ``pam`` itself
            is returned.
        """
        strategy = strategy_name.lower()
        # Strategies that build a direction getter from the spherical
        # harmonic coefficients stored in ``pam``:
        # (accepted aliases, log label, direction getter class).
        sh_strategies = (
            (
                ("deterministic", "det"),
                "Deterministic",
                DeterministicMaximumDirectionGetter,
            ),
            (
                ("probabilistic", "prob"),
                "Probabilistic",
                ProbabilisticDirectionGetter,
            ),
            (
                ("closestpeaks", "cp"),
                "ClosestPeaks",
                ClosestPeakDirectionGetter,
            ),
        )
        for aliases, label, getter_cls in sh_strategies:
            if strategy in aliases:
                dg = getter_cls.from_shcoeff(
                    pam.shm_coeff,
                    sphere=pam.sphere,
                    max_angle=max_angle,
                    pmf_threshold=pmf_threshold,
                )
                logging.info(f"{label} direction getter strategy selected")
                return dg

        if strategy == "eudx":
            logging.info("Eudx direction getter strategy selected")
        else:
            # Unknown names keep falling back to EuDX for backward
            # compatibility, but a typo should not go unnoticed.
            logging.warning(
                f"Unknown tracking method '{strategy_name}'. No direction "
                "getter defined. Eudx direction getter strategy selected"
            )
        return pam

    @staticmethod
    def _split_tracking_result(tracking_result, save_seeds):
        """Materialize tracking output into streamlines and per-streamline data.

        Parameters
        ----------
        tracking_result : iterable
            Output of the tracker. When ``save_seeds`` is True each item is a
            ``(streamline, seed)`` pair, otherwise each item is a streamline.
        save_seeds : bool
            Whether the tracker was asked to return seeds.

        Returns
        -------
        streamlines : list
            The tracked streamlines.
        data_per_streamline : dict
            ``{"seeds": [...]}`` when seeds were saved and at least one
            streamline was produced, otherwise an empty dict.
        """
        # ``zip(*tracking_result)`` cannot be unpacked into two names when
        # tracking produced nothing, so materialize the result first.
        results = list(tracking_result)
        if not results:
            logging.warning("Tracking produced no streamlines")
            return results, {}
        if not save_seeds:
            return results, {}
        streamlines = [streamline for streamline, _ in results]
        seeds = [seed for _, seed in results]
        return streamlines, {"seeds": seeds}

    def _core_run(
        self,
        stopping_path,
        use_binary_mask,
        stopping_thr,
        seeding_path,
        seed_density,
        step_size,
        direction_getter,
        out_tract,
        save_seeds,
    ):
        """Run local tracking for one input set and save the tractogram.

        Parameters
        ----------
        stopping_path : str
            Image used to build the stopping criterion.
        use_binary_mask : bool
            If True, binarize the stopping image with ``stopping_thr`` and use
            a binary stopping criterion; otherwise use a threshold criterion.
        stopping_thr : float
            Threshold applied to the stopping image.
        seeding_path : str
            Seed mask image; also used as the reference of the saved
            tractogram.
        seed_density : int
            Number of seeds per dimension inside each voxel.
        step_size : float
            Tracking step size.
        direction_getter : DirectionGetter or PeaksAndMetrics
            Direction source passed to ``LocalTracking``.
        out_tract : str
            Output tractogram path.
        save_seeds : bool
            If True, store each streamline's seed under the ``"seeds"`` key
            of ``data_per_streamline``.
        """
        stop, affine = load_nifti(stopping_path)
        if use_binary_mask:
            stopping_criterion = BinaryStoppingCriterion(stop > stopping_thr)
        else:
            stopping_criterion = ThresholdStoppingCriterion(stop, stopping_thr)
        logging.info("stopping criterion done")

        # NOTE(review): the seed mask's own affine is discarded and the
        # stopping image's affine is used instead; this assumes both images
        # share the same grid — confirm for your inputs.
        seed_mask, _ = load_nifti(seeding_path)
        seeds = utils.seeds_from_mask(seed_mask, affine, density=[seed_density] * 3)
        logging.info("seeds done")

        tracking_result = LocalTracking(
            direction_getter,
            stopping_criterion,
            seeds,
            affine,
            step_size=step_size,
            save_seeds=save_seeds,
        )
        logging.info("LocalTracking initiated")

        streamlines, data_per_streamline = self._split_tracking_result(
            tracking_result, save_seeds
        )
        sft = StatefulTractogram(
            streamlines,
            seeding_path,
            Space.RASMM,
            data_per_streamline=data_per_streamline,
        )
        save_tractogram(sft, out_tract, bbox_valid_check=False)
        logging.info(f"Saved {out_tract}")

    def run(
        self,
        pam_files,
        stopping_files,
        seeding_files,
        use_binary_mask=False,
        stopping_thr=0.2,
        seed_density=1,
        step_size=0.5,
        tracking_method="eudx",
        pmf_threshold=0.1,
        max_angle=30.0,
        out_dir="",
        out_tractogram="tractogram.trk",
        save_seeds=False,
    ):
        """Workflow for Local Fiber Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        See :footcite:p:`Garyfallidis2012b` and :footcite:p:`Amirbekian2016`
        for further details about the method.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. This path may contain
            wildcards to use multiple PAM files at once.
        stopping_files : string
            Path to images (e.g. FA) used for stopping criterion for tracking.
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        use_binary_mask : bool, optional
            If True, uses a binary stopping criterion. If the provided
            `stopping_files` are not binary, `stopping_thr` will be used to
            binarize the images.
        stopping_thr : float, optional
            Threshold applied to stopping volume's data to identify where
            tracking has to stop.
        seed_density : int, optional
            Number of seeds per dimension inside voxel.
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        step_size : float, optional
            Step size (in mm) used for tracking.
        tracking_method : string, optional
            Select direction getter strategy (case-insensitive):
             - "eudx" (Uses the peaks saved in the pam_files, default)
             - "deterministic" or "det" for a deterministic tracking
               (Uses the sh saved in the pam_files)
             - "probabilistic" or "prob" for a Probabilistic tracking
               (Uses the sh saved in the pam_files)
             - "closestpeaks" or "cp" for a ClosestPeaks tracking
               (Uses the sh saved in the pam_files)
            Unrecognized values fall back to "eudx" with a warning.
        pmf_threshold : float, optional
            Threshold for ODF functions.
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90]).
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved.
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key.

        References
        ----------
        .. footbibliography::
        """
        # Must stay the first statement: the IO iterator is built from this
        # call frame's argument values.
        io_it = self.get_io_iterator()
        for pams_path, stopping_path, seeding_path, out_tract in io_it:
            logging.info(f"Local tracking on {pams_path}")
            pam = load_pam(pams_path, verbose=False)
            dg = self._get_direction_getter(
                tracking_method, pam, pmf_threshold=pmf_threshold, max_angle=max_angle
            )
            self._core_run(
                stopping_path,
                use_binary_mask,
                stopping_thr,
                seeding_path,
                seed_density,
                step_size,
                dg,
                out_tract,
                save_seeds,
            )
 
class PFTrackingPAMFlow(Workflow):
    """Particle filtering tracking workflow driven by a saved PAM file."""

    @classmethod
    def get_short_name(cls):
        """Return the short name of this workflow, ``"track_pft"``."""
        return "track_pft"

    @staticmethod
    def _split_tracking_result(tracking_result, save_seeds):
        """Materialize tracking output into streamlines and per-streamline data.

        Parameters
        ----------
        tracking_result : iterable
            Output of the tracker. When ``save_seeds`` is True each item is a
            ``(streamline, seed)`` pair, otherwise each item is a streamline.
        save_seeds : bool
            Whether the tracker was asked to return seeds.

        Returns
        -------
        streamlines : list
            The tracked streamlines.
        data_per_streamline : dict
            ``{"seeds": [...]}`` when seeds were saved and at least one
            streamline was produced, otherwise an empty dict.
        """
        # ``zip(*tracking_result)`` cannot be unpacked into two names when
        # tracking produced nothing, so materialize the result first.
        results = list(tracking_result)
        if not results:
            logging.warning("Tracking produced no streamlines")
            return results, {}
        if not save_seeds:
            return results, {}
        streamlines = [streamline for streamline, _ in results]
        seeds = [seed for _, seed in results]
        return streamlines, {"seeds": seeds}

    def run(
        self,
        pam_files,
        wm_files,
        gm_files,
        csf_files,
        seeding_files,
        step_size=0.2,
        seed_density=1,
        pmf_threshold=0.1,
        max_angle=20.0,
        pft_back=2,
        pft_front=1,
        pft_count=15,
        out_dir="",
        out_tractogram="tractogram.trk",
        save_seeds=False,
        min_wm_pve_before_stopping=0,
        pft_max_trial=20,
    ):
        """Workflow for Particle Filtering Tracking.

        This workflow uses a saved peaks and metrics (PAM) file as input.

        See :footcite:p:`Girard2014` for further details about the method.

        Parameters
        ----------
        pam_files : string
            Path to the peaks and metrics files. This path may contain
            wildcards to use multiple PAM files at once.
        wm_files : string
            Path to white matter partial volume estimate for tracking (CMC).
        gm_files : string
            Path to grey matter partial volume estimate for tracking (CMC).
        csf_files : string
            Path to cerebrospinal fluid partial volume estimate for tracking
            (CMC).
        seeding_files : string
            A binary image showing where we need to seed for tracking.
        step_size : float, optional
            Step size (in mm) used for tracking.
        seed_density : int, optional
            Number of seeds per dimension inside voxel.
            For example, seed_density of 2 means 8 regularly distributed
            points in the voxel. And seed density of 1 means 1 point at the
            center of the voxel.
        pmf_threshold : float, optional
            Threshold for ODF functions.
        max_angle : float, optional
            Maximum angle between streamline segments (range [0, 90]).
        pft_back : float, optional
            Distance in mm to back track before starting the particle filtering
            tractography. The total particle filtering tractography distance
            is equal to pft_back + pft_front.
        pft_front : float, optional
            Distance in mm to run the particle filtering tractography after
            the back track distance. The total particle filtering tractography
            distance is equal to pft_back + pft_front.
        pft_count : int, optional
            Number of particles to use in the particle filter.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the tractogram file to be saved.
        save_seeds : bool, optional
            If true, save the seeds associated to their streamline
            in the 'data_per_streamline' Tractogram dictionary using
            'seeds' as the key.
        min_wm_pve_before_stopping : float, optional
            Minimum white matter pve (1 - stopping_criterion.include_map -
            stopping_criterion.exclude_map) to reach before allowing the
            tractography to stop.
        pft_max_trial : int, optional
            Maximum number of particle filtering trials, passed to
            ParticleFilteringTracking as ``pft_max_trial``.

        References
        ----------
        .. footbibliography::
        """
        # Must stay the first statement: the IO iterator is built from this
        # call frame's argument values.
        io_it = self.get_io_iterator()
        for pams_path, wm_path, gm_path, csf_path, seeding_path, out_tract in io_it:
            logging.info(f"Particle Filtering tracking on {pams_path}")
            pam = load_pam(pams_path, verbose=False)

            # NOTE(review): only the WM image's affine and voxel size are
            # used; GM, CSF and seed mask are assumed to share its grid.
            wm, affine, voxel_size = load_nifti(wm_path, return_voxsize=True)
            gm, _ = load_nifti(gm_path)
            csf, _ = load_nifti(csf_path)
            average_voxel_size = sum(voxel_size) / len(voxel_size)
            stopping_criterion = CmcStoppingCriterion.from_pve(
                wm,
                gm,
                csf,
                step_size=step_size,
                average_voxel_size=average_voxel_size,
            )
            logging.info("stopping criterion done")

            seed_mask, _ = load_nifti(seeding_path)
            seeds = utils.seeds_from_mask(
                seed_mask, affine, density=[seed_density] * 3
            )
            logging.info("seeds done")

            direction_getter = ProbabilisticDirectionGetter.from_shcoeff(
                pam.shm_coeff,
                max_angle=max_angle,
                sphere=pam.sphere,
                pmf_threshold=pmf_threshold,
            )
            tracking_result = ParticleFilteringTracking(
                direction_getter,
                stopping_criterion,
                seeds,
                affine,
                step_size=step_size,
                pft_back_tracking_dist=pft_back,
                pft_front_tracking_dist=pft_front,
                pft_max_trial=pft_max_trial,
                particle_count=pft_count,
                save_seeds=save_seeds,
                min_wm_pve_before_stopping=min_wm_pve_before_stopping,
            )
            logging.info("ParticleFilteringTracking initiated")

            streamlines, data_per_streamline = self._split_tracking_result(
                tracking_result, save_seeds
            )
            sft = StatefulTractogram(
                streamlines,
                seeding_path,
                Space.RASMM,
                data_per_streamline=data_per_streamline,
            )
            save_tractogram(sft, out_tract, bbox_valid_check=False)
            logging.info(f"Saved {out_tract}")