from copy import deepcopy
import logging
import os
import time
import nibabel as nib
from nibabel.streamlines import detect_format
from nibabel.streamlines.tractogram import Tractogram
import numpy as np
import trx.trx_file_memmap as tmm
from dipy.io.dpy import Dpy
from dipy.io.stateful_tractogram import Origin, Space, StatefulTractogram
from dipy.io.utils import create_tractogram_header, is_header_compatible
from dipy.io.vtk import load_vtk_streamlines, save_vtk_streamlines
from dipy.testing.decorators import warning_for_keywords
[docs]
@warning_for_keywords()
def save_tractogram(sft, filename, *, bbox_valid_check=True):
    """Save the stateful tractogram in any format (trk/tck/vtk/vtp/fib/dpy)

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save
    filename : string
        Filename with valid extension
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.

    Returns
    -------
    output : bool
        True if the saving operation was successful

    Raises
    ------
    TypeError
        If the extension of ``filename`` is not a supported one.
    ValueError
        If ``bbox_valid_check`` is True and the bounding box of ``sft`` is
        not valid in voxel space.

    Notes
    -----
    The ``.trx`` extension is accepted as well. The tractogram is moved to
    RASMM space / center origin for the duration of the write and is put
    back in its initial space and origin afterwards, even when the write
    fails.
    """
    # NOTE: the summary line of the docstring above is rewritten with
    # str.replace() by save_generator; keep its wording in sync.
    _, extension = os.path.splitext(filename)
    if extension not in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
        raise TypeError("Output filename is not one of the supported format.")

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError(
            "Bounding box is not valid in voxel space, cannot "
            "save a valid file if some coordinates are invalid.\n"
            "Please set bbox_valid_check to False and then use "
            "the function remove_invalid_streamlines to discard "
            "invalid streamlines."
        )

    # Remember the caller's state so it can be restored after writing.
    old_space = deepcopy(sft.space)
    old_origin = deepcopy(sft.origin)

    try:
        sft.to_rasmm()
        sft.to_center()

        timer = time.time()
        if extension in (".trk", ".tck"):
            tractogram_type = detect_format(filename)
            header = create_tractogram_header(tractogram_type, *sft.space_attributes)
            new_tractogram = Tractogram(sft.streamlines, affine_to_rasmm=np.eye(4))

            # Per-point / per-streamline data is only forwarded for trk files.
            if extension == ".trk":
                new_tractogram.data_per_point = sft.data_per_point
                new_tractogram.data_per_streamline = sft.data_per_streamline

            fileobj = tractogram_type(new_tractogram, header=header)
            nib.streamlines.save(fileobj, filename)
        elif extension in (".vtk", ".vtp", ".fib"):
            # vtk and fib are written in binary mode, vtp is not.
            binary = extension in (".vtk", ".fib")
            save_vtk_streamlines(sft.streamlines, filename, binary=binary)
        elif extension == ".dpy":
            dpy_obj = Dpy(filename, mode="w")
            try:
                dpy_obj.write_tracks(sft.streamlines)
            finally:
                # Release the file handle even if writing fails.
                dpy_obj.close()
        elif extension == ".trx":
            trx = tmm.TrxFile.from_sft(sft)
            try:
                tmm.save(trx, filename)
            finally:
                trx.close()

        logging.debug(
            "Save %s with %s streamlines in %s seconds.",
            filename,
            len(sft),
            round(time.time() - timer, 3),
        )
    finally:
        # Never leave the caller's tractogram in a modified space/origin.
        sft.to_space(old_space)
        sft.to_origin(old_origin)

    return True
[docs]
@warning_for_keywords()
def load_tractogram(
    filename,
    reference,
    *,
    to_space=Space.RASMM,
    to_origin=Origin.NIFTI,
    bbox_valid_check=True,
    trk_header_check=True,
):
    """Load the stateful tractogram from any format (trk/tck/vtk/vtp/fib/dpy)

    Parameters
    ----------
    filename : string
        Filename with valid extension
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading
    to_origin : Enum (dipy.io.stateful_tractogram.Origin)
        Origin to which the streamlines will be transformed after loading
            NIFTI standard, default (center of the voxel)
            TRACKVIS standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded

    Returns
    -------
    output : StatefulTractogram
        The tractogram to load (must have been saved properly).
        ``False`` is returned instead (and an error is logged) when the
        extension is not supported, ``to_space`` is not a ``Space`` member,
        ``reference`` is 'same' for a format other than trk/trx, or the trk
        header check fails.

    Raises
    ------
    ValueError
        If ``bbox_valid_check`` is True and the bounding box of the loaded
        tractogram is not valid in voxel space.

    Notes
    -----
    The ``.trx`` extension is accepted as well; for that format the
    tractogram is built from the file itself and 'same' may also be used as
    ``reference``.
    """
    # NOTE: the summary line of the docstring above is rewritten with
    # str.replace() by load_generator; keep its wording in sync.
    _, extension = os.path.splitext(filename)
    if extension not in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
        logging.error("Input filename is not one of the supported format.")
        return False

    # isinstance behaves the same on every Python version, whereas
    # ``x in Space`` raises TypeError for non-members before Python 3.12.
    if not isinstance(to_space, Space):
        logging.error("Space MUST be one of the 3 choices (Enum).")
        return False

    if isinstance(reference, str) and reference == "same":
        if extension in (".trk", ".trx"):
            reference = filename
        else:
            logging.error(
                'Reference must be provided, "same" is only '
                "available for Trk or Trx files."
            )
            return False

    if trk_header_check and extension == ".trk":
        if not is_header_compatible(filename, reference):
            logging.error("Trk file header does not match the provided reference.")
            return False

    timer = time.time()
    if extension == ".trx":
        trx_obj = tmm.load(filename)
        try:
            sft = trx_obj.to_sft()
        finally:
            # Release the file even if the conversion fails.
            trx_obj.close()
    else:
        data_per_point = None
        data_per_streamline = None

        if extension in (".trk", ".tck"):
            tractogram_obj = nib.streamlines.load(filename).tractogram
            streamlines = tractogram_obj.streamlines
            # Only trk files have their extra data forwarded.
            if extension == ".trk":
                data_per_point = tractogram_obj.data_per_point
                data_per_streamline = tractogram_obj.data_per_streamline
        elif extension in (".vtk", ".vtp", ".fib"):
            streamlines = load_vtk_streamlines(filename)
        else:  # ".dpy", the only supported extension left
            dpy_obj = Dpy(filename, mode="r")
            try:
                streamlines = list(dpy_obj.read_tracks())
            finally:
                # Release the file handle even if reading fails.
                dpy_obj.close()

        sft = StatefulTractogram(
            streamlines,
            reference,
            Space.RASMM,
            origin=Origin.NIFTI,
            data_per_point=data_per_point,
            data_per_streamline=data_per_streamline,
        )

    logging.debug(
        "Load %s with %s streamlines in %s seconds.",
        filename,
        len(sft),
        round(time.time() - timer, 3),
    )

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError(
            "Bounding box is not valid in voxel space, cannot "
            "load a valid file if some coordinates are invalid.\n"
            "Please set bbox_valid_check to False and then use "
            "the function remove_invalid_streamlines to discard "
            "invalid streamlines."
        )

    sft.to_space(to_space)
    sft.to_origin(to_origin)

    return sft
[docs]
def load_generator(ttype):
    """Generate a loading function that performs a file extension
    check to restrict the user to a single file format.

    Parameters
    ----------
    ttype : string
        Extension of the file format that requires a loader
        (including the leading dot, e.g. ".trk")

    Returns
    -------
    output : function
        Function (load_tractogram) that handle only one file format.
        It takes the same arguments as load_tractogram and raises
        ValueError when the filename extension differs from ``ttype``.
    """

    @warning_for_keywords()
    def f_gen(
        filename,
        reference,
        *,
        to_space=Space.RASMM,
        to_origin=Origin.NIFTI,
        bbox_valid_check=True,
        trk_header_check=True,
    ):
        _, extension = os.path.splitext(filename)
        if extension != ttype:
            msg = f"This function can only load {ttype} files, "
            msg += "for a more general purpose, use load_tractogram instead."
            raise ValueError(msg)

        # Forward every option as given by the caller; in particular the
        # requested space must not be replaced by a hard-coded default.
        return load_tractogram(
            filename,
            reference,
            to_space=to_space,
            to_origin=to_origin,
            bbox_valid_check=bbox_valid_check,
            trk_header_check=trk_header_check,
        )

    # Docstrings are None when Python runs with -OO; skip the
    # specialization then instead of failing at import time.
    base_doc = load_tractogram.__doc__
    if base_doc is not None:
        f_gen.__doc__ = base_doc.replace(
            "from any format (trk/tck/vtk/vtp/fib/dpy)", f"of the {ttype} format"
        )
    return f_gen
[docs]
def save_generator(ttype):
    """Generate a saving function that performs a file extension
    check to restrict the user to a single file format.

    Parameters
    ----------
    ttype : string
        Extension of the file format that requires a saver
        (including the leading dot, e.g. ".trk")

    Returns
    -------
    output : function
        Function (save_tractogram) that handle only one file format.
        It raises ValueError when the filename extension differs from
        ``ttype`` and otherwise returns what save_tractogram returns.
    """

    def f_gen(sft, filename, bbox_valid_check=True):
        _, extension = os.path.splitext(filename)
        if extension != ttype:
            msg = f"This function can only save {ttype} file, "
            msg += "for more general cases, use save_tractogram instead."
            raise ValueError(msg)

        # Propagate the success flag so the wrapper honours the documented
        # return value instead of always returning None.
        return save_tractogram(sft, filename, bbox_valid_check=bbox_valid_check)

    # Docstrings are None when Python runs with -OO; skip the
    # specialization then instead of failing at import time.
    base_doc = save_tractogram.__doc__
    if base_doc is not None:
        f_gen.__doc__ = base_doc.replace(
            "in any format (trk/tck/vtk/vtp/fib/dpy)", f"of the {ttype} format"
        )
    return f_gen
# Format-specific entry points: one loader/saver pair per supported
# extension, each rejecting filenames with any other extension.

# TrackVis
load_trk = load_generator(".trk")
save_trk = save_generator(".trk")

# MRtrix
load_tck = load_generator(".tck")
save_tck = save_generator(".tck")

# TRX
load_trx = load_generator(".trx")
save_trx = save_generator(".trx")

# VTK family (.vtk / .vtp / .fib)
load_vtk = load_generator(".vtk")
save_vtk = save_generator(".vtk")
load_vtp = load_generator(".vtp")
save_vtp = save_generator(".vtp")
load_fib = load_generator(".fib")
save_fib = save_generator(".fib")

# DIPY
load_dpy = load_generator(".dpy")
save_dpy = save_generator(".dpy")