Source code for dipy.segment.fss

import warnings

import numpy as np
from scipy.sparse import coo_array
from scipy.spatial import cKDTree

from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.segment.metric import mean_euclidean_distance
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import set_number_of_points


class FastStreamlineSearch:
    @warning_for_keywords()
    def __init__(
        self,
        ref_streamlines,
        max_radius,
        *,
        nb_mpts=4,
        bin_size=20.0,
        resampling=24,
        bidirectional=True,
    ):
        """Fast Streamline Search (FSS)

        Generate the Binned K-D Tree structure with reference streamlines,
        using streamlines barycenter and mean-points.

        See :footcite:p:`StOnge2022` for further details.

        Parameters
        ----------
        ref_streamlines : Streamlines or StatefulTractogram
            Streamlines (ref) to generate the tree structure. When a
            StatefulTractogram is given, its ``streamlines`` are used.
        max_radius : float
            The maximum radius (distance) for subsequent streamline search.
            Used to compute the overlap in-between bins. Must be positive.
        nb_mpts : int, optional
            Number of mean-points to improve computation speed. Must be a
            positive factor of ``resampling``.
            (this only changes computation time)
        bin_size : float, optional
            The bin size to separate streamlines in groups. Must be positive.
            (this only changes computation time)
        resampling : int, optional
            Number of points used to reshape each streamline.
        bidirectional : bool, optional
            Compute the smallest distance with and without flip.

        Raises
        ------
        ValueError
            If ``max_radius`` or ``bin_size`` is not positive, if ``nb_mpts``
            is not a positive factor of ``resampling``, or if
            ``ref_streamlines`` contains no streamline.

        Warns
        -----
        UserWarning
            If ``resampling`` is lower than 20.

        Notes
        -----
        Make sure that streamlines are aligned in the same space.
        Preferably in millimeter space (voxmm or rasmm).

        References
        ----------
        .. footbibliography::
        """
        if max_radius <= 0.0:
            raise ValueError("max_radius needs to be a positive value")
        if bin_size <= 0.0:
            # A null/negative bin size would yield a degenerate bin grid.
            raise ValueError("bin_size needs to be a positive value")
        if resampling < 20:
            warnings.warn(
                "For accurate results, resampling should be"
                " at least >= 10 and preferably >= 20",
                stacklevel=2,
            )
        # Checked before the modulo below, which raises ZeroDivisionError
        # for nb_mpts == 0.
        if nb_mpts < 1:
            raise ValueError("nb_mpts needs to be a positive integer")
        if resampling % nb_mpts != 0:
            raise ValueError("nb_mpts needs to be a factor of resampling")

        if isinstance(ref_streamlines, StatefulTractogram):
            ref_streamlines = ref_streamlines.streamlines

        self.nb_mpts = nb_mpts
        self.bin_size = bin_size
        self.bidirectional = bidirectional
        self.resampling = resampling
        self.max_radius = max_radius

        # Resample streamlines: array of shape (nb_slines, resampling, 3)
        self.ref_slines = self._resample(ref_streamlines)
        self.ref_nb_slines = len(self.ref_slines)
        if self.ref_nb_slines == 0:
            # The bounding box below is undefined without any barycenter.
            raise ValueError("ref_streamlines needs at least one streamline")

        if self.bidirectional:
            # Append flipped copies: row i + ref_nb_slines is row i reversed.
            self.ref_slines = np.concatenate(
                [self.ref_slines, np.flip(self.ref_slines, axis=1)]
            )

        # Compute streamlines barycenter
        barycenters = self._slines_barycenters(self.ref_slines)

        # Bounding box of all barycenters, padded on each side by the bin
        # overlap (max_radius); the grid of bins tiles this box.
        bin_overlap = max_radius
        self.min_box = np.min(barycenters, axis=0) - bin_overlap
        self.max_box = np.max(barycenters, axis=0) + bin_overlap
        box_length = self.max_box - self.min_box
        self.bin_shape = (box_length // bin_size).astype(int) + 1

        # Compute the center of each bin (bins enumerated by flat index)
        bin_list = np.arange(np.prod(self.bin_shape))
        all_bins = np.vstack(np.unravel_index(bin_list, self.bin_shape)).T
        bins_center = all_bins * bin_size + self.min_box + bin_size / 2.0

        # Assign to each bin every barycenter within half a bin plus the
        # overlap, using the L-inf (p=inf) distance, so adjacent bins share
        # the streamlines lying in their overlap region.
        baryc_tree = cKDTree(barycenters)
        center_dist = bin_size / 2.0 + bin_overlap
        baryc_bins = baryc_tree.query_ball_point(bins_center, center_dist, p=np.inf)

        # Compute streamlines mean-points
        meanpts = self._slines_mean_points(self.ref_slines)

        # Flat bin index -> (streamline indices, k-d tree of their
        # mean-points). Empty bins are not stored.
        self.bin_dict = {}
        for i, baryc_b in enumerate(baryc_bins):
            if baryc_b:
                slines_id = np.asarray(baryc_b)
                self.bin_dict[i] = (slines_id, cKDTree(meanpts[slines_id]))

    def _resample(self, streamlines):
        """Resample streamlines to ``self.resampling`` points each.

        Returns a float32 array of shape (nb_streamlines, resampling, 3).
        Streamlines with fewer than 2 points are not passed to
        ``set_number_of_points``; they are broadcast into the output row
        (a single point is repeated along the resampled axis).
        """
        s = np.zeros([len(streamlines), self.resampling, 3], dtype=np.float32)
        for i, sline in enumerate(streamlines):
            if len(sline) < 2:
                s[i] = sline
            else:
                s[i] = set_number_of_points(sline, nb_points=self.resampling)
        return s

    def _slines_barycenters(self, slines_arr):
        """Compute streamlines barycenter (mean over the points axis)."""
        return np.mean(slines_arr, axis=1)

    def _slines_mean_points(self, slines_arr):
        """Compute streamlines mean-points.

        Each streamline of ``slines_arr`` (nb_slines, nb_pts, 3) is split
        into ``nb_mpts`` consecutive segments of ``nb_pts // nb_mpts`` points,
        each segment is averaged, and the result is flattened to shape
        (nb_slines, nb_mpts * 3).
        """
        nb_slines, nb_pts = slines_arr.shape[:2]
        # Explicit sizes (instead of -1) keep this valid for zero streamlines.
        r_arr = slines_arr.reshape(
            (nb_slines, self.nb_mpts, nb_pts // self.nb_mpts, 3)
        )
        mpts = np.mean(r_arr, axis=2)
        return mpts.reshape(nb_slines, self.nb_mpts * 3)

    def _barycenters_binning(self, barycenters):
        """Group indices by the bin containing their barycenter position.

        Barycenters outside ``[min_box, max_box]`` are ignored.

        Returns
        -------
        u_bin : ndarray
            Sorted flat indices of the non-empty bins.
        slines_ids : list of ndarray
            ``slines_ids[k]`` holds the indices (into ``barycenters``) whose
            barycenter falls in bin ``u_bin[k]``; same length as ``u_bin``.
        """
        in_bin = np.logical_and(
            np.all(barycenters >= self.min_box, axis=1),
            np.all(barycenters <= self.max_box, axis=1),
        )
        baryc_to_box = barycenters[in_bin] - self.min_box
        baryc_bins_id = (baryc_to_box // self.bin_size).astype(int)
        baryc_multiid = np.ravel_multi_index(baryc_bins_id.T, self.bin_shape)
        # Sort by bin so each bin's members form one contiguous run.
        sort_id = np.argsort(baryc_multiid)
        u_bin, mapping = np.unique(baryc_multiid[sort_id], return_index=True)
        if u_bin.size == 0:
            # np.split would return one empty chunk here, leaving
            # slines_ids longer than u_bin.
            return u_bin, []
        slines_ids = np.split(np.flatnonzero(in_bin)[sort_id], mapping[1:])
        return u_bin, slines_ids
def nearest_from_matrix_row(coo_array):
    """Return the nearest (smallest) entry for each row of a sparse COO matrix.

    Parameters
    ----------
    coo_array : scipy COOrdinates sparse array (nb_slines x nb_slines_ref)
        Adjacency matrix containing all neighbors within the given radius.
        Stored values are compared by absolute value, so negative-encoded
        distances are supported.

    Returns
    -------
    non_zero_ids : numpy array (nb_non_empty_row,)
        Sorted indices of each non-empty slines (row).
    nearest_id : numpy array (nb_non_empty_row,)
        Indices of the nearest reference match (column).
    nearest_dist : numpy array (nb_non_empty_row,)
        Distance for each nearest match.
    """
    non_zero_ids = np.unique(coo_array.row)
    if coo_array.nnz == 0:
        # No neighbor at all: np.max below would fail on empty data.
        return non_zero_ids, np.zeros(0, dtype=np.intp), np.zeros(0, dtype=float)

    sparse_matrix = abs(coo_array.tocsr())
    # Map each stored distance d to upper_limit - d (always >= 1): the
    # smallest distance becomes the largest value, and implicit zeros
    # (no neighbor) can never win the argmax/max below.
    upper_limit = np.max(sparse_matrix.data) + 1.0
    sparse_matrix.data = upper_limit - sparse_matrix.data

    # ravel (not squeeze) keeps 1-D results even for a single-row matrix.
    nearest_id = np.asarray(sparse_matrix.argmax(axis=1)).ravel()[non_zero_ids]
    row_max = sparse_matrix.max(axis=1).toarray().ravel()[non_zero_ids]
    nearest_dist = upper_limit - row_max
    return non_zero_ids, nearest_id, nearest_dist
def nearest_from_matrix_col(coo_array):
    """Return the nearest (smallest) entry for each column of a sparse COO matrix.

    Parameters
    ----------
    coo_array : scipy COOrdinates sparse matrix (nb_slines x nb_slines_ref)
        Adjacency matrix containing all neighbors within the given radius.
        Stored values are compared by absolute value, so negative-encoded
        distances are supported.

    Returns
    -------
    non_zero_ids : numpy array (nb_non_empty_col,)
        Sorted indices of each non-empty reference (column).
    nearest_id : numpy array (nb_non_empty_col,)
        Indices of the nearest slines match (row).
    nearest_dist : numpy array (nb_non_empty_col,)
        Distance for each nearest match.
    """
    non_zero_ids = np.unique(coo_array.col)
    if coo_array.nnz == 0:
        # No neighbor at all: np.max below would fail on empty data.
        return non_zero_ids, np.zeros(0, dtype=np.intp), np.zeros(0, dtype=float)

    sparse_matrix = abs(coo_array.tocsc())
    # Map each stored distance d to upper_limit - d (always >= 1): the
    # smallest distance becomes the largest value, and implicit zeros
    # (no neighbor) can never win the argmax/max below.
    upper_limit = np.max(sparse_matrix.data) + 1.0
    sparse_matrix.data = upper_limit - sparse_matrix.data

    # ravel (not squeeze) keeps 1-D results even for a single-column matrix.
    nearest_id = np.asarray(sparse_matrix.argmax(axis=0)).ravel()[non_zero_ids]
    col_max = sparse_matrix.max(axis=0).toarray().ravel()[non_zero_ids]
    nearest_dist = upper_limit - col_max
    return non_zero_ids, nearest_id, nearest_dist