import importlib
from inspect import getmembers, isfunction
import logging
import os
import sys
import warnings
import numpy as np
import trx.trx_file_memmap as tmm
from dipy.core.sphere import Sphere
from dipy.data import get_sphere
from dipy.io.image import load_nifti, save_nifti
from dipy.io.peaks import (
    load_pam,
    niftis_to_pam,
    pam_to_niftis,
    tensor_to_pam,
)
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.reconst.shm import convert_sh_descoteaux_tournier
from dipy.reconst.utils import convert_tensors
from dipy.tracking.streamlinespeed import length
from dipy.utils.tractogram import concatenate_tractogram
from dipy.workflows.workflow import Workflow
[docs]
class IoInfoFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "io_info" 
[docs]
    def run(
        self,
        input_files,
        b0_threshold=50,
        bvecs_tol=0.01,
        bshell_thr=100,
        reference=None,
    ):
        """Provides useful information about different files used in
        medical imaging. Any number of input files can be provided. The
        program identifies the type of file by its extension.
        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1, bvals or bvecs files.
        b0_threshold : float, optional
            Threshold used to find b0 volumes.
        bvecs_tol : float, optional
            Threshold used to check that norm(bvec) = 1 +/- bvecs_tol
            b-vectors are unit vectors.
        bshell_thr : float, optional
            Threshold for distinguishing b-values in different shells.
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy file.
            support (.nii or .nii.gz).
        """
        np.set_printoptions(3, suppress=True)
        io_it = self.get_io_iterator()
        for input_path in io_it:
            mult_ = len(input_path)
            logging.info(f"-----------{mult_ * '-'}")
            logging.info(f"Looking at {input_path}")
            logging.info(f"-----------{mult_ * '-'}")
            ipath_lower = input_path.lower()
            extension = os.path.splitext(ipath_lower)[1]
            if ipath_lower.endswith(".nii") or ipath_lower.endswith(".nii.gz"):
                data, affine, img, vox_sz, affcodes = load_nifti(
                    input_path, return_img=True, return_voxsize=True, return_coords=True
                )
                logging.info(f"Data size {data.shape}")
                logging.info(f"Data type {data.dtype}")
                if data.ndim == 3:
                    logging.info(
                        f"Data min {data.min()} max {data.max()} avg {data.mean()}"
                    )
                    logging.info(
                        f"2nd percentile {np.percentile(data, 2)} "
                        f"98th percentile {np.percentile(data, 98)}"
                    )
                if data.ndim == 4:
                    logging.info(
                        f"Data min {data[..., 0].min()} "
                        f"max {data[..., 0].max()} "
                        f"avg {data[..., 0].mean()} of vol 0"
                    )
                    msg = (
                        f"2nd percentile {np.percentile(data[..., 0], 2)} "
                        f"98th percentile {np.percentile(data[..., 0], 98)} "
                        f"of vol 0"
                    )
                    logging.info(msg)
                logging.info(f"Native coordinate system {''.join(affcodes)}")
                logging.info(f"Affine Native to RAS matrix \n{affine}")
                logging.info(f"Voxel size {np.array(vox_sz)}")
                if np.sum(np.abs(np.diff(vox_sz))) > 0.1:
                    msg = "Voxel size is not isotropic. Please reslice.\n"
                    logging.warning(msg, stacklevel=2)
            if os.path.basename(input_path).lower().find("bval") > -1:
                bvals = np.loadtxt(input_path)
                logging.info(f"b-values \n{bvals}")
                logging.info(f"Total number of b-values {len(bvals)}")
                shells = np.sum(np.diff(np.sort(bvals)) > bshell_thr)
                logging.info(f"Number of gradient shells {shells}")
                logging.info(
                    f"Number of b0s {np.sum(bvals <= b0_threshold)} "
                    f"(b0_thr {b0_threshold})\n"
                )
            if os.path.basename(input_path).lower().find("bvec") > -1:
                bvecs = np.loadtxt(input_path)
                logging.info(f"Bvectors shape on disk is {bvecs.shape}")
                rows, cols = bvecs.shape
                if rows < cols:
                    bvecs = bvecs.T
                logging.info(f"Bvectors are \n{bvecs}")
                norms = np.array([np.linalg.norm(bvec) for bvec in bvecs])
                res = np.where((norms <= 1 + bvecs_tol) & (norms >= 1 - bvecs_tol))
                ncl1 = np.sum(norms < 1 - bvecs_tol)
                logging.info(f"Total number of unit bvectors {len(res[0])}")
                logging.info(f"Total number of non-unit bvectors {ncl1}\n")
            if extension in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
                sft = None
                if extension in [".trk", ".trx"]:
                    sft = load_tractogram(input_path, "same", bbox_valid_check=False)
                else:
                    sft = load_tractogram(input_path, reference, bbox_valid_check=False)
                lengths_mm = list(length(sft.streamlines))
                sft.to_voxmm()
                lengths, steps = [], []
                for streamline in sft.streamlines:
                    lengths += [len(streamline)]
                    steps += [np.sqrt(np.sum(np.diff(streamline, axis=0) ** 2, axis=1))]
                steps = np.hstack(steps)
                logging.info(f"Number of streamlines: {len(sft)}")
                logging.info(f"min_length_mm: {float(np.min(lengths_mm))}")
                logging.info(f"mean_length_mm: {float(np.mean(lengths_mm))}")
                logging.info(f"max_length_mm: {float(np.max(lengths_mm))}")
                logging.info(f"std_length_mm: {float(np.std(lengths_mm))}")
                logging.info(f"min_length_nb_points: {float(np.min(lengths))}")
                logging.info("mean_length_nb_points: " f"{float(np.mean(lengths))}")
                logging.info(f"max_length_nb_points: {float(np.max(lengths))}")
                logging.info(f"std_length_nb_points: {float(np.std(lengths))}")
                logging.info(f"min_step_size: {float(np.min(steps))}")
                logging.info(f"mean_step_size: {float(np.mean(steps))}")
                logging.info(f"max_step_size: {float(np.max(steps))}")
                logging.info(f"std_step_size: {float(np.std(steps))}")
                logging.info(
                    "data_per_point_keys: " f"{list(sft.data_per_point.keys())}"
                )
                logging.info(
                    "data_per_streamline_keys: "
                    f"{list(sft.data_per_streamline.keys())}"
                )
        np.set_printoptions() 
 
[docs]
class FetchFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "fetch" 
[docs]
    @staticmethod
    def get_fetcher_datanames():
        """Gets available dataset and function names.
        Returns
        -------
        available_data: dict
            Available dataset and function names.
        """
        fetcher_module = FetchFlow.load_module("dipy.data.fetcher")
        available_data = dict(
            {
                (name.replace("fetch_", ""), func)
                for name, func in getmembers(fetcher_module, isfunction)
                if name.lower().startswith("fetch_")
                and func is not fetcher_module.fetch_data
                if name.lower() not in ["fetch_hbn", "fetch_hcp"]
            }
        )
        return available_data 
[docs]
    @staticmethod
    def load_module(module_path):
        """Load / reload an external module.
        Parameters
        ----------
        module_path: string
            the path to the module relative to the main script
        Returns
        -------
        module: module object
        """
        if module_path in sys.modules:
            return importlib.reload(sys.modules[module_path])
        else:
            return importlib.import_module(module_path) 
[docs]
    def run(self, data_names, out_dir=""):
        """Download files to folder and check their md5 checksums.
        To see all available datasets, please type "list" in data_names.
        Parameters
        ----------
        data_names : variable string
            Any number of Nifti1, bvals or bvecs files.
        out_dir : string, optional
            Output directory. (default current directory)
        """
        if out_dir:
            dipy_home = os.environ.get("DIPY_HOME", None)
            os.environ["DIPY_HOME"] = out_dir
        available_data = FetchFlow.get_fetcher_datanames()
        data_names = [name.lower() for name in data_names]
        if "all" in data_names:
            for name, fetcher_function in available_data.items():
                logging.info("------------------------------------------")
                logging.info(f"Fetching at {name}")
                logging.info("------------------------------------------")
                fetcher_function()
        elif "list" in data_names:
            logging.info(
                "Please, select between the following data names: "
                f"{', '.join(available_data.keys())}"
            )
        else:
            skipped_names = []
            for data_name in data_names:
                if data_name not in available_data.keys():
                    skipped_names.append(data_name)
                    continue
                logging.info("------------------------------------------")
                logging.info(f"Fetching at {data_name}")
                logging.info("------------------------------------------")
                available_data[data_name]()
            nb_success = len(data_names) - len(skipped_names)
            print("\n")
            logging.info(f"Fetched {nb_success} / {len(data_names)} Files ")
            if skipped_names:
                logging.warn(f"Skipped data name(s): {' '.join(skipped_names)}")
                logging.warn(
                    "Please, select between the following data names: "
                    f"{', '.join(available_data.keys())}"
                )
        if out_dir:
            if dipy_home:
                os.environ["DIPY_HOME"] = dipy_home
            else:
                os.environ.pop("DIPY_HOME", None)
            # We load the module again so that if we run another one of these
            # in the same process, we don't have the env variable pointing
            # to the wrong place
            self.load_module("dipy.data.fetcher") 
 
[docs]
class SplitFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "split" 
[docs]
    def run(self, input_files, vol_idx=0, out_dir="", out_split="split.nii.gz"):
        """Splits the input 4D file and extracts the required 3D volume.
        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1 files
        vol_idx : int, optional
            Index of the 3D volume to extract.
        out_dir : string, optional
            Output directory. (default current directory)
        out_split : string, optional
            Name of the resulting split volume
        """
        io_it = self.get_io_iterator()
        for fpath, osplit in io_it:
            logging.info(f"Splitting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)
            if vol_idx == 0:
                logging.info("Splitting and extracting 1st b0")
            split_vol = data[..., vol_idx]
            save_nifti(osplit, split_vol, affine, hdr=image.header)
            logging.info(f"Split volume saved as {osplit}") 
 
[docs]
class ConcatenateTractogramFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "concatracks" 
[docs]
    def run(
        self,
        tractogram_files,
        reference=None,
        delete_dpv=False,
        delete_dps=False,
        delete_groups=False,
        check_space_attributes=True,
        preallocation=False,
        out_dir="",
        out_extension="trx",
        out_tractogram="concatenated_tractogram",
    ):
        """Concatenate multiple tractograms into one.
        Parameters
        ----------
        tractogram_list : variable string
            The stateful tractogram filenames to concatenate
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy file.
            support (.nii or .nii.gz).
        delete_dpv : bool, optional
            Delete dpv keys that do not exist in all the provided TrxFiles
        delete_dps : bool, optional
            Delete dps keys that do not exist in all the provided TrxFile
        delete_groups : bool, optional
            Delete all the groups that currently exist in the TrxFiles
        check_space_attributes : bool, optional
            Verify that dimensions and size of data are similar between all the
            TrxFiles
        preallocation : bool, optional
            Preallocated TrxFile has already been generated and is the first
            element in trx_list (Note: delete_groups must be set to True as
            well)
        out_dir : string, optional
            Output directory. (default current directory)
        out_extension : string, optional
            Extension of the resulting tractogram
        out_tractogram : string, optional
            Name of the resulting tractogram
        """
        io_it = self.get_io_iterator()
        trx_list = []
        has_group = False
        for fpath, _, _ in io_it:
            if fpath.lower().endswith(".trx") or fpath.lower().endswith(".trk"):
                reference = "same"
            if not reference:
                raise ValueError(
                    "No reference provided. It is needed for tck,"
                    "fib, dpy or vtk files"
                )
            tractogram_obj = load_tractogram(fpath, reference, bbox_valid_check=False)
            if not isinstance(tractogram_obj, tmm.TrxFile):
                tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj)
            elif len(tractogram_obj.groups):
                has_group = True
            trx_list.append(tractogram_obj)
        trx = concatenate_tractogram(
            trx_list,
            delete_dpv=delete_dpv,
            delete_dps=delete_dps,
            delete_groups=delete_groups or not has_group,
            check_space_attributes=check_space_attributes,
            preallocation=preallocation,
        )
        valid_extensions = ["trk", "trx", "tck", "fib", "dpy", "vtk"]
        if out_extension.lower() not in valid_extensions:
            raise ValueError(
                f"Invalid extension. Valid extensions are: {valid_extensions}"
            )
        out_fpath = os.path.join(out_dir, f"{out_tractogram}.{out_extension}")
        save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False) 
 
[docs]
class ConvertSHFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "convert_dipy_mrtrix" 
[docs]
    def run(
        self,
        input_files,
        out_dir="",
        out_file="sh_convert_dipy_mrtrix_out.nii.gz",
    ):
        """Converts SH basis representation between DIPY and MRtrix3 formats.
        Because this conversion is equal to its own inverse, it can be used to
        convert in either direction: DIPY to MRtrix3 or vice versa.
        Parameters
        ----------
        input_files : string
            Path to the input files. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string, optional
            Where the resulting file will be saved. (default '')
        out_file : string, optional
            Name of the result file to be saved.
            (default 'sh_convert_dipy_mrtrix_out.nii.gz')
        """
        io_it = self.get_io_iterator()
        for in_file, out_file in io_it:
            data, affine, image = load_nifti(in_file, return_img=True)
            data = convert_sh_descoteaux_tournier(data)
            save_nifti(out_file, data, affine, hdr=image.header) 
 
[docs]
class ConvertTensorsFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "convert_tensors" 
[docs]
    def run(
        self,
        tensor_files,
        from_format="mrtrix",
        to_format="dipy",
        out_dir=".",
        out_tensor="converted_tensor",
    ):
        """Converts tensor representation between different formats.
        Parameters
        ----------
        tensor_files : variable string
            Any number of tensor files
        from_format : string, optional
            Format of the input tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        to_format : string, optional
            Format of the output tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tensor : string, optional
            Name of the resulting tensor file
        """
        io_it = self.get_io_iterator()
        for fpath, otensor in io_it:
            logging.info(f"Converting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)
            data = convert_tensors(data, from_format, to_format)
            save_nifti(otensor, data, affine, hdr=image.header) 
 
[docs]
class ConvertTractogramFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "convert_tractogram" 
[docs]
    def run(
        self,
        input_files,
        reference=None,
        pos_dtype="float32",
        offsets_dtype="uint32",
        out_dir="",
        out_tractogram="converted_tractogram.trk",
    ):
        """Converts tractogram between different formats.
        Parameters
        ----------
        input_files : variable string
            Any number of tractogram files
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy file.
            support (.nii or .nii.gz).
        pos_dtype : string, optional
            Data type of the tractogram points, used for vtk files.
        offsets_dtype : string, optional
            Data type of the tractogram offsets, used for vtk files.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the resulting tractogram
        """
        io_it = self.get_io_iterator()
        for fpath, otracks in io_it:
            in_extension = fpath.lower().split(".")[-1]
            out_extension = otracks.lower().split(".")[-1]
            if in_extension == out_extension:
                warnings.warn(
                    "Input and output are the same file format. Skipping...",
                    stacklevel=2,
                )
                continue
            if not reference and in_extension in ["trx", "trk"]:
                reference = "same"
            if not reference and in_extension not in ["trx", "trk"]:
                raise ValueError(
                    "No reference provided. It is needed for tck,"
                    "fib, dpy or vtk files"
                )
            sft = load_tractogram(fpath, reference, bbox_valid_check=False)
            if out_extension != "trx":
                if out_extension == "vtk":
                    if sft.streamlines._data.dtype.name != pos_dtype:
                        sft.streamlines._data = sft.streamlines._data.astype(pos_dtype)
                    if offsets_dtype == "uint64" or offsets_dtype == "uint32":
                        offsets_dtype = offsets_dtype[1:]
                    if sft.streamlines._offsets.dtype.name != offsets_dtype:
                        sft.streamlines._offsets = sft.streamlines._offsets.astype(
                            offsets_dtype
                        )
                save_tractogram(sft, otracks, bbox_valid_check=False)
            else:
                trx = tmm.TrxFile.from_sft(sft)
                if trx.streamlines._data.dtype.name != pos_dtype:
                    trx.streamlines._data = trx.streamlines._data.astype(pos_dtype)
                if trx.streamlines._offsets.dtype.name != offsets_dtype:
                    trx.streamlines._offsets = trx.streamlines._offsets.astype(
                        offsets_dtype
                    )
                tmm.save(trx, otracks)
                trx.close() 
 
[docs]
class NiftisToPamFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "niftis_to_pam" 
[docs]
    def run(
        self,
        peaks_dir_files,
        peaks_values_files,
        peaks_indices_files,
        shm_files=None,
        gfa_files=None,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple nifti files to a single pam5 file.
        Parameters
        ----------
        peaks_dir_files : string
            Path to the input peaks directions volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_values_files : string
            Path to the input peaks values volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_indices_files : string
            Path to the input peaks indices volume. This path may contain
            wildcards to process multiple inputs at once.
        shm_files : string, optional
            Path to the input spherical harmonics volume. This path may
            contain wildcards to process multiple inputs at once.
        gfa_files : string, optional
            Path to the input generalized FA volume. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not define,
            default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by
            sphere_files option. Possible options: ['symmetric362', 'symmetric642',
            'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.
        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"
        for fpeak_dirs, fpeak_values, fpeak_indices, opam in io_it:
            logging.info("Converting nifti files to pam5")
            peak_dirs, affine = load_nifti(fpeak_dirs)
            peak_values, _ = load_nifti(fpeak_values)
            peak_indices, _ = load_nifti(fpeak_indices)
            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)
            niftis_to_pam(
                affine=affine,
                peak_dirs=peak_dirs,
                sphere=sphere,
                peak_values=peak_values,
                peak_indices=peak_indices,
                pam_file=opam,
            )
            logging.info(msg.replace("pam5", opam)) 
 
[docs]
class TensorToPamFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "tensor_to_niftis" 
[docs]
    def run(
        self,
        evals_files,
        evecs_files,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple tensor files(evals, evecs) to pam5 files.
        Parameters
        ----------
        evals_files : string
            Path to the input eigen values volumes. This path may contain
            wildcards to process multiple inputs at once.
        evecs_files : string
            Path to the input eigen vectors volumes. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not define,
            default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by sphere_files
            option. Possible options: ['symmetric362', 'symmetric642',
            'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.
        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"
        for fevals, fevecs, opam in io_it:
            logging.info("Converting tensor files to pam5...")
            evals, affine = load_nifti(fevals)
            evecs, _ = load_nifti(fevecs)
            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)
            tensor_to_pam(evals, evecs, affine, sphere=sphere, pam_file=opam)
            logging.info(msg.replace("pam5", opam)) 
 
[docs]
class PamToNiftisFlow(Workflow):
[docs]
    @classmethod
    def get_short_name(cls):
        return "pam_to_niftis" 
[docs]
    def run(
        self,
        pam_files,
        out_dir="",
        out_peaks_dir="peaks_dirs.nii.gz",
        out_peaks_values="peaks_values.nii.gz",
        out_peaks_indices="peaks_indices.nii.gz",
        out_shm="shm.nii.gz",
        out_gfa="gfa.nii.gz",
        out_sphere="sphere.txt",
        out_b="B.nii.gz",
        out_qa="qa.nii.gz",
    ):
        """Convert pam5 files to multiple nifti files.
        Parameters
        ----------
        pam_files : string
            Path to the input peaks volumes. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string, optional
            Output directory (default input file directory).
        out_peaks_dir : string, optional
            Name of the peaks directions volume to be saved.
        out_peaks_values : string, optional
            Name of the peaks values volume to be saved.
        out_peaks_indices : string, optional
            Name of the peaks indices volume to be saved.
        out_shm : string, optional
            Name of the spherical harmonics volume to be saved.
        out_gfa : string, optional
            Generalized FA volume name to be saved.
        out_sphere : string, optional
            Sphere vertices name to be saved.
        out_b : string, optional
            Name of the B Matrix to be saved.
        out_qa : string, optional
            Name of the Quantitative Anisotropy file to be saved.
        """
        io_it = self.get_io_iterator()
        msg = f"Nifti files saved in {out_dir or 'current directory'}"
        for (
            ipam,
            opeaks_dir,
            opeaks_values,
            opeaks_indices,
            oshm,
            ogfa,
            osphere,
            ob,
            oqa,
        ) in io_it:
            logging.info("Converting %s file to niftis...", ipam)
            pam = load_pam(ipam)
            pam_to_niftis(
                pam,
                fname_peaks_dir=opeaks_dir,
                fname_shm=oshm,
                fname_peaks_values=opeaks_values,
                fname_peaks_indices=opeaks_indices,
                fname_sphere=osphere,
                fname_gfa=ogfa,
                fname_b=ob,
                fname_qa=oqa,
            )
            logging.info(msg)