Note
Go to the end to download the full example code.
Saving and Loading QuickBundles Clustering Results#
This example shows how to save clustering results produced by QuickBundles [1] into tractogram files using two formats.
Part 1 — TRX uses the native TRX structure: cluster labels are stored in
dps/clusters_QB and each cluster is saved as a named entry under
groups/. This is the recommended approach because TRX natively supports
both concepts, uses memory-mapped arrays for efficiency, and avoids any
concatenation hack.
Part 2 — TRK uses a merge strategy: the centroids are appended to the
streamline array and an is_centroid flag is stored in
data_per_streamline. This is necessary because the TRK format has no
native groups concept.
import numpy as np
import trx.trx_file_memmap as tmm
from dipy.data import get_fnames
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.segment.clustering import QuickBundles
Generating Clusters and Centroids#
We fetch a sample fornix dataset and group the streamlines into clusters
using the QuickBundles algorithm. QuickBundles produces a ClusterMap
that stores per-cluster indices and a representative centroid streamline.
# Resolve the path of the sample fornix tractogram.
fname = get_fnames(name="fornix")

# The fornix streamlines extend slightly outside the reference image bounding
# box, so the bounding-box validity check must be disabled when loading.
fornix = load_tractogram(fname, "same", bbox_valid_check=False)
streamlines = fornix.streamlines

# Group the streamlines with QuickBundles (threshold of 10.0).
qb = QuickBundles(threshold=10.0)
clusters = qb.cluster(streamlines)

# Report how many clusters were found and how many streamlines each holds.
cluster_sizes = [len(cluster) for cluster in clusters]
print(f"Number of clusters: {len(clusters)}")
print(f"Cluster sizes: {cluster_sizes}")
Part 1 — TRX: Native Groups Approach#
TRX natively supports two complementary ways to encode clustering results:
- dps/clusters_QB — a per-streamline integer label array stored under data_per_streamline. The key name clusters_QB is a TRX convention established in the OHBM demo tractogram.
- groups/cluster_N — named index arrays, one per cluster, stored under groups/. Each array contains the indices of the streamlines that belong to that cluster, enabling random access without loading all data.
We build a StatefulTractogram with only the original streamlines and the
clusters_QB label array, convert it to a TrxFile, then attach the
per-cluster index arrays as named groups before saving.
# Per-streamline cluster label. Start from a -1 sentinel instead of an
# uninitialised buffer: any streamline that is not assigned to a cluster is
# then detectable rather than carrying arbitrary memory contents.
labels = np.full(len(streamlines), -1, dtype=np.int16)

# One integer index array per cluster; reused for the labels and the groups.
cluster_indices = [
    np.asarray(cluster.indices, dtype=np.intp) for cluster in clusters
]
for i, indices in enumerate(cluster_indices):
    # Vectorised assignment: every streamline of cluster i gets label i.
    labels[indices] = i

# Tractogram holding only the original streamlines plus the label array.
sft = StatefulTractogram(
    streamlines,
    reference=fornix,
    space=Space.RASMM,
    data_per_streamline={"clusters_QB": labels},
)

trx = tmm.TrxFile.from_sft(sft)
try:
    # Attach each cluster's streamline indices as a named group.
    for i, indices in enumerate(cluster_indices):
        trx.groups[f"cluster_{i}"] = indices.astype(np.uint32)
    tmm.save(trx, "clustering_results.trx")
finally:
    # Always release the TRX handle, even if attaching groups or saving fails.
    trx.close()
Reloading the TRX File and Accessing Groups#
After reloading, data_per_streamline["clusters_QB"] gives the per-
streamline label and groups["cluster_N"] gives the index array for each
named cluster — no concatenation or sentinel flags required.
trx_loaded = tmm.load("clustering_results.trx")
try:
    print(f"Loaded {len(trx_loaded.streamlines)} streamlines from TRX")
    print(f"dps keys: {list(trx_loaded.data_per_streamline.keys())}")
    print(f"group keys: {list(trx_loaded.groups.keys())}")

    # np.array_equal returns False on any shape mismatch. Per-streamline data
    # may come back as an (N, 1) column rather than the 1-D array that was
    # built in memory, so flatten it to compare values instead of shapes.
    loaded_labels = np.ravel(trx_loaded.data_per_streamline["clusters_QB"])
    print(f"labels match: {np.array_equal(loaded_labels, labels)}")
finally:
    # Release the loaded file even if one of the checks above raises.
    trx_loaded.close()
Part 2 — TRK: Merging Streamlines and Centroids#
The TRK format has no native groups concept, so the standard approach is
to append the centroid streamlines to the original streamlines and record
two integer arrays in data_per_streamline:
- cluster — the cluster index for every streamline (centroids carry their own cluster index).
- is_centroid — 0 for original streamlines, 1 for centroids.
n_streamlines = len(streamlines)
n_clusters = len(clusters)

# Cluster index per entry: the original labels followed by one label per
# centroid (centroid k belongs to cluster k).
centroid_labels = np.arange(n_clusters, dtype=np.int32)
all_labels = np.concatenate([labels.ravel().astype(np.int32), centroid_labels])

# Flag array: 0 for the original streamlines, 1 for the appended centroids.
is_centroid = np.zeros(n_streamlines + n_clusters, dtype=np.int32)
is_centroid[n_streamlines:] = 1

# Original streamlines first, centroids appended at the end.
combined_streamlines = [*streamlines, *clusters.centroids]

per_streamline_data = {"cluster": all_labels, "is_centroid": is_centroid}
sft_combined = StatefulTractogram(
    combined_streamlines,
    reference=fornix,
    space=Space.RASMM,
    data_per_streamline=per_streamline_data,
)
save_tractogram(sft_combined, "clustering_results.trk", bbox_valid_check=False)
# Reload the merged tractogram; the bounding-box check stays disabled for the
# same reason as when the file was saved.
sft_trk = load_tractogram("clustering_results.trk", "same", bbox_valid_check=False)
print(f"Loaded {len(sft_trk.streamlines)} total streamlines from TRK")

# np.array_equal returns False on any shape mismatch. Per-streamline data is
# held as an (N, 1) column after loading, whereas the reference arrays are
# 1-D, so flatten the loaded arrays to compare values rather than shapes.
trk_labels = np.ravel(sft_trk.data_per_streamline["cluster"])
trk_flags = np.ravel(sft_trk.data_per_streamline["is_centroid"])

print(f"TRK labels match: {np.array_equal(trk_labels, all_labels)}")
print(f"TRK centroid flags match: {np.array_equal(trk_flags, is_centroid)}")