import numpy as np
from dipy.viz.horizon.tab import HorizonTab, build_slider
class ClustersTab(HorizonTab):
    """Horizon tab with size, length and threshold sliders for clusters.

    The size and length sliders control which cluster actors stay visible.
    Releasing the threshold slider asks the visualizer to re-cluster the
    tractograms and re-seeds the size and length sliders from the new
    cluster statistics.
    """

    # Percentiles of the cluster statistics used to seed the sliders: the
    # initial (selected) value of each slider and the shared upper bound.
    _SIZE_INITIAL_PERCENTILE = 50
    _LENGTH_INITIAL_PERCENTILE = 25
    _MAX_PERCENTILE = 98

    def __init__(self, clusters_visualizer, threshold):
        """Initialize Interaction tab for cluster visualization.

        Parameters
        ----------
        clusters_visualizer : ClusterVisualizer
            Visualizer providing the cluster ``sizes`` and ``lengths``, the
            cluster and centroid actors and ``recluster_tractograms``.
        threshold : float
            Initial value of the threshold slider.
        """
        super().__init__()

        self._visualizer = clusters_visualizer

        self._name = "Clusters"

        self._tab_id = 0

        size_initial, size_min, size_max = self._percentile_range(
            self._visualizer.sizes, self._SIZE_INITIAL_PERCENTILE
        )
        self._size_slider_label, self._size_slider = build_slider(
            initial_value=size_initial,
            min_value=size_min,
            max_value=size_max,
            text_template="{value:.0f}",
            label="Size",
            on_change=self._change_size,
        )

        length_initial, length_min, length_max = self._percentile_range(
            self._visualizer.lengths, self._LENGTH_INITIAL_PERCENTILE
        )
        self._length_slider_label, self._length_slider = build_slider(
            initial_value=length_initial,
            min_value=length_min,
            max_value=length_max,
            text_template="{value:.0f}",
            label="Length",
            on_change=self._change_length,
        )

        # Re-clustering is only triggered when the handle is released, not on
        # every intermediate value, hence ``on_handle_released``.
        self._threshold_slider_label, self._threshold_slider = build_slider(
            initial_value=threshold,
            min_value=5,
            max_value=25,
            text_template="{value:.0f}",
            label="Threshold",
            on_handle_released=self._change_threshold,
        )

        self._register_elements(
            self._size_slider_label,
            self._size_slider,
            self._length_slider_label,
            self._length_slider,
            self._threshold_slider_label,
            self._threshold_slider,
        )

    @classmethod
    def _percentile_range(cls, values, initial_percentile):
        """Compute the slider seed values for a cluster statistic.

        Parameters
        ----------
        values : array_like
            Cluster statistic (sizes or lengths).
        initial_percentile : float
            Percentile of `values` used as the initial slider value.

        Returns
        -------
        tuple
            ``(initial, minimum, maximum)`` where ``minimum`` is the smallest
            value and ``maximum`` is the ``_MAX_PERCENTILE`` percentile.
        """
        return (
            np.percentile(values, initial_percentile),
            np.min(values),
            np.percentile(values, cls._MAX_PERCENTILE),
        )

    @staticmethod
    def _refresh_slider(slider, min_value, max_value):
        """Push a new range and the selected value to the slider object.

        Parameters
        ----------
        slider : object
            Slider element as returned by ``build_slider``.
        min_value : float
            New lower bound of the slider.
        max_value : float
            New upper bound of the slider.
        """
        # The range is set before the value so the value is applied against
        # the new bounds.
        slider.obj.min_value = min_value
        slider.obj.max_value = max_value
        slider.obj.value = slider.selected_value
        slider.obj.update()

    def _change_length(self, slider):
        """Change the length threshold for visibility.

        Parameters
        ----------
        slider : LineSlider2D
            FURY object for slider.
        """
        self._length_slider.selected_value = int(np.rint(slider.value))
        self._update_clusters()

    def _change_size(self, slider):
        """Change the size threshold for visibility.

        Parameters
        ----------
        slider : LineSlider2D
            FURY object for slider.
        """
        self._size_slider.selected_value = int(np.rint(slider.value))
        self._update_clusters()

    def _change_threshold(self, _istyle, _obj, slider):
        """Re-cluster the tractograms according to the new threshold set.

        It will also update the size and length slider for corresponding
        threshold.

        Parameters
        ----------
        _istyle : vtkInteractor
            Should not be used.
        _obj : vtkObject
            Should not be used.
        slider : LineSlider2D
            FURY object for slider.
        """
        value = int(np.rint(slider.value))
        # Nothing to do when the handle is released on the current threshold.
        if value == self._threshold_slider.selected_value:
            return

        self._visualizer.recluster_tractograms(value)

        size_initial, size_min, size_max = self._percentile_range(
            self._visualizer.sizes, self._SIZE_INITIAL_PERCENTILE
        )
        length_initial, length_min, length_max = self._percentile_range(
            self._visualizer.lengths, self._LENGTH_INITIAL_PERCENTILE
        )

        # Re-seed the selected values first so the visibility update uses the
        # statistics of the new clusters.
        self._size_slider.selected_value = size_initial
        self._length_slider.selected_value = length_initial
        self._update_clusters()

        self._refresh_slider(self._size_slider, size_min, size_max)
        self._refresh_slider(self._length_slider, length_min, length_max)

        self._threshold_slider.selected_value = value

    def _update_clusters(self):
        """Updates the clusters according to set size and length.

        An entry whose ``"length"`` or ``"size"`` is below the corresponding
        slider value gets the actor stored under ``"actor"`` hidden, and its
        key actor hidden as well if it is currently visible. Otherwise the
        actor stored under ``"actor"`` is made visible.
        """
        min_length = self._length_slider.selected_value
        min_size = self._size_slider.selected_value

        for key_actor, cluster in self.cluster_actors.items():
            too_short = cluster["length"] < min_length
            too_small = cluster["size"] < min_size
            if too_short or too_small:
                cluster["actor"].SetVisibility(False)
                if key_actor.GetVisibility():
                    key_actor.SetVisibility(False)
            else:
                cluster["actor"].SetVisibility(True)

    def build(self, tab_id):
        """Position the elements in the tab.

        Parameters
        ----------
        tab_id : int
            Id of the tab.
        """
        self._tab_id = tab_id

        # Labels share one column; each slider sits on the row of its label.
        x_pos = 0.02
        self._size_slider_label.position = (x_pos, 0.85)
        self._length_slider_label.position = (x_pos, 0.62)
        self._threshold_slider_label.position = (x_pos, 0.38)

        x_pos = 0.12
        self._size_slider.position = (x_pos, 0.85)
        self._length_slider.position = (x_pos, 0.62)
        self._threshold_slider.position = (x_pos, 0.38)

    @property
    def name(self):
        """Title of the tab.

        Returns
        -------
        str
        """
        return self._name

    @property
    def cluster_actors(self):
        """Cluster actors of the tractograms.

        Returns
        -------
        dict
            various properties of clusters.
        """
        return self._visualizer.cluster_actors

    @property
    def centroid_actors(self):
        """Centroid actors of the tractograms.

        Returns
        -------
        dict
            various properties of centroids.
        """
        return self._visualizer.centroid_actors

    @property
    def actors(self):
        """All the actors in the visualizer.

        Returns
        -------
        list
            The ``"actor"`` entries of the cluster actors followed by those
            of the centroid actors.
        """
        actors = [entry["actor"] for entry in self.cluster_actors.values()]
        actors.extend(
            entry["actor"] for entry in self.centroid_actors.values()
        )
        return actors