import numpy as np
from dipy.segment.clustering import qbx_and_merge
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import length
from dipy.utils.optpkg import optional_package
# FURY is treated as an optional dependency: the FURY-only names below are
# bound only when ``optional_package`` reports the package as available
# (``has_fury`` truthy), so this module can still be imported without it.
fury, has_fury, setup_module = optional_package("fury", min_version="0.10.0")

if has_fury:
    from fury import actor
    from fury.lib import VTK_OBJECT, calldata_type
    from fury.shaders import add_shader_callback, shader_to_actor
class ClustersVisualizer:
    """Build and manage paired centroid/cluster actors for tractograms.

    Each tractogram is clustered with ``qbx_and_merge``. Every resulting
    cluster is represented by two actors added to the scene: a streamtube
    drawn from the cluster centroid, and a (initially hidden) line actor
    holding the cluster's streamlines. The two actors reference each other
    through the ``centroid_actors`` and ``cluster_actors`` dictionaries.

    Parameters
    ----------
    show_manager : object
        Object whose ``render()`` method is called after a click callback
        changes the selection state.
    scene : object
        Object whose ``add()`` / ``rm()`` methods are used to register and
        remove the created actors.
    tractograms : iterable
        Objects exposing a ``streamlines`` attribute; iterated again by
        :meth:`recluster_tractograms`.
    enable_callbacks : bool, optional
        When True, ``LeftButtonPressEvent`` observers are attached to every
        created actor.
    """

    @warning_for_keywords()
    def __init__(self, show_manager, scene, tractograms, *, enable_callbacks=True):
        # TODO: Avoid passing the entire show manager to the visualizer
        self.__show_man = show_manager
        self.__scene = scene
        self.__tractograms = tractograms
        self.__enable_callbacks = enable_callbacks

        # Clustering results keyed by tractogram index.
        self.__tractogram_clusters = {}
        # True until the first reclustering; while True, the colors passed to
        # ``add_cluster_actors`` are remembered for later reuse.
        self.__first_time = True
        self.__tractogram_colors = []
        # centroid actor -> metadata dict (its "actor" is the cluster actor)
        self.__centroid_actors = {}
        # cluster actor -> metadata dict (its "actor" is the centroid actor)
        self.__cluster_actors = {}
        # Flat lists accumulated over all tractograms.
        self.__lengths = []
        self.__sizes = []

    def __apply_shader(self, dict_element):
        """Make an actor render brighter while its entry is selected.

        Parameters
        ----------
        dict_element : dict
            Metadata entry holding the target actor under ``"actor"`` and the
            selection flag under ``"selected"``. The dict itself is captured
            by the shader callback, so later changes to ``"selected"`` are
            seen without re-registering anything.
        """
        target = dict_element["actor"]

        decl = """
            uniform float selected;
            """
        impl = """
            if (selected == 1)
            {
                fragOutput0 += vec4(.2, .2, .2, 0);
            }
            """
        shader_to_actor(target, "fragment", decl_code=decl)
        shader_to_actor(target, "fragment", impl_code=impl, block="light")

        @calldata_type(VTK_OBJECT)
        def uniform_selected_callback(caller, event, calldata=None):
            # ``calldata`` is the shader program when one is provided.
            program = calldata
            if program is not None:
                program.SetUniformf("selected", dict_element["selected"])

        add_shader_callback(target, uniform_selected_callback, priority=100)

    def __left_click_centroid_callback(self, obj, event):
        """Toggle the selection of a centroid and mirror it on its cluster."""
        centroid_info = self.__centroid_actors[obj]
        centroid_info["selected"] = not centroid_info["selected"]
        # ``centroid_info["actor"]`` is the paired cluster actor.
        self.__cluster_actors[centroid_info["actor"]]["selected"] = centroid_info[
            "selected"
        ]
        # TODO: Find another way to rerender
        self.__show_man.render()

    def __left_click_cluster_callback(self, obj, event):
        """Collapse a selected cluster back to its (deselected) centroid."""
        cluster_info = self.__cluster_actors[obj]
        if cluster_info["selected"]:
            # ``cluster_info["actor"]`` is the paired centroid actor.
            ca = cluster_info["actor"]
            ca.VisibilityOn()
            self.__centroid_actors[ca]["selected"] = 0
            obj.VisibilityOff()
            self.__centroid_actors[ca]["expanded"] = 0
        # TODO: Find another way to rerender
        self.__show_man.render()

    def add_cluster_actors(self, tract_idx, streamlines, thr, colors):
        """Cluster a tractogram and add one actor pair per cluster.

        Parameters
        ----------
        tract_idx : hashable
            Key under which the clustering result is stored; also recorded in
            each actor's metadata as ``"tractogram"``.
        streamlines : sequence
            Streamlines handed to ``qbx_and_merge``.
        thr : float
            Final clustering threshold, appended to the fixed coarse
            thresholds ``[40, 30, 25, 20]``.
        colors : object
            Colors forwarded to ``actor.streamtube`` for the centroid actors.

        Notes
        -----
        Before the first reclustering, ``colors`` is appended to an internal
        list that :meth:`recluster_tractograms` later indexes by position in
        ``tractograms``, so calls are expected in tractogram order.

        If clustering yields no clusters, nothing is added to the scene.
        """
        # Saving the tractogram colors in case of reclustering
        if self.__first_time:
            self.__tractogram_colors.append(colors)

        print(f"\nClustering threshold {thr}")
        clusters = qbx_and_merge(streamlines, [40, 30, 25, 20, thr])
        self.__tractogram_clusters[tract_idx] = clusters
        centroids = clusters.centroids
        print(f"Total number of centroids = {len(centroids)}")

        lengths = [length(c) for c in centroids]
        self.__lengths.extend(lengths)
        lengths = np.array(lengths)

        sizes = [len(c) for c in clusters]
        self.__sizes.extend(sizes)
        if not sizes:
            # No clusters: np.min/np.max below would raise on an empty array
            # and there is nothing to draw.
            return
        sizes = np.array(sizes)

        # Centroid tube width scales linearly with the cluster size.
        min_size, max_size = np.min(sizes), np.max(sizes)
        linewidths = np.interp(sizes, [min_size, max_size], [0.1, 2.0])

        print(f"Minimum number of streamlines in cluster {min_size}")
        print(f"Maximum number of streamlines in cluster {max_size}")
        print("Building cluster actors\n")

        for idx, cent in enumerate(centroids):
            centroid_actor = actor.streamtube(
                [cent], colors=colors, linewidth=linewidths[idx], lod=False
            )
            self.__scene.add(centroid_actor)

            # The cluster actor starts hidden; only its centroid is shown.
            cluster_actor = actor.line(clusters[idx][:], lod=False)
            cluster_property = cluster_actor.GetProperty()
            cluster_property.SetRenderLinesAsTubes(1)
            cluster_property.SetLineWidth(6)
            cluster_property.SetOpacity(1)
            cluster_actor.VisibilityOff()
            self.__scene.add(cluster_actor)

            # Every centroid actor is paired to a cluster actor
            self.__centroid_actors[centroid_actor] = {
                "actor": cluster_actor,
                "cluster": idx,
                "tractogram": tract_idx,
                "size": sizes[idx],
                "length": lengths[idx],
                "selected": 0,
                "expanded": 0,
            }
            self.__cluster_actors[cluster_actor] = {
                "actor": centroid_actor,
                "cluster": idx,
                "tractogram": tract_idx,
                "size": sizes[idx],
                "length": lengths[idx],
                "selected": 0,
                "highlighted": 0,
            }

            self.__apply_shader(self.__centroid_actors[centroid_actor])
            self.__apply_shader(self.__cluster_actors[cluster_actor])

            if self.__enable_callbacks:
                centroid_actor.AddObserver(
                    "LeftButtonPressEvent", self.__left_click_centroid_callback, 1.0
                )
                cluster_actor.AddObserver(
                    "LeftButtonPressEvent", self.__left_click_cluster_callback, 1.0
                )

    def recluster_tractograms(self, thr):
        """Discard all current actors and recluster every tractogram.

        Parameters
        ----------
        thr : float
            New final clustering threshold, forwarded to
            :meth:`add_cluster_actors`.
        """
        # Each entry's "actor" is the *paired* actor, so the first loop removes
        # the cluster actors and the second one removes the centroid actors.
        for centroid_info in self.__centroid_actors.values():
            self.__scene.rm(centroid_info["actor"])
        for cluster_info in self.__cluster_actors.values():
            self.__scene.rm(cluster_info["actor"])

        self.__tractogram_clusters = {}
        self.__centroid_actors = {}
        self.__cluster_actors = {}
        self.__lengths = []
        self.__sizes = []

        # Keeping states of some attributes
        # (the saved colors must not be appended a second time).
        self.__first_time = False

        for t, sft in enumerate(self.__tractograms):
            self.add_cluster_actors(
                t, sft.streamlines, thr, self.__tractogram_colors[t]
            )

    @property
    def centroid_actors(self):
        """dict: Centroid actor -> metadata (``"actor"`` is the cluster actor)."""
        return self.__centroid_actors

    @property
    def cluster_actors(self):
        """dict: Cluster actor -> metadata (``"actor"`` is the centroid actor)."""
        return self.__cluster_actors

    @property
    def lengths(self):
        """ndarray: Centroid lengths accumulated over all tractograms."""
        return np.array(self.__lengths)

    @property
    def sizes(self):
        """ndarray: Cluster sizes accumulated over all tractograms."""
        return np.array(self.__sizes)

    @property
    def tractogram_clusters(self):
        """dict: Clustering result of each tractogram, keyed by its index."""
        return self.__tractogram_clusters