from copy import deepcopy
import logging
import os
import time
import nibabel as nib
from nibabel.streamlines import detect_format
from nibabel.streamlines.tractogram import Tractogram
import numpy as np
import trx.trx_file_memmap as tmm
from dipy.io.dpy import Dpy
from dipy.io.stateful_tractogram import Origin, Space, StatefulTractogram
from dipy.io.utils import create_tractogram_header, is_header_compatible
from dipy.io.vtk import load_vtk_streamlines, save_vtk_streamlines
from dipy.testing.decorators import warning_for_keywords
[docs]
@warning_for_keywords()
def save_tractogram(sft, filename, *, bbox_valid_check=True):
    """Save the stateful tractogram in any format (trk/tck/vtk/vtp/fib/dpy)

    The ``.trx`` format is supported as well.

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save. It is temporarily moved to RASMM
        space with a center (NIFTI-style) origin for writing, and is always
        restored to its original space and origin afterwards, even if
        writing fails.
    filename : string
        Filename with valid extension (.trk, .tck, .trx, .vtk, .vtp, .fib
        or .dpy).
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.

    Returns
    -------
    output : bool
        True if the saving operation was successful

    Raises
    ------
    TypeError
        If the extension of ``filename`` is not one of the supported ones.
    ValueError
        If ``bbox_valid_check`` is True and some coordinates fall outside
        the volume described by the tractogram's spatial attributes.
    """
    _, extension = os.path.splitext(filename)
    if extension not in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
        raise TypeError("Output filename is not one of the supported format.")

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError(
            "Bounding box is not valid in voxel space, cannot "
            "load a valid file if some coordinates are invalid.\n"
            "Please set bbox_valid_check to False and then use "
            "the function remove_invalid_streamlines to discard "
            "invalid streamlines."
        )

    # Remember where the caller's tractogram lives so it can be put back.
    old_space = deepcopy(sft.space)
    old_origin = deepcopy(sft.origin)
    try:
        # All writers below expect RASMM coordinates with a center origin.
        sft.to_rasmm()
        sft.to_center()

        timer = time.time()
        if extension in (".trk", ".tck"):
            tractogram_type = detect_format(filename)
            header = create_tractogram_header(tractogram_type, *sft.space_attributes)
            new_tractogram = Tractogram(sft.streamlines, affine_to_rasmm=np.eye(4))
            # Only TRK carries per-point / per-streamline metadata here.
            if extension == ".trk":
                new_tractogram.data_per_point = sft.data_per_point
                new_tractogram.data_per_streamline = sft.data_per_streamline
            fileobj = tractogram_type(new_tractogram, header=header)
            nib.streamlines.save(fileobj, filename)
        elif extension in (".vtk", ".vtp", ".fib"):
            binary = extension in (".vtk", ".fib")
            save_vtk_streamlines(sft.streamlines, filename, binary=binary)
        elif extension == ".dpy":
            dpy_obj = Dpy(filename, mode="w")
            try:
                dpy_obj.write_tracks(sft.streamlines)
            finally:
                dpy_obj.close()
        elif extension == ".trx":
            trx_obj = tmm.TrxFile.from_sft(sft)
            try:
                tmm.save(trx_obj, filename)
            finally:
                trx_obj.close()

        logging.debug(
            "Save %s with %s streamlines in %s seconds.",
            filename,
            len(sft),
            round(time.time() - timer, 3),
        )
    finally:
        # Restore the caller's space/origin whether or not writing succeeded,
        # so a failed save does not leave ``sft`` silently transformed.
        sft.to_space(old_space)
        sft.to_origin(old_origin)

    return True
[docs]
@warning_for_keywords()
def load_tractogram(
    filename,
    reference,
    *,
    to_space=Space.RASMM,
    to_origin=Origin.NIFTI,
    bbox_valid_check=True,
    trk_header_check=True,
):
    """Load the stateful tractogram from any format (trk/tck/vtk/vtp/fib/dpy)

    The ``.trx`` format is supported as well.

    Parameters
    ----------
    filename : string
        Filename with valid extension (.trk, .tck, .trx, .vtk, .vtp, .fib
        or .dpy).
    reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
        trk.header (dict), or 'same' if the input is a trk or trx file.
        Reference that provides the spatial attribute.
        Typically a nifti-related object from the native diffusion used for
        streamlines generation. For .trx files the tractogram is built by
        ``TrxFile.to_sft()`` and this argument is not used.
    to_space : Enum (dipy.io.stateful_tractogram.Space)
        Space to which the streamlines will be transformed after loading
    to_origin : Enum (dipy.io.stateful_tractogram.Origin)
        Origin to which the streamlines will be transformed after loading
            NIFTI standard, default (center of the voxel)
            TRACKVIS standard (corner of the voxel)
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.
    trk_header_check : bool
        Verification that the reference has the same header as the spatial
        attributes as the input tractogram when a Trk is loaded

    Returns
    -------
    output : StatefulTractogram or bool
        The tractogram to load (must have been saved properly), or False
        when the extension, ``to_space`` or ``reference`` is rejected (the
        reason is logged as an error).

    Raises
    ------
    ValueError
        If ``bbox_valid_check`` is True and some coordinates fall outside
        the reference volume.
    """
    _, extension = os.path.splitext(filename)
    if extension not in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
        logging.error("Output filename is not one of the supported format.")
        return False

    if to_space not in Space:
        logging.error("Space MUST be one of the 3 choices (Enum).")
        return False

    # Only strings can mean "same": comparing image/header/array-like
    # references to a str with ``==`` is unnecessary and, for array-likes,
    # can yield an elementwise result whose truth value is ambiguous.
    if isinstance(reference, str) and reference == "same":
        if extension in (".trk", ".trx"):
            reference = filename
        else:
            logging.error(
                'Reference must be provided, "same" is only ' "available for Trk file."
            )
            return False

    if trk_header_check and extension == ".trk":
        if not is_header_compatible(filename, reference):
            logging.error("Trk file header does not match the provided reference.")
            return False

    timer = time.time()
    if extension == ".trx":
        trx_obj = tmm.load(filename)
        try:
            sft = trx_obj.to_sft()
        finally:
            # Release the TRX handle even if conversion fails.
            trx_obj.close()
    else:
        data_per_point = None
        data_per_streamline = None
        if extension in (".trk", ".tck"):
            tractogram_obj = nib.streamlines.load(filename).tractogram
            streamlines = tractogram_obj.streamlines
            # Only TRK metadata is forwarded to the StatefulTractogram.
            if extension == ".trk":
                data_per_point = tractogram_obj.data_per_point
                data_per_streamline = tractogram_obj.data_per_streamline
        elif extension in (".vtk", ".vtp", ".fib"):
            streamlines = load_vtk_streamlines(filename)
        else:  # ".dpy" -- the only remaining supported extension
            dpy_obj = Dpy(filename, mode="r")
            try:
                streamlines = list(dpy_obj.read_tracks())
            finally:
                dpy_obj.close()

        # Non-TRX readers above produce RASMM coordinates with NIFTI origin.
        sft = StatefulTractogram(
            streamlines,
            reference,
            Space.RASMM,
            origin=Origin.NIFTI,
            data_per_point=data_per_point,
            data_per_streamline=data_per_streamline,
        )

    logging.debug(
        "Load %s with %s streamlines in %s seconds.",
        filename,
        len(sft),
        round(time.time() - timer, 3),
    )

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError(
            "Bounding box is not valid in voxel space, cannot "
            "load a valid file if some coordinates are invalid.\n"
            "Please set bbox_valid_check to False and then use "
            "the function remove_invalid_streamlines to discard "
            "invalid streamlines."
        )

    sft.to_space(to_space)
    sft.to_origin(to_origin)
    return sft
[docs]
def load_generator(ttype):
    """Generate a loading function that performs a file extension
    check to restrict the user to a single file format.

    Parameters
    ----------
    ttype : string
        Extension of the file format that requires a loader, including the
        leading dot (e.g. ".trk"), as returned by ``os.path.splitext``.

    Returns
    -------
    output : function
        Function (load_tractogram) that handle only one file format
    """

    @warning_for_keywords()
    def f_gen(
        filename,
        reference,
        *,
        to_space=Space.RASMM,
        to_origin=Origin.NIFTI,
        bbox_valid_check=True,
        trk_header_check=True,
    ):
        _, extension = os.path.splitext(filename)
        if not extension == ttype:
            msg = f"This function can only load {ttype} files, "
            msg += "for a more general purpose, use load_tractogram instead."
            raise ValueError(msg)
        # Forward every option unchanged; previously to_space was hard-coded
        # to Space.RASMM, silently ignoring the caller's requested space.
        sft = load_tractogram(
            filename,
            reference,
            to_space=to_space,
            to_origin=to_origin,
            bbox_valid_check=bbox_valid_check,
            trk_header_check=trk_header_check,
        )
        return sft

    f_gen.__doc__ = load_tractogram.__doc__.replace(
        "from any format (trk/tck/vtk/vtp/fib/dpy)", f"of the {ttype} format"
    )
    return f_gen
[docs]
def save_generator(ttype):
    """Generate a saving function that performs a file extension
    check to restrict the user to a single file format.

    Parameters
    ----------
    ttype : string
        Extension of the file format that requires a saver, including the
        leading dot (e.g. ".trk"), as returned by ``os.path.splitext``.

    Returns
    -------
    output : function
        Function (save_tractogram) that handle only one file format
    """

    def f_gen(sft, filename, bbox_valid_check=True):
        _, extension = os.path.splitext(filename)
        if not extension == ttype:
            msg = f"This function can only save {ttype} file, "
            msg += "for more general cases, use save_tractogram instead."
            raise ValueError(msg)
        # Propagate save_tractogram's result: the docstring copied below
        # promises a bool, whereas the wrapper used to return None.
        return save_tractogram(sft, filename, bbox_valid_check=bbox_valid_check)

    f_gen.__doc__ = save_tractogram.__doc__.replace(
        "in any format (trk/tck/vtk/vtp/fib/dpy)", f"of the {ttype} format"
    )
    return f_gen
# Format-specific convenience functions. Each loader/saver pair only accepts
# files whose extension matches its own and otherwise delegates to
# load_tractogram / save_tractogram.
load_trk, save_trk = load_generator(".trk"), save_generator(".trk")
load_tck, save_tck = load_generator(".tck"), save_generator(".tck")
load_trx, save_trx = load_generator(".trx"), save_generator(".trx")
load_vtk, save_vtk = load_generator(".vtk"), save_generator(".vtk")
load_vtp, save_vtp = load_generator(".vtp"), save_generator(".vtp")
load_fib, save_fib = load_generator(".fib"), save_generator(".fib")
load_dpy, save_dpy = load_generator(".dpy"), save_generator(".dpy")