# Source code for dipy.viz.horizon.app

from warnings import warn

import numpy as np
from packaging.version import Version

from dipy import __version__ as horizon_version
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import save_tractogram
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import Streamlines
from dipy.utils.optpkg import optional_package
from dipy.viz.gmem import GlobalHorizon
from dipy.viz.horizon.tab import (
    ClustersTab,
    PeaksTab,
    ROIsTab,
    SlicesTab,
    SurfaceTab,
    TabManager,
    build_label,
)
from dipy.viz.horizon.util import (
    check_img_dtype,
    check_img_shapes,
    check_peak_size,
    unpack_image,
    unpack_surface,
)
from dipy.viz.horizon.visualizer import (
    ClustersVisualizer,
    PeaksVisualizer,
    SlicesVisualizer,
    SurfaceVisualizer,
)

# FURY is an optional dependency. ``has_fury`` tells whether it could be
# imported (presumably also False when the installed version is below
# ``min_version`` -- confirm against ``optional_package``). The FURY names used
# throughout this module are only imported when it is available;
# ``Horizon.__init__`` raises ImportError otherwise.
fury, has_fury, setup_module = optional_package("fury", min_version="0.10.0")

if has_fury:
    from fury import __version__ as fury_version, actor, ui, window
    from fury.colormap import distinguishable_colormap


# TODO: Re-enable >> right click: see menu
# Text of the on-screen help panel that Horizon builds when clustering is
# enabled. The keys listed here are handled in ``Horizon.__key_press_events``;
# keep the two in sync when adding or changing a binding.
HELP_MESSAGE = """
>> left click: select centroid
>> e: expand centroids
>> r: collapse all clusters
>> h: hide unselected centroids
>> i: invert selection
>> a: select all centroids
>> s: save in file
>> y: new window
>> o: hide/show this panel
"""


class Horizon:
    """Interactive viewer for tractograms, images, peaks and surfaces.

    Typical usage is ``scene = hz.build_scene()`` followed by
    ``hz.build_show(scene)``.
    """

    @warning_for_keywords()
    def __init__(
        self,
        *,
        tractograms=None,
        images=None,
        pams=None,
        surfaces=None,
        cluster=False,
        rgb=False,
        cluster_thr=15.0,
        random_colors=None,
        length_gt=0,
        length_lt=1000,
        clusters_gt=0,
        clusters_lt=10000,
        world_coords=True,
        interactive=True,
        out_png="tmp.png",
        recorded_events=None,
        return_showm=False,
        bg_color=(0, 0, 0),
        order_transparent=True,
        buan=False,
        buan_colors=None,
        roi_images=False,
        roi_colors=(1, 0, 0),
        surface_colors=((1, 0, 0),),
    ):
        """Interactive medical visualization - Invert the Horizon!

        :footcite:p:`Garyfallidis2019`.

        Parameters
        ----------
        tractograms : sequence of StatefulTractograms
            StatefulTractograms are used for making sure that the coordinate
            systems are correct.
        images : sequence of tuples
            Each tuple contains data and affine.
        pams : sequence of PeakAndMetrics
            Contains peak directions and spherical harmonic coefficients.
        surfaces : sequence of tuples
            Each tuple contains vertices and faces.
        cluster : bool
            Enable QuickBundlesX clustering.
        rgb : bool, optional
            Enable the color image (rgb only, alpha channel will be ignored).
        cluster_thr : float
            Distance threshold used for clustering, in mm. The default is
            15.0; for small animal data you may need something smaller such
            as 2.0. Only used when ``cluster`` is enabled.
        random_colors : str or sequence of str, optional
            Controls random coloring of tractograms and/or ROIs. ``None``
            (default) disables random colors. An empty value (e.g. ``[]``)
            gives every tractogram and every ROI its own color from a
            distinguishable colormap. Otherwise include ``'tracts'`` and/or
            ``'rois'`` to apply the effect only to tractograms or ROIs.
        length_gt : float
            Clusters with average length greater than ``length_gt`` amount in
            mm will be shown.
        length_lt : float
            Clusters with average length less than ``length_lt`` amount in mm
            will be shown.
        clusters_gt : int
            Clusters with size greater than ``clusters_gt`` will be shown.
        clusters_lt : int
            Clusters with size less than ``clusters_lt`` will be shown.
        world_coords : bool
            Show data in their world coordinates (not native voxel
            coordinates). Default True. Native coordinates are currently not
            supported for streamlines.
        interactive : bool
            Allow user interaction. If False then Horizon goes on stealth mode
            and just saves pictures.
        out_png : string
            Filename of saved picture (used when ``interactive`` is False).
        recorded_events : string
            File path to replay recorded events.
        return_showm : bool
            Return ShowManager object. Used only at Python level. Can be used
            for extending Horizon's capabilities externally and for testing
            purposes.
        bg_color : ndarray or list or tuple
            Define the background color of the scene. Default is black
            (0, 0, 0).
        order_transparent : bool
            Default True. Use depth peeling to sort transparent objects. If
            True also enables anti-aliasing.
        buan : bool, optional
            Enables BUAN framework visualization. Default is False.
        buan_colors : list, optional
            List of colors for bundles, indexed per tractogram when ``buan``
            is True and clustering is disabled.
        roi_images : bool, optional
            Displays binary images as contours. Default is False.
        roi_colors : ndarray or list or tuple, optional
            Define the colors of the roi images. Default is red (1, 0, 0).
        surface_colors : sequence of tuples, optional
            Stored as ``self._surface_colors``. NOTE: ``build_show`` currently
            colors surfaces from a distinguishable colormap and does not read
            this value.

        Raises
        ------
        ImportError
            If FURY is not installed.
        ValueError
            If the installed FURY is older than 0.10.0.

        References
        ----------
        .. footbibliography::
        """
        if not has_fury:
            raise ImportError(
                "Horizon requires FURY. Please install it with pip install fury"
            )
        if Version(fury_version) < Version("0.10.0"):
            # Previously the exception was constructed but never raised.
            raise ValueError(
                "Horizon requires FURY version 0.10.0 or higher."
                " Please upgrade FURY with pip install -U fury."
            )

        self.cluster = cluster
        self.rgb = rgb
        self.cluster_thr = cluster_thr
        self.random_colors = random_colors
        self.length_lt = length_lt
        self.length_gt = length_gt
        self.clusters_lt = clusters_lt
        self.clusters_gt = clusters_gt
        self.world_coords = world_coords
        self.interactive = interactive
        self.prng = np.random.RandomState(27)
        self.tractograms = tractograms or []
        self.out_png = out_png
        self.images = images or []
        self.pams = pams or []
        self._surfaces = surfaces or []
        self.cea = {}  # holds centroid actors
        self.cla = {}  # holds cluster actors
        self.recorded_events = recorded_events
        self.show_m = None
        self._scene = None
        self.return_showm = return_showm
        self.bg_color = bg_color
        self.order_transparent = order_transparent
        self.buan = buan
        self.buan_colors = buan_colors
        self.__roi_images = roi_images
        self.__roi_colors = roi_colors
        self._surface_colors = surface_colors
        self.color_gen = distinguishable_colormap()

        # None -> no random colors; an empty value -> randomize both kinds;
        # anything else is kept and later queried with ``in``.
        if self.random_colors is None:
            self.random_colors = []
        elif not self.random_colors:
            self.random_colors = ["tracts", "rois"]

        self.__clusters_visualizer = None
        self.__tabs = []
        self.__tab_mgr = None
        self.__help_visible = True
        # TODO: Move to another class/module
        self.__hide_centroids = True
        self.__select_all = False
        self.__win_size = (0, 0)

    # ------------------------------------------------------------------
    # Cluster interaction (TODO: Move to another class/module)
    # ------------------------------------------------------------------

    def __centroid_filters(self):
        """Yield ``(centroid, info, valid_length, valid_size)`` per centroid.

        ``info`` is the clusters visualizer's per-centroid record (keys used
        in this class: ``length``, ``size``, ``selected``, ``expanded`` and
        ``actor``). The validity flags compare the centroid's length/size to
        the minimum over the visualizer's ``lengths``/``sizes``.
        """
        viz = self.__clusters_visualizer
        min_length = np.min(viz.lengths)
        min_size = np.min(viz.sizes)
        for cent, info in viz.centroid_actors.items():
            yield (
                cent,
                info,
                info["length"] >= min_length,
                info["size"] >= min_size,
            )

    def __visible_streamlines(self):
        """Gather the streamlines of every currently visible cluster actor."""
        cluster_actors = self.__clusters_visualizer.cluster_actors
        tractogram_clusters = self.__clusters_visualizer.tractogram_clusters
        streamlines = Streamlines()
        for bundle, info in cluster_actors.items():
            if bundle.GetVisibility():
                members = tractogram_clusters[info["tractogram"]][info["cluster"]]
                streamlines.extend(Streamlines(members))
        return streamlines

    def __expand(self):
        """Replace each selected, not-yet-expanded centroid by its cluster."""
        for cent, info, valid_length, valid_size in self.__centroid_filters():
            if (
                info["selected"]
                and not info["expanded"]
                and valid_length
                and valid_size
            ):
                info["actor"].VisibilityOn()
                cent.VisibilityOff()
                info["expanded"] = 1
        self.show_m.render()

    def __hide(self):
        """Toggle the visibility of unselected centroids."""
        for cent, info, valid_length, valid_size in self.__centroid_filters():
            unselected = info["selected"] == 0
            if self.__hide_centroids:
                if (valid_length or valid_size) and unselected:
                    cent.VisibilityOff()
            elif valid_length and valid_size and unselected:
                cent.VisibilityOn()
        self.__hide_centroids = not self.__hide_centroids
        self.show_m.render()

    def __invert(self):
        """Invert the selection state of every valid centroid/cluster pair."""
        cluster_actors = self.__clusters_visualizer.cluster_actors
        for _cent, info, valid_length, valid_size in self.__centroid_filters():
            if valid_length and valid_size:
                info["selected"] = not info["selected"]
                cluster_actors[info["actor"]]["selected"] = info["selected"]
        self.show_m.render()

    def __reset(self):
        """Collapse clusters: hide cluster actors and show their centroids."""
        for cent, info, valid_length, valid_size in self.__centroid_filters():
            if valid_length and valid_size:
                info["actor"].VisibilityOff()
                cent.VisibilityOn()
                info["expanded"] = 0
        self.show_m.render()

    def __show_all(self):
        """Alternate between selecting and deselecting all valid centroids."""
        cluster_actors = self.__clusters_visualizer.cluster_actors
        new_state = 0 if self.__select_all else 1
        for _cent, info, valid_length, valid_size in self.__centroid_filters():
            if valid_length and valid_size:
                info["selected"] = new_state
                cluster_actors[info["actor"]]["selected"] = new_state
        self.__select_all = not self.__select_all
        self.show_m.render()

    def __save(self):
        """Write the streamlines of visible clusters to ``tmp.trk``."""
        saving_streamlines = self.__visible_streamlines()
        print("Saving result in tmp.trk")
        # Using the header of the first of the tractograms
        sft_new = StatefulTractogram(
            saving_streamlines, self.tractograms[0], Space.RASMM
        )
        save_tractogram(sft_new, "tmp.trk", bbox_valid_check=False)
        print("Saved!")

    def __new_window(self):
        """Open a new Horizon re-clustering the visible streamlines.

        The new window uses half the current clustering threshold.
        """
        active_streamlines = self.__visible_streamlines()
        # Using the header of the first of the tractograms
        active_sft = StatefulTractogram(
            active_streamlines, self.tractograms[0], Space.RASMM
        )
        hz2 = Horizon(
            tractograms=[active_sft],
            images=self.images,
            cluster=True,
            cluster_thr=self.cluster_thr / 2.0,
            random_colors=self.random_colors,
            length_lt=np.inf,
            length_gt=0,
            clusters_lt=np.inf,
            clusters_gt=0,
            world_coords=True,
            interactive=True,
        )
        ren2 = hz2.build_scene()
        hz2.build_show(ren2)

    # ------------------------------------------------------------------
    # Window / keyboard callbacks
    # ------------------------------------------------------------------

    def __clustering_ready(self):
        """Return True when clustering is on and its UI has been built.

        ``help_panel`` and the clusters visualizer only exist after
        ``build_show`` has processed at least one tractogram with clustering.
        """
        return self.cluster and self.__clusters_visualizer is not None

    def __help_panel_position(self):
        """Return the help panel position for the current window size."""
        win_size = np.array(self.__win_size)
        if self.__help_visible:
            return win_size - self.help_panel._get_size() - 5
        # Retracted: only a 10-pixel corner of the panel remains in view
        # (assuming the position is the panel's lower-left corner).
        return win_size - 10

    def __key_press_events(self, obj, event):
        """Dispatch key presses to the cluster interaction commands."""
        key = obj.GetKeySym()
        # TODO: Move to another class/module
        if not key or not self.__clustering_ready():
            return

        # retract help panel
        if key in ("o", "O"):
            self.__help_visible = not self.__help_visible
            self.help_panel._set_position(self.__help_panel_position())
            self.show_m.render()
            return

        actions = {
            "a": self.__show_all,
            "e": self.__expand,
            "h": self.__hide,  # hide on/off unselected centroids
            "i": self.__invert,  # invert selection
            "r": self.__reset,
            "s": self.__save,  # save current result
            "y": self.__new_window,
        }
        action = actions.get(key.lower())
        if action is not None:
            action()

    def __win_callback(self, obj, event):
        """Reposition window-anchored UI elements after a resize."""
        if self.__win_size == obj.GetSize():
            return
        self.__win_size = obj.GetSize()
        if len(self.__tabs) > 0:
            self.__tab_mgr.reposition(self.__win_size)
        if self.__clustering_ready():
            self.help_panel._set_position(self.__help_panel_position())

    # ------------------------------------------------------------------
    # Scene construction
    # ------------------------------------------------------------------

    def build_scene(self):
        """Create the global state holder and an empty background scene.

        Returns
        -------
        scene : fury.window.Scene
            Scene with ``bg_color`` applied; pass it to :meth:`build_show`.
        """
        self.mem = GlobalHorizon()
        scene = window.Scene()
        scene.background(self.bg_color)
        return scene

    def _show_force_render(self, _element):
        """
        Callback function for lower level elements to force render.
        """
        self.show_m.render()

    def _update_actors(self, actors):
        """Update actors in the scene. It essentially brings them forward in
        the stack.

        Parameters
        ----------
        actors : list
            list of FURY actors.
        """
        self._scene.rm(*actors)
        self._scene.add(*actors)

    def __add_tractograms(self, scene):
        """Add tractograms as clustered or plain line actors, plus help UI."""
        if not self.world_coords:
            # TODO: Get affine from a StatefullTractogram
            raise ValueError(
                "Currently native coordinates are not supported for "
                "streamlines."
            )

        if self.cluster:
            self.__clusters_visualizer = ClustersVisualizer(
                self.show_m, scene, self.tractograms
            )

        for t, sft in enumerate(self.tractograms):
            streamlines = sft.streamlines
            colors = (
                next(self.color_gen) if "tracts" in self.random_colors else None
            )
            if self.cluster:
                self.__clusters_visualizer.add_cluster_actors(
                    t, streamlines, self.cluster_thr, colors
                )
                continue
            if self.buan:
                colors = self.buan_colors[t]
            streamline_actor = actor.line(streamlines, colors=colors)
            streamline_actor.GetProperty().SetEdgeVisibility(1)
            streamline_actor.GetProperty().SetRenderLinesAsTubes(1)
            streamline_actor.GetProperty().SetLineWidth(6)
            streamline_actor.GetProperty().SetOpacity(1)
            scene.add(streamline_actor)

        if self.cluster:
            # Information panel
            # It will be changed once all the elements wrapped in horizon
            # elements.
            text_block = build_label(HELP_MESSAGE, font_size=18)
            self.help_panel = ui.Panel2D(
                size=(300, 200),
                position=(1615, 875),
                color=(0.8, 0.8, 1.0),
                opacity=0.2,
                align="left",
            )
            self.help_panel.add_element(text_block.obj, coords=(0.02, 0.01))
            scene.add(self.help_panel)
            self.__tabs.append(
                ClustersTab(self.__clusters_visualizer, self.cluster_thr)
            )

    def __add_images(self, scene):
        """Add images as slicers or ROI contours.

        Returns
        -------
        sync_slices, sync_vol : bool
            Flags from ``check_img_shapes`` (both False without images).
        """
        sync_slices = sync_vol = False
        self.images = check_img_dtype(self.images)
        if len(self.images) == 0:
            return sync_slices, sync_vol

        roi_color = self.__roi_colors
        # Always defined so the ROI-tab check below cannot hit a NameError
        # when images are shown as slicers.
        roi_actors = []
        sync_slices, sync_vol = check_img_shapes(self.images)
        for img_count, img in enumerate(self.images):
            title = f"Image {img_count + 1}"
            data, affine, fname = unpack_image(img)
            self.vox2ras = affine
            if self.__roi_images:
                if "rois" in self.random_colors:
                    roi_color = next(self.color_gen)
                roi_actor = actor.contour_from_roi(
                    data, affine=affine, color=roi_color
                )
                scene.add(roi_actor)
                roi_actors.append(roi_actor)
            else:
                slices_viz = SlicesVisualizer(
                    self.show_m.iren,
                    scene,
                    data,
                    affine=affine,
                    world_coords=self.world_coords,
                    rgb=self.rgb,
                )
                self.__tabs.append(
                    SlicesTab(
                        slices_viz,
                        title,
                        fname,
                        force_render=self._show_force_render,
                    )
                )

        if roi_actors:
            self.__tabs.append(ROIsTab(roi_actors))
        return sync_slices, sync_vol

    def __add_peaks(self, scene, sync_slices):
        """Add one peaks actor/tab per PAM and return the peaks-sync flag."""
        if len(self.pams) == 0:
            return False
        if self.images:
            sync_peaks = check_peak_size(
                self.pams,
                ref_img_shape=self.images[0][0].shape[:3],
                sync_imgs=sync_slices,
            )
        else:
            sync_peaks = check_peak_size(self.pams)
        for pam in self.pams:
            peak_viz = PeaksVisualizer(
                (pam.peak_dirs, pam.affine), self.world_coords
            )
            scene.add(peak_viz.actors[0])
            self.__tabs.append(PeaksTab(peak_viz.actors[0]))
        return sync_peaks

    def __add_surfaces(self, scene):
        """Add one surface visualizer/tab per valid surface; warn on bad ones."""
        for idx, surface in enumerate(self._surfaces):
            try:
                vertices, faces, fname = unpack_surface(surface)
            except ValueError as e:
                # stacklevel=3: point at the caller of build_show, as before
                # this code moved into a helper.
                warn(str(e), stacklevel=3)
                continue
            color = next(self.color_gen)
            title = f"Surface {idx + 1}"
            surf_viz = SurfaceVisualizer((vertices, faces), scene, color)
            self.__tabs.append(SurfaceTab(surf_viz, title, fname))

    def __add_context_menu(self):
        """Add the (hidden) cluster context menu and its mouse callbacks."""
        menu_actions = {
            r"un\hide centroids": self.__hide,
            "invert selection": self.__invert,
            r"un\select all": self.__show_all,
            "expand clusters": self.__expand,
            "collapse clusters": self.__reset,
            "save streamlines": self.__save,
            "recluster": self.__new_window,
        }
        listbox = ui.ListBox2D(
            values=list(menu_actions),
            position=(10, 300),
            size=(200, 270),
            multiselection=False,
            font_size=18,
        )

        def display_element():
            action = menu_actions.get(listbox.selected[0])
            if action is not None:
                action()

        listbox.on_change = display_element
        listbox.panel.opacity = 0.2
        listbox.set_visibility(0)
        self.show_m.scene.add(listbox)

        def left_click_centroid_callback(obj, event):
            self.cea[obj]["selected"] = not self.cea[obj]["selected"]
            self.cla[self.cea[obj]["cluster_actor"]]["selected"] = self.cea[obj][
                "selected"
            ]
            self.show_m.render()

        def right_click_centroid_callback(obj, event):
            for lactor in listbox._get_actors():
                lactor.SetVisibility(not lactor.GetVisibility())
            listbox.scroll_bar.set_visibility(False)
            self.show_m.render()

        def left_click_cluster_callback(obj, event):
            if self.cla[obj]["selected"]:
                self.cla[obj]["centroid_actor"].VisibilityOn()
                ca = self.cla[obj]["centroid_actor"]
                self.cea[ca]["selected"] = 0
                obj.VisibilityOff()
                self.cea[ca]["expanded"] = 0
            self.show_m.render()

        def right_click_cluster_callback(obj, event):
            print("Cluster Area Selected")
            self.show_m.render()

        # NOTE(review): nothing in this class fills ``self.cla``; unless
        # external code does, this loop registers no observers.
        for cl in self.cla:
            cl.AddObserver("LeftButtonPressEvent", left_click_cluster_callback, 1.0)
            cl.AddObserver(
                "RightButtonPressEvent", right_click_cluster_callback, 1.0
            )
            self.cla[cl]["centroid_actor"].AddObserver(
                "LeftButtonPressEvent", left_click_centroid_callback, 1.0
            )
            self.cla[cl]["centroid_actor"].AddObserver(
                "RightButtonPressEvent", right_click_centroid_callback, 1.0
            )

    def build_show(self, scene):
        """Populate ``scene``, then run the interactive window or record it.

        :meth:`build_scene` must have been called first (it sets ``self.mem``).

        Parameters
        ----------
        scene : fury.window.Scene
            Scene returned by :meth:`build_scene`.

        Returns
        -------
        show_m : fury.window.ShowManager or None
            The ShowManager when both ``interactive`` and ``return_showm``
            are True; None otherwise.
        """
        self._scene = scene
        title = f"Horizon {horizon_version}"
        self.show_m = window.ShowManager(
            scene=scene,
            title=title,
            size=(1920, 1080),
            reset_camera=False,
            order_transparent=self.order_transparent,
        )

        if len(self.tractograms) > 0:
            self.__add_tractograms(scene)

        sync_slices, sync_vol = self.__add_images(scene)
        sync_peaks = self.__add_peaks(scene, sync_slices)
        self.__add_surfaces(scene)

        self.__win_size = scene.GetSize()

        if len(self.__tabs) > 0:
            self.__tab_mgr = TabManager(
                tabs=self.__tabs,
                win_size=scene.GetSize(),
                on_tab_changed=self._update_actors,
                add_to_scene=self._scene.add,
                remove_from_scene=self._scene.rm,
                sync_slices=sync_slices,
                sync_volumes=sync_vol,
                sync_peaks=sync_peaks,
            )
            scene.add(self.__tab_mgr.tab_ui)
            self.__tab_mgr.handle_text_overflows()
            self.__tabs[-1].on_tab_selected()

        self.show_m.initialize()
        self.__add_context_menu()

        self.mem.window_timer_cnt = 0

        def timer_callback(obj, event):
            self.mem.window_timer_cnt += 1
            # TODO possibly add automatic rotation option

        scene.reset_camera()
        scene.zoom(1.5)
        scene.reset_clipping_range()

        if self.interactive:
            self.show_m.add_window_callback(self.__win_callback)
            self.show_m.add_timer_callback(True, 200, timer_callback)
            self.show_m.iren.AddObserver("KeyPressEvent", self.__key_press_events)

            if self.return_showm:
                return self.show_m

            if self.recorded_events is None:
                self.show_m.render()
                self.show_m.start()
            else:
                # set to True if event recorded file needs updating
                recording = False
                recording_filename = self.recorded_events

                if recording:
                    self.show_m.record_events_to_file(recording_filename)
                else:
                    self.show_m.play_events_from_file(recording_filename)
        else:
            window.record(
                scene=scene,
                out_path=self.out_png,
                size=(1200, 900),
                reset_camera=False,
            )
@warning_for_keywords()
def horizon(
    *,
    tractograms=None,
    images=None,
    pams=None,
    surfaces=None,
    cluster=False,
    rgb=False,
    cluster_thr=15.0,
    random_colors=None,
    bg_color=(0, 0, 0),
    order_transparent=True,
    length_gt=0,
    length_lt=1000,
    clusters_gt=0,
    clusters_lt=10000,
    world_coords=True,
    interactive=True,
    buan=False,
    buan_colors=None,
    roi_images=False,
    roi_colors=(1, 0, 0),
    out_png="tmp.png",
    recorded_events=None,
    return_showm=False,
):
    """Interactive medical visualization - Invert the Horizon!

    Convenience wrapper: builds a :class:`Horizon`, creates its scene and
    shows it (or saves a picture when ``interactive`` is False). See
    :footcite:p:`Garyfallidis2019` for further details about Horizon.

    Parameters
    ----------
    tractograms : sequence of StatefulTractograms
        StatefulTractograms are used for making sure that the coordinate
        systems are correct.
    images : sequence of tuples
        Each tuple contains data and affine.
    pams : sequence of PeakAndMetrics
        Contains peak directions and spherical harmonic coefficients.
    surfaces : sequence of tuples
        Each tuple contains vertices and faces.
    cluster : bool
        Enable QuickBundlesX clustering.
    rgb : bool, optional
        Enable the color image.
    cluster_thr : float
        Distance threshold used for clustering, in mm. The default is 15.0;
        for small animal data you may need something smaller such as 2.0.
        Only used when ``cluster`` is enabled.
    random_colors : str or sequence of str, optional
        Random coloring of tractograms and/or ROIs. ``None`` disables it; an
        empty value gives every tractogram and ROI a different color from a
        distinguishable colormap; ``'tracts'`` and/or ``'rois'`` restrict the
        effect to tractograms or ROIs respectively.
    bg_color : ndarray or list or tuple
        Define the background color of the scene. Default is black (0, 0, 0).
    order_transparent : bool
        Default True. Use depth peeling to sort transparent objects. If True
        also enables anti-aliasing.
    length_gt : float
        Clusters with average length greater than ``length_gt`` amount in mm
        will be shown.
    length_lt : float
        Clusters with average length less than ``length_lt`` amount in mm will
        be shown.
    clusters_gt : int
        Clusters with size greater than ``clusters_gt`` will be shown.
    clusters_lt : int
        Clusters with size less than ``clusters_lt`` will be shown.
    world_coords : bool
        Show data in their world coordinates (not native voxel coordinates).
        Default True.
    interactive : bool
        Allow user interaction. If False then Horizon goes on stealth mode and
        just saves pictures.
    buan : bool, optional
        Enables BUAN framework visualization. Default is False.
    buan_colors : list, optional
        List of colors for bundles.
    roi_images : bool, optional
        Displays binary images as contours. Default is False.
    roi_colors : ndarray or list or tuple, optional
        Define the color of the roi images. Default is red (1, 0, 0).
    out_png : string
        Filename of saved picture.
    recorded_events : string
        File path to replay recorded events.
    return_showm : bool
        Return ShowManager object. Used only at Python level. Can be used for
        extending Horizon's capabilities externally and for testing purposes.

    Returns
    -------
    show_m : ShowManager or None
        Whatever ``Horizon.build_show`` returns when ``return_showm`` is True;
        None otherwise.

    References
    ----------
    .. footbibliography::
    """
    viewer = Horizon(
        tractograms=tractograms,
        images=images,
        pams=pams,
        surfaces=surfaces,
        cluster=cluster,
        cluster_thr=cluster_thr,
        rgb=rgb,
        random_colors=random_colors,
        bg_color=bg_color,
        order_transparent=order_transparent,
        length_gt=length_gt,
        length_lt=length_lt,
        clusters_gt=clusters_gt,
        clusters_lt=clusters_lt,
        world_coords=world_coords,
        interactive=interactive,
        buan=buan,
        buan_colors=buan_colors,
        roi_images=roi_images,
        roi_colors=roi_colors,
        out_png=out_png,
        recorded_events=recorded_events,
        return_showm=return_showm,
    )
    show_manager = viewer.build_show(viewer.build_scene())
    # The ShowManager is only handed back when explicitly requested.
    return show_manager if return_showm else None