import logging
from time import time
import numpy as np
from dipy.io.image import load_nifti, save_nifti
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.segment.bundles import RecoBundles
from dipy.segment.mask import median_otsu
from dipy.tracking import Streamlines
from dipy.workflows.utils import handle_vol_idx
from dipy.workflows.workflow import Workflow
class RecoBundlesFlow(Workflow):
    """Workflow that recognizes bundles in a tractogram with RecoBundles."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "recobundles"

    def run(
        self,
        streamline_files,
        model_bundle_files,
        greater_than=50,
        less_than=1000000,
        no_slr=False,
        clust_thr=15.0,
        reduction_thr=15.0,
        reduction_distance="mdf",
        model_clust_thr=2.5,
        pruning_thr=8.0,
        pruning_distance="mdf",
        slr_metric="symmetric",
        slr_transform="similarity",
        slr_matrix="small",
        refine=False,
        r_reduction_thr=12.0,
        r_pruning_thr=6.0,
        no_r_slr=False,
        out_dir="",
        out_recognized_transf="recognized.trk",
        out_recognized_labels="labels.npy",
    ):
        """Recognize bundles

        See :footcite:p:`Garyfallidis2018` and :footcite:p:`Chandio2020a` for
        further details about the method.

        Parameters
        ----------
        streamline_files : string
            The path of streamline files where you want to recognize bundles.
        model_bundle_files : string
            The path of model bundle files.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value in mm.
        less_than : int, optional
            Keep streamlines that have length less than this value
            in mm.
        no_slr : bool, optional
            Don't enable local Streamline-based Linear
            Registration.
        clust_thr : float, optional
            MDF distance threshold for all streamlines.
        reduction_thr : float, optional
            Reduce search space by (mm).
        reduction_distance : string, optional
            Reduction distance type can be mdf or mam.
        model_clust_thr : float, optional
            MDF distance threshold for the model bundles.
        pruning_thr : float, optional
            Pruning after matching.
        pruning_distance : string, optional
            Pruning distance type can be mdf or mam.
        slr_metric : string, optional
            Options are None, symmetric, asymmetric or diagonal.
        slr_transform : string, optional
            Transformation allowed. translation, rigid, similarity or scaling.
        slr_matrix : string, optional
            Options are 'nano', 'tiny', 'small', 'medium', 'large', 'huge'.
        refine : bool, optional
            Enable refine recognized bundle.
        r_reduction_thr : float, optional
            Refine reduce search space by (mm).
        r_pruning_thr : float, optional
            Refine pruning after matching.
        no_r_slr : bool, optional
            Don't enable Refine local Streamline-based Linear
            Registration.
        out_dir : string, optional
            Output directory. (default current directory)
        out_recognized_transf : string, optional
            Recognized bundle in the space of the model bundle.
        out_recognized_labels : string, optional
            Indices of recognized bundle in the original tractogram.

        Raises
        ------
        ValueError
            If ``slr_matrix`` is not one of the supported options.

        References
        ----------
        .. footbibliography::

        """
        # The CLI exposes negative flags; RecoBundles expects positive ones.
        slr = not no_slr
        r_slr = not no_r_slr

        # Values forwarded as ``slr_select`` for each named matrix size.
        # NOTE(review): presumably the number of streamlines sampled on each
        # side for the registration -- confirm against RecoBundles.recognize.
        slr_select_by_size = {
            "nano": (100, 100),
            "tiny": (250, 250),
            "small": (400, 400),
            "medium": (600, 600),
            "large": (800, 800),
            "huge": (1200, 1200),
        }
        slr_matrix = slr_matrix.lower()
        try:
            slr_select = slr_select_by_size[slr_matrix]
        except KeyError:
            # Fail early with a clear message instead of hitting an unbound
            # ``slr_select`` only after the tractogram has been loaded.
            valid = ", ".join(repr(name) for name in slr_select_by_size)
            raise ValueError(
                f"Unknown slr_matrix {slr_matrix!r}; options are {valid}."
            ) from None

        # The bounds are ordered so that each transform uses a prefix of the
        # list; an unrecognized transform name keeps the full list.
        bounds = [
            (-30, 30),
            (-30, 30),
            (-30, 30),
            (-45, 45),
            (-45, 45),
            (-45, 45),
            (0.8, 1.2),
            (0.8, 1.2),
            (0.8, 1.2),
        ]
        n_bounds_by_transform = {
            "translation": 3,
            "rigid": 6,
            "similarity": 7,
            "scaling": 9,
        }
        slr_transform = slr_transform.lower()
        bounds = bounds[: n_bounds_by_transform.get(slr_transform, len(bounds))]

        logging.info("### RecoBundles ###")

        io_it = self.get_io_iterator()

        # The input tractogram is loaded once and shared by every model bundle.
        t = time()
        logging.info(streamline_files)
        input_obj = load_tractogram(streamline_files, "same", bbox_valid_check=False)
        streamlines = input_obj.streamlines
        logging.info(f" Loading time {time() - t:0.3f} sec")

        rb = RecoBundles(streamlines, greater_than=greater_than, less_than=less_than)

        for _, mb, out_rec, out_labels in io_it:
            t = time()
            logging.info(mb)
            model_bundle = load_tractogram(
                mb, "same", bbox_valid_check=False
            ).streamlines
            logging.info(f" Loading time {time() - t:0.3f} sec")
            logging.info("model file = ")
            logging.info(mb)

            recognized_bundle, labels = rb.recognize(
                model_bundle,
                model_clust_thr=model_clust_thr,
                reduction_thr=reduction_thr,
                reduction_distance=reduction_distance,
                pruning_thr=pruning_thr,
                pruning_distance=pruning_distance,
                slr=slr,
                slr_metric=slr_metric,
                slr_x0=slr_transform,
                slr_bounds=bounds,
                slr_select=slr_select,
                slr_method="L-BFGS-B",
            )

            # Refinement is only attempted when more than one streamline was
            # recognized.
            if refine and len(recognized_bundle) > 1:
                # 12-parameter (affine) starting point and matching bounds.
                x0 = np.array([0, 0, 0, 0, 0, 0, 1.0, 1.0, 1, 0, 0, 0])
                affine_bounds = [
                    (-30, 30),
                    (-30, 30),
                    (-30, 30),
                    (-45, 45),
                    (-45, 45),
                    (-45, 45),
                    (0.8, 1.2),
                    (0.8, 1.2),
                    (0.8, 1.2),
                    (-10, 10),
                    (-10, 10),
                    (-10, 10),
                ]
                recognized_bundle, labels = rb.refine(
                    model_bundle,
                    recognized_bundle,
                    model_clust_thr=model_clust_thr,
                    reduction_thr=r_reduction_thr,
                    reduction_distance=reduction_distance,
                    pruning_thr=r_pruning_thr,
                    pruning_distance=pruning_distance,
                    slr=r_slr,
                    slr_metric=slr_metric,
                    slr_x0=x0,
                    slr_bounds=affine_bounds,
                    slr_select=slr_select,
                    slr_method="L-BFGS-B",
                )

            # Metrics are only computed when something was recognized.
            if len(labels) > 0:
                ba, bmd = rb.evaluate_results(
                    model_bundle, recognized_bundle, slr_select
                )
                logging.info(f"Bundle adjacency Metric {ba}")
                logging.info(f"Bundle Min Distance Metric {bmd}")

            # The input tractogram file serves as the spatial reference.
            new_tractogram = StatefulTractogram(
                recognized_bundle, streamline_files, Space.RASMM
            )
            save_tractogram(new_tractogram, out_rec, bbox_valid_check=False)

            logging.info("Saving output files ...")
            np.save(out_labels, np.array(labels))
            logging.info(out_rec)
            logging.info(out_labels)
 
class LabelsBundlesFlow(Workflow):
    """Workflow that extracts a bundle from a tractogram using saved labels."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "labelsbundles"

    def run(
        self,
        streamline_files,
        labels_files,
        out_dir="",
        out_bundle="recognized_orig.trk",
    ):
        """Extract bundles using existing indices (labels)

        See :footcite:p:`Garyfallidis2018` for further details about the method.

        Parameters
        ----------
        streamline_files : string
            The path of streamline files where you want to recognize bundles.
        labels_files : string
            The path of the labels files (NumPy arrays of streamline indices).
        out_dir : string, optional
            Output directory. (default current directory)
        out_bundle : string, optional
            Recognized bundle extracted from the input tractogram.

        References
        ----------
        .. footbibliography::

        """
        logging.info("### Labels to Bundles ###")

        io_it = self.get_io_iterator()
        for f_streamlines, f_labels, out_bundle in io_it:
            logging.info(f_streamlines)
            sft = load_tractogram(f_streamlines, "same", bbox_valid_check=False)
            streamlines = sft.streamlines

            logging.info(f_labels)
            location = np.load(f_labels)

            # With no labels, save an empty bundle rather than indexing the
            # streamlines with an empty array.
            if len(location) < 1:
                bundle = Streamlines([])
            else:
                bundle = streamlines[location]

            logging.info("Saving output files ...")
            # The loaded tractogram is passed as the reference for the output.
            new_sft = StatefulTractogram(bundle, sft, Space.RASMM)
            save_tractogram(new_sft, out_bundle, bbox_valid_check=False)
            logging.info(out_bundle)