Source code for dipy.workflows.io

import importlib
from inspect import getmembers, isfunction
import logging
import os
import sys
import warnings

import numpy as np
import trx.trx_file_memmap as tmm

from dipy.core.sphere import Sphere
from dipy.data import get_sphere
from dipy.io.image import load_nifti, save_nifti
from dipy.io.peaks import (
    load_pam,
    niftis_to_pam,
    pam_to_niftis,
    tensor_to_pam,
)
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.reconst.shm import convert_sh_descoteaux_tournier
from dipy.reconst.utils import convert_tensors
from dipy.tracking.streamlinespeed import length
from dipy.utils.tractogram import concatenate_tractogram
from dipy.workflows.workflow import Workflow


class IoInfoFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "io_info"

    def run(
        self,
        input_files,
        b0_threshold=50,
        bvecs_tol=0.01,
        bshell_thr=100,
        reference=None,
    ):
        """Provides useful information about different files used in
        medical imaging. Any number of input files can be provided. The
        program identifies the type of file by its extension.

        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1, bvals or bvecs files.
        b0_threshold : float, optional
            Threshold used to find b0 volumes.
        bvecs_tol : float, optional
            Threshold used to check that norm(bvec) = 1 +/- bvecs_tol,
            i.e., that b-vectors are unit vectors.
        bshell_thr : float, optional
            Threshold for distinguishing b-values in different shells.
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files.
            Supported extensions are .nii and .nii.gz.

        """
        np.set_printoptions(3, suppress=True)

        io_it = self.get_io_iterator()

        for input_path in io_it:
            mult_ = len(input_path)
            logging.info(f"-----------{mult_ * '-'}")
            logging.info(f"Looking at {input_path}")
            logging.info(f"-----------{mult_ * '-'}")

            ipath_lower = input_path.lower()
            extension = os.path.splitext(ipath_lower)[1]

            if ipath_lower.endswith(".nii") or ipath_lower.endswith(".nii.gz"):
                data, affine, img, vox_sz, affcodes = load_nifti(
                    input_path,
                    return_img=True,
                    return_voxsize=True,
                    return_coords=True,
                )
                logging.info(f"Data size {data.shape}")
                logging.info(f"Data type {data.dtype}")

                if data.ndim == 3:
                    logging.info(
                        f"Data min {data.min()} max {data.max()} avg {data.mean()}"
                    )
                    logging.info(
                        f"2nd percentile {np.percentile(data, 2)} "
                        f"98th percentile {np.percentile(data, 98)}"
                    )

                if data.ndim == 4:
                    logging.info(
                        f"Data min {data[..., 0].min()} "
                        f"max {data[..., 0].max()} "
                        f"avg {data[..., 0].mean()} of vol 0"
                    )
                    msg = (
                        f"2nd percentile {np.percentile(data[..., 0], 2)} "
                        f"98th percentile {np.percentile(data[..., 0], 98)} "
                        f"of vol 0"
                    )
                    logging.info(msg)

                logging.info(f"Native coordinate system {''.join(affcodes)}")
                logging.info(f"Affine Native to RAS matrix \n{affine}")
                logging.info(f"Voxel size {np.array(vox_sz)}")

                if np.sum(np.abs(np.diff(vox_sz))) > 0.1:
                    msg = "Voxel size is not isotropic. Please reslice.\n"
                    logging.warning(msg, stacklevel=2)

            if os.path.basename(input_path).lower().find("bval") > -1:
                bvals = np.loadtxt(input_path)
                logging.info(f"b-values \n{bvals}")
                logging.info(f"Total number of b-values {len(bvals)}")
                shells = np.sum(np.diff(np.sort(bvals)) > bshell_thr)
                logging.info(f"Number of gradient shells {shells}")
                logging.info(
                    f"Number of b0s {np.sum(bvals <= b0_threshold)} "
                    f"(b0_thr {b0_threshold})\n"
                )

            if os.path.basename(input_path).lower().find("bvec") > -1:
                bvecs = np.loadtxt(input_path)
                logging.info(f"Bvectors shape on disk is {bvecs.shape}")
                rows, cols = bvecs.shape
                if rows < cols:
                    bvecs = bvecs.T
                logging.info(f"Bvectors are \n{bvecs}")
                norms = np.array([np.linalg.norm(bvec) for bvec in bvecs])
                res = np.where((norms <= 1 + bvecs_tol) & (norms >= 1 - bvecs_tol))
                ncl1 = np.sum(norms < 1 - bvecs_tol)
                logging.info(f"Total number of unit bvectors {len(res[0])}")
                logging.info(f"Total number of non-unit bvectors {ncl1}\n")

            if extension in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
                sft = None
                if extension in [".trk", ".trx"]:
                    sft = load_tractogram(input_path, "same", bbox_valid_check=False)
                else:
                    sft = load_tractogram(
                        input_path, reference, bbox_valid_check=False
                    )

                lengths_mm = list(length(sft.streamlines))

                sft.to_voxmm()

                lengths, steps = [], []
                for streamline in sft.streamlines:
                    lengths += [len(streamline)]
                    steps += [
                        np.sqrt(np.sum(np.diff(streamline, axis=0) ** 2, axis=1))
                    ]
                steps = np.hstack(steps)

                logging.info(f"Number of streamlines: {len(sft)}")
                logging.info(f"min_length_mm: {float(np.min(lengths_mm))}")
                logging.info(f"mean_length_mm: {float(np.mean(lengths_mm))}")
                logging.info(f"max_length_mm: {float(np.max(lengths_mm))}")
                logging.info(f"std_length_mm: {float(np.std(lengths_mm))}")
                logging.info(f"min_length_nb_points: {float(np.min(lengths))}")
                logging.info(f"mean_length_nb_points: {float(np.mean(lengths))}")
                logging.info(f"max_length_nb_points: {float(np.max(lengths))}")
                logging.info(f"std_length_nb_points: {float(np.std(lengths))}")
                logging.info(f"min_step_size: {float(np.min(steps))}")
                logging.info(f"mean_step_size: {float(np.mean(steps))}")
                logging.info(f"max_step_size: {float(np.max(steps))}")
                logging.info(f"std_step_size: {float(np.std(steps))}")
                logging.info(
                    f"data_per_point_keys: {list(sft.data_per_point.keys())}"
                )
                logging.info(
                    f"data_per_streamline_keys: "
                    f"{list(sft.data_per_streamline.keys())}"
                )

        np.set_printoptions()
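
# A minimal usage sketch for IoInfoFlow, assuming the ``run`` signature
# above; the filenames are hypothetical placeholders:
#
#     from dipy.workflows.io import IoInfoFlow
#     info_flow = IoInfoFlow()
#     info_flow.run(["dwi.nii.gz", "dwi.bval", "dwi.bvec"])
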
class FetchFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "fetch"

    @staticmethod
    def get_fetcher_datanames():
        """Gets available dataset and function names.

        Returns
        -------
        available_data : dict
            Available dataset and function names.

        """
        fetcher_module = FetchFlow.load_module("dipy.data.fetcher")

        available_data = dict(
            {
                (name.replace("fetch_", ""), func)
                for name, func in getmembers(fetcher_module, isfunction)
                if name.lower().startswith("fetch_")
                and func is not fetcher_module.fetch_data
                if name.lower() not in ["fetch_hbn", "fetch_hcp"]
            }
        )

        return available_data

    @staticmethod
    def load_module(module_path):
        """Load / reload an external module.

        Parameters
        ----------
        module_path : string
            The path to the module relative to the main script.

        Returns
        -------
        module : module object

        """
        if module_path in sys.modules:
            return importlib.reload(sys.modules[module_path])
        else:
            return importlib.import_module(module_path)

    def run(self, data_names, out_dir=""):
        """Download files to folder and check their md5 checksums.

        To see all available datasets, please type "list" in data_names.

        Parameters
        ----------
        data_names : variable string
            Any number of available dataset names.
        out_dir : string, optional
            Output directory. (default current directory)

        """
        if out_dir:
            dipy_home = os.environ.get("DIPY_HOME", None)
            os.environ["DIPY_HOME"] = out_dir

        available_data = FetchFlow.get_fetcher_datanames()

        data_names = [name.lower() for name in data_names]

        if "all" in data_names:
            for name, fetcher_function in available_data.items():
                logging.info("------------------------------------------")
                logging.info(f"Fetching: {name}")
                logging.info("------------------------------------------")
                fetcher_function()

        elif "list" in data_names:
            logging.info(
                "Please select between the following data names: "
                f"{', '.join(available_data.keys())}"
            )

        else:
            skipped_names = []
            for data_name in data_names:
                if data_name not in available_data.keys():
                    skipped_names.append(data_name)
                    continue

                logging.info("------------------------------------------")
                logging.info(f"Fetching: {data_name}")
                logging.info("------------------------------------------")
                available_data[data_name]()

            nb_success = len(data_names) - len(skipped_names)
            print("\n")
            logging.info(f"Fetched {nb_success} / {len(data_names)} files")

            if skipped_names:
                logging.warning(f"Skipped data name(s): {' '.join(skipped_names)}")
                logging.warning(
                    "Please select between the following data names: "
                    f"{', '.join(available_data.keys())}"
                )

        if out_dir:
            if dipy_home:
                os.environ["DIPY_HOME"] = dipy_home
            else:
                os.environ.pop("DIPY_HOME", None)

            # We load the module again so that if we run another one of these
            # in the same process, we don't have the env variable pointing
            # to the wrong place.
            self.load_module("dipy.data.fetcher")
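
# A minimal usage sketch for FetchFlow: pass "list" to print the available
# dataset names, then fetch one. "stanford_hardi" corresponds to DIPY's
# ``fetch_stanford_hardi``; the out_dir value is a hypothetical placeholder:
#
#     from dipy.workflows.io import FetchFlow
#     FetchFlow().run(["list"])
#     FetchFlow().run(["stanford_hardi"], out_dir="/tmp/dipy_data")
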
class SplitFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "split"

    def run(self, input_files, vol_idx=0, out_dir="", out_split="split.nii.gz"):
        """Splits the input 4D file and extracts the required 3D volume.

        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1 files
        vol_idx : int, optional
            Index of the 3D volume to extract.
        out_dir : string, optional
            Output directory. (default current directory)
        out_split : string, optional
            Name of the resulting split volume

        """
        io_it = self.get_io_iterator()
        for fpath, osplit in io_it:
            logging.info(f"Splitting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)

            if vol_idx == 0:
                logging.info("Splitting and extracting 1st b0")

            split_vol = data[..., vol_idx]
            save_nifti(osplit, split_vol, affine, hdr=image.header)

            logging.info(f"Split volume saved as {osplit}")
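
# A minimal usage sketch for SplitFlow, extracting volume 0 from a
# hypothetical 4D series:
#
#     from dipy.workflows.io import SplitFlow
#     SplitFlow().run(["dwi_4d.nii.gz"], vol_idx=0, out_split="b0.nii.gz")
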
class ConcatenateTractogramFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "concatracks"

    def run(
        self,
        tractogram_files,
        reference=None,
        delete_dpv=False,
        delete_dps=False,
        delete_groups=False,
        check_space_attributes=True,
        preallocation=False,
        out_dir="",
        out_extension="trx",
        out_tractogram="concatenated_tractogram",
    ):
        """Concatenate multiple tractograms into one.

        Parameters
        ----------
        tractogram_files : variable string
            The stateful tractogram filenames to concatenate
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files.
            Supported extensions are .nii and .nii.gz.
        delete_dpv : bool, optional
            Delete dpv keys that do not exist in all the provided TrxFiles
        delete_dps : bool, optional
            Delete dps keys that do not exist in all the provided TrxFiles
        delete_groups : bool, optional
            Delete all the groups that currently exist in the TrxFiles
        check_space_attributes : bool, optional
            Verify that dimensions and size of data are similar between all
            the TrxFiles
        preallocation : bool, optional
            Preallocated TrxFile has already been generated and is the first
            element in trx_list (Note: delete_groups must be set to True as
            well)
        out_dir : string, optional
            Output directory. (default current directory)
        out_extension : string, optional
            Extension of the resulting tractogram
        out_tractogram : string, optional
            Name of the resulting tractogram

        """
        io_it = self.get_io_iterator()

        trx_list = []
        has_group = False
        for fpath, _, _ in io_it:
            if fpath.lower().endswith(".trx") or fpath.lower().endswith(".trk"):
                reference = "same"

            if not reference:
                raise ValueError(
                    "No reference provided. It is needed for tck, "
                    "fib, dpy or vtk files"
                )

            tractogram_obj = load_tractogram(fpath, reference, bbox_valid_check=False)

            if not isinstance(tractogram_obj, tmm.TrxFile):
                tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj)
            elif len(tractogram_obj.groups):
                has_group = True
            trx_list.append(tractogram_obj)

        trx = concatenate_tractogram(
            trx_list,
            delete_dpv=delete_dpv,
            delete_dps=delete_dps,
            delete_groups=delete_groups or not has_group,
            check_space_attributes=check_space_attributes,
            preallocation=preallocation,
        )

        valid_extensions = ["trk", "trx", "tck", "fib", "dpy", "vtk"]
        if out_extension.lower() not in valid_extensions:
            raise ValueError(
                f"Invalid extension. Valid extensions are: {valid_extensions}"
            )

        out_fpath = os.path.join(out_dir, f"{out_tractogram}.{out_extension}")
        save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False)
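
# A minimal usage sketch for ConcatenateTractogramFlow; .trk/.trx inputs
# need no reference, and the filenames are hypothetical placeholders:
#
#     from dipy.workflows.io import ConcatenateTractogramFlow
#     ConcatenateTractogramFlow().run(
#         ["bundle_1.trk", "bundle_2.trk"], out_tractogram="whole_brain"
#     )
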
class ConvertSHFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_dipy_mrtrix"

    def run(
        self,
        input_files,
        out_dir="",
        out_file="sh_convert_dipy_mrtrix_out.nii.gz",
    ):
        """Converts SH basis representation between DIPY and MRtrix3 formats.

        Because this conversion is equal to its own inverse, it can be used to
        convert in either direction: DIPY to MRtrix3 or vice versa.

        Parameters
        ----------
        input_files : string
            Path to the input files. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string, optional
            Where the resulting file will be saved. (default '')
        out_file : string, optional
            Name of the result file to be saved.
            (default 'sh_convert_dipy_mrtrix_out.nii.gz')

        """
        io_it = self.get_io_iterator()

        for in_file, out_file in io_it:
            data, affine, image = load_nifti(in_file, return_img=True)
            data = convert_sh_descoteaux_tournier(data)
            save_nifti(out_file, data, affine, hdr=image.header)
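
# A minimal usage sketch for ConvertSHFlow; the same call converts DIPY to
# MRtrix3 or back, since the mapping is its own inverse. The filename is a
# hypothetical placeholder:
#
#     from dipy.workflows.io import ConvertSHFlow
#     ConvertSHFlow().run("sh_coeffs.nii.gz", out_file="sh_converted.nii.gz")
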
class ConvertTensorsFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_tensors"

    def run(
        self,
        tensor_files,
        from_format="mrtrix",
        to_format="dipy",
        out_dir=".",
        out_tensor="converted_tensor",
    ):
        """Converts tensor representation between different formats.

        Parameters
        ----------
        tensor_files : variable string
            Any number of tensor files
        from_format : string, optional
            Format of the input tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        to_format : string, optional
            Format of the output tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tensor : string, optional
            Name of the resulting tensor file

        """
        io_it = self.get_io_iterator()
        for fpath, otensor in io_it:
            logging.info(f"Converting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)
            data = convert_tensors(data, from_format, to_format)
            save_nifti(otensor, data, affine, hdr=image.header)
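
# A minimal usage sketch for ConvertTensorsFlow, converting a hypothetical
# MRtrix-layout tensor volume to FSL layout:
#
#     from dipy.workflows.io import ConvertTensorsFlow
#     ConvertTensorsFlow().run(
#         ["tensors.nii.gz"], from_format="mrtrix", to_format="fsl"
#     )
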
class ConvertTractogramFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_tractogram"

    def run(
        self,
        input_files,
        reference=None,
        pos_dtype="float32",
        offsets_dtype="uint32",
        out_dir="",
        out_tractogram="converted_tractogram.trk",
    ):
        """Converts tractogram between different formats.

        Parameters
        ----------
        input_files : variable string
            Any number of tractogram files
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files.
            Supported extensions are .nii and .nii.gz.
        pos_dtype : string, optional
            Data type of the tractogram points, used for vtk files.
        offsets_dtype : string, optional
            Data type of the tractogram offsets, used for vtk files.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the resulting tractogram

        """
        io_it = self.get_io_iterator()

        for fpath, otracks in io_it:
            in_extension = fpath.lower().split(".")[-1]
            out_extension = otracks.lower().split(".")[-1]

            if in_extension == out_extension:
                warnings.warn(
                    "Input and output are the same file format. Skipping...",
                    stacklevel=2,
                )
                continue

            if not reference and in_extension in ["trx", "trk"]:
                reference = "same"

            if not reference and in_extension not in ["trx", "trk"]:
                raise ValueError(
                    "No reference provided. It is needed for tck, "
                    "fib, dpy or vtk files"
                )

            sft = load_tractogram(fpath, reference, bbox_valid_check=False)

            if out_extension != "trx":
                if out_extension == "vtk":
                    if sft.streamlines._data.dtype.name != pos_dtype:
                        sft.streamlines._data = sft.streamlines._data.astype(pos_dtype)
                    if offsets_dtype == "uint64" or offsets_dtype == "uint32":
                        offsets_dtype = offsets_dtype[1:]
                    if sft.streamlines._offsets.dtype.name != offsets_dtype:
                        sft.streamlines._offsets = sft.streamlines._offsets.astype(
                            offsets_dtype
                        )
                save_tractogram(sft, otracks, bbox_valid_check=False)
            else:
                trx = tmm.TrxFile.from_sft(sft)
                if trx.streamlines._data.dtype.name != pos_dtype:
                    trx.streamlines._data = trx.streamlines._data.astype(pos_dtype)
                if trx.streamlines._offsets.dtype.name != offsets_dtype:
                    trx.streamlines._offsets = trx.streamlines._offsets.astype(
                        offsets_dtype
                    )
                tmm.save(trx, otracks)
                trx.close()
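
# A minimal usage sketch for ConvertTractogramFlow; a .tck input needs a
# NIfTI reference, and both filenames are hypothetical placeholders:
#
#     from dipy.workflows.io import ConvertTractogramFlow
#     ConvertTractogramFlow().run(
#         ["tracks.tck"], reference="t1.nii.gz", out_tractogram="tracks.trk"
#     )
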
class NiftisToPamFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "niftis_to_pam"

    def run(
        self,
        peaks_dir_files,
        peaks_values_files,
        peaks_indices_files,
        shm_files=None,
        gfa_files=None,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple nifti files to a single pam5 file.

        Parameters
        ----------
        peaks_dir_files : string
            Path to the input peaks directions volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_values_files : string
            Path to the input peaks values volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_indices_files : string
            Path to the input peaks indices volume. This path may contain
            wildcards to process multiple inputs at once.
        shm_files : string, optional
            Path to the input spherical harmonics volume. This path may
            contain wildcards to process multiple inputs at once.
        gfa_files : string, optional
            Path to the input generalized FA volume. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not
            defined, the default_sphere_name option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by the
            sphere_files option. Possible options: ['symmetric362',
            'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100',
            'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"

        for fpeak_dirs, fpeak_values, fpeak_indices, opam in io_it:
            logging.info("Converting nifti files to pam5")
            peak_dirs, affine = load_nifti(fpeak_dirs)
            peak_values, _ = load_nifti(fpeak_values)
            peak_indices, _ = load_nifti(fpeak_indices)

            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)

            niftis_to_pam(
                affine=affine,
                peak_dirs=peak_dirs,
                sphere=sphere,
                peak_values=peak_values,
                peak_indices=peak_indices,
                pam_file=opam,
            )
            logging.info(msg.replace("pam5", opam))
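
# A minimal usage sketch for NiftisToPamFlow, bundling three hypothetical
# peak volumes into a single .pam5 file:
#
#     from dipy.workflows.io import NiftisToPamFlow
#     NiftisToPamFlow().run(
#         "peaks_dirs.nii.gz", "peaks_values.nii.gz", "peaks_indices.nii.gz"
#     )
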
class TensorToPamFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "tensor_to_pam"

    def run(
        self,
        evals_files,
        evecs_files,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple tensor files (evals, evecs) to pam5 files.

        Parameters
        ----------
        evals_files : string
            Path to the input eigenvalue volumes. This path may contain
            wildcards to process multiple inputs at once.
        evecs_files : string
            Path to the input eigenvector volumes. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not
            defined, the default_sphere_name option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by the
            sphere_files option. Possible options: ['symmetric362',
            'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100',
            'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"

        for fevals, fevecs, opam in io_it:
            logging.info("Converting tensor files to pam5...")
            evals, affine = load_nifti(fevals)
            evecs, _ = load_nifti(fevecs)

            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)

            tensor_to_pam(evals, evecs, affine, sphere=sphere, pam_file=opam)
            logging.info(msg.replace("pam5", opam))
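
# A minimal usage sketch for TensorToPamFlow; the eigenvalue/eigenvector
# filenames are hypothetical placeholders:
#
#     from dipy.workflows.io import TensorToPamFlow
#     TensorToPamFlow().run("evals.nii.gz", "evecs.nii.gz", out_pam="dti.pam5")
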
class PamToNiftisFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "pam_to_niftis"

    def run(
        self,
        pam_files,
        out_dir="",
        out_peaks_dir="peaks_dirs.nii.gz",
        out_peaks_values="peaks_values.nii.gz",
        out_peaks_indices="peaks_indices.nii.gz",
        out_shm="shm.nii.gz",
        out_gfa="gfa.nii.gz",
        out_sphere="sphere.txt",
        out_b="B.nii.gz",
        out_qa="qa.nii.gz",
    ):
        """Convert pam5 files to multiple nifti files.

        Parameters
        ----------
        pam_files : string
            Path to the input peaks volumes. This path may contain wildcards
            to process multiple inputs at once.
        out_dir : string, optional
            Output directory (default input file directory).
        out_peaks_dir : string, optional
            Name of the peaks directions volume to be saved.
        out_peaks_values : string, optional
            Name of the peaks values volume to be saved.
        out_peaks_indices : string, optional
            Name of the peaks indices volume to be saved.
        out_shm : string, optional
            Name of the spherical harmonics volume to be saved.
        out_gfa : string, optional
            Generalized FA volume name to be saved.
        out_sphere : string, optional
            Sphere vertices name to be saved.
        out_b : string, optional
            Name of the B matrix to be saved.
        out_qa : string, optional
            Name of the Quantitative Anisotropy file to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"Nifti files saved in {out_dir or 'current directory'}"
        for (
            ipam,
            opeaks_dir,
            opeaks_values,
            opeaks_indices,
            oshm,
            ogfa,
            osphere,
            ob,
            oqa,
        ) in io_it:
            logging.info("Converting %s file to niftis...", ipam)
            pam = load_pam(ipam)
            pam_to_niftis(
                pam,
                fname_peaks_dir=opeaks_dir,
                fname_shm=oshm,
                fname_peaks_values=opeaks_values,
                fname_peaks_indices=opeaks_indices,
                fname_sphere=osphere,
                fname_gfa=ogfa,
                fname_b=ob,
                fname_qa=oqa,
            )
            logging.info(msg)
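
# A minimal usage sketch for PamToNiftisFlow, unpacking a hypothetical
# .pam5 file into the individual NIfTI/text outputs named above:
#
#     from dipy.workflows.io import PamToNiftisFlow
#     PamToNiftisFlow().run("peaks.pam5", out_dir="pam_niftis")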