An introduction to the Fast Tracking Module#

The fast tracking module allow to run tractography on multiple CPU cores.

Current implemented algorithms are probabilistic, deterministic and parallel transport tractography (PTT).

See An introduction to the Probabilistic Tractography An introduction to the Deterministic Tractography Parallel Transport Tractography for detailed examples of those algorithms.

import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_erosion
from scipy.stats import pearsonr

from dipy.core.gradients import gradient_table
from dipy.data import default_sphere, get_fnames
from dipy.io.image import load_nifti, load_nifti_data
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram, save_trk
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel, auto_response_ssst
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from dipy.tracking.streamline import Streamlines
from dipy.tracking.tracker import (
    deterministic_tracking,
    probabilistic_tracking,
    ptt_tracking,
)
from dipy.tracking.utils import connectivity_matrix, seeds_from_mask
from dipy.viz import actor, colormap, has_fury, window

# Enables/disables interactive visualization
interactive = False

print("Downloading data...")
Downloading data...

Prepare the synthetic DiSCo data for fast tracking. The ground-truth connectome will be use to evaluate tractography performances.

fnames = get_fnames(name="disco1")
disco1_fnames = [os.path.basename(f) for f in fnames]

GT_connectome_fname = fnames[
    disco1_fnames.index("DiSCo1_Connectivity_Matrix_Cross-Sectional_Area.txt")
]
GT_connectome = np.loadtxt(GT_connectome_fname)

# select the non-zero connections of the upper triangular part of the connectome
connectome_mask = np.tril(np.ones(GT_connectome.shape), -1) > 0

# load the
labels_fname = fnames[disco1_fnames.index("highRes_DiSCo1_ROIs.nii.gz")]
labels, affine, labels_img = load_nifti(labels_fname, return_img=True)
labels = labels.astype(int)

print("Loading data...")
tracks_fname = fnames[disco1_fnames.index("DiSCo1_Strands_Trajectories.tck")]
GT_streams = load_tractogram(tracks_fname, reference=labels_img).streamlines
if has_fury:
    scene = window.Scene()
    scene.add(actor.line(GT_streams, colors=colormap.line_colors(GT_streams)))
    window.record(scene=scene, out_path="tractogram_ground_truth.png", size=(800, 800))
    if interactive:
        window.show(scene)

plt.imshow(GT_connectome, origin="lower", cmap="viridis", interpolation="nearest")
plt.axis("off")
plt.savefig("connectome_ground_truth.png")
plt.close()
  • tracking fast tractography
  • tracking fast tractography
Loading data...

DiSCo ground-truth trajectories (left) and connectivity matrix (right).

# Tracking mask, seed positions and initial directions
mask_fname = fnames[disco1_fnames.index("highRes_DiSCo1_mask.nii.gz")]
mask = load_nifti_data(mask_fname)
sc = BinaryStoppingCriterion(mask)

voxel_size = np.ones(3)
seed_fname = fnames[disco1_fnames.index("highRes_DiSCo1_ROIs-mask.nii.gz")]
seed_mask = load_nifti_data(seed_fname)
seed_mask = binary_erosion(seed_mask * mask, iterations=1)
seeds = seeds_from_mask(seed_mask, affine, density=2)

plt.imshow(seed_mask[:, :, 17], origin="lower", cmap="gray", interpolation="nearest")
plt.axis("off")
plt.title("Seeding Mask")
plt.savefig("seeding_mask.png")
plt.close()
plt.imshow(mask[:, :, 17], origin="lower", cmap="gray", interpolation="nearest")
plt.axis("off")
plt.title("Tracking Mask")
plt.savefig("tracking_mask.png")
plt.close()
  • tracking fast tractography
  • tracking fast tractography

DiSCo seeding (left) and tracking (right) masks.

# Compute ODFs
data_fname = fnames[disco1_fnames.index("highRes_DiSCo1_DWI_RicianNoise-snr10.nii.gz")]
data = load_nifti_data(data_fname)

bvecs = fnames[disco1_fnames.index("DiSCo_gradients_dipy.bvecs")]
bvals = fnames[disco1_fnames.index("DiSCo_gradients.bvals")]
gtab = gradient_table(bvals=bvals, bvecs=bvecs)

response, _ = auto_response_ssst(gtab, data, roi_radii=10, fa_thr=0.7)
csd_model = ConstrainedSphericalDeconvModel(gtab, response, sh_order_max=8)
csd_fit = csd_model.fit(data, mask=mask)
ODFs = csd_fit.odf(default_sphere)

if has_fury:
    scene = window.Scene()
    ODF_spheres = actor.odf_slicer(
        ODFs[:, :, 17:18, :],
        sphere=default_sphere,
        scale=2,
        norm=False,
        colormap="plasma",
    )
    scene.add(ODF_spheres)
    window.record(scene=scene, out_path="GT_odfs.png", size=(600, 600))
    if interactive:
        window.show(scene)
tracking fast tractography

DiSCo ground-truth ODFs.

# Perform fast deterministic tractography using 1 thread (cpu)
print("Running fast Deterministic Tractography...")
streamline_generator = deterministic_tracking(
    seeds, sc, affine, sf=ODFs, nbr_threads=1, random_seed=42, sphere=default_sphere
)

det_streams = Streamlines(streamline_generator)
sft = StatefulTractogram(det_streams, labels_img, Space.RASMM)
save_trk(sft, "tractogram_fast_deterministic.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(det_streams, colors=colormap.line_colors(det_streams)))
    window.record(
        scene=scene, out_path="tractogram_fast_deterministic.png", size=(800, 800)
    )
    if interactive:
        window.show(scene)

# Compare the estimated connectivity with the ground-truth connectivity
connectome = connectivity_matrix(det_streams, affine, labels)[1:, 1:]
r, _ = pearsonr(
    GT_connectome[connectome_mask].flatten(), connectome[connectome_mask].flatten()
)
print("DiSCo ground-truth correlation (fast deterministic tractography): ", r)

plt.imshow(connectome, origin="lower", cmap="viridis", interpolation="nearest")
plt.axis("off")
plt.savefig("connectome_deterministic.png")
plt.close()
  • tracking fast tractography
  • tracking fast tractography
Running fast Deterministic Tractography...
DiSCo ground-truth correlation (fast deterministic tractography):  0.8666522736708173

DiSCo Deterministic tractogram and corresponding connectome.

Perform fast probabilistic tractography using 4 threads (cpus)

print("Running fast Probabilistic Tractography...")
streamline_generator = probabilistic_tracking(
    seeds, sc, affine, sf=ODFs, nbr_threads=4, random_seed=42, sphere=default_sphere
)
prob_streams = Streamlines(streamline_generator)
sft = StatefulTractogram(prob_streams, labels_img, Space.RASMM)
save_trk(sft, "tractogram_fast_probabilistic.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(prob_streams, colors=colormap.line_colors(prob_streams)))
    window.record(
        scene=scene, out_path="tractogram_fast_probabilistic.png", size=(800, 800)
    )
    if interactive:
        window.show(scene)

# Compare the estimated connectivity with the ground-truth connectivity
connectome = connectivity_matrix(prob_streams, affine, labels)[1:, 1:]
r, _ = pearsonr(
    GT_connectome[connectome_mask].flatten(), connectome[connectome_mask].flatten()
)
print("DiSCo ground-truth correlation (fast probabilistic tractography): ", r)

plt.imshow(connectome, origin="lower", cmap="viridis", interpolation="nearest")
plt.axis("off")
plt.savefig("connectome_probabilistic.png")
plt.close()
  • tracking fast tractography
  • tracking fast tractography
Running fast Probabilistic Tractography...
DiSCo ground-truth correlation (fast probabilistic tractography):  0.8712155682513311

DiSCo Probabilistic tractogram and corresponding connectome.

# Perform fast parallel transport tractography tractography using all threads (cpus)
print("Running fast Parallel Transport Tractography...")
streamline_generator = ptt_tracking(
    seeds, sc, affine, sf=ODFs, nbr_threads=0, random_seed=42, sphere=default_sphere
)
ptt_streams = Streamlines(streamline_generator)
sft = StatefulTractogram(ptt_streams, labels_img, Space.RASMM)
save_trk(sft, "tractogram_fast_ptt.trk")

if has_fury:
    scene = window.Scene()
    scene.add(actor.line(ptt_streams, colors=colormap.line_colors(ptt_streams)))
    window.record(scene=scene, out_path="tractogram_fast_ptt.png", size=(800, 800))
    if interactive:
        window.show(scene)

# Compare the estimated connectivity with the ground-truth connectivity
connectome = connectivity_matrix(ptt_streams, affine, labels)[1:, 1:]
r, _ = pearsonr(
    GT_connectome[connectome_mask].flatten(), connectome[connectome_mask].flatten()
)
print("DiSCo ground-truth correlation (fast PTT tractography): ", r)
plt.imshow(connectome, origin="lower", cmap="viridis", interpolation="nearest")
plt.axis("off")
plt.savefig("connectome_ptt.png")
plt.close()
  • tracking fast tractography
  • tracking fast tractography
Running fast Parallel Transport Tractography...
DiSCo ground-truth correlation (fast PTT tractography):  0.8732068660131755

DiSCo PTT tractogram and corresponding connectome.

Total running time of the script: (0 minutes 46.881 seconds)

Gallery generated by Sphinx-Gallery