from itertools import chain
import logging
from time import time
from nibabel.affines import apply_affine
import numpy as np
from dipy.align.streamlinear import (
BundleMinDistanceAsymmetricMetric,
BundleMinDistanceMetric,
BundleSumDistanceMatrixMetric,
StreamlineLinearRegistration,
)
from dipy.segment.clustering import qbx_and_merge
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.distances import bundles_distances_mam, bundles_distances_mdf
from dipy.tracking.streamline import (
Streamlines,
length,
nbytes,
select_random_set_of_streamlines,
set_number_of_points,
)
def check_range(streamline, gt, lt):
    """Tell whether a streamline's length lies strictly between two bounds.

    Parameters
    ----------
    streamline : ndarray
        A single streamline (presumably an (N, 3) array of points — confirm
        against callers).
    gt : float
        Exclusive lower bound on the streamline length.
    lt : float
        Exclusive upper bound on the streamline length.

    Returns
    -------
    bool
        True when ``gt < length(streamline) < lt``, False otherwise
        (including when the length is NaN).
    """
    length_s = length(streamline)
    return bool(gt < length_s < lt)
# Module-level logger; RecoBundles routes all of its progress/timing messages
# through it when constructed with ``verbose=True``.
logger = logging.getLogger(__name__)
def bundle_adjacency(dtracks0, dtracks1, threshold):
    """Find bundle adjacency between two given tracks/bundles.

    See :footcite:p:`Garyfallidis2012a` for further details about the method.

    A streamline is counted as adjacent to the other bundle when its minimum
    MDF distance to that bundle is strictly below ``threshold``. The score is
    the mean of the adjacent fraction of ``dtracks0`` and the adjacent
    fraction of ``dtracks1``.

    Parameters
    ----------
    dtracks0 : Streamlines
        White matter tract from one subject.
    dtracks1 : Streamlines
        White matter tract from another subject.
    threshold : float
        Threshold controls how much strictness user wants while calculating
        bundle adjacency between two bundles. Smaller threshold means bundles
        should be strictly adjacent to get higher BA score.

    Returns
    -------
    res : float
        Bundle adjacency score between two tracts, in [0, 1]. Returns 0.0
        when either bundle is empty.

    Notes
    -----
    MDF distances presumably require every streamline to have the same
    number of points; ``ba_analysis`` resamples both bundles before calling
    this function.

    References
    ----------
    .. footbibliography::
    """
    if len(dtracks0) == 0 or len(dtracks1) == 0:
        # Nothing can be adjacent to an empty bundle; this also avoids a
        # division by zero and a reduction over an empty distance row.
        return 0.0
    d01 = np.asarray(bundles_distances_mdf(dtracks0, dtracks1))
    # Row minima: closest dtracks1 streamline for every dtracks0 streamline.
    # Column minima: closest dtracks0 streamline for every dtracks1 one.
    frac0 = np.count_nonzero(d01.min(axis=1) < threshold) / len(dtracks0)
    frac1 = np.count_nonzero(d01.min(axis=0) < threshold) / len(dtracks1)
    return 0.5 * (frac0 + frac1)
@warning_for_keywords()
def ba_analysis(recognized_bundle, expert_bundle, *, nb_pts=20, threshold=6.0):
    """Compute the bundle adjacency (BA) score of two bundles.

    Both bundles are first resampled to ``nb_pts`` points per streamline so
    that MDF distances between them are defined, then ``bundle_adjacency``
    is applied. See :footcite:p:`Garyfallidis2012a` for further details
    about the method.

    Parameters
    ----------
    recognized_bundle : Streamlines
        Extracted bundle from the whole brain tractogram (eg: AF_L).
    expert_bundle : Streamlines
        Model bundle used as reference while extracting similar type bundle
        from input tractogram.
    nb_pts : integer, optional
        Number of points every streamline is resampled to.
    threshold : float, optional
        Threshold used in computing bundle adjacency. It controls how strict
        the comparison is: a smaller threshold means bundles must be more
        tightly adjacent to get a higher BA score.

    Returns
    -------
    float
        Bundle adjacency score between the two tracts.

    References
    ----------
    .. footbibliography::
    """
    resampled = [
        set_number_of_points(bundle, nb_points=nb_pts)
        for bundle in (recognized_bundle, expert_bundle)
    ]
    return bundle_adjacency(resampled[0], resampled[1], threshold)
@warning_for_keywords()
def cluster_bundle(bundle, clust_thr, rng, *, nb_pts=20, select_randomly=500000):
    """Cluster a bundle with QuickBundlesX and return the cluster centroids.

    See :footcite:p:`Garyfallidis2012a` for further details about the method.

    Parameters
    ----------
    bundle : Streamlines
        White matter tract.
    clust_thr : float or sequence of float
        Clustering threshold(s) passed to ``qbx_and_merge``
        (``bundle_shape_similarity`` passes a tuple of thresholds).
    rng : np.random.Generator
        Numpy random generator used for random streamline selection.
    nb_pts : integer, optional
        Number of points each streamline is resampled to for clustering.
    select_randomly : integer, optional
        Number of streamlines randomly selected from the input bundle.

    Returns
    -------
    centroids : Streamlines
        Clustered centroids of the input bundle.

    References
    ----------
    .. footbibliography::
    """
    return qbx_and_merge(
        bundle, clust_thr, rng=rng, select_randomly=select_randomly, nb_pts=nb_pts
    ).centroids
@warning_for_keywords()
def bundle_shape_similarity(
    bundle1, bundle2, rng, *, clust_thr=(5, 3, 1.5), threshold=6
):
    """Compute the shape similarity of two bundles via bundle adjacency (BA).

    Each bundle is reduced to its QuickBundlesX centroids and the BA score
    of the two centroid sets is returned. See :footcite:p:`Garyfallidis2012a`
    and :footcite:p:`Chandio2020a` for further details about the method.

    Parameters
    ----------
    bundle1 : Streamlines
        White matter tract from one subject (eg: AF_L).
    bundle2 : Streamlines
        White matter tract from another subject (eg: AF_L).
    rng : np.random.Generator
        Random number generator, shared by both clusterings.
    clust_thr : array-like, optional
        Clustering thresholds used in QuickBundlesX.
    threshold : float, optional
        Threshold used in computing bundle adjacency. A smaller threshold
        means bundles must be more similar to get a higher score.

    Returns
    -------
    ba_value : float
        Bundle similarity score between the two tracts (0 if either bundle
        is empty).

    References
    ----------
    .. footbibliography::
    """
    if not (len(bundle1) and len(bundle2)):
        return 0
    # Cluster bundle1 first, then bundle2, so the rng is consumed in order.
    centroid_sets = [
        cluster_bundle(bundle, clust_thr=clust_thr, rng=rng)
        for bundle in (bundle1, bundle2)
    ]
    recognized, expert = (Streamlines(c) for c in centroid_sets)
    return ba_analysis(
        recognized_bundle=recognized,
        expert_bundle=expert,
        threshold=threshold,
    )
class RecoBundles:
    """Recognize model bundles inside a target tractogram (RecoBundles).

    See ``__init__`` for construction parameters and ``recognize`` /
    ``refine`` for bundle extraction.
    """

    @warning_for_keywords()
    def __init__(
        self,
        streamlines,
        *,
        greater_than=50,
        less_than=1000000,
        cluster_map=None,
        clust_thr=15,
        nb_pts=20,
        rng=None,
        verbose=False,
    ):
        """Recognition of bundles

        Extract bundles from a participants' tractograms using model bundles
        segmented from a different subject or an atlas of bundles.
        See :footcite:p:`Garyfallidis2018` for the details.

        Parameters
        ----------
        streamlines : Streamlines
            The tractogram in which you want to recognize bundles.
        greater_than : int, optional
            Keep streamlines that have length greater than
            this value (default 50)
        less_than : int, optional
            Keep streamlines have length less than this value (default 1000000)
        cluster_map : QB map, optional.
            Provide existing clustering to start RB faster (default None).
        clust_thr : float, optional.
            Distance threshold in mm for clustering `streamlines`.
            Default: 15.
        nb_pts : int, optional.
            Number of points per streamline (default 20)
        rng : np.random.Generator
            If None define generator in initialization function.
            Default: None
        verbose: bool, optional.
            If True, log information.

        Notes
        -----
        Make sure that before creating this class that the streamlines and
        the model bundles are roughly in the same space.
        Also default thresholds are assumed in RAS 1mm^3 space. You may
        want to adjust those if your streamlines are not in world coordinates.

        References
        ----------
        .. footbibliography::
        """
        # Keep only streamlines whose length is strictly inside the bounds.
        map_ind = np.array(
            [check_range(s, greater_than, less_than) for s in streamlines],
            dtype=bool,
        )
        self.orig_indices = np.arange(len(streamlines))
        # Maps an index into self.streamlines back to the input tractogram.
        self.filtered_indices = self.orig_indices[map_ind]
        self.streamlines = Streamlines(streamlines[map_ind])
        self.nb_streamlines = len(self.streamlines)
        self.verbose = verbose
        if self.verbose:
            logger.info(f"target brain streamlines length = {len(streamlines)}")
            logger.info(
                f"After refining target brain streamlines"
                f" length = {len(self.streamlines)}"
            )
        # Coarse QuickBundlesX thresholds prepended to every finer clustering
        # threshold used by this object.
        self.start_thr = [40, 25, 20]
        self.rng = np.random.default_rng() if rng is None else rng

        if cluster_map is None:
            self._cluster_streamlines(clust_thr=clust_thr, nb_pts=nb_pts)
        else:
            if self.verbose:
                t = time()
            self.cluster_map = cluster_map
            # NOTE(review): assumes the supplied clustering was computed on
            # the same length-filtered streamlines — confirm with callers.
            self.cluster_map.refdata = self.streamlines
            self.centroids = self.cluster_map.centroids
            self.nb_centroids = len(self.centroids)
            self.indices = [cluster.indices for cluster in self.cluster_map]
            if self.verbose:
                logger.info(f" Streamlines have {self.nb_centroids} centroids")
                logger.info(f" Total loading duration {time() - t:0.3f} s\n")

    def _cluster_streamlines(self, clust_thr, nb_pts):
        """Cluster ``self.streamlines`` with QBx using thresholds
        ``self.start_thr + [clust_thr]`` and store the resulting cluster map,
        centroids, centroid count and per-cluster indices on ``self``."""
        if self.verbose:
            t = time()
            logger.info("# Cluster streamlines using QBx")
            logger.info(f" Tractogram has {len(self.streamlines)} streamlines")
            logger.info(f" Size is {nbytes(self.streamlines):0.3f} MB")
            logger.info(f" Distance threshold {clust_thr:0.3f}")

        # TODO this needs to become a default parameter
        thresholds = self.start_thr + [clust_thr]
        merged_cluster_map = qbx_and_merge(
            self.streamlines,
            thresholds,
            nb_pts=nb_pts,
            select_randomly=None,
            rng=self.rng,
            verbose=self.verbose,
        )
        self.cluster_map = merged_cluster_map
        self.centroids = merged_cluster_map.centroids
        self.nb_centroids = len(self.centroids)
        self.indices = [cluster.indices for cluster in self.cluster_map]
        if self.verbose:
            logger.info(f" Streamlines have {self.nb_centroids} centroids")
            logger.info(f" Total duration {time() - t:0.3f} s\n")

    @warning_for_keywords()
    def recognize(
        self,
        model_bundle,
        model_clust_thr,
        *,
        reduction_thr=10,
        reduction_distance="mdf",
        slr=True,
        num_threads=None,
        slr_metric=None,
        slr_x0=None,
        slr_bounds=None,
        slr_select=(400, 600),
        slr_method="L-BFGS-B",
        pruning_thr=5,
        pruning_distance="mdf",
    ):
        """Recognize the model_bundle in self.streamlines

        See :footcite:p:`Garyfallidis2018` for further details about the method.

        Parameters
        ----------
        model_bundle : Streamlines
            model bundle streamlines used as a reference to extract similar
            streamlines from input tractogram
        model_clust_thr : float
            MDF distance threshold for the model bundles
        reduction_thr : float, optional
            Reduce search space in the target tractogram by (mm) (default 10)
        reduction_distance : string, optional
            Reduction distance type can be mdf or mam (default mdf)
        slr : bool, optional
            Use Streamline-based Linear Registration (SLR) locally
            (default True)
        num_threads : int, optional
            Number of threads to be used for OpenMP parallelization. If None
            (default) the value of OMP_NUM_THREADS environment variable is used
            if it is set, otherwise all available threads are used. If < 0 the
            maximal number of threads minus $|num_threads + 1|$ is used (enter
            -1 to use as many threads as possible). 0 raises an error.
        slr_metric : BundleMinDistanceMetric or str, optional
            SLR metric instance, or one of None/"symmetric" (default
            BundleMinDistanceMetric), "asymmetric" or "diagonal".
        slr_x0 : array or int or str, optional
            Transformation allowed. translation, rigid, similarity or scaling
            Initial parametrization for the optimization (default
            "similarity").

            If 1D array with:
                a) 6 elements then only rigid registration is performed with
                the 3 first elements for translation and 3 for rotation.
                b) 7 elements also isotropic scaling is performed (similarity).
                c) 12 elements then translation, rotation (in degrees),
                scaling and shearing are performed (affine).

                Here is an example of x0 with 12 elements:
                ``x0=np.array([0, 10, 0, 40, 0, 0, 2., 1.5, 1, 0.1, -0.5, 0])``

                This has translation (0, 10, 0), rotation (40, 0, 0) in
                degrees, scaling (2., 1.5, 1) and shearing (0.1, -0.5, 0).

            If int:
                a) 6
                    ``x0 = np.array([0, 0, 0, 0, 0, 0])``
                b) 7
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1.])``
                c) 12
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])``

            If str:
                a) "rigid"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0])``
                b) "similarity"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1.])``
                c) "affine"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])``
        slr_bounds : array, optional
            SLR bounds (default: +/-30 translation, +/-45 rotation, 0.8-1.2
            scaling).
        slr_select : tuple, optional
            Select the number of streamlines from model to neighborhood of
            model to perform the local SLR.
        slr_method : string, optional
            Optimization method 'L-BFGS-B' or 'Powell' optimizers can be used.
            (default 'L-BFGS-B')
        pruning_thr : float, optional
            Pruning after reducing the search space.
        pruning_distance : string, optional
            Pruning distance type can be mdf or mam.

        Returns
        -------
        recognized_transf : Streamlines
            Recognized bundle in the space of the model tractogram
        recognized_labels : array
            Indices of recognized bundle in the original tractogram

        References
        ----------
        .. footbibliography::
        """
        if self.verbose:
            t = time()
            logger.info("## Recognize given bundle ## \n")

        model_centroids = self._cluster_model_bundle(
            model_bundle, model_clust_thr=model_clust_thr
        )
        neighb_streamlines, neighb_indices = self._reduce_search_space(
            model_centroids,
            reduction_thr=reduction_thr,
            reduction_distance=reduction_distance,
        )
        if len(neighb_streamlines) == 0:
            return Streamlines([]), []

        if slr:
            transf_streamlines, _ = self._register_neighb_to_model(
                model_bundle,
                neighb_streamlines,
                metric=slr_metric,
                x0=slr_x0,
                bounds=slr_bounds,
                select_model=slr_select[0],
                select_target=slr_select[1],
                method=slr_method,
                num_threads=num_threads,
            )
        else:
            transf_streamlines = neighb_streamlines

        pruned_streamlines, labels = self._prune_what_not_in_model(
            model_centroids,
            transf_streamlines,
            neighb_indices,
            pruning_thr=pruning_thr,
            pruning_distance=pruning_distance,
        )
        if self.verbose:
            logger.info(
                f"Total duration of recognition time" f" is {time()-t:0.3f} s\n"
            )
        # labels index self.streamlines; map them back to the input tractogram.
        return pruned_streamlines, self.filtered_indices[labels]

    @warning_for_keywords()
    def refine(
        self,
        model_bundle,
        pruned_streamlines,
        model_clust_thr,
        *,
        reduction_thr=14,
        reduction_distance="mdf",
        slr=True,
        slr_metric=None,
        slr_x0=None,
        slr_bounds=None,
        slr_select=(400, 600),
        slr_method="L-BFGS-B",
        pruning_thr=6,
        pruning_distance="mdf",
        num_threads=None,
    ):
        """Refine and recognize the model_bundle in self.streamlines

        This method expects once pruned streamlines as input. It refines the
        first output of RecoBundles by applying second local slr (optional),
        and second pruning. This method is useful when we are dealing with
        noisy data or when we want to extract small tracks from tractograms.
        This time, search space is created using pruned bundle and not model
        bundle.

        See :footcite:p:`Garyfallidis2018`, :footcite:p:`Chandio2020a` for
        further details about the method.

        Parameters
        ----------
        model_bundle : Streamlines
            model bundle streamlines used as a reference to extract similar
            streamlines from input tractogram
        pruned_streamlines : Streamlines
            Recognized bundle from target tractogram by RecoBundles.
        model_clust_thr : float
            MDF distance threshold for the model bundles
        reduction_thr : float, optional
            Reduce search space by (mm) (default 14)
        reduction_distance : string, optional
            Reduction distance type can be mdf or mam (default mdf)
        slr : bool, optional
            Use Streamline-based Linear Registration (SLR) locally. When
            False, pruning is applied to the untransformed neighborhood.
        slr_metric : BundleMinDistanceMetric or str, optional
            SLR metric instance, or one of None/"symmetric" (default
            BundleMinDistanceMetric), "asymmetric" or "diagonal".
        slr_x0 : array or int or str, optional
            Transformation allowed. translation, rigid, similarity or scaling
            Initial parametrization for the optimization (default
            "similarity").

            If 1D array with:
                a) 6 elements then only rigid registration is performed with
                the 3 first elements for translation and 3 for rotation.
                b) 7 elements also isotropic scaling is performed (similarity).
                c) 12 elements then translation, rotation (in degrees),
                scaling and shearing are performed (affine).

                Here is an example of x0 with 12 elements:
                ``x0=np.array([0, 10, 0, 40, 0, 0, 2., 1.5, 1, 0.1, -0.5, 0])``

                This has translation (0, 10, 0), rotation (40, 0, 0) in
                degrees, scaling (2., 1.5, 1) and shearing (0.1, -0.5, 0).

            If int:
                a) 6
                    ``x0 = np.array([0, 0, 0, 0, 0, 0])``
                b) 7
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1.])``
                c) 12
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])``

            If str:
                a) "rigid"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0])``
                b) "similarity"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1.])``
                c) "affine"
                    ``x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0])``
        slr_bounds : array, optional
            SLR bounds.
        slr_select : tuple, optional
            Select the number of streamlines from model to neighborhood of
            model to perform the local SLR.
        slr_method : string, optional
            Optimization method 'L-BFGS-B' or 'Powell' optimizers can be used.
        pruning_thr : float, optional
            Pruning after reducing the search space.
        pruning_distance : string, optional
            Pruning distance type can be mdf or mam.
        num_threads : int, optional
            Number of threads used for OpenMP parallelization of the default
            SLR metric (same meaning as in ``recognize``).

        Returns
        -------
        recognized_transf : Streamlines
            Recognized bundle in the space of the model tractogram
        recognized_labels : array
            Indices of recognized bundle in the original tractogram

        References
        ----------
        .. footbibliography::
        """
        if self.verbose:
            t = time()
            logger.info("## Refine recognize given bundle ## \n")

        model_centroids = self._cluster_model_bundle(
            model_bundle, model_clust_thr=model_clust_thr
        )
        # The search space is built around the already-pruned bundle.
        pruned_model_centroids = self._cluster_model_bundle(
            pruned_streamlines, model_clust_thr=model_clust_thr
        )
        neighb_streamlines, neighb_indices = self._reduce_search_space(
            pruned_model_centroids,
            reduction_thr=reduction_thr,
            reduction_distance=reduction_distance,
        )
        if len(neighb_streamlines) == 0:  # if no streamlines recognized
            return Streamlines([]), []

        if self.verbose:
            logger.info("2nd local Slr")

        if slr:
            transf_streamlines, _ = self._register_neighb_to_model(
                model_bundle,
                neighb_streamlines,
                metric=slr_metric,
                x0=slr_x0,
                bounds=slr_bounds,
                select_model=slr_select[0],
                select_target=slr_select[1],
                method=slr_method,
                num_threads=num_threads,
            )
        else:
            # Without SLR, prune the neighborhood as-is (mirrors recognize).
            transf_streamlines = neighb_streamlines

        if self.verbose:
            logger.info("pruning after 2nd local Slr")

        pruned_streamlines, labels = self._prune_what_not_in_model(
            model_centroids,
            transf_streamlines,
            neighb_indices,
            pruning_thr=pruning_thr,
            pruning_distance=pruning_distance,
        )
        if self.verbose:
            logger.info(
                f"Total duration of recognition time" f" is {time()-t:0.3f} s\n"
            )
        return pruned_streamlines, self.filtered_indices[labels]

    def evaluate_results(self, model_bundle, pruned_streamlines, slr_select):
        """Compare the similarity between two given bundles, model bundle,
        and extracted bundle.

        Random selections use ``self.rng`` so results are reproducible when
        the object was built with a seeded generator.

        Parameters
        ----------
        model_bundle : Streamlines
            Model bundle streamlines.
        pruned_streamlines : Streamlines
            Pruned bundle streamlines.
        slr_select : tuple
            Select the number of streamlines from model to neighborhood of
            model to perform the local SLR.

        Returns
        -------
        ba_value : float
            bundle adjacency value between model bundle and pruned bundle
        bmd_value : float
            bundle minimum distance value between model bundle and
            pruned bundle
        """
        spruned_streamlines = Streamlines(pruned_streamlines)
        recog_centroids = Streamlines(
            self._cluster_model_bundle(spruned_streamlines, model_clust_thr=1.25)
        )
        model_centroids = Streamlines(
            self._cluster_model_bundle(model_bundle, model_clust_thr=1.25)
        )
        ba_value = bundle_adjacency(
            set_number_of_points(recog_centroids, nb_points=20),
            set_number_of_points(model_centroids, nb_points=20),
            threshold=10,
        )

        bmd_metric = BundleMinDistanceMetric()
        static = select_random_set_of_streamlines(
            model_bundle, slr_select[0], rng=self.rng
        )
        moving = select_random_set_of_streamlines(
            pruned_streamlines, slr_select[1], rng=self.rng
        )
        nb_pts = 20
        static = set_number_of_points(static, nb_points=nb_pts)
        moving = set_number_of_points(moving, nb_points=nb_pts)
        bmd_metric.setup(static, moving)
        # Identity affine parametrization: evaluate BMD without moving anything.
        x0 = np.array([0, 0, 0, 0, 0, 0, 1.0, 1.0, 1, 0, 0, 0])
        bmd_value = bmd_metric.distance(x0.tolist())
        return ba_value, bmd_value

    @warning_for_keywords()
    def _cluster_model_bundle(
        self, model_bundle, model_clust_thr, *, nb_pts=20, select_randomly=500000
    ):
        """Return the QBx centroids of ``model_bundle`` clustered with
        thresholds ``self.start_thr + [model_clust_thr]``."""
        if self.verbose:
            t = time()
            logger.info("# Cluster model bundle using QBX")
            logger.info(f" Model bundle has {len(model_bundle)} streamlines")
            logger.info(f" Distance threshold {model_clust_thr:0.3f}")
        thresholds = self.start_thr + [model_clust_thr]
        model_cluster_map = qbx_and_merge(
            model_bundle,
            thresholds,
            nb_pts=nb_pts,
            select_randomly=select_randomly,
            rng=self.rng,
        )
        model_centroids = model_cluster_map.centroids
        if self.verbose:
            logger.info(f" Model bundle has {len(model_centroids)} centroids")
            logger.info(f" Duration {time() - t:0.3f} s\n")
        return model_centroids

    @warning_for_keywords()
    def _reduce_search_space(
        self, model_centroids, *, reduction_thr=20, reduction_distance="mdf"
    ):
        """Collect target clusters whose centroid lies within ``reduction_thr``
        of at least one model centroid.

        Returns
        -------
        neighb_streamlines : Streamlines
            Streamlines of the selected target clusters (empty if none).
        neighb_indices : list
            Per-cluster ``indices`` (into ``self.streamlines``) of the
            selected clusters, in the same order as ``neighb_streamlines``.

        Raises
        ------
        ValueError
            If ``reduction_distance`` is neither "mdf" nor "mam".
        """
        if self.verbose:
            t = time()
            logger.info("# Reduce search space")
            logger.info(f" Reduction threshold {reduction_thr:0.3f}")
            logger.info(f" Reduction distance {reduction_distance}")

        if reduction_distance.lower() == "mdf":
            if self.verbose:
                logger.info(" Using MDF")
            centroid_matrix = bundles_distances_mdf(model_centroids, self.centroids)
        elif reduction_distance.lower() == "mam":
            if self.verbose:
                logger.info(" Using MAM")
            centroid_matrix = bundles_distances_mam(model_centroids, self.centroids)
        else:
            raise ValueError("Given reduction distance not known")

        # Undefined (NaN) distances count as "too far", as in pruning;
        # otherwise np.min would propagate NaN and mark the cluster as close.
        centroid_matrix[np.isnan(centroid_matrix)] = np.inf
        centroid_matrix[centroid_matrix > reduction_thr] = np.inf
        mins = np.min(centroid_matrix, axis=0)
        close_clusters_indices = list(np.where(mins != np.inf)[0])
        close_clusters = self.cluster_map[close_clusters_indices]
        neighb_indices = [cluster.indices for cluster in close_clusters]
        neighb_streamlines = Streamlines(chain(*close_clusters))
        nb_neighb_streamlines = len(neighb_streamlines)

        if nb_neighb_streamlines == 0:
            if self.verbose:
                logger.info("You have no neighbor streamlines... No bundle recognition")
            return Streamlines([]), []

        if self.verbose:
            logger.info(f" Number of neighbor streamlines" f" {nb_neighb_streamlines}")
            logger.info(f" Duration {time() - t:0.3f} s\n")
        return neighb_streamlines, neighb_indices

    @warning_for_keywords()
    def _register_neighb_to_model(
        self,
        model_bundle,
        neighb_streamlines,
        *,
        metric=None,
        x0=None,
        bounds=None,
        select_model=400,
        select_target=600,
        method="L-BFGS-B",
        nb_pts=20,
        num_threads=None,
    ):
        """Run a local SLR of random subsets of ``neighb_streamlines`` onto
        ``model_bundle`` and apply the resulting affine to all of
        ``neighb_streamlines``.

        Returns
        -------
        transf_streamlines : Streamlines
            Transformed copy of ``neighb_streamlines``.
        slr_bmd : float
            Optimal metric value (``fopt``) reached by the optimizer.

        Raises
        ------
        ValueError
            If ``metric`` is a string other than "symmetric", "asymmetric"
            or "diagonal".
        """
        if self.verbose:
            logger.info("# Local SLR of neighb_streamlines to model")
            t = time()

        if metric is None or metric == "symmetric":
            metric = BundleMinDistanceMetric(num_threads=num_threads)
        elif metric == "asymmetric":
            metric = BundleMinDistanceAsymmetricMetric()
        elif metric == "diagonal":
            metric = BundleSumDistanceMatrixMetric()
        elif isinstance(metric, str):
            raise ValueError(f"Unknown SLR metric {metric!r}")

        if x0 is None:
            x0 = "similarity"

        if bounds is None:
            # 3 translations (mm), 3 rotations (degrees), isotropic scaling.
            bounds = [
                (-30, 30),
                (-30, 30),
                (-30, 30),
                (-45, 45),
                (-45, 45),
                (-45, 45),
                (0.8, 1.2),
            ]

        # TODO this can be speeded up by using directly the centroids
        static = select_random_set_of_streamlines(
            model_bundle, select_model, rng=self.rng
        )
        moving = select_random_set_of_streamlines(
            neighb_streamlines, select_target, rng=self.rng
        )
        static = set_number_of_points(static, nb_points=nb_pts)
        moving = set_number_of_points(moving, nb_points=nb_pts)

        slr = StreamlineLinearRegistration(
            metric=metric, x0=x0, bounds=bounds, method=method
        )
        slm = slr.optimize(static, moving)

        transf_streamlines = neighb_streamlines.copy()
        transf_streamlines._data = apply_affine(slm.matrix, transf_streamlines._data)
        slr_bmd = slm.fopt

        if self.verbose:
            logger.info(f" Square-root of BMD is {np.sqrt(slr_bmd):.3f}")
            if slm.iterations is not None:
                logger.info(f" Number of iterations {slm.iterations}")
            logger.info(f" Matrix size {slm.matrix.shape}")
            # Context manager restores print options even if logging raises.
            with np.printoptions(precision=3, suppress=True):
                logger.info(slm.matrix)
                logger.info(slm.xopt)
            logger.info(f" Duration {time() - t:0.3f} s\n")

        return transf_streamlines, slr_bmd

    @warning_for_keywords()
    def _prune_what_not_in_model(
        self,
        model_centroids,
        transf_streamlines,
        neighb_indices,
        *,
        mdf_thr=5,
        pruning_thr=10,
        pruning_distance="mdf",
    ):
        """Keep only transformed streamlines whose QBx cluster centroid lies
        within ``pruning_thr`` of some model centroid.

        ``mdf_thr`` is the finest QBx threshold used to cluster
        ``transf_streamlines``.

        Returns
        -------
        pruned_streamlines : Streamlines
            The kept streamlines (empty if none survive).
        labels : list
            Indices into ``self.streamlines`` of the kept streamlines, taken
            from the flattened ``neighb_indices``.

        Raises
        ------
        ValueError
            If ``pruning_distance`` is neither "mdf" nor "mam".
        """
        if self.verbose:
            if pruning_thr < 0:
                logger.info("Pruning_thr has to be greater or equal to 0")

            logger.info("# Prune streamlines using the MDF distance")
            logger.info(f" Pruning threshold {pruning_thr:0.3f}")
            logger.info(f" Pruning distance {pruning_distance}")
            t = time()

        thresholds = [40, 30, 20, 10, mdf_thr]
        rtransf_cluster_map = qbx_and_merge(
            transf_streamlines,
            thresholds,
            nb_pts=20,
            select_randomly=500000,
            rng=self.rng,
        )

        if self.verbose:
            logger.info(f" QB Duration {time() - t:0.3f} s\n")

        rtransf_centroids = rtransf_cluster_map.centroids

        if pruning_distance.lower() == "mdf":
            if self.verbose:
                logger.info(" Using MDF")
            dist_matrix = bundles_distances_mdf(model_centroids, rtransf_centroids)
        elif pruning_distance.lower() == "mam":
            if self.verbose:
                logger.info(" Using MAM")
            dist_matrix = bundles_distances_mam(model_centroids, rtransf_centroids)
        else:
            raise ValueError("Given pruning distance is not available")

        dist_matrix[np.isnan(dist_matrix)] = np.inf
        dist_matrix[dist_matrix > pruning_thr] = np.inf

        if self.verbose:
            logger.info(f" Pruning matrix size is {dist_matrix.shape}")

        mins = np.min(dist_matrix, axis=0)
        pruned_indices = list(
            chain.from_iterable(
                rtransf_cluster_map[i].indices for i in np.where(mins != np.inf)[0]
            )
        )
        idx = np.array(pruned_indices)
        if len(idx) == 0:
            if self.verbose:
                logger.info(" You have removed all streamlines")
            return Streamlines([]), []

        pruned_streamlines = transf_streamlines[idx]

        # Positions in transf_streamlines -> indices into self.streamlines.
        initial_indices = list(chain.from_iterable(neighb_indices))
        labels = [initial_indices[i] for i in pruned_indices]

        if self.verbose:
            logger.info(f" Number of centroids: {len(rtransf_centroids)}")
            logger.info(
                f" Number of streamlines after pruning:" f" {len(pruned_streamlines)}"
            )
            logger.info(f" Duration {time() - t:0.3f} s\n")

        return pruned_streamlines, labels