import enum
import importlib
from inspect import getmembers, isfunction
import os
from pathlib import Path
import re
import shutil
import sys
import warnings
import nibabel as nib
import numpy as np
import trx.trx_file_memmap as tmm
from dipy.core.gradients import (
extract_b0,
extract_dwi_shell,
gradient_table,
mask_non_weighted_bvals,
)
from dipy.core.sphere import Sphere
from dipy.data import get_sphere
from dipy.io.gradients import read_bvals_bvecs
from dipy.io.image import load_nifti, save_nifti
from dipy.io.peaks import (
load_pam,
niftis_to_pam,
pam_to_niftis,
tensor_to_pam,
)
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.io.utils import split_filename_extension
from dipy.reconst.shm import convert_sh_descoteaux_tournier, order_from_ncoef
from dipy.reconst.utils import convert_tensors
from dipy.tracking.streamlinespeed import length
from dipy.utils.logging import logger
from dipy.utils.optpkg import optional_package
from dipy.utils.tractogram import concatenate_tractogram
from dipy.workflows.base import format_key_value_table
from dipy.workflows.utils import handle_vol_idx
from dipy.workflows.workflow import Workflow
ne, have_ne, _ = optional_package("numexpr")
[docs]
class StatsPropertyName(enum.Enum):
    """Labels of the summary statistics reported for numeric data.

    Each member's value is the human-readable label of one statistic.
    """

    # Extremes of the data.
    MIN = "min"
    MAX = "max"
    # Central tendency.
    MEDIAN = "median"
    MEAN = "mean"
    # Spread.
    STD_DEV = "std dev"
[docs]
class PercentilePropertyName(enum.Enum):
    """Labels of the percentile values reported for voxel data.

    Each member's value is the human-readable label of one percentile.
    """

    PERCENTILE_2 = "2nd percentile"
    PERCENTILE_98 = "98th percentile"
[docs]
class VolumetricPropertyName(enum.Enum):
    """Labels of the properties reported for volumetric (image) data.

    Each member's value is the human-readable label of one property.
    """

    # Spatial description of the volume.
    AFFINE = "Affine matrix"
    VOXEL_ORDER = "Voxel order"
    # Description of the stored array.
    DATA_TYPE = "Data type"
    DIMENSIONS = "Dimensions"
    VOXEL_SIZE = "Voxel size"
[docs]
class BvalPropertyName(enum.Enum):
    """Labels of the properties reported for b-value data.

    Each member's value is the human-readable label of one property.
    """

    B0_THRESHOLD = "b0 threshold"
    B_VALUES = "b-values"
    # Counts derived from the b-values.
    NUMBER_B0s = "Number of b0s"
    NUMBER_B_VALUES = "Total number of b-values"
    NUMBER_SHELLS = "Number of gradient shells"
[docs]
class BvecPropertyName(enum.Enum):
    """Labels of the properties reported for b-vector data.

    Each member's value is the human-readable label of one property.
    """

    B_VECTORS = "b-vectors"
    B_VECTORS_SHAPE = "Shape of b-vectors on disk"
    # Counts derived from the b-vector norms.
    NUMBER_NONUNIT_B_VECTORS = "Total number of non-unit b-vectors"
    NUMBER_UNIT_B_VECTORS = "Total number of unit b-vectors"
[docs]
class TractographyPropertyName(enum.Enum):
    """Labels of the properties reported for tractography data.

    Each member's value is the human-readable label of one property.
    """

    # Keys of the data attached to the tractogram.
    DPP_KEYS = "Data per point keys"
    DPS_KEYS = "Data per streamline keys"
    # Streamline measurements.
    LENGTH_MM = "Length (mm)"
    LENGTH_NUM_PTS = "Length (nb points)"
    NUM_STRML = "Number of streamlines"
    # Spatial attributes of the tractogram.
    ORIGIN = "Origin"
    SPACE = "Space"
    STEP_SIZE = "Step size (mm)"
[docs]
class PamPropertyName(enum.Enum):
    """Labels of the properties reported for PAM5 (PeaksAndMetrics) data.

    Each member's value is the human-readable label of one property.
    """

    # File and volume description.
    VERSION = "PAM5 version"
    DIMENSIONS = "Volume dimensions"
    VOXEL_SIZE = "Voxel size"
    VOXEL_ORDER = "Voxel order"
    # Peak summary.
    NUM_PEAKS = "Peaks per voxel (max)"
    PEAK_COVERAGE = "Peak coverage (fraction)"
    MEAN_PEAKS = "Mean peaks per non-empty voxel"
    # Shapes of the peak arrays and size of the sphere.
    PEAK_DIRS_SHAPE = "Peak dirs shape"
    PEAK_VALUES_SHAPE = "Peak values shape"
    PEAK_INDICES_SHAPE = "Peak indices shape"
    SPHERE_VERTICES = "Sphere vertices"
    # Spherical harmonics information.
    SHM_COEFF_SHAPE = "SH coefficients shape"
    SH_ORDER = "SH order"
    B_SHAPE = "B matrix shape"
    # Optional metric arrays.
    GFA_SHAPE = "GFA shape"
    QA_SHAPE = "QA shape"
    ODF_SHAPE = "ODF shape"
    # Scalar attributes and affine.
    TOTAL_WEIGHT = "Total weight"
    ANG_THR = "Angular threshold"
    AFFINE = "Affine matrix"
def _print_property_information(prop, val, alignment_space):
    """Log a ``property: value`` line with the label padded to a fixed width.

    A colon is appended to the property name and the result is left-aligned
    in a field of ``alignment_space`` characters, followed by the value.

    Parameters
    ----------
    prop : str
        Name of the property.
    val : scalar, str, ndarray, list
        Value of the property.
    alignment_space : int
        Character width for the property, value pair alignment.
    """
    label = prop + ":"
    logger.info(f"{label.ljust(alignment_space)}{val}")
def _print_stats_information(data, alignment_space, tab):
    """Log summary statistics left-aligned at the given width.

    Logs the minimum, maximum, median, mean and standard deviation of the
    given data, one per line, each label being prefixed by ``tab``. The
    labels are taken from ``StatsPropertyName``.

    Parameters
    ----------
    data : ndarray or list
        Numeric data whose statistics are to be logged.
    alignment_space : int
        Character width for the property, value pair alignment.
    tab : str
        Whitespace characters corresponding to a tab character for the
        indentation.
    """
    reducers = (
        (StatsPropertyName.MIN, np.min),
        (StatsPropertyName.MAX, np.max),
        (StatsPropertyName.MEDIAN, np.median),
        (StatsPropertyName.MEAN, np.mean),
        (StatsPropertyName.STD_DEV, np.std),
    )
    # np.min / np.max raise ValueError on empty input; report the statistics
    # as unavailable instead of aborting the whole report.
    is_empty = np.size(data) == 0
    for name, reducer in reducers:
        value = "Not available" if is_empty else reducer(data)
        _print_property_information(tab + name.value, value, alignment_space)
def _print_volumetric_information(data, affine, vox_sz, affcodes, alignment_space, tab):
    """Log the properties of a volumetric dataset.

    Logs the dimensions and data type, voxel statistics (for 3D data, or for
    the first volume of 4D data), the voxel order, the affine matrix and the
    voxel size. A warning is logged when the voxel sizes differ noticeably
    from each other.

    Parameters
    ----------
    data : ndarray
        Data whose properties are to be printed.
    affine : ndarray
        Affine matrix.
    vox_sz : tuple
        Voxel size.
    affcodes : tuple
        Voxel order (anatomical coordinate system).
    alignment_space : int
        Character width for the property, value pair alignment.
    tab : str
        Whitespace characters corresponding to a tab character for the
        indentation.
    """

    def _print_voxel_information(_name, _data, _alignment_space, _tab):
        """Log the statistics and percentiles of some voxel data.

        Parameters
        ----------
        _name : str
            Name of the information piece.
        _data : ndarray
            Data whose properties are to be printed.
        _alignment_space : int
            Character width for the property, value pair alignment.
        _tab : str
            Whitespace characters corresponding to a tab character
            for the indentation.
        """
        logger.info(f"{_name}:")
        _print_stats_information(_data, _alignment_space, _tab)
        percentiles = (
            (PercentilePropertyName.PERCENTILE_2, 2),
            (PercentilePropertyName.PERCENTILE_98, 98),
        )
        for pctl_name, pctl in percentiles:
            _print_property_information(
                _tab + pctl_name.value,
                np.percentile(_data, pctl),
                _alignment_space,
            )

    _print_property_information(
        VolumetricPropertyName.DIMENSIONS.value, data.shape, alignment_space
    )
    _print_property_information(
        VolumetricPropertyName.DATA_TYPE.value, data.dtype, alignment_space
    )

    # Voxel statistics are only reported for 3D data and for the first
    # volume of 4D data.
    n_dims = data.ndim
    if n_dims == 3:
        _print_voxel_information("Data", data, alignment_space, tab)
    elif n_dims == 4:
        first_volume = data[..., 0]
        _print_voxel_information("Data (0th vol)", first_volume, alignment_space, tab)

    voxel_order = "".join(affcodes)
    _print_property_information(
        VolumetricPropertyName.VOXEL_ORDER.value, voxel_order, alignment_space
    )
    logger.info(f"Affine matrix:\n{affine}")
    voxel_size = tuple(map(float, vox_sz))
    _print_property_information(
        VolumetricPropertyName.VOXEL_SIZE.value, voxel_size, alignment_space
    )

    # Summed absolute difference between consecutive voxel sizes.
    size_spread = np.sum(np.abs(np.diff(vox_sz)))
    if size_spread > 0.1:
        msg = "Voxel size is not isotropic. Please reslice.\n"
        logger.warning(msg, stacklevel=2)
def _print_bval_data_information(bvals, b0_threshold, bshell_thr, alignment_space):
    """Log the properties of a set of b-values.

    Logs the b-values, how many there are, the number of gradient shells and
    the number of b0s together with the threshold used to find them.

    Parameters
    ----------
    bvals : ndarray
        b-values.
    b0_threshold : float
        b0 threshold.
    bshell_thr : float
        Threshold value to determine shells.
    alignment_space : int
        Character width for the property, value pair alignment.
    """
    logger.info(f"{BvalPropertyName.B_VALUES.value}:\n{bvals}")
    _print_property_information(
        BvalPropertyName.NUMBER_B_VALUES.value, len(bvals), alignment_space
    )

    # The shell count is the number of gaps between consecutive sorted
    # b-values that exceed the shell threshold.
    gaps = np.diff(np.sort(bvals))
    shells = np.sum(gaps > bshell_thr)
    _print_property_information(
        BvalPropertyName.NUMBER_SHELLS.value, shells, alignment_space
    )

    num_b0s = np.sum(bvals <= b0_threshold)
    b0_label = BvalPropertyName.NUMBER_B0s.value + ":"
    threshold_note = f" ({BvalPropertyName.B0_THRESHOLD.value}: {b0_threshold})"
    logger.info(f"{b0_label:<{alignment_space}}{num_b0s}{threshold_note}")
def _print_bvec_data_information(bvecs, bvecs_tol, alignment_space):
    r"""Log the properties of a set of b-vectors.

    Logs the shape of the b-vectors as given, the b-vectors themselves (one
    per row) and how many of them are unit / non-unit vectors.

    Parameters
    ----------
    bvecs : ndarray
        b-vectors, as a 2D array.
    bvecs_tol : float
        Threshold used to check that
        :math:`norm(\text{bvec}) = 1 \pm \text{bvecs_tol}` b-vectors are
        unit vectors.
    alignment_space : int
        Character width for the property, value pair alignment.
    """
    _print_property_information(
        BvecPropertyName.B_VECTORS_SHAPE.value, bvecs.shape, alignment_space
    )

    # Put the vectors along the rows when the array has fewer rows than
    # columns.
    rows, cols = bvecs.shape
    if rows < cols:
        bvecs = bvecs.T
    logger.info(f"{BvecPropertyName.B_VECTORS.value}\n{bvecs}")

    norms = np.linalg.norm(bvecs, axis=1)
    is_unit = (norms <= 1 + bvecs_tol) & (norms >= 1 - bvecs_tol)
    n_unit = int(np.count_nonzero(is_unit))
    # Every vector that is not a unit vector is non-unit, whether its norm
    # falls below or above the tolerance band, so the two counts always add
    # up to the number of vectors.
    n_nonunit = int(norms.size) - n_unit

    _print_property_information(
        BvecPropertyName.NUMBER_UNIT_B_VECTORS.value, n_unit, alignment_space
    )
    _print_property_information(
        BvecPropertyName.NUMBER_NONUNIT_B_VECTORS.value, n_nonunit, alignment_space
    )
def _print_tractography_information(
    sft, lengths_mm, lengths, step_size, alignment_space, tab
):
    """Log the properties of a tractogram.

    Logs the number of streamlines, statistics on the streamline lengths and
    step sizes, the per-point / per-streamline data keys and the spatial
    attributes of the tractogram.

    Parameters
    ----------
    sft : StatefulTractogram
        Tractogram.
    lengths_mm : list
        Length of the streamlines in millimeters.
    lengths : list
        Length of the streamlines in number of points.
    step_size : ndarray
        Step sizes.
    alignment_space : int
        Character width for the property, value pair alignment.
    tab : str
        Whitespace characters corresponding to a tab character for the
        indentation.
    """

    def _emit(name, value):
        # All property lines of the report share the same column width.
        _print_property_information(name.value, value, alignment_space)

    _emit(TractographyPropertyName.NUM_STRML, len(sft))

    # One titled statistics section per measurement, in reporting order.
    stats_sections = (
        (TractographyPropertyName.LENGTH_MM, lengths_mm),
        (TractographyPropertyName.LENGTH_NUM_PTS, lengths),
        (TractographyPropertyName.STEP_SIZE, step_size),
    )
    for section_name, values in stats_sections:
        logger.info(f"{section_name.value}:")
        _print_stats_information(values, alignment_space, tab)

    _emit(TractographyPropertyName.DPP_KEYS, list(sft.data_per_point.keys()))
    _emit(TractographyPropertyName.DPS_KEYS, list(sft.data_per_streamline.keys()))

    logger.info(f"Affine matrix:\n{sft.affine}")

    _emit(VolumetricPropertyName.DIMENSIONS, tuple(map(int, sft.dimensions)))
    _emit(VolumetricPropertyName.DATA_TYPE, sft.streamlines.get_data().dtype)
    _emit(TractographyPropertyName.ORIGIN, sft.origin)
    _emit(TractographyPropertyName.SPACE, sft.space)
    _emit(VolumetricPropertyName.VOXEL_ORDER, sft.voxel_order)
    _emit(VolumetricPropertyName.VOXEL_SIZE, tuple(map(float, sft.voxel_sizes)))
def _print_pam_information(pam, version, alignment_space, tab):
    """Log the properties of a PAM5 (PeaksAndMetrics) object.

    Optional attributes that are missing or set to ``None`` are reported as
    "Not available".

    Parameters
    ----------
    pam : PeaksAndMetrics
        Loaded PAM5 object.
    version : str
        PAM5 file format version.
    alignment_space : int
        Character width for the property, value pair alignment.
    tab : str
        Whitespace characters corresponding to a tab character for the
        indentation.
    """
    missing = "Not available"

    def _emit(name, value):
        # All property lines of the report share the same column width.
        _print_property_information(name.value, value, alignment_space)

    def _optional(attr_name):
        # Absent attributes and attributes set to None are treated alike.
        return getattr(pam, attr_name, None)

    _emit(PamPropertyName.VERSION, version)
    _emit(PamPropertyName.DIMENSIONS, tuple(map(int, pam.peak_dirs.shape[:3])))

    affine = _optional("affine")
    if affine is None:
        _emit(PamPropertyName.VOXEL_SIZE, missing)
        _emit(PamPropertyName.VOXEL_ORDER, missing)
    else:
        # Voxel size is taken as the norm of each column of the 3x3 block of
        # the affine.
        vox_sz = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
        _emit(PamPropertyName.VOXEL_SIZE, tuple(map(float, vox_sz)))
        _emit(PamPropertyName.VOXEL_ORDER, "".join(nib.aff2axcodes(affine)))

    _emit(PamPropertyName.NUM_PEAKS, int(pam.peak_dirs.shape[3]))

    # Peak coverage: fraction of voxels holding at least one positive peak
    # value.
    peak_values = pam.peak_values
    nonzero_peaks = peak_values > 0
    voxel_has_peak = np.any(nonzero_peaks, axis=-1)
    n_voxels = int(np.prod(peak_values.shape[:-1]))
    n_nonempty = int(voxel_has_peak.sum())
    coverage = n_nonempty / n_voxels if n_voxels else 0.0
    _emit(
        PamPropertyName.PEAK_COVERAGE,
        f"{coverage:.4f} ({n_nonempty}/{n_voxels})",
    )

    mean_peaks = float(nonzero_peaks.sum() / n_nonempty) if n_nonempty else 0.0
    _emit(PamPropertyName.MEAN_PEAKS, f"{mean_peaks:.3f}")

    _emit(PamPropertyName.PEAK_DIRS_SHAPE, pam.peak_dirs.shape)
    _emit(PamPropertyName.PEAK_VALUES_SHAPE, peak_values.shape)
    _emit(PamPropertyName.PEAK_INDICES_SHAPE, pam.peak_indices.shape)
    _emit(PamPropertyName.SPHERE_VERTICES, int(pam.sphere.vertices.shape[0]))

    logger.info("Peak values stats:")
    _print_stats_information(peak_values, alignment_space, tab)

    shm_coeff = _optional("shm_coeff")
    if shm_coeff is None:
        _emit(PamPropertyName.SHM_COEFF_SHAPE, missing)
        _emit(PamPropertyName.SH_ORDER, missing)
    else:
        _emit(PamPropertyName.SHM_COEFF_SHAPE, shm_coeff.shape)
        n_coeffs = shm_coeff.shape[-1]
        try:
            sh_order = order_from_ncoef(n_coeffs)
        except ValueError:
            logger.warning(
                f"Could not derive SH order from shm_coeff with "
                f"{n_coeffs} coefficients."
            )
            sh_order = missing
        _emit(PamPropertyName.SH_ORDER, sh_order)

    # Optional arrays, in reporting order: attribute name, shape label and,
    # where statistics are reported too, the title of the statistics section.
    optional_arrays = (
        ("B", PamPropertyName.B_SHAPE, None),
        ("gfa", PamPropertyName.GFA_SHAPE, "GFA stats:"),
        ("qa", PamPropertyName.QA_SHAPE, "QA stats:"),
        ("odf", PamPropertyName.ODF_SHAPE, None),
    )
    for attr_name, shape_name, stats_title in optional_arrays:
        array = _optional(attr_name)
        if array is None:
            _emit(shape_name, missing)
            continue
        _emit(shape_name, array.shape)
        if stats_title is not None:
            logger.info(stats_title)
            _print_stats_information(array, alignment_space, tab)

    # Optional attributes reported as-is.
    for attr_name, value_name in (
        ("total_weight", PamPropertyName.TOTAL_WEIGHT),
        ("ang_thr", PamPropertyName.ANG_THR),
    ):
        value = _optional(attr_name)
        _emit(value_name, missing if value is None else value)

    if affine is None:
        _emit(PamPropertyName.AFFINE, missing)
    else:
        logger.info(f"{PamPropertyName.AFFINE.value}:\n{affine}")
[docs]
class IoInfoFlow(Workflow):
    """Workflow logging a summary of each input file it recognizes."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "io_info"

    def run(
        self,
        input_files,
        b0_threshold=50,
        bvecs_tol=0.01,
        bshell_thr=100,
        reference=None,
    ):
        r"""Provides useful information about different files used in
        medical imaging. Any number of input files can be provided. The
        program identifies the type of file by its extension.

        Parameters
        ----------
        input_files : variable string or Path
            Any number of NIfTI, bvals, bvecs, tractography or PAM5
            (``*.pam5``) data files.
        b0_threshold : float, optional
            Threshold used to find b0 volumes.
        bvecs_tol : float, optional
            Threshold used to check that
            :math:`norm(\text{bvec}) = 1 \pm \text{bvecs_tol}` b-vectors are
            unit vectors.
        bshell_thr : float, optional
            Threshold for distinguishing b-values in different shells.
        reference : string or Path, optional
            Reference anatomy for ``*.tck``, ``*.vtk``/``*.vtp``, ``*.fib``, and
            ``*.dpy`` tractography files.
        """

        def _label_width(property_names):
            # Longest label of the enumeration plus two extra columns.
            return max(len(item.value) for item in property_names) + 2

        io_it = self.get_io_iterator()

        vol_property_length = _label_width(VolumetricPropertyName)
        stats_property_length = _label_width(StatsPropertyName)
        pctl_property_length = _label_width(PercentilePropertyName)

        tab = "\t".expandtabs(4)

        # The context manager restores the caller's NumPy print options on
        # exit, including when the workflow leaves through ``sys.exit`` or an
        # exception. Calling ``np.set_printoptions()`` without arguments
        # would leave the modified options in place.
        with np.printoptions(precision=3, suppress=True):
            for input_path in io_it:
                input_path = Path(input_path)

                mult_ = len(str(input_path))
                logger.info(f"-----------{mult_ * '-'}")
                logger.info(f"Looking at {input_path}")
                logger.info(f"-----------{mult_ * '-'}")

                _, extension = split_filename_extension(input_path)
                extension = extension.lower()

                if extension in [".nii", ".nii.gz"]:
                    data, affine, _, vox_sz, affcodes = load_nifti(
                        input_path,
                        return_img=True,
                        return_voxsize=True,
                        return_coords=True,
                    )

                    # Statistic and percentile labels are logged indented by
                    # ``tab``: widen the column by the tab length when one of
                    # them is the longest label.
                    indented_length = max(
                        stats_property_length, pctl_property_length
                    )
                    apply_tab_offset = vol_property_length < indented_length
                    alignment_space = max(vol_property_length, indented_length)
                    if apply_tab_offset:
                        alignment_space += len(tab)

                    _print_volumetric_information(
                        data, affine, vox_sz, affcodes, alignment_space, tab
                    )

                if "bval" in input_path.name.lower():
                    bval_property_length = _label_width(BvalPropertyName)
                    bvals = np.loadtxt(input_path)
                    _print_bval_data_information(
                        bvals, b0_threshold, bshell_thr, bval_property_length
                    )

                if "bvec" in input_path.name.lower():
                    bvec_property_length = _label_width(BvecPropertyName)
                    bvecs = np.loadtxt(input_path)
                    _print_bvec_data_information(
                        bvecs, bvecs_tol, bvec_property_length
                    )

                if extension in [
                    ".trk",
                    ".tck",
                    ".trx",
                    ".vtk",
                    ".vtp",
                    ".fib",
                    ".dpy",
                ]:
                    tractogr_property_length = _label_width(
                        TractographyPropertyName
                    )

                    # Only the statistic labels are logged indented by
                    # ``tab``: widen the column when they are the longest.
                    unindented_length = max(
                        vol_property_length, tractogr_property_length
                    )
                    apply_tab_offset = stats_property_length >= unindented_length
                    alignment_space = max(stats_property_length, unindented_length)
                    if apply_tab_offset:
                        alignment_space += len(tab)

                    if extension in [".trk", ".trx"]:
                        sft = load_tractogram(
                            input_path, "same", bbox_valid_check=False
                        )
                    else:
                        if not reference or not Path(reference).exists():
                            msg = (
                                "No reference provided. It is needed for tck, fib, dpy or "
                                "vtk files to load properly. Please provide a reference, "
                                "Nifti or Trk file using the option "
                                "--reference my_files.nii.gz ."
                            )
                            logger.error(msg, stacklevel=2)
                            sys.exit(1)

                        sft = load_tractogram(
                            input_path, reference, bbox_valid_check=False
                        )

                    # Lengths in millimeters are computed before converting
                    # the streamlines with ``to_voxmm``.
                    lengths_mm = list(length(sft.streamlines))

                    sft.to_voxmm()

                    lengths, steps = [], []
                    for streamline in sft.streamlines:
                        lengths.append(len(streamline))
                        # Euclidean distance between consecutive points.
                        steps.append(
                            np.sqrt(np.sum(np.diff(streamline, axis=0) ** 2, axis=1))
                        )
                    steps = np.hstack(steps)

                    _print_tractography_information(
                        sft, lengths_mm, lengths, steps, alignment_space, tab
                    )

                if extension == ".pam5":
                    pam_property_length = _label_width(PamPropertyName)
                    pam = load_pam(input_path)
                    _print_pam_information(
                        pam,
                        "0.0.1",
                        max(pam_property_length, stats_property_length) + len(tab),
                        tab,
                    )
[docs]
class FetchFlow(Workflow):
    """Workflow downloading the datasets exposed by ``dipy.data.fetcher``."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "fetch"

    @staticmethod
    def get_fetcher_datanames():
        """Gets available dataset and function names.

        The ``dipy.data.fetcher`` module is (re)loaded and every function
        whose lower-cased name starts with ``fetch_`` is collected, except
        ``fetch_data``.

        Returns
        -------
        available_data: dict
            Available dataset and function names.
        """
        fetcher_module = FetchFlow.load_module("dipy.data.fetcher")

        # A dict comprehension keeps the order in which ``getmembers``
        # reports the functions, so listings are reproducible.
        return {
            name.replace("fetch_", ""): func
            for name, func in getmembers(fetcher_module, isfunction)
            if name.lower().startswith("fetch_")
            and func is not fetcher_module.fetch_data
        }

    @staticmethod
    def load_module(module_path):
        """Load / reload an external module.

        Parameters
        ----------
        module_path: string
            the path to the module relative to the main script

        Returns
        -------
        module: module object
        """
        if module_path in sys.modules:
            return importlib.reload(sys.modules[module_path])
        return importlib.import_module(module_path)

    def run(
        self,
        data_names,
        subjects=None,
        include_optional=False,
        include_afq=False,
        hcp_bucket="hcp-openaccess",
        hcp_profile_name="hcp",
        hcp_study="HCP_1200",
        hcp_aws_access_key_id=None,
        hcp_aws_secret_access_key=None,
        out_dir="",
    ):
        """Download files to folder and check their md5 checksums.

        - To see all available datasets, please type "dipy_fetch list".
        - To download all datasets, please type "dipy_fetch all".

        Parameters
        ----------
        data_names : variable string
            One or more dataset names to fetch. Use ``list`` to print all
            available names or ``all`` to fetch everything. Available names
            include (among others): ``atlas_schaefer_2018``,
            ``bundle_atlas_hcp842``, ``stanford_hardi``, ``mni_template``.
        subjects : variable string, optional
            Identifiers of the subjects to download. Used only by the HBN & HCP dataset.
            For example with HBN dataset: --subject NDARAA948VFH NDAREK918EC2
        include_optional : bool, optional
            Include optional datasets.
        include_afq : bool, optional
            Whether to include pyAFQ derivatives. Used only by the HBN dataset.
        hcp_bucket : string, optional
            The name of the HCP S3 bucket.
        hcp_profile_name : string, optional
            The name of the AWS profile used for access.
        hcp_study : string, optional
            Which HCP study to grab.
        hcp_aws_access_key_id : string, optional
            AWS credentials to HCP AWS S3. Will only be used if `profile_name` is
            set to False.
        hcp_aws_secret_access_key : string, optional
            AWS credentials to HCP AWS S3. Will only be used if `profile_name` is
            set to False.
        out_dir : string or Path, optional
            Output directory.
        """
        dipy_home = None
        if out_dir:
            dipy_home = os.environ.get("DIPY_HOME", None)
            # ``os.environ`` only accepts strings; ``out_dir`` may be a Path.
            os.environ["DIPY_HOME"] = os.fspath(out_dir)

        try:
            # The fetcher module is (re)loaded here, after DIPY_HOME was set.
            available_data = FetchFlow.get_fetcher_datanames()

            data_names = [name.lower() for name in data_names]

            if "all" in data_names:
                logger.warning("Skipping HCP and HBN datasets.")
                available_data.pop("hcp", None)
                available_data.pop("hbn", None)
                for name, fetcher_function in available_data.items():
                    logger.info("------------------------------------------")
                    logger.info(f"Fetching at {name}")
                    logger.info("------------------------------------------")
                    fetcher_function(include_optional=include_optional)

            elif "list" in data_names:
                # NOTE(review): ``format_data_names_table`` is not among the
                # imports visible in this part of the file -- confirm it is
                # defined or imported elsewhere in the module.
                logger.info(
                    "Please, select between the following data names: \n"
                    f"{format_data_names_table(available_data)}"
                )

            else:
                skipped_names = []
                nb_failed = 0
                for data_name in data_names:
                    if data_name not in available_data:
                        skipped_names.append(data_name)
                        continue

                    logger.info("------------------------------------------")
                    logger.info(f"Fetching at {data_name}")
                    logger.info("------------------------------------------")
                    if data_name == "hcp":
                        if not subjects:
                            logger.error(
                                "Please provide the subjects to download the HCP dataset."
                            )
                            nb_failed += 1
                            continue
                        try:
                            available_data[data_name](
                                subjects=subjects,
                                hcp_bucket=hcp_bucket,
                                profile_name=hcp_profile_name,
                                study=hcp_study,
                                aws_access_key_id=hcp_aws_access_key_id,
                                aws_secret_access_key=hcp_aws_secret_access_key,
                            )
                        except Exception as e:
                            # Best effort: report the failure and move on to
                            # the next requested dataset.
                            logger.error(
                                f"Error while fetching HCP dataset: {e}", exc_info=True
                            )
                            nb_failed += 1
                    elif data_name == "hbn":
                        if not subjects:
                            logger.error(
                                "Please provide the subjects to download the HBN dataset."
                            )
                            nb_failed += 1
                            continue
                        try:
                            available_data[data_name](
                                subjects=subjects, include_afq=include_afq
                            )
                        except Exception as e:
                            # Best effort: report the failure and move on to
                            # the next requested dataset.
                            logger.error(
                                f"Error while fetching HBN dataset: {e}", exc_info=True
                            )
                            nb_failed += 1
                    else:
                        available_data[data_name](include_optional=include_optional)

                # Unknown names and failed HCP / HBN downloads do not count
                # as fetched datasets.
                nb_success = len(data_names) - len(skipped_names) - nb_failed
                logger.info("\n")
                logger.info(f"Fetched {nb_success} / {len(data_names)} Datasets ")

                if skipped_names:
                    logger.warning(f"Skipped data name(s): {' '.join(skipped_names)}")
                    logger.warning(
                        "Please, select between the following data names: "
                        f"{', '.join(available_data.keys())}"
                    )
        finally:
            # Put the environment back even when a fetcher raised.
            if out_dir:
                if dipy_home is not None:
                    os.environ["DIPY_HOME"] = dipy_home
                else:
                    os.environ.pop("DIPY_HOME", None)

                # We load the module again so that if we run another one of
                # these in the same process, we don't have the env variable
                # pointing to the wrong place
                self.load_module("dipy.data.fetcher")
[docs]
class SplitFlow(Workflow):
    """Workflow extracting a single volume from each input file."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "split"

    def run(self, input_files, vol_idx=0, out_dir="", out_split="split.nii.gz"):
        """Splits the input 4D file and extracts the required 3D volume.

        Parameters
        ----------
        input_files : variable string or Path
            Any number of Nifti1 files
        vol_idx : int, optional
            Index of the 3D volume to extract.
        out_dir : string, optional
            Output directory.
        out_split : string or Path, optional
            Name of the resulting split volume
        """
        io_it = self.get_io_iterator()
        for fpath, osplit in io_it:
            logger.info(f"Splitting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)

            # Indexing the last axis of data with fewer than four dimensions
            # would silently save a slice of a single volume instead of a
            # volume.
            if data.ndim < 4:
                raise ValueError(
                    f"Cannot split {fpath}: expected data with at least 4 "
                    f"dimensions, got {data.ndim}."
                )

            if vol_idx == 0:
                logger.info("Splitting and extracting 1st b0")

            split_vol = data[..., vol_idx]
            save_nifti(osplit, split_vol, affine, hdr=image.header)

            logger.info(f"Split volume saved as {osplit}")
[docs]
class ConcatenateTractogramFlow(Workflow):
    """Workflow concatenating several tractograms into a single one."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "concatracks"

    def run(
        self,
        tractogram_files,
        reference=None,
        delete_dpv=False,
        delete_dps=False,
        delete_groups=False,
        check_space_attributes=True,
        preallocation=False,
        out_dir="",
        out_extension="trx",
        out_tractogram="concatenated_tractogram",
    ):
        """Concatenate multiple tractograms into one.

        Parameters
        ----------
        tractogram_files : variable string or Path
            The stateful tractogram filenames to concatenate
        reference : string or Path, optional
            Reference anatomy for tck/vtk/fib/dpy file.
            support (.nii or .nii.gz).
        delete_dpv : bool, optional
            Delete dpv keys that do not exist in all the provided TrxFiles
        delete_dps : bool, optional
            Delete dps keys that do not exist in all the provided TrxFile
        delete_groups : bool, optional
            Delete all the groups that currently exist in the TrxFiles
        check_space_attributes : bool, optional
            Verify that dimensions and size of data are similar between all the
            TrxFiles
        preallocation : bool, optional
            Preallocated TrxFile has already been generated and is the first
            element in trx_list (Note: delete_groups must be set to True as
            well)
        out_dir : string or Path, optional
            Output directory.
        out_extension : string, optional
            Extension of the resulting tractogram
        out_tractogram : string, optional
            Name of the resulting tractogram
        """
        # Reject an unsupported output format before loading anything.
        valid_extensions = ["trk", "trx", "tck", "fib", "dpy", "vtk"]
        if out_extension.lower() not in valid_extensions:
            raise ValueError(
                f"Invalid extension. Valid extensions are: {valid_extensions}"
            )

        io_it = self.get_io_iterator()

        trx_list = []
        has_group = False
        for fpath, _, _ in io_it:
            _, extension = split_filename_extension(fpath)

            # trx / trk files are loaded with "same" as reference. The choice
            # is made per file so that it never replaces the reference given
            # by the caller for the other formats.
            if extension.lower() in [".trx", ".trk"]:
                file_reference = "same"
            else:
                file_reference = reference

            if not file_reference:
                raise ValueError(
                    "No reference provided. It is needed for tck,fib, dpy or vtk files"
                )

            tractogram_obj = load_tractogram(
                fpath, file_reference, bbox_valid_check=False
            )

            if not isinstance(tractogram_obj, tmm.TrxFile):
                tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj)
            elif len(tractogram_obj.groups):
                has_group = True
            trx_list.append(tractogram_obj)

        trx = concatenate_tractogram(
            trx_list,
            delete_dpv=delete_dpv,
            delete_dps=delete_dps,
            delete_groups=delete_groups or not has_group,
            check_space_attributes=check_space_attributes,
            preallocation=preallocation,
        )

        out_fpath = Path(out_dir) / f"{out_tractogram}.{out_extension}"
        save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False)
[docs]
class ConvertSHFlow(Workflow):
    """Workflow converting SH coefficient volumes between two bases."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "convert_dipy_mrtrix"

    def run(
        self,
        input_files,
        out_dir="",
        out_file="sh_convert_dipy_mrtrix_out.nii.gz",
    ):
        """Converts SH basis representation between DIPY and MRtrix3 formats.
        Because this conversion is equal to its own inverse, it can be used to
        convert in either direction: DIPY to MRtrix3 or vice versa.

        Parameters
        ----------
        input_files : string or Path
            Path to the input files. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string or Path, optional
            Where the resulting file will be saved. (default '')
        out_file : string, optional
            Name of the result file to be saved.
            (default 'sh_convert_dipy_mrtrix_out.nii.gz')
        """
        for in_path, out_path in self.get_io_iterator():
            sh_coeffs, affine, image = load_nifti(in_path, return_img=True)
            converted = convert_sh_descoteaux_tournier(sh_coeffs)
            # The header of the input image is reused for the output.
            save_nifti(out_path, converted, affine, hdr=image.header)
[docs]
class ConvertTensorsFlow(Workflow):
    """Workflow converting tensor volumes from one format to another."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "convert_tensors"

    def run(
        self,
        tensor_files,
        from_format="mrtrix",
        to_format="dipy",
        out_dir=".",
        out_tensor="converted_tensor",
    ):
        """Converts tensor representation between different formats.

        Parameters
        ----------
        tensor_files : variable string or Path
            Any number of tensor files
        from_format : string, optional
            Format of the input tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        to_format : string, optional
            Format of the output tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        out_dir : string or Path, optional
            Output directory.
        out_tensor : string, optional
            Name of the resulting tensor file
        """
        for in_path, out_path in self.get_io_iterator():
            logger.info(f"Converting {in_path}")
            tensors, affine, image = load_nifti(in_path, return_img=True)
            converted = convert_tensors(tensors, from_format, to_format)
            # The header of the input image is reused for the output.
            save_nifti(out_path, converted, affine, hdr=image.header)
[docs]
class ConvertTractogramFlow(Workflow):
    """Workflow converting tractogram files from one format to another."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "convert_tractogram"

    def run(
        self,
        input_files,
        reference=None,
        pos_dtype="float32",
        offsets_dtype="uint32",
        out_dir="",
        out_tractogram="converted_tractogram.trk",
    ):
        """Converts tractogram between different formats.

        Parameters
        ----------
        input_files : variable string or Path
            Any number of tractogram files
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy file.
            support (.nii or .nii.gz).
        pos_dtype : string, optional
            Data type of the tractogram points, used for vtk files.
        offsets_dtype : string, optional
            Data type of the tractogram offsets, used for vtk files.
        out_dir : string or Path, optional
            Output directory.
        out_tractogram : string, optional
            Name of the resulting tractogram
        """
        io_it = self.get_io_iterator()

        for fpath, otracks in io_it:
            in_extension = Path(fpath).suffix.lower()
            out_extension = Path(otracks).suffix.lower()

            if in_extension == out_extension:
                warnings.warn(
                    "Input and output are the same file format. Skipping...",
                    stacklevel=2,
                )
                continue

            # The reference is resolved per file: falling back to "same" for
            # a trx / trk input must not leak into the next iteration, where
            # it would bypass the missing-reference check of another format.
            file_reference = reference
            if not file_reference:
                if in_extension not in [".trx", ".trk"]:
                    raise ValueError(
                        "No reference provided. It is needed for tck,fib, dpy or vtk files"
                    )
                file_reference = "same"

            sft = load_tractogram(fpath, file_reference, bbox_valid_check=False)

            if out_extension != ".trx":
                if out_extension == ".vtk":
                    if sft.streamlines._data.dtype.name != pos_dtype:
                        sft.streamlines._data = sft.streamlines._data.astype(pos_dtype)

                    # For vtk output, "uint32" / "uint64" are replaced by
                    # their signed counterparts. A local name is used so the
                    # argument itself is left untouched.
                    vtk_offsets_dtype = offsets_dtype
                    if vtk_offsets_dtype in ("uint64", "uint32"):
                        vtk_offsets_dtype = vtk_offsets_dtype[1:]
                    if sft.streamlines._offsets.dtype.name != vtk_offsets_dtype:
                        sft.streamlines._offsets = sft.streamlines._offsets.astype(
                            vtk_offsets_dtype
                        )
                save_tractogram(sft, otracks, bbox_valid_check=False)
            else:
                trx = tmm.TrxFile.from_sft(sft)
                try:
                    if trx.streamlines._data.dtype.name != pos_dtype:
                        trx.streamlines._data = trx.streamlines._data.astype(pos_dtype)
                    if trx.streamlines._offsets.dtype.name != offsets_dtype:
                        trx.streamlines._offsets = trx.streamlines._offsets.astype(
                            offsets_dtype
                        )
                    tmm.save(trx, otracks)
                finally:
                    # Close the TrxFile even when the conversion or the save
                    # fails.
                    trx.close()
[docs]
class NiftisToPamFlow(Workflow):
    """Workflow packing peak NIfTI volumes into a pam5 file."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "niftis_to_pam"

    def run(
        self,
        peaks_dir_files,
        peaks_values_files,
        peaks_indices_files,
        shm_files=None,
        gfa_files=None,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple nifti files to a single pam5 file.

        Parameters
        ----------
        peaks_dir_files : string or Path
            Path to the input peaks directions volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_values_files : string or Path
            Path to the input peaks values volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_indices_files : string or Path
            Path to the input peaks indices volume. This path may contain
            wildcards to process multiple inputs at once.
        shm_files : string or Path, optional
            Path to the input spherical harmonics volume. This path may
            contain wildcards to process multiple inputs at once.
        gfa_files : string or Path, optional
            Path to the input generalized FA volume. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string or Path, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not defined,
            default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by
            sphere_files option. Possible options: ['symmetric362', 'symmetric642',
            'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'].
        out_dir : string or Path, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        destination = out_dir or "current directory"

        # Neither argument is read below, so tell the user instead of
        # silently dropping the data they supplied.
        if shm_files or gfa_files:
            logger.warning(
                "shm_files and gfa_files are currently not used by this "
                "workflow; they are not added to the pam5 file."
            )

        # The sphere does not depend on the current set of inputs, so it is
        # built once instead of re-reading the vertices file per iteration.
        if sphere_files:
            sphere = Sphere(xyz=np.loadtxt(sphere_files))
        else:
            sphere = get_sphere(name=default_sphere_name)

        for fpeak_dirs, fpeak_values, fpeak_indices, opam in io_it:
            logger.info("Converting nifti files to pam5")

            # The affine of the directions volume is the one stored in the pam.
            peak_dirs, affine = load_nifti(fpeak_dirs)
            peak_values, _ = load_nifti(fpeak_values)
            peak_indices, _ = load_nifti(fpeak_indices)

            niftis_to_pam(
                affine=affine,
                peak_dirs=peak_dirs,
                sphere=sphere,
                peak_values=peak_values,
                peak_indices=peak_indices,
                pam_file=opam,
            )
            # Build the message directly: a str.replace on a template would
            # also rewrite any "pam5" appearing inside the output directory.
            logger.info(f"{opam} files saved in {destination}")
[docs]
class TensorToPamFlow(Workflow):
    """Workflow converting tensor eigenvalue/eigenvector volumes to pam5."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        # NOTE(review): the value says "niftis" although this workflow writes
        # pam5 files; it looks like a copy-paste slip, but it is kept as-is
        # because callers may rely on the current value -- confirm.
        return "tensor_to_niftis"

    def run(
        self,
        evals_files,
        evecs_files,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple tensor files(evals, evecs) to pam5 files.

        Parameters
        ----------
        evals_files : string or Path
            Path to the input eigen values volumes. This path may contain
            wildcards to process multiple inputs at once.
        evecs_files : string or Path
            Path to the input eigen vectors volumes. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string or Path, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not defined,
            default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by sphere_files
            option. Possible options: ['symmetric362', 'symmetric642',
            'symmetric724', 'repulsion724', 'repulsion100', 'repulsion200'].
        out_dir : string or Path, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        destination = out_dir or "current directory"

        # The sphere does not depend on the current set of inputs, so it is
        # built once instead of re-reading the vertices file per iteration.
        if sphere_files:
            sphere = Sphere(xyz=np.loadtxt(sphere_files))
        else:
            sphere = get_sphere(name=default_sphere_name)

        for fevals, fevecs, opam in io_it:
            logger.info("Converting tensor files to pam5...")

            # The affine of the eigenvalues volume is the one passed on.
            evals, affine = load_nifti(fevals)
            evecs, _ = load_nifti(fevecs)

            tensor_to_pam(evals, evecs, affine, sphere=sphere, pam_file=opam)
            # Build the message directly: a str.replace on a template would
            # also rewrite any "pam5" appearing inside the output directory.
            logger.info(f"{opam} files saved in {destination}")
[docs]
class PamToNiftisFlow(Workflow):
    """Workflow splitting pam5 files into separate output files."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "pam_to_niftis"

    def run(
        self,
        pam_files,
        out_dir="",
        out_peaks_dir="peaks_dirs.nii.gz",
        out_peaks_values="peaks_values.nii.gz",
        out_peaks_indices="peaks_indices.nii.gz",
        out_shm="shm.nii.gz",
        out_gfa="gfa.nii.gz",
        out_sphere="sphere.txt",
        out_b="B.nii.gz",
        out_qa="qa.nii.gz",
    ):
        """Convert pam5 files to multiple nifti files.

        Parameters
        ----------
        pam_files : string
            Path to the input peaks volumes. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string, optional
            Output directory (default input file directory).
        out_peaks_dir : string, optional
            Name of the peaks directions volume to be saved.
        out_peaks_values : string, optional
            Name of the peaks values volume to be saved.
        out_peaks_indices : string, optional
            Name of the peaks indices volume to be saved.
        out_shm : string, optional
            Name of the spherical harmonics volume to be saved.
        out_gfa : string, optional
            Generalized FA volume name to be saved.
        out_sphere : string, optional
            Sphere vertices name to be saved.
        out_b : string, optional
            Name of the B Matrix to be saved.
        out_qa : string, optional
            Name of the Quantitative Anisotropy file to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"Nifti files saved in {out_dir or 'current directory'}"

        for paths in io_it:
            # One input path followed by the eight output paths, in the
            # order the out_* parameters are declared.
            (
                pam_path,
                dirs_path,
                values_path,
                indices_path,
                shm_path,
                gfa_path,
                sphere_path,
                b_path,
                qa_path,
            ) = paths

            logger.info(f"Converting file {pam_path} to niftis...")
            loaded = load_pam(pam_path)
            pam_to_niftis(
                loaded,
                fname_peaks_dir=dirs_path,
                fname_peaks_values=values_path,
                fname_peaks_indices=indices_path,
                fname_shm=shm_path,
                fname_gfa=gfa_path,
                fname_sphere=sphere_path,
                fname_b=b_path,
                fname_qa=qa_path,
            )
            logger.info(msg)
[docs]
class MathFlow(Workflow):
    """Workflow evaluating a mathematical expression over NIfTI volumes."""

    @classmethod
    def get_short_name(cls):
        """Return the short name identifying this workflow."""
        return "math_flow"

    def broadcast_arrays(self, vol_dict, operation):
        """Bring the volumes used in ``operation`` to a common shape.

        Volumes with fewer dimensions get trailing singleton axes appended,
        then every referenced volume is broadcast to the common shape.

        Parameters
        ----------
        vol_dict : dict
            Mapping of variable name to array. It is updated in place.
        operation : string
            Expression whose word tokens are matched against ``vol_dict`` keys.

        Returns
        -------
        vol_dict : dict
            The same dictionary, with the referenced entries reshaped and/or
            broadcast. It is returned untouched when the expression references
            none of its keys.

        Raises
        ------
        ValueError
            If the referenced volumes cannot be broadcast together.

        """
        # dict.fromkeys drops repeated names while keeping first-seen order.
        names = list(
            dict.fromkeys(
                token
                for token in re.findall(r"\b\w+\b", operation)
                if token in vol_dict
            )
        )
        if not names:
            # Nothing to align. Returning here (rather than calling max() on
            # an empty sequence) lets the caller report the unknown names.
            return vol_dict

        max_dims = max(len(vol_dict[name].shape) for name in names)

        for name in names:
            shape = vol_dict[name].shape
            missing = max_dims - len(shape)
            if missing:
                # Pad on the right, e.g. (x, y, z) -> (x, y, z, 1).
                vol_dict[name] = vol_dict[name].reshape(shape + (1,) * missing)

        try:
            broadcast_shape = np.broadcast_shapes(
                *(vol_dict[name].shape for name in names)
            )
        except ValueError as e:
            raise ValueError(f"Shape mismatch after adding dimensions: {e}") from e

        for name in names:
            if vol_dict[name].shape != broadcast_shape:
                vol_dict[name] = np.broadcast_to(vol_dict[name], broadcast_shape)

        return vol_dict

    def run(
        self,
        operation,
        input_files,
        dtype=None,
        disable_check=False,
        out_dir="",
        out_file="math_out.nii.gz",
    ):
        """Perform mathematical operations on volume input files.

        This workflow allows the user to perform mathematical operations on
        multiple input files. e.g. to add two volumes together, subtract one:
        ``dipy_math "vol1 + vol2 - vol3" t1.nii.gz t1_a.nii.gz t1_b.nii.gz``

        The input files must be in Nifti format and have the same shape.

        Parameters
        ----------
        operation : string
            Mathematical operation to perform. supported operators are:

            - Bitwise operators (and, or, not, xor): ``&, |, ~, ^``
            - Comparison operators: ``<, <=, ==, !=, >=, >``
            - Unary arithmetic operators: ``-``
            - Binary arithmetic operators: ``+, -, *, /, **, <<, >>``

            Supported functions are:

            - ``where(bool, number1, number2) -> number``: number1 if the bool
              condition is true, number2 otherwise.
            - ``{sin,cos,tan}(float|complex) -> float|complex``: trigonometric sine,
              cosine or tangent.
            - ``{arcsin,arccos,arctan}(float|complex) -> float|complex``:
              trigonometric inverse sine, cosine or tangent.
            - ``arctan2(float1, float2) -> float``: trigonometric inverse tangent of
              float1/float2.
            - ``{sinh,cosh,tanh}(float|complex) -> float|complex``: hyperbolic
              sine, cosine or tangent.
            - ``{arcsinh,arccosh,arctanh}(float|complex) -> float|complex``:
              hyperbolic inverse sine, cosine or tangent.
            - ``{log,log10,log1p}(float|complex) -> float|complex``: natural,
              base-10 and log(1+x) logarithms.
            - ``{exp,expm1}(float|complex) -> float|complex``: exponential and
              exponential minus one.
            - ``sqrt(float|complex) -> float|complex``: square root.
            - ``abs(float|complex) -> float|complex``: absolute value.
            - ``conj(complex) -> complex``: conjugate value.
            - ``{real,imag}(complex) -> float``: real or imaginary part of complex.
            - ``complex(float, float) -> complex``: complex from real and imaginary
              parts.
            - ``contains(np.str, np.str) -> bool``: returns True for every string
              in op1 that contains op2.

        input_files : variable string or Path
            Any number of Nifti1 files
        dtype : string, optional
            Data type of the resulting file.
        disable_check : bool, optional
            If True, the workflow will not check if all input files have the same
            shape and affine matrix.
        out_dir : string, optional
            Output directory
        out_file : string, optional
            Name of the resulting file to be saved.

        """
        if not input_files:
            # Without at least one volume there is no affine to save with.
            logger.error("No input file provided.")
            raise SystemExit()

        vol_dict = {}
        ref_affine = None
        ref_shape = None
        info_msg = ""
        have_errors = False

        # Volumes are exposed to the expression as vol1, vol2, ... in the
        # order the files were given.
        for i, fname in enumerate(input_files, start=1):
            if not Path(fname).is_file():
                logger.error(f"Input file {fname} does not exist.")
                raise SystemExit()

            _, ext = split_filename_extension(fname)
            if ext not in [".nii.gz", ".nii"]:
                msg = (
                    f"Wrong volume type: {fname}. Only Nifti files are supported"
                    " (*.nii or *.nii.gz)."
                )
                logger.error(msg)
                raise SystemExit()

            data, affine = load_nifti(fname)
            vol_dict[f"vol{i}"] = data
            info_msg += f"{fname}:\n- vol index: {i}\n- shape: {data.shape}"
            info_msg += f"\n- affine:\n{affine}\n"

            if ref_affine is None:
                # The first file is the reference the others are compared to.
                ref_affine, ref_shape = affine, data.shape
                continue

            if not have_errors:
                have_errors = (
                    not np.allclose(ref_affine, affine, rtol=1e-05, atol=1e-08)
                    or ref_shape != data.shape
                )

        if have_errors:
            logger.warning(info_msg)
            if not disable_check:
                msg = "All input files must have the same shape and affine matrix."
                logger.error(msg)
                raise SystemExit()

        try:
            if disable_check:
                vol_dict = self.broadcast_arrays(vol_dict, operation)
            res = ne.evaluate(operation, local_dict=vol_dict)
        except KeyError as e:
            msg = (
                f"Impossible key {e} in the operation. You have {len(input_files)}"
                f" volumes available with the following keys: {list(vol_dict.keys())}"
            )
            logger.error(msg)
            raise SystemExit() from e

        if dtype:
            try:
                res = res.astype(dtype)
            except TypeError as e:
                msg = (
                    f"Impossible to cast to {dtype}. Check possible numpy type here: "
                    "https://numpy.org/doc/stable/reference/arrays.interface.html"
                )
                logger.error(msg)
                raise SystemExit() from e

        if res.dtype == bool:
            # Boolean results are stored as 0/1 bytes.
            res = res.astype(np.uint8)

        out_fname = Path(out_dir) / out_file
        logger.info(f"Saving result to {out_fname}")
        # `affine` is the one of the last file loaded; with the default check
        # enabled it matches the first file's affine within tolerance.
        save_nifti(out_fname, res, affine)