import warnings

import numpy as np
from scipy.sparse import coo_array
from scipy.spatial import cKDTree

from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.segment.metric import mean_euclidean_distance
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import set_number_of_points
class FastStreamlineSearch:
    """Binned K-D tree index for fast radius search between streamlines."""

    @warning_for_keywords()
    def __init__(
        self,
        ref_streamlines,
        max_radius,
        *,
        nb_mpts=4,
        bin_size=20.0,
        resampling=24,
        bidirectional=True,
    ):
        """Fast Streamline Search (FSS)

        Generate the Binned K-D Tree structure with reference streamlines,
        using streamlines barycenter and mean-points.
        See :footcite:p:`StOnge2022` for further details.

        Parameters
        ----------
        ref_streamlines : Streamlines or StatefulTractogram
            Streamlines (ref) to generate the tree structure.
        max_radius : float
            The maximum radius (distance) for subsequent streamline search.
            Used to compute the overlap in-between bins.
        nb_mpts : int, optional
            Number of mean-points used to speed up the search; must be a
            positive factor of ``resampling``.
            (this only changes computation time)
        bin_size : float, optional
            The bin size to separate streamlines in groups.
            (this only changes computation time)
        resampling : int, optional
            Number of points used to reshape each streamline.
        bidirectional : bool, optional
            Compute the smallest distance with and without flip.

        Raises
        ------
        ValueError
            If ``max_radius`` or ``bin_size`` is not positive, if ``nb_mpts``
            is not a positive factor of ``resampling``, or if
            ``ref_streamlines`` is empty.

        Notes
        -----
        Make sure that streamlines are aligned in the same space.
        Preferably in millimeter space (voxmm or rasmm).

        References
        ----------
        .. footbibliography::
        """
        if max_radius <= 0.0:
            raise ValueError("max_radius needs to be a positive value")

        if bin_size <= 0.0:
            raise ValueError("bin_size needs to be a positive value")

        if resampling < 20:
            warnings.warn(
                "For accurate results, resampling should be"
                " at least >= 10 and preferably >= 20",
                stacklevel=2,
            )

        # Checking nb_mpts < 1 first avoids a ZeroDivisionError on the modulo
        if nb_mpts < 1 or resampling % nb_mpts != 0:
            raise ValueError("nb_mpts needs to be a factor of resampling")

        if isinstance(ref_streamlines, StatefulTractogram):
            ref_streamlines = ref_streamlines.streamlines

        self.nb_mpts = nb_mpts
        self.bin_size = bin_size
        self.bidirectional = bidirectional
        self.resampling = resampling
        self.max_radius = max_radius

        # Resample streamlines to (nb_slines, resampling, 3) float32
        self.ref_slines = self._resample(ref_streamlines)
        self.ref_nb_slines = len(self.ref_slines)
        if self.ref_nb_slines == 0:
            raise ValueError("ref_streamlines needs to contain at least one streamline")

        if self.bidirectional:
            # Append reversed copies: index i + ref_nb_slines is the flip of i
            self.ref_slines = np.concatenate(
                [self.ref_slines, np.flip(self.ref_slines, axis=1)]
            )

        # Compute streamlines barycenter
        barycenters = self._slines_barycenters(self.ref_slines)

        # Compute bin shape (min, max, shape); the box is padded by the
        # maximum radius so neighbors near the border are not missed.
        bin_overlap = max_radius
        self.min_box = np.min(barycenters, axis=0) - bin_overlap
        self.max_box = np.max(barycenters, axis=0) + bin_overlap

        box_length = self.max_box - self.min_box
        self.bin_shape = (box_length // bin_size).astype(int) + 1

        # Compute the center of each bin
        bin_list = np.arange(np.prod(self.bin_shape))
        all_bins = np.vstack(np.unravel_index(bin_list, self.bin_shape)).T
        bins_center = all_bins * bin_size + self.min_box + bin_size / 2.0

        # Assign to each bin every streamline whose barycenter lies within
        # the bin enlarged by the overlap (Chebyshev / L-inf distance).
        baryc_tree = cKDTree(barycenters)
        center_dist = bin_size / 2.0 + bin_overlap
        baryc_bins = baryc_tree.query_ball_point(bins_center, center_dist, p=np.inf)

        # Compute streamlines mean-points
        meanpts = self._slines_mean_points(self.ref_slines)

        # Compute bin indices, streamlines + mean-points tree (non-empty bins)
        self.bin_dict = {}
        for i, baryc_b in enumerate(baryc_bins):
            if baryc_b:
                slines_id = np.asarray(baryc_b)
                self.bin_dict[i] = (slines_id, cKDTree(meanpts[slines_id]))

    @warning_for_keywords()
    def radius_search(self, streamlines, radius, *, use_negative=True):
        """Radius Search using Fast Streamline Search

        For each given streamlines, return all reference streamlines
        within the given radius. See :footcite:p:`StOnge2022` for further
        details.

        Parameters
        ----------
        streamlines : Streamlines or StatefulTractogram
            Query streamlines to compare with the reference streamlines.
        radius : float
            Search radius (with MDF / average L2 distance)
            must be smaller than max_radius when FSS was initialized.
        use_negative : bool, optional
            When used with bidirectional,
            negative values are returned for reversed order neighbors.

        Returns
        -------
        res : scipy COOrdinates sparse array (nb_slines x nb_slines_ref)
            Adjacency matrix containing all neighbors within the given radius.
            Each (query, reference) pair appears at most once; in
            bidirectional mode the closest orientation is kept.

        Raises
        ------
        ValueError
            If ``radius`` is larger than ``max_radius``.

        Notes
        -----
        Given streamlines should be already aligned with ref streamlines.
        Preferably in millimeter space (voxmm or rasmm).

        References
        ----------
        .. footbibliography::
        """
        if radius > self.max_radius:
            raise ValueError(
                "radius should be smaller or equal to the given"
                "\n 'max_radius' in FastStreamlineSearch init"
            )

        if isinstance(streamlines, StatefulTractogram):
            streamlines = streamlines.streamlines

        # Resample query streamlines
        q_slines = self._resample(streamlines)
        q_nb_slines = len(q_slines)

        # Compute streamlines barycenter
        q_baryc = self._slines_barycenters(q_slines)

        # Group query streamlines by the bin holding their barycenter;
        # barycenters outside the min-max box are dropped (no neighbors).
        u_bin, binned_slines_ids = self._barycenters_binning(q_baryc)

        # Adapting radius for L1 query: sqrt(3) = 1.73205080756887729..
        # Rounded up for float32 precision to avoid error / false negative
        l1_sum_dist = 1.73205081 * radius * self.nb_mpts

        # Search for all similar streamlines
        list_id = []
        list_id_ref = []
        list_dist = []
        for bin_id, slines_id in zip(u_bin, binned_slines_ids):
            if bin_id not in self.bin_dict:
                continue
            slines_id_ref, ref_tree = self.bin_dict[bin_id]
            mpts = self._slines_mean_points(q_slines[slines_id])

            # Coarse filter: Tree L1 Query with mean-points
            res = ref_tree.query_ball_point(mpts, l1_sum_dist, p=1)

            # Refine distance with the complete streamlines (MDF)
            for s_id, ref_ids in zip(slines_id, res):
                if not ref_ids:
                    continue
                rs_ids = slines_id_ref[ref_ids]
                d = mean_euclidean_distance(q_slines[s_id], self.ref_slines[rs_ids])

                # Return all pairs within the radius
                in_dist_max = d < radius
                id_ref = rs_ids[in_dist_max]
                list_id.append(np.full_like(id_ref, s_id))
                list_id_ref.append(id_ref)
                list_dist.append(d[in_dist_max])

        # No results, return an empty sparse array
        if not list_id:
            return coo_array((q_nb_slines, self.ref_nb_slines))

        ids_in = np.hstack(list_id)
        ids_ref = np.hstack(list_id_ref)
        dist = np.hstack(list_dist)

        if self.bidirectional:
            # Map flipped reference ids back to their original index
            flipped = ids_ref >= self.ref_nb_slines
            ids_ref[flipped] -= self.ref_nb_slines
            if use_negative:
                dist[flipped] *= -1.0
            ids_in, ids_ref, dist = self._keep_closest_pairs(ids_in, ids_ref, dist)

        # Combine all results in a COO sparse array
        return coo_array(
            (dist, (ids_in, ids_ref)), shape=(q_nb_slines, self.ref_nb_slines)
        )

    @staticmethod
    def _keep_closest_pairs(ids_in, ids_ref, dist):
        """Drop duplicate (query, reference) pairs, keeping the smallest |dist|.

        In bidirectional mode a reference can match a query in both
        orientations; duplicate COO entries would otherwise be summed when
        converted to CSR/CSC/dense. Ties keep the earliest entry.
        """
        order = np.lexsort((np.abs(dist), ids_ref, ids_in))
        ids_in, ids_ref, dist = ids_in[order], ids_ref[order], dist[order]
        keep = np.ones(len(order), dtype=bool)
        keep[1:] = (ids_in[1:] != ids_in[:-1]) | (ids_ref[1:] != ids_ref[:-1])
        return ids_in[keep], ids_ref[keep], dist[keep]

    def _resample(self, streamlines):
        """Resample each streamline to ``self.resampling`` points (float32).

        Streamlines with fewer than two points cannot be interpolated and are
        broadcast as-is into their output row (a single point is repeated).
        """
        s = np.zeros([len(streamlines), self.resampling, 3], dtype=np.float32)
        for i, sline in enumerate(streamlines):
            if len(sline) < 2:
                s[i] = sline
            else:
                s[i] = set_number_of_points(sline, nb_points=self.resampling)
        return s

    def _slines_barycenters(self, slines_arr):
        """Compute streamlines barycenter: mean over the points axis."""
        return np.mean(slines_arr, axis=1)

    def _slines_mean_points(self, slines_arr):
        """Compute streamlines mean-points.

        Each streamline is split into ``nb_mpts`` consecutive segments whose
        points are averaged; the result is flattened to
        (nb_slines, nb_mpts * 3).
        """
        r_arr = slines_arr.reshape((len(slines_arr), self.nb_mpts, -1, 3))
        mpts = np.mean(r_arr, axis=2)
        return mpts.reshape(len(slines_arr), -1)

    def _barycenters_binning(self, barycenters):
        """Bin indices in a list according to their barycenter position.

        Returns the sorted unique flat bin indices and, for each of them, the
        array of input indices whose barycenter falls in that bin. Inputs
        outside the min-max box are ignored.
        """
        in_bin = np.logical_and(
            np.all(barycenters >= self.min_box, axis=1),
            np.all(barycenters <= self.max_box, axis=1),
        )

        baryc_to_box = barycenters[in_bin] - self.min_box
        baryc_bins_id = (baryc_to_box // self.bin_size).astype(int)
        baryc_multiid = np.ravel_multi_index(baryc_bins_id.T, self.bin_shape)

        sort_id = np.argsort(baryc_multiid)
        u_bin, mapping = np.unique(baryc_multiid[sort_id], return_index=True)
        slines_ids = np.split(np.flatnonzero(in_bin)[sort_id], mapping[1:])
        return u_bin, slines_ids


def nearest_from_matrix_row(coo_array):
    """Return the nearest (smallest) match for each row of a COO sparse array.

    Parameters
    ----------
    coo_array : scipy COOrdinates sparse array (nb_slines x nb_slines_ref)
        Adjacency matrix containing all neighbors within the given radius.
        Negative values (reversed-order neighbors) are compared by absolute
        value. Other scipy sparse formats are converted with ``tocoo()``.

    Returns
    -------
    non_zero_ids : numpy array (nb_non_empty_row,)
        Indices of each non-empty slines (row), in ascending order.
    nearest_id : numpy array (nb_non_empty_row,)
        Indices of the nearest reference match (column).
    nearest_dist : numpy array (nb_non_empty_row,)
        Distance (absolute value) for each nearest match.

    Notes
    -----
    Duplicate (row, column) entries are not summed; the smallest absolute
    value among them is used. Ties go to the smallest column index.
    An empty input returns three empty arrays.
    """
    mat = coo_array.tocoo()
    rows = np.asarray(mat.row)
    cols = np.asarray(mat.col)
    dist = np.abs(np.asarray(mat.data))

    # Sort by row, then distance, then column: the first entry of each row
    # is its nearest match.
    order = np.lexsort((cols, dist, rows))
    rows, cols, dist = rows[order], cols[order], dist[order]

    # Mark the first entry of every run of equal rows
    first = np.ones(len(rows), dtype=bool)
    first[1:] = rows[1:] != rows[:-1]
    return rows[first], cols[first], dist[first]


def nearest_from_matrix_col(coo_array):
    """Return the nearest (smallest) match for each column of a COO sparse array.

    Parameters
    ----------
    coo_array : scipy COOrdinates sparse array (nb_slines x nb_slines_ref)
        Adjacency matrix containing all neighbors within the given radius.
        Negative values (reversed-order neighbors) are compared by absolute
        value. Other scipy sparse formats are converted with ``tocoo()``.

    Returns
    -------
    non_zero_ids : numpy array (nb_non_empty_col,)
        Indices of each non-empty reference (column), in ascending order.
    nearest_id : numpy array (nb_non_empty_col,)
        Indices of the nearest slines match (row).
    nearest_dist : numpy array (nb_non_empty_col,)
        Distance (absolute value) for each nearest match.

    Notes
    -----
    Duplicate (row, column) entries are not summed; the smallest absolute
    value among them is used. Ties go to the smallest row index.
    An empty input returns three empty arrays.
    """
    mat = coo_array.tocoo()
    rows = np.asarray(mat.row)
    cols = np.asarray(mat.col)
    dist = np.abs(np.asarray(mat.data))

    # Sort by column, then distance, then row: the first entry of each
    # column is its nearest match.
    order = np.lexsort((rows, dist, cols))
    rows, cols, dist = rows[order], cols[order], dist[order]

    # Mark the first entry of every run of equal columns
    first = np.ones(len(cols), dtype=bool)
    first[1:] = cols[1:] != cols[:-1]
    return cols[first], rows[first], dist[first]