from abc import ABCMeta, abstractmethod
import logging
import operator
from time import time
import numpy as np
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metricspeed import (
AveragePointwiseEuclideanMetric,
Metric,
MinimumAverageDirectFlipMetric,
)
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import nbytes, set_number_of_points
logger = logging.getLogger(__name__)
[docs]
class Identity:
    """Index-passthrough container.

    Acts as a stand-in for any indexable object (e.g. list, tuple):
    indexing an instance simply echoes back the index that was given,
    so ``Identity()[i] == i``.  Slicing is not supported.
    """

    def __getitem__(self, key):
        return key
[docs]
class Cluster:
    """Provides functionalities for interacting with a cluster.

    Useful container to retrieve index of elements grouped together. If
    a reference to the data is provided to `cluster_map`, elements will
    be returned instead of their index when possible.

    Parameters
    ----------
    cluster_map : `ClusterMap` object
        Reference to the set of clusters this cluster is being part of.
    id : int, optional
        Id of this cluster in its associated `cluster_map` object.
    refdata : list, optional
        Actual elements that clustered indices refer to.

    Notes
    -----
    A cluster does not contain actual data but instead knows how to
    retrieve them using its `ClusterMap` object.
    """

    @warning_for_keywords()
    def __init__(self, *, id=0, indices=None, refdata=None):
        self.id = id
        self.refdata = Identity() if refdata is None else refdata
        self.indices = [] if indices is None else indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Gets element(s) through indexing.

        If a reference to the data was provided (via refdata property)
        elements will be returned instead of their index.

        Parameters
        ----------
        idx : int, slice or list
            Index of the element(s) to get.

        Returns
        -------
        `Cluster` object(s)
            When `idx` is a int, returns a single element.
            When `idx` is either a slice or a list, returns a list of elements.
        """
        if isinstance(idx, (int, np.integer)):
            return self.refdata[self.indices[idx]]

        if isinstance(idx, slice):
            selected = self.indices[idx]
            return [self.refdata[i] for i in selected]

        if isinstance(idx, list):
            # Recurse so each sub-index goes through the same dispatch.
            return list(map(self.__getitem__, idx))

        msg = f"Index must be a int or a slice! Not '{type(idx)}'"
        raise TypeError(msg)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __str__(self):
        return "[{}]".format(", ".join(str(i) for i in self.indices))

    def __repr__(self):
        return f"Cluster({self})"

    def __eq__(self, other):
        if not isinstance(other, Cluster):
            return False
        return self.indices == other.indices

    def __ne__(self, other):
        return not self == other

    def __cmp__(self, other):
        raise TypeError("Cannot compare Cluster objects.")

    def assign(self, *indices):
        """Assigns indices to this cluster.

        Parameters
        ----------
        *indices : list of indices
            Indices to add to this cluster.
        """
        self.indices += indices
[docs]
class ClusterCentroid(Cluster):
    """Provides functionalities for interacting with a cluster.

    Useful container to retrieve the indices of elements grouped together and
    the cluster's centroid. If a reference to the data is provided to
    `cluster_map`, elements will be returned instead of their index when
    possible.

    Parameters
    ----------
    cluster_map : `ClusterMapCentroid` object
        Reference to the set of clusters this cluster is being part of.
    id : int, optional
        Id of this cluster in its associated `cluster_map` object.
    refdata : list, optional
        Actual elements that clustered indices refer to.

    Notes
    -----
    A cluster does not contain actual data but instead knows how to
    retrieve them using its `ClusterMapCentroid` object.
    """

    @warning_for_keywords()
    def __init__(self, centroid, *, id=0, indices=None, refdata=None):
        if refdata is None:
            refdata = Identity()
        super().__init__(id=id, indices=indices, refdata=refdata)
        # `new_centroid` accumulates the running mean of assigned features
        # until `update()` commits it into `centroid`.
        self.centroid = centroid.copy()
        self.new_centroid = centroid.copy()

    def __eq__(self, other):
        return (
            isinstance(other, ClusterCentroid)
            and np.all(self.centroid == other.centroid)
            and super().__eq__(other)
        )

    def assign(self, id_datum, features):
        """Assigns a data point to this cluster.

        Parameters
        ----------
        id_datum : int
            Index of the data point to add to this cluster.
        features : 2D array
            Data point's features to modify this cluster's centroid.
        """
        # Incremental mean: new = (old * N + features) / (N + 1).
        N = len(self)
        self.new_centroid = ((self.new_centroid * N) + features) / (N + 1.0)
        super().assign(id_datum)

    def update(self):
        """Update centroid of this cluster.

        Returns
        -------
        converged : bool
            True if the centroid did not move (i.e. it has converged).
        """
        # Bug fix: `np.equal` alone yields an elementwise boolean array,
        # while the documented contract is a single bool.  Reduce with
        # `np.all` so `if cluster.update(): ...` works for any centroid
        # shape instead of raising on multi-element arrays.
        converged = bool(np.all(np.equal(self.centroid, self.new_centroid)))
        self.centroid = self.new_centroid.copy()
        return converged
[docs]
class ClusterMap:
    """Provides functionalities for interacting with clustering outputs.

    Useful container to create, remove, retrieve and filter clusters.
    If `refdata` is given, elements will be returned instead of their
    index when using `Cluster` objects.

    Parameters
    ----------
    refdata : list
        Actual elements that clustered indices refer to.
    """

    @warning_for_keywords()
    def __init__(self, *, refdata=None):
        if refdata is None:
            refdata = Identity()
        self._clusters = []
        # Assign through the property setter so `refdata` is propagated to
        # clusters (none yet here, but keeps a single code path).
        self.refdata = refdata

    @property
    def clusters(self):
        return self._clusters

    @property
    def refdata(self):
        return self._refdata

    @refdata.setter
    def refdata(self, value):
        if value is None:
            value = Identity()
        self._refdata = value
        # Keep every contained cluster pointing at the same data source.
        for cluster in self.clusters:
            cluster.refdata = self._refdata

    def __len__(self):
        return len(self.clusters)

    def __getitem__(self, idx):
        """Gets cluster(s) through indexing.

        Parameters
        ----------
        idx : int, slice, list or boolean array
            Index of the element(s) to get.

        Returns
        -------
        `Cluster` object(s)
            When `idx` is an int, returns a single `Cluster` object.
            When `idx` is either a slice, list or boolean array, returns
            a list of `Cluster` objects.
        """
        if isinstance(idx, np.ndarray) and idx.dtype == bool:
            return [self.clusters[i] for i, take_it in enumerate(idx) if take_it]
        elif isinstance(idx, slice):
            return [self.clusters[i] for i in range(*idx.indices(len(self)))]
        elif isinstance(idx, list):
            return [self.clusters[i] for i in idx]

        return self.clusters[idx]

    def __iter__(self):
        return iter(self.clusters)

    def __str__(self):
        return "[" + ", ".join(map(str, self)) + "]"

    def __repr__(self):
        return f"ClusterMap({str(self)})"

    def _richcmp(self, other, op):
        """Compares this cluster map with another cluster map or an integer.

        Two `ClusterMap` objects are equal if they contain the same clusters.
        When comparing a `ClusterMap` object with an integer, the comparison
        will be performed on the size of the clusters instead.

        Parameters
        ----------
        other : `ClusterMap` object or int
            Object to compare to.
        op : rich comparison operators (see module `operator`)
            Valid operators are: lt, le, eq, ne, gt or ge.

        Returns
        -------
        bool or 1D array (bool)
            When comparing to another `ClusterMap` object, it returns whether
            the two `ClusterMap` objects contain the same clusters or not.
            When comparing to an integer the comparison is performed on the
            clusters sizes, it returns an array of boolean.
        """
        if isinstance(other, ClusterMap):
            if op is operator.eq:
                # (A redundant `isinstance(other, ClusterMap)` check was
                # removed here: this branch already guarantees it.)
                return len(self) == len(other) and self.clusters == other.clusters
            elif op is operator.ne:
                return not self == other

            raise NotImplementedError(
                "Can only check if two ClusterMap instances are equal or not."
            )
        elif isinstance(other, (int, np.integer)):
            # Also accept NumPy integer scalars, consistent with
            # `Cluster.__getitem__`.
            return np.array([op(len(cluster), other) for cluster in self])

        msg = (
            "ClusterMap only supports comparison with a int or another"
            " instance of Clustermap."
        )
        raise NotImplementedError(msg)

    def __eq__(self, other):
        return self._richcmp(other, operator.eq)

    def __ne__(self, other):
        return self._richcmp(other, operator.ne)

    def __lt__(self, other):
        return self._richcmp(other, operator.lt)

    def __le__(self, other):
        return self._richcmp(other, operator.le)

    def __gt__(self, other):
        return self._richcmp(other, operator.gt)

    def __ge__(self, other):
        return self._richcmp(other, operator.ge)

    def add_cluster(self, *clusters):
        """Adds one or multiple clusters to this cluster map.

        Parameters
        ----------
        *clusters : `Cluster` object, ...
            Cluster(s) to be added in this cluster map.
        """
        for cluster in clusters:
            self.clusters.append(cluster)
            cluster.refdata = self.refdata

    def remove_cluster(self, *clusters):
        """Remove one or multiple clusters from this cluster map.

        Parameters
        ----------
        *clusters : `Cluster` object, ...
            Cluster(s) to be removed from this cluster map.
        """
        for cluster in clusters:
            self.clusters.remove(cluster)

    def clear(self):
        """Remove all clusters from this cluster map."""
        del self.clusters[:]

    def size(self):
        """Gets number of clusters contained in this cluster map."""
        return len(self)

    def clusters_sizes(self):
        """Gets the size of every cluster contained in this cluster map.

        Returns
        -------
        list of int
            Sizes of every cluster in this cluster map.
        """
        return list(map(len, self))

    def get_large_clusters(self, min_size):
        """Gets clusters which contains at least `min_size` elements.

        Parameters
        ----------
        min_size : int
            Minimum number of elements a cluster needs to have to be selected.

        Returns
        -------
        list of `Cluster` objects
            Clusters having at least `min_size` elements.
        """
        return self[self >= min_size]

    def get_small_clusters(self, max_size):
        """Gets clusters which contains at most `max_size` elements.

        Parameters
        ----------
        max_size : int
            Maximum number of elements a cluster can have to be selected.

        Returns
        -------
        list of `Cluster` objects
            Clusters having at most `max_size` elements.
        """
        return self[self <= max_size]
[docs]
class ClusterMapCentroid(ClusterMap):
    """Provides functionalities for interacting with clustering outputs
    that have centroids.

    Allows to retrieve easily the centroid of every cluster. Also, it is
    a useful container to create, remove, retrieve and filter clusters.
    If `refdata` is given, elements will be returned instead of their
    index when using `ClusterCentroid` objects.

    Parameters
    ----------
    refdata : list
        Actual elements that clustered indices refer to.
    """

    @property
    def centroids(self):
        # One centroid per contained cluster, in map order.
        return list(map(operator.attrgetter("centroid"), self.clusters))
[docs]
# Bug fix: the original declared ``__metaclass__ = ABCMeta``, a Python-2
# idiom that Python 3 silently ignores, so the class was never actually
# abstract.  Declaring the metaclass properly makes ``@abstractmethod``
# enforceable on subclasses.
class Clustering(metaclass=ABCMeta):
    """Abstract base class for clustering algorithms.

    Subclasses must implement :meth:`cluster`.
    """

    @abstractmethod
    @warning_for_keywords()
    def cluster(self, data, *, ordering=None):
        """Clusters `data`.

        Subclasses will perform their clustering algorithm here.

        Parameters
        ----------
        data : list of N-dimensional arrays
            Each array represents a data point.
        ordering : iterable of indices, optional
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `ClusterMap` object
            Result of the clustering.
        """
        # Kept as a safety net in case the method is invoked via super().
        msg = "Subclass has to define method 'cluster(data, ordering)'!"
        raise NotImplementedError(msg)
[docs]
class QuickBundles(Clustering):
    r"""Clusters streamlines using QuickBundles.

    Given a list of streamlines, the QuickBundles algorithm
    :footcite:p:`Garyfallidis2012a` sequentially assigns each streamline to its
    closest bundle in $\mathcal{O}(Nk)$ where $N$ is the number of streamlines
    and $k$ is the final number of bundles. If for a given streamline its
    closest bundle is farther than `threshold`, a new bundle is created and the
    streamline is assigned to it except if the number of bundles has already
    exceeded `max_nb_clusters`.

    Parameters
    ----------
    threshold : float
        The maximum distance from a bundle for a streamline to be still
        considered as part of it.
    metric : str or `Metric` object, optional
        The distance metric to use when comparing two streamlines. By default,
        the Minimum average Direct-Flip (MDF) distance
        :footcite:p:`Garyfallidis2012a` is used and streamlines are
        automatically resampled so they have 12 points.
    max_nb_clusters : int, optional
        Limits the creation of bundles.

    Examples
    --------
    >>> from dipy.segment.clustering import QuickBundles
    >>> from dipy.data import get_fnames
    >>> from dipy.io.streamline import load_tractogram
    >>> from dipy.tracking.streamline import Streamlines
    >>> fname = get_fnames(name='fornix')
    >>> fornix = load_tractogram(fname, 'same',
    ...                          bbox_valid_check=False).streamlines
    >>> streamlines = Streamlines(fornix)
    >>> # Segment fornix with a threshold of 10mm and streamlines resampled
    >>> # to 12 points.
    >>> qb = QuickBundles(threshold=10.)
    >>> clusters = qb.cluster(streamlines)
    >>> len(clusters)
    4
    >>> list(map(len, clusters))
    [61, 191, 47, 1]
    >>> # Resampling streamlines differently is done explicitly as follows.
    >>> # Note this has an impact on the speed and the accuracy (tradeoff).
    >>> from dipy.segment.featurespeed import ResampleFeature
    >>> from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
    >>> feature = ResampleFeature(nb_points=2)
    >>> metric = AveragePointwiseEuclideanMetric(feature)
    >>> qb = QuickBundles(threshold=10., metric=metric)
    >>> clusters = qb.cluster(streamlines)
    >>> len(clusters)
    4
    >>> list(map(len, clusters))
    [58, 142, 72, 28]

    References
    ----------
    .. footbibliography::
    """

    @warning_for_keywords()
    def __init__(self, threshold, *, metric="MDF_12points", max_nb_clusters=None):
        self.threshold = threshold
        # Default to "effectively unlimited" bundles.
        self.max_nb_clusters = (
            np.iinfo("i4").max if max_nb_clusters is None else max_nb_clusters
        )

        if isinstance(metric, MinimumAverageDirectFlipMetric):
            raise ValueError("Use AveragePointwiseEuclideanMetric instead")

        if isinstance(metric, Metric):
            self.metric = metric
        elif metric == "MDF_12points":
            # MDF over streamlines resampled to a fixed 12 points.
            self.metric = AveragePointwiseEuclideanMetric(
                ResampleFeature(nb_points=12)
            )
        else:
            raise ValueError(f"Unknown metric: {metric}")

    @warning_for_keywords()
    def cluster(self, streamlines, *, ordering=None):
        """Clusters `streamlines` into bundles.

        Performs quickbundles algorithm using predefined metric and threshold.

        Parameters
        ----------
        streamlines : list of 2D arrays
            Each 2D array represents a sequence of 3D points (points, 3).
        ordering : iterable of indices, optional
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `ClusterMapCentroid` object
            Result of the clustering.
        """
        # Imported lazily to avoid a circular import at module load time.
        from dipy.segment.clustering_algorithms import quickbundles

        clusters = quickbundles(
            streamlines,
            self.metric,
            threshold=self.threshold,
            max_nb_clusters=self.max_nb_clusters,
            ordering=ordering,
        )
        clusters.refdata = streamlines
        return clusters
[docs]
class QuickBundlesX(Clustering):
    r"""Clusters streamlines using QuickBundlesX.

    See :footcite:p:`Garyfallidis2016` for further details about the method.

    Parameters
    ----------
    thresholds : list of float
        Thresholds to use for each clustering layer. A threshold represents the
        maximum distance from a cluster for a streamline to be still considered
        as part of it.
    metric : str or `Metric` object, optional
        The distance metric to use when comparing two streamlines. By default,
        the Minimum average Direct-Flip (MDF) distance
        :footcite:p:`Garyfallidis2012a` is used and streamlines are
        automatically resampled so they have 12 points.

    References
    ----------
    .. footbibliography::
    """

    @warning_for_keywords()
    def __init__(self, thresholds, *, metric="MDF_12points"):
        self.thresholds = thresholds

        if isinstance(metric, MinimumAverageDirectFlipMetric):
            raise ValueError("Use AveragePointwiseEuclideanMetric instead")

        if isinstance(metric, Metric):
            self.metric = metric
        elif metric == "MDF_12points":
            # MDF over streamlines resampled to a fixed 12 points.
            self.metric = AveragePointwiseEuclideanMetric(
                ResampleFeature(nb_points=12)
            )
        else:
            raise ValueError(f"Unknown metric: {metric}")

    @warning_for_keywords()
    def cluster(self, streamlines, *, ordering=None):
        """Clusters `streamlines` into bundles.

        Performs QuickbundleX using a predefined metric and thresholds.

        Parameters
        ----------
        streamlines : list of 2D arrays
            Each 2D array represents a sequence of 3D points (points, 3).
        ordering : iterable of indices
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `TreeClusterMap` object
            Result of the clustering.
        """
        # Imported lazily to avoid a circular import at module load time.
        from dipy.segment.clustering_algorithms import quickbundlesx

        hierarchy = quickbundlesx(
            streamlines, self.metric, thresholds=self.thresholds, ordering=ordering
        )
        hierarchy.refdata = streamlines
        return hierarchy
[docs]
class TreeCluster(ClusterCentroid):
    """Cluster node in a QuickBundlesX hierarchy.

    A `ClusterCentroid` augmented with a per-layer `threshold` and
    parent/children links forming a tree.
    """

    @warning_for_keywords()
    def __init__(self, threshold, centroid, *, indices=None):
        super().__init__(centroid=centroid, indices=indices)
        self.threshold = threshold
        self.parent = None
        self.children = []

    def add(self, child):
        """Attach `child` below this node and record its parent link."""
        child.parent = self
        self.children.append(child)

    @property
    def is_leaf(self):
        # A node with no children is a leaf.
        return not self.children

    def return_indices(self):
        # NOTE(review): despite its name, this returns the child nodes,
        # not `self.indices` — kept as-is for compatibility.
        return self.children
[docs]
class TreeClusterMap(ClusterMap):
    """Hierarchical cluster map produced by QuickBundlesX.

    Wraps a tree of `TreeCluster` nodes rooted at `root` and caches the
    leaf nodes at construction time.
    """

    def __init__(self, root):
        # NOTE(review): `ClusterMap.__init__` is not called here, so this
        # subclass manages its own state — confirm this is intentional.
        self.root = root
        self.leaves = []

        def _collect_leaf(node):
            if node.is_leaf:
                self.leaves.append(node)

        self.traverse_postorder(self.root, _collect_leaf)

    @property
    def refdata(self):
        return self._refdata

    @refdata.setter
    def refdata(self, value):
        if value is None:
            value = Identity()
        self._refdata = value

        def _propagate(node):
            node.refdata = self._refdata

        # Push the new refdata down to every node of the tree.
        self.traverse_postorder(self.root, _propagate)

    def traverse_postorder(self, node, visit):
        """Depth-first walk calling `visit` on children before their parent."""
        for child in node.children:
            self.traverse_postorder(child, visit)
        visit(node)

    def iter_preorder(self, node):
        """Yield each node before its descendants.

        Note that on backtracking, pending siblings are popped
        last-stacked-first, so siblings after the first are visited in
        reverse order.
        """
        pending = []
        while pending or node is not None:
            if node is None:
                node = pending.pop()
                continue

            yield node
            if node.children:
                pending += node.children[1:]
                node = node.children[0]
            else:
                node = None

    def __iter__(self):
        return self.iter_preorder(self.root)

    def get_clusters(self, wanted_level):
        """Collect every node sitting exactly `wanted_level` edges below root."""
        clusters = ClusterMapCentroid()

        def _descend(node, depth=0):
            if depth == wanted_level:
                clusters.add_cluster(node)
                return

            for child in node.children:
                _descend(child, depth + 1)

        _descend(self.root)
        return clusters
[docs]
@warning_for_keywords()
def qbx_and_merge(
    streamlines, thresholds, *, nb_pts=20, select_randomly=None, rng=None, verbose=False
):
    """Run QuickBundlesX and then run again on the centroids of the last layer.

    Running again QuickBundles at a layer has the effect of merging
    some of the clusters that may be originally divided because of branching.
    This function helps obtain a result of QuickBundles quality but with
    QuickBundlesX speed. The merging phase has low cost because it is applied
    only on the centroids rather than the entire dataset.

    See :footcite:p:`Garyfallidis2012a` and :footcite:p:`Garyfallidis2016` for
    further details about the method.

    Parameters
    ----------
    streamlines : Streamlines
        Streamlines.
    thresholds : sequence
        List of distance thresholds for QuickBundlesX.
    nb_pts : int, optional
        Number of points for discretizing each streamline.
    select_randomly : int, optional
        Randomly select a specific number of streamlines. If None all the
        streamlines are used.
    rng : numpy.random.Generator, optional
        If None then generator is initialized internally.
    verbose : bool, optional
        If True, log information. Default False.

    Returns
    -------
    clusters : obj
        Contains the clusters of the last layer of QuickBundlesX after merging.

    References
    ----------
    .. footbibliography::
    """
    t = time()
    len_s = len(streamlines)
    if select_randomly is None:
        select_randomly = len_s

    if rng is None:
        rng = np.random.default_rng()
    # Pick (without replacement) which streamlines take part in the first pass.
    indices = rng.choice(len_s, min(select_randomly, len_s), replace=False)
    # NOTE(review): all streamlines are resampled even when only a random
    # subset is clustered below — confirm this is intentional.
    sample_streamlines = set_number_of_points(streamlines, nb_points=nb_pts)

    if verbose:
        logger.info(f" Resampled to {nb_pts} points")
        logger.info(f" Size is {nbytes(sample_streamlines):0.3f} MB")
        logger.info(f" Duration of resampling is {time() - t:0.3f} s")
        logger.info(" QBX phase starting...")

    # First pass: full multi-layer QuickBundlesX over the sampled subset
    # (the `ordering` argument restricts clustering to `indices`).
    qbx = QuickBundlesX(thresholds, metric=AveragePointwiseEuclideanMetric())

    t1 = time()
    qbx_clusters = qbx.cluster(sample_streamlines, ordering=indices)

    if verbose:
        logger.info(" Merging phase starting ...")

    # Second pass: single-threshold QuickBundlesX over the centroids of the
    # finest layer, merging clusters that were split by branching.
    qbx_merge = QuickBundlesX(
        [thresholds[-1]], metric=AveragePointwiseEuclideanMetric()
    )

    final_level = len(thresholds)
    len_qbx_fl = len(qbx_clusters.get_clusters(final_level))
    # Random visiting order over the first-pass centroids for the merge pass.
    qbx_ordering_final = rng.choice(len_qbx_fl, len_qbx_fl, replace=False)

    qbx_merged_cluster_map = qbx_merge.cluster(
        qbx_clusters.get_clusters(final_level).centroids, ordering=qbx_ordering_final
    ).get_clusters(1)

    qbx_cluster_map = qbx_clusters.get_clusters(final_level)

    # Each merged cluster indexes first-pass centroids; map those back to the
    # original streamline indices they represented.
    merged_cluster_map = ClusterMapCentroid()
    for cluster in qbx_merged_cluster_map:
        merged_cluster = ClusterCentroid(centroid=cluster.centroid)
        for i in cluster.indices:
            merged_cluster.indices.extend(qbx_cluster_map[i].indices)
        merged_cluster_map.add_cluster(merged_cluster)

    merged_cluster_map.refdata = streamlines

    if verbose:
        logger.info(f" QuickBundlesX time for {select_randomly} random streamlines")
        logger.info(f" Duration {time() - t1:0.3f} s\n")
    return merged_cluster_map