from abc import ABCMeta, abstractmethod
import logging
import operator
from time import time
import numpy as np
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metricspeed import (
    AveragePointwiseEuclideanMetric,
    Metric,
    MinimumAverageDirectFlipMetric,
)
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import nbytes, set_number_of_points
logger = logging.getLogger(__name__)
[docs]
class Identity:
    """Index-echoing stand-in for an indexable data container.

    It can be used wherever an indexable object (e.g. list, tuple) is only
    needed for look-ups: ``obj[idx]`` evaluates to ``idx`` itself rather than
    to a stored element. Slices are not expanded into elements; a slice key
    is simply handed back unchanged.
    """

    def __getitem__(self, idx):
        # Nothing is stored: the key is its own value.
        return idx
[docs]
class Cluster:
    """Container holding the indices of elements grouped together.

    Indexing or iterating a cluster goes through `refdata`: with the default
    `Identity` reference the stored indices themselves are returned,
    otherwise the referenced elements are.

    Parameters
    ----------
    id : int, optional
        Identifier of this cluster.
    indices : list of int, optional
        Indices of the elements belonging to this cluster. When omitted the
        cluster starts empty. The given object is stored as is (no copy).
    refdata : list, optional
        Actual elements that clustered indices refer to.

    Notes
    -----
    A cluster does not contain actual data, only indices that are resolved
    through `refdata` on access.
    """

    @warning_for_keywords()
    def __init__(self, *, id=0, indices=None, refdata=None):
        self.id = id
        self.refdata = Identity() if refdata is None else refdata
        if indices is None:
            indices = []
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Gets element(s) through indexing.

        Stored indices are looked up in `refdata`, so elements are returned
        when reference data was provided and plain indices otherwise.

        Parameters
        ----------
        idx : int, slice or list
            Position(s), within this cluster, of the element(s) to get.

        Returns
        -------
        element or list of elements
            A single element when `idx` is an int; a list of elements when
            `idx` is a slice or a list.

        Raises
        ------
        TypeError
            If `idx` is neither an int, a slice nor a list.
        """
        if isinstance(idx, (int, np.integer)):
            return self.refdata[self.indices[idx]]

        if isinstance(idx, slice):
            return [self.refdata[position] for position in self.indices[idx]]

        if isinstance(idx, list):
            # Each entry is resolved through this same method.
            return [self[position] for position in idx]

        msg = f"Index must be a int or a slice! Not '{type(idx)}'"
        raise TypeError(msg)

    def __iter__(self):
        return (self[position] for position in range(len(self)))

    def __str__(self):
        joined = ", ".join(map(str, self.indices))
        return "[" + joined + "]"

    def __repr__(self):
        return f"Cluster({str(self)})"

    def __eq__(self, other):
        # Two clusters are equal when they hold the same indices.
        if not isinstance(other, Cluster):
            return False
        return self.indices == other.indices

    def __ne__(self, other):
        return not (self == other)

    def __cmp__(self, other):
        raise TypeError("Cannot compare Cluster objects.")

    def assign(self, *indices):
        """Assigns indices to this cluster.

        Parameters
        ----------
        *indices : int, ...
            Indices to add to this cluster.
        """
        # In-place concatenation of the positional arguments (a tuple).
        self.indices += indices
 
[docs]
class ClusterCentroid(Cluster):
    """Cluster that additionally keeps track of a centroid.

    Useful container to retrieve the indices of elements grouped together and
    the cluster's centroid. If reference data is provided, elements will be
    returned instead of their index when possible.

    Two centroids are stored: `centroid`, the current one, and
    `new_centroid`, a running mean updated by `assign`. Calling `update`
    copies `new_centroid` into `centroid`.

    Parameters
    ----------
    centroid : ndarray
        Initial centroid of this cluster. It is copied (via ``.copy()``).
    id : int, optional
        Identifier of this cluster.
    indices : list of int, optional
        Indices of the elements belonging to this cluster.
    refdata : list, optional
        Actual elements that clustered indices refer to.

    Notes
    -----
    A cluster does not contain actual data, only indices that are resolved
    through `refdata` on access.
    """

    @warning_for_keywords()
    def __init__(self, centroid, *, id=0, indices=None, refdata=None):
        # `Cluster.__init__` already substitutes `Identity()` for a missing
        # `refdata`, so it is forwarded untouched.
        super().__init__(id=id, indices=indices, refdata=refdata)
        self.centroid = centroid.copy()
        self.new_centroid = centroid.copy()

    def __eq__(self, other):
        # Equal when both centroids match element-wise and the indices match.
        return (
            isinstance(other, ClusterCentroid)
            and np.all(self.centroid == other.centroid)
            and super().__eq__(other)
        )

    def assign(self, id_datum, features):
        """Assigns a data point to this cluster.

        Parameters
        ----------
        id_datum : int
            Index of the data point to add to this cluster.
        features : 2D array
            Data point's features to modify this cluster's centroid.
        """
        N = len(self)
        # Incremental mean: weight the running centroid by the current size.
        self.new_centroid = ((self.new_centroid * N) + features) / (N + 1.0)
        super().assign(id_datum)

    def update(self):
        """Update centroid of this cluster.

        Replaces `centroid` by a copy of `new_centroid`.

        Returns
        -------
        converged : bool
            True if the centroid did not change, i.e. every component of the
            previous centroid equals the corresponding one of the new
            centroid. Returned as a NumPy boolean scalar.
        """
        # Reduce the element-wise comparison to a single truth value; the raw
        # `np.equal` result is an array and cannot be used as a boolean.
        converged = np.all(np.equal(self.centroid, self.new_centroid))
        self.centroid = self.new_centroid.copy()
        return converged
 
[docs]
class ClusterMap:
    """Provides functionalities for interacting with clustering outputs.

    Useful container to create, remove, retrieve and filter clusters.
    If `refdata` is given, elements will be returned instead of their
    index when using `Cluster` objects.

    Parameters
    ----------
    refdata : list, optional
        Actual elements that clustered indices refer to. When omitted, an
        `Identity` object is used so clusters return plain indices.
    """

    @warning_for_keywords()
    def __init__(self, *, refdata=None):
        self._clusters = []
        # The property setter substitutes `Identity()` when `refdata` is None.
        self.refdata = refdata

    @property
    def clusters(self):
        """List of the clusters held by this cluster map."""
        return self._clusters

    @property
    def refdata(self):
        """Reference data shared with every cluster of this map."""
        return self._refdata

    @refdata.setter
    def refdata(self, value):
        if value is None:
            value = Identity()
        self._refdata = value
        # Keep every contained cluster pointing at the same reference data.
        for cluster in self.clusters:
            cluster.refdata = self._refdata

    def __len__(self):
        return len(self.clusters)

    def __getitem__(self, idx):
        """Gets cluster(s) through indexing.

        Parameters
        ----------
        idx : int, slice, list or boolean array
            Index of the element(s) to get.

        Returns
        -------
        `Cluster` object(s)
            When `idx` is an int, returns a single `Cluster` object.
            When `idx` is either a slice, list or boolean array, returns
            a list of `Cluster` objects.
        """
        if isinstance(idx, np.ndarray) and idx.dtype == bool:
            return [self.clusters[i] for i, take_it in enumerate(idx) if take_it]
        elif isinstance(idx, slice):
            return [self.clusters[i] for i in range(*idx.indices(len(self)))]
        elif isinstance(idx, list):
            return [self.clusters[i] for i in idx]
        return self.clusters[idx]

    def __iter__(self):
        return iter(self.clusters)

    def __str__(self):
        return "[" + ", ".join(map(str, self)) + "]"

    def __repr__(self):
        return f"ClusterMap({str(self)})"

    def _richcmp(self, other, op):
        """Compares this cluster map with another cluster map or an integer.

        Two `ClusterMap` objects are equal if they contain the same clusters.
        When comparing a `ClusterMap` object with an integer, the comparison
        will be performed on the size of the clusters instead.

        Parameters
        ----------
        other : `ClusterMap` object or int
            Object to compare to. NumPy integer scalars are accepted as well.
        op : rich comparison operators (see module `operator`)
            Valid operators are: lt, le, eq, ne, gt or ge.

        Returns
        -------
        bool or 1D array (bool)
            When comparing to another `ClusterMap` object, it returns whether
            the two `ClusterMap` objects contain the same clusters or not.
            When comparing to an integer the comparison is performed on the
            clusters sizes, it returns an array of boolean.

        Raises
        ------
        NotImplementedError
            If `other` is neither a `ClusterMap` nor an integer, or if two
            `ClusterMap` objects are compared with an ordering operator.
        """
        if isinstance(other, ClusterMap):
            if op is operator.eq:
                return len(self) == len(other) and self.clusters == other.clusters
            elif op is operator.ne:
                return not self == other
            raise NotImplementedError(
                "Can only check if two ClusterMap instances are equal or not."
            )
        elif isinstance(other, (int, np.integer)):
            # `dtype=bool` keeps the mask boolean even when there are no
            # clusters; otherwise an empty float array would not be treated
            # as a mask by `__getitem__`.
            return np.array(
                [op(len(cluster), other) for cluster in self], dtype=bool
            )
        msg = (
            "ClusterMap only supports comparison with a int or another"
            " instance of Clustermap."
        )
        raise NotImplementedError(msg)

    def __eq__(self, other):
        return self._richcmp(other, operator.eq)

    def __ne__(self, other):
        return self._richcmp(other, operator.ne)

    def __lt__(self, other):
        return self._richcmp(other, operator.lt)

    def __le__(self, other):
        return self._richcmp(other, operator.le)

    def __gt__(self, other):
        return self._richcmp(other, operator.gt)

    def __ge__(self, other):
        return self._richcmp(other, operator.ge)

    def add_cluster(self, *clusters):
        """Adds one or multiple clusters to this cluster map.

        Each added cluster has its `refdata` replaced by this map's `refdata`.

        Parameters
        ----------
        *clusters : `Cluster` object, ...
            Cluster(s) to be added in this cluster map.
        """
        for cluster in clusters:
            self.clusters.append(cluster)
            cluster.refdata = self.refdata

    def remove_cluster(self, *clusters):
        """Remove one or multiple clusters from this cluster map.

        Parameters
        ----------
        *clusters : `Cluster` object, ...
            Cluster(s) to be removed from this cluster map.
        """
        for cluster in clusters:
            self.clusters.remove(cluster)

    def clear(self):
        """Remove all clusters from this cluster map."""
        del self.clusters[:]

    def size(self):
        """Gets number of clusters contained in this cluster map."""
        return len(self)

    def clusters_sizes(self):
        """Gets the size of every cluster contained in this cluster map.

        Returns
        -------
        list of int
            Sizes of every cluster in this cluster map.
        """
        return list(map(len, self))

    def get_large_clusters(self, min_size):
        """Gets clusters which contains at least `min_size` elements.

        Parameters
        ----------
        min_size : int
            Minimum number of elements a cluster needs to have to be selected.

        Returns
        -------
        list of `Cluster` objects
            Clusters having at least `min_size` elements.
        """
        return self[self >= min_size]

    def get_small_clusters(self, max_size):
        """Gets clusters which contains at most `max_size` elements.

        Parameters
        ----------
        max_size : int
            Maximum number of elements a cluster can have to be selected.

        Returns
        -------
        list of `Cluster` objects
            Clusters having at most `max_size` elements.
        """
        return self[self <= max_size]
 
[docs]
class ClusterMapCentroid(ClusterMap):
    """Cluster map whose clusters expose a centroid.

    On top of what `ClusterMap` offers (creating, removing, retrieving and
    filtering clusters), the centroid of every cluster can be collected in
    one go through the `centroids` property. If `refdata` is given, elements
    will be returned instead of their index when using `ClusterCentroid`
    objects.

    Parameters
    ----------
    refdata : list
        Actual elements that clustered indices refer to.
    """

    @property
    def centroids(self):
        """List with the `centroid` of each cluster, in cluster order."""
        return [member.centroid for member in self.clusters]
[docs]
class Clustering(metaclass=ABCMeta):
    """Abstract base class of clustering algorithms.

    Subclasses must implement `cluster`. The metaclass is declared with the
    Python 3 ``metaclass=`` keyword so that `abstractmethod` is actually
    enforced; a ``__metaclass__`` class attribute is ignored by Python 3.
    """

    @abstractmethod
    @warning_for_keywords()
    def cluster(self, data, *, ordering=None):
        """Clusters `data`.

        Subclasses will perform their clustering algorithm here.

        Parameters
        ----------
        data : list of N-dimensional arrays
            Each array represents a data point.
        ordering : iterable of indices, optional
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `ClusterMap` object
            Result of the clustering.

        Raises
        ------
        NotImplementedError
            Always, when this base implementation is called directly.
        """
        msg = "Subclass has to define method 'cluster(data, ordering)'!"
        raise NotImplementedError(msg)
 
[docs]
class QuickBundles(Clustering):
    r"""Clusters streamlines using QuickBundles.

    The QuickBundles algorithm :footcite:p:`Garyfallidis2012a` goes through a
    list of streamlines sequentially and assigns each one to its closest
    bundle, in $\mathcal{O}(Nk)$ where $N$ is the number of streamlines and
    $k$ is the final number of bundles. When the closest bundle of a
    streamline is farther than `threshold`, a new bundle is created for that
    streamline, except if the number of bundles has already exceeded
    `max_nb_clusters`.

    Parameters
    ----------
    threshold : float
        The maximum distance from a bundle for a streamline to be still
        considered as part of it.
    metric : str or `Metric` object, optional
        The distance metric to use when comparing two streamlines. By default,
        the Minimum average Direct-Flip (MDF) distance
        :footcite:p:`Garyfallidis2012a` is used and streamlines are
        automatically resampled so they have 12 points.
    max_nb_clusters : int, optional
        Limits the creation of bundles. When omitted, the largest 4-byte
        signed integer is used.

    Examples
    --------
    >>> from dipy.segment.clustering import QuickBundles
    >>> from dipy.data import get_fnames
    >>> from dipy.io.streamline import load_tractogram
    >>> from dipy.tracking.streamline import Streamlines
    >>> fname = get_fnames(name='fornix')
    >>> fornix = load_tractogram(fname, 'same',
    ...                          bbox_valid_check=False).streamlines
    >>> streamlines = Streamlines(fornix)
    >>> # Segment fornix with a threshold of 10mm and streamlines resampled
    >>> # to 12 points.
    >>> qb = QuickBundles(threshold=10.)
    >>> clusters = qb.cluster(streamlines)
    >>> len(clusters)
    4
    >>> list(map(len, clusters))
    [61, 191, 47, 1]
    >>> # Resampling streamlines differently is done explicitly as follows.
    >>> # Note this has an impact on the speed and the accuracy (tradeoff).
    >>> from dipy.segment.featurespeed import ResampleFeature
    >>> from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
    >>> feature = ResampleFeature(nb_points=2)
    >>> metric = AveragePointwiseEuclideanMetric(feature)
    >>> qb = QuickBundles(threshold=10., metric=metric)
    >>> clusters = qb.cluster(streamlines)
    >>> len(clusters)
    4
    >>> list(map(len, clusters))
    [58, 142, 72, 28]

    References
    ----------
    .. footbibliography::
    """

    @warning_for_keywords()
    def __init__(self, threshold, *, metric="MDF_12points", max_nb_clusters=None):
        self.threshold = threshold
        # No explicit limit: fall back to the largest 4-byte signed integer.
        self.max_nb_clusters = (
            np.iinfo("i4").max if max_nb_clusters is None else max_nb_clusters
        )

        # This metric type is explicitly refused.
        if isinstance(metric, MinimumAverageDirectFlipMetric):
            raise ValueError("Use AveragePointwiseEuclideanMetric instead")

        if isinstance(metric, Metric):
            chosen_metric = metric
        elif metric == "MDF_12points":
            chosen_metric = AveragePointwiseEuclideanMetric(
                ResampleFeature(nb_points=12)
            )
        else:
            raise ValueError(f"Unknown metric: {metric}")
        self.metric = chosen_metric

    @warning_for_keywords()
    def cluster(self, streamlines, *, ordering=None):
        """Clusters `streamlines` into bundles.

        Runs the quickbundles algorithm with the metric, threshold and
        maximum number of clusters configured on this object.

        Parameters
        ----------
        streamlines : list of 2D arrays
            Each 2D array represents a sequence of 3D points (points, 3).
        ordering : iterable of indices, optional
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `ClusterMapCentroid` object
            Result of the clustering.
        """
        from dipy.segment.clustering_algorithms import quickbundles

        result = quickbundles(
            streamlines,
            self.metric,
            threshold=self.threshold,
            max_nb_clusters=self.max_nb_clusters,
            ordering=ordering,
        )
        # Let the returned clusters resolve their indices to streamlines.
        result.refdata = streamlines
        return result
 
[docs]
class QuickBundlesX(Clustering):
    r"""Clusters streamlines using QuickBundlesX.

    The method is described in :footcite:p:`Garyfallidis2016`.

    Parameters
    ----------
    thresholds : list of float
        Thresholds to use for each clustering layer. A threshold represents the
        maximum distance from a cluster for a streamline to be still considered
        as part of it.
    metric : str or `Metric` object, optional
        The distance metric to use when comparing two streamlines. By default,
        the Minimum average Direct-Flip (MDF) distance
        :footcite:p:`Garyfallidis2012a` is used and streamlines are
        automatically resampled so they have 12 points.

    References
    ----------
    .. footbibliography::
    """

    @warning_for_keywords()
    def __init__(self, thresholds, *, metric="MDF_12points"):
        self.thresholds = thresholds

        # This metric type is explicitly refused.
        if isinstance(metric, MinimumAverageDirectFlipMetric):
            raise ValueError("Use AveragePointwiseEuclideanMetric instead")

        if isinstance(metric, Metric):
            chosen_metric = metric
        elif metric == "MDF_12points":
            chosen_metric = AveragePointwiseEuclideanMetric(
                ResampleFeature(nb_points=12)
            )
        else:
            raise ValueError(f"Unknown metric: {metric}")
        self.metric = chosen_metric

    @warning_for_keywords()
    def cluster(self, streamlines, *, ordering=None):
        """Clusters `streamlines` into bundles.

        Runs QuickBundlesX with the metric and thresholds configured on this
        object.

        Parameters
        ----------
        streamlines : list of 2D arrays
            Each 2D array represents a sequence of 3D points (points, 3).
        ordering : iterable of indices, optional
            Specifies the order in which data points will be clustered.

        Returns
        -------
        `TreeClusterMap` object
            Result of the clustering.
        """
        from dipy.segment.clustering_algorithms import quickbundlesx

        result_tree = quickbundlesx(
            streamlines,
            self.metric,
            thresholds=self.thresholds,
            ordering=ordering,
        )
        # Let the nodes of the tree resolve their indices to streamlines.
        result_tree.refdata = streamlines
        return result_tree
 
[docs]
class TreeCluster(ClusterCentroid):
    """Centroid cluster that is also a node of a tree of clusters.

    Parameters
    ----------
    threshold : float
        Threshold stored on this node.
    centroid : ndarray
        Centroid of this cluster.
    indices : list of int, optional
        Indices of the elements belonging to this cluster.
    """

    @warning_for_keywords()
    def __init__(self, threshold, centroid, *, indices=None):
        super().__init__(centroid=centroid, indices=indices)
        self.threshold = threshold
        # A freshly created node is detached: no parent, no children.
        self.parent = None
        self.children = []

    def add(self, child):
        """Attach `child` under this node and record this node as its parent."""
        child.parent = self
        self.children.append(child)

    @property
    def is_leaf(self):
        """Whether this node has no children."""
        return not len(self.children)

    def return_indices(self):
        """Return the list of child nodes of this node."""
        # NOTE(review): despite its name, this hands back `children`, not
        # `indices` -- confirm the intent with callers.
        return self.children
 
[docs]
class TreeClusterMap(ClusterMap):
    """Cluster map organised as a tree of nodes hanging from `root`.

    Parameters
    ----------
    root : node object
        Root of the tree. Nodes are expected to expose `children` and
        `is_leaf`.
    """

    def __init__(self, root):
        self.root = root
        self.leaves = []

        def _collect_if_leaf(node):
            if node.is_leaf:
                self.leaves.append(node)

        # Leaves are gathered once, in post-order, at construction time.
        self.traverse_postorder(root, _collect_if_leaf)

    @property
    def refdata(self):
        """Reference data shared with every node of the tree."""
        return self._refdata

    @refdata.setter
    def refdata(self, value):
        self._refdata = Identity() if value is None else value

        def _propagate(node):
            node.refdata = self._refdata

        # Push the reference data down to every node of the tree.
        self.traverse_postorder(self.root, _propagate)

    def traverse_postorder(self, node, visit):
        """Call `visit` on every descendant of `node`, then on `node` itself."""
        for descendant in node.children:
            self.traverse_postorder(descendant, visit)
        visit(node)

    def iter_preorder(self, node):
        """Yield `node` and its descendants, each parent before its children.

        The first child is followed right away; remaining siblings are put
        on a stack and taken back from its end.
        """
        pending = []
        current = node
        while pending or current is not None:
            if current is None:
                current = pending.pop()
                continue
            yield current
            if len(current.children) > 0:
                pending += current.children[1:]
                current = current.children[0]
            else:
                current = None

    def __iter__(self):
        return self.iter_preorder(self.root)

    def get_clusters(self, wanted_level):
        """Collect the nodes located `wanted_level` levels below the root.

        Parameters
        ----------
        wanted_level : int
            Depth of the nodes to collect; the root is at level 0.

        Returns
        -------
        `ClusterMapCentroid` object
            New cluster map holding the nodes found at that depth.
        """
        selected = ClusterMapCentroid()
        # Explicit depth-first walk; children are pushed in reverse so they
        # are popped (and therefore added) in their original order.
        to_visit = [(self.root, 0)]
        while to_visit:
            node, depth = to_visit.pop()
            if depth == wanted_level:
                selected.add_cluster(node)
                continue
            to_visit.extend((child, depth + 1) for child in reversed(node.children))
        return selected
 
[docs]
@warning_for_keywords()
def qbx_and_merge(
    streamlines, thresholds, *, nb_pts=20, select_randomly=None, rng=None, verbose=False
):
    """Run QuickBundlesX and then run again on the centroids of the last layer.

    Running again QuickBundles at a layer has the effect of merging
    some of the clusters that may be originally divided because of branching.
    This function helps obtain a result at a QuickBundles quality but with
    QuickBundlesX speed. The merging phase has low cost because it is applied
    only on the centroids rather than the entire dataset.

    See :footcite:p:`Garyfallidis2012a` and :footcite:p:`Garyfallidis2016` for
    further details about the method.

    Parameters
    ----------
    streamlines : Streamlines
        Streamlines.
    thresholds : sequence
        List of distance thresholds for QuickBundlesX.
    nb_pts : int, optional
        Number of points for discretizing each streamline.
    select_randomly : int, optional
        Randomly select a specific number of streamlines. If None all the
        streamlines are used.
    rng : numpy.random.Generator, optional
        If None then generator is initialized internally.
    verbose : bool, optional
        If True, log information. Default False.

    Returns
    -------
    clusters : `ClusterMapCentroid` object
        Contains the clusters of the last layer of QuickBundlesX after merging.
        Its reference data is the original (not resampled) `streamlines`.

    References
    ----------
    .. footbibliography::
    """
    t = time()
    len_s = len(streamlines)
    if select_randomly is None:
        select_randomly = len_s
    if rng is None:
        rng = np.random.default_rng()
    # Random subset (without replacement), capped at the number of streamlines.
    indices = rng.choice(len_s, min(select_randomly, len_s), replace=False)
    sample_streamlines = set_number_of_points(streamlines, nb_points=nb_pts)

    if verbose:
        logger.info(f" Resampled to {nb_pts} points")
        logger.info(f" Size is {nbytes(sample_streamlines):0.3f} MB")
        logger.info(f" Duration of resampling is {time() - t:0.3f} s")
        logger.info(" QBX phase starting...")

    qbx = QuickBundlesX(thresholds, metric=AveragePointwiseEuclideanMetric())
    t1 = time()
    qbx_clusters = qbx.cluster(sample_streamlines, ordering=indices)

    if verbose:
        logger.info(" Merging phase starting ...")

    # The merge step reuses only the last (finest) threshold.
    qbx_merge = QuickBundlesX(
        [thresholds[-1]], metric=AveragePointwiseEuclideanMetric()
    )

    # Extract the last layer once and reuse it for its size, its centroids
    # and the final index look-up.
    qbx_cluster_map = qbx_clusters.get_clusters(len(thresholds))
    len_qbx_fl = len(qbx_cluster_map)
    qbx_ordering_final = rng.choice(len_qbx_fl, len_qbx_fl, replace=False)

    qbx_merged_cluster_map = qbx_merge.cluster(
        qbx_cluster_map.centroids, ordering=qbx_ordering_final
    ).get_clusters(1)

    merged_cluster_map = ClusterMapCentroid()
    for cluster in qbx_merged_cluster_map:
        merged_cluster = ClusterCentroid(centroid=cluster.centroid)
        # Indices of a merged cluster refer to last-layer clusters; replace
        # them by the streamline indices those clusters contain.
        for i in cluster.indices:
            merged_cluster.indices.extend(qbx_cluster_map[i].indices)
        merged_cluster_map.add_cluster(merged_cluster)

    merged_cluster_map.refdata = streamlines

    if verbose:
        logger.info(f" QuickBundlesX time for {select_randomly} random streamlines")
        logger.info(f" Duration {time() - t1:0.3f} s\n")

    return merged_cluster_map