Source code for dipy.viz.horizon.visualizer.peak

from os.path import join as pjoin
import warnings

import numpy as np

from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package

# FURY is an optional dependency: ``has_fury`` tells whether it could be
# imported with at least the requested minimum version.
fury, has_fury, setup_module = optional_package("fury", min_version="0.9.0")
if has_fury:
    from fury.colormap import colormap_lookup_table
    from fury.lib import (
        VTK_OBJECT,
        Actor,
        CellArray,
        Command,
        PolyData,
        PolyDataMapper,
        calldata_type,
        numpy_support,
    )
    from fury.shaders import (
        attribute_to_actor,
        compose_shader,
        import_fury_shader,
        shader_to_actor,
    )
    from fury.utils import apply_affine, numpy_to_vtk_colors, numpy_to_vtk_points
else:
    # Minimal stand-ins so that this module can still be imported (and its
    # classes defined) when FURY is not installed. They provide no rendering
    # functionality.

    class Actor:
        """Placeholder base class used when FURY is unavailable."""

    def calldata_type(data_type):
        """Return a decorator that leaves the decorated function unchanged.

        In this module the name is used as a decorator *factory*
        (``@calldata_type(VTK_OBJECT)``), so the fallback must return a
        decorator. The previous stand-in returned a wrapper that, once applied
        to the method, called ``VTK_OBJECT(method)`` and replaced the method
        with ``None``.

        Parameters
        ----------
        data_type : object
            Call-data type marker. Ignored by this fallback.

        Returns
        -------
        decorator : callable
            Identity decorator.
        """

        def decorator(func):
            return func

        return decorator

    def VTK_OBJECT(*args):
        """Placeholder for the VTK_OBJECT constant; accepts anything, returns None."""
        pass


class PeakActor(Actor):
    """FURY actor that draws DWI peaks as line segments.

    Every non-zero peak direction found at the requested voxels becomes one
    two-point line. Visibility of the lines is decided in the fragment shader,
    either by an extent (range) test or by a cross-section test on the voxel
    center each line belongs to.

    Parameters
    ----------
    directions : ndarray
        Peak directions. The shape of the array should be (X, Y, Z, D, 3).
    indices : tuple
        Voxel indices given as ``(x_indices, y_indices, z_indices)``; only
        these voxels are considered.
    values : ndarray, optional
        Peak values used to scale each direction. The shape of the array
        should be (X, Y, Z, D). When omitted every peak is scaled by 1.
    affine : array, optional
        4x4 transformation array from native coordinates to world
        coordinates, applied to the voxel positions.
    colors : None or string ('rgb_standard') or tuple (3D or 4D) or
             array/ndarray (N, 3 or 4) or (K, 3 or 4) or (N, ) or (K, )
        Coloring policy, forwarded to ``_peaks_colors_from_points``.
        If None a standard orientation colormap is used for every line.
        A single tuple colors all lines the same. Arrays with N entries
        (number of points) or K entries (number of lines) give per-point or
        per-line RGB(A) colors (2D arrays) or colormap scalars (1D arrays).
    lookup_colormap : vtkLookupTable, optional
        Lookup table used when the colors are scalars. When omitted,
        ``colormap_lookup_table()`` provides a default.
    linewidth : float, optional
        Line thickness.
    symmetric : bool, optional
        If True, each line runs from ``-direction`` to ``+direction`` around
        the voxel position. Otherwise it runs from the voxel position to
        ``+direction``.
    """

    @warning_for_keywords()
    def __init__(
        self,
        directions,
        indices,
        *,
        values=None,
        affine=None,
        colors=None,
        lookup_colormap=None,
        linewidth=1,
        symmetric=True,
    ):
        use_world = affine is not None
        if use_world:
            # One world-space position per selected voxel.
            w_pos = apply_affine(affine, np.asarray(indices).T)

        valid_dirs = directions[indices]

        # A peak counts as present when any of its components is non-zero.
        has_peak = np.abs(valid_dirs).max(axis=-1) > 0
        num_dirs = len(np.nonzero(has_peak)[0])

        pnts_per_line = 2

        points_array = np.empty((num_dirs * pnts_per_line, 3))
        centers_array = np.empty_like(points_array, dtype=int)
        diffs_array = np.empty_like(points_array)

        line_count = 0
        voxels = zip(indices[0], indices[1], indices[2])
        for idx, center in enumerate(voxels):
            xyz = w_pos[idx, :] if use_world else np.asarray(center)

            voxel_has_peak = np.abs(valid_dirs[idx, :, :]).max(axis=-1) > 0.0
            valid_peaks = np.nonzero(voxel_has_peak)[0]

            for direction in valid_peaks:
                pv = 1.0 if values is None else values[center][direction]

                scaled_dir = directions[center][direction] * pv
                point_i = scaled_dir + xyz
                point_e = -scaled_dir + xyz if symmetric else xyz

                diff = point_e - point_i

                # Each line owns two consecutive rows: end point, then
                # start point. Both rows share the same center and diff.
                end_row = line_count * pnts_per_line
                start_row = end_row + 1
                points_array[end_row, :] = point_e
                points_array[start_row, :] = point_i
                for row in (end_row, start_row):
                    centers_array[row, :] = center
                    diffs_array[row, :] = diff
                line_count += 1

        vtk_points = numpy_to_vtk_points(points_array)
        vtk_cells = _points_to_vtk_cells(points_array)

        colors_tuple = _peaks_colors_from_points(points_array, colors=colors)
        vtk_colors, colors_are_scalars, self.__global_opacity = colors_tuple

        poly_data = PolyData()
        poly_data.SetPoints(vtk_points)
        poly_data.SetLines(vtk_cells)
        poly_data.GetPointData().SetScalars(vtk_colors)

        self.__mapper = PolyDataMapper()
        self.__mapper.SetInputData(poly_data)
        self.__mapper.ScalarVisibilityOn()
        self.__mapper.SetScalarModeToUsePointFieldData()
        self.__mapper.SelectColorArray("colors")
        self.__mapper.Update()

        self.SetMapper(self.__mapper)

        # Per-vertex attributes consumed by the shader code below.
        attribute_to_actor(self, centers_array, "center")
        attribute_to_actor(self, diffs_array, "diff")

        vs_var_dec = (
            " in vec3 center; in vec3 diff;"
            " flat out vec3 centerVertexMCVSOutput; "
        )
        fs_var_dec = (
            " flat in vec3 centerVertexMCVSOutput;"
            " uniform bool isRange;"
            " uniform vec3 crossSection;"
            " uniform vec3 lowRanges;"
            " uniform vec3 highRanges; "
        )

        orient_to_rgb = import_fury_shader(pjoin("utils", "orient_to_rgb.glsl"))
        visible_cross_section = import_fury_shader(
            pjoin("interaction", "visible_cross_section.glsl")
        )
        visible_range = import_fury_shader(pjoin("interaction", "visible_range.glsl"))

        vs_dec = compose_shader([vs_var_dec, orient_to_rgb])
        fs_dec = compose_shader([fs_var_dec, visible_cross_section, visible_range])

        # Vertices whose color is exactly black get an orientation color.
        vs_impl = (
            " centerVertexMCVSOutput = center;"
            " if (vertexColorVSOutput.rgb == vec3(0))"
            " { vertexColorVSOutput.rgb = orient2rgb(diff); } "
        )
        # Fragments outside the active extent / cross-section are dropped.
        fs_impl = (
            " if (isRange)"
            " { if (!inVisibleRange(centerVertexMCVSOutput)) discard; }"
            " else"
            " { if (!inVisibleCrossSection(centerVertexMCVSOutput)) discard; } "
        )

        shader_to_actor(self, "vertex", decl_code=vs_dec, impl_code=vs_impl)
        shader_to_actor(self, "fragment", decl_code=fs_dec)
        shader_to_actor(self, "fragment", impl_code=fs_impl, block="light")

        # Color scale with a lookup table
        if colors_are_scalars:
            if lookup_colormap is None:
                lookup_colormap = colormap_lookup_table()

            self.__mapper.SetLookupTable(lookup_colormap)
            self.__mapper.UseLookupTableScalarRangeOn()
            self.__mapper.Update()

        self.__lw = linewidth
        self.GetProperty().SetLineWidth(self.__lw)

        # A negative global opacity means the colors carry their own alpha.
        if self.__global_opacity >= 0:
            self.GetProperty().SetOpacity(self.__global_opacity)

        self.__min_centers = np.zeros(shape=(3,))
        self.__max_centers = np.array(directions.shape[:3])

        # Start by showing the full extent of the volume.
        self.__is_range = True
        self.__low_ranges = self.__min_centers
        self.__high_ranges = self.__max_centers
        self.__cross_section = self.__high_ranges // 2

        self.__mapper.AddObserver(
            Command.UpdateShaderEvent, self.__display_peaks_vtk_callback
        )

    @calldata_type(VTK_OBJECT)
    def __display_peaks_vtk_callback(self, caller, event, calldata=None):
        """Push the current visibility state to the shader uniforms."""
        if calldata is None:
            return
        calldata.SetUniformi("isRange", self.__is_range)
        calldata.SetUniform3f("highRanges", self.__high_ranges)
        calldata.SetUniform3f("lowRanges", self.__low_ranges)
        calldata.SetUniform3f("crossSection", self.__cross_section)

    def display_cross_section(self, x, y, z):
        """Switch to cross-section mode and store ``[x, y, z]`` as the section."""
        self.__is_range = False
        self.__cross_section = [x, y, z]

    def display_extent(self, x1, x2, y1, y2, z1, z2):
        """Switch to range mode with lows ``[x1, y1, z1]`` and highs ``[x2, y2, z2]``."""
        self.__is_range = True
        self.__low_ranges = [x1, y1, z1]
        self.__high_ranges = [x2, y2, z2]

    @property
    def cross_section(self):
        return self.__cross_section

    @property
    def global_opacity(self):
        return self.__global_opacity

    @global_opacity.setter
    def global_opacity(self, opacity):
        self.__global_opacity = opacity
        self.GetProperty().SetOpacity(self.__global_opacity)

    @property
    def high_ranges(self):
        return self.__high_ranges

    @property
    def is_range(self):
        return self.__is_range

    @property
    def low_ranges(self):
        return self.__low_ranges

    @property
    def linewidth(self):
        return self.__lw

    @linewidth.setter
    def linewidth(self, linewidth):
        self.__lw = linewidth
        self.GetProperty().SetLineWidth(self.__lw)

    @property
    def max_centers(self):
        return self.__max_centers

    @property
    def min_centers(self):
        return self.__min_centers
@warning_for_keywords()
def peak(
    peaks_dirs,
    *,
    peaks_values=None,
    mask=None,
    affine=None,
    colors=None,
    linewidth=1,
    lookup_colormap=None,
    symmetric=True,
):
    """Build a :class:`PeakActor` from the output of ``peaks_from_model``.

    The inputs are validated, the voxels holding at least one non-zero peak
    (and, when a usable mask is given, lying inside it) are selected, and the
    actor is created for those voxels.

    Parameters
    ----------
    peaks_dirs : ndarray
        Peak directions. The shape of the array should be (X, Y, Z, D, 3).
    peaks_values : ndarray, optional
        Peak values. The shape of the array should be (X, Y, Z, D).
    mask : ndarray, optional
        3D mask with shape (X, Y, Z). A mask with the wrong number of
        dimensions or the wrong shape is ignored with a ``UserWarning``.
    affine : array, optional
        4x4 transformation array from native coordinates to world
        coordinates.
    colors : tuple or None, optional
        Default None. If None then every peak gets an orientation color in
        similarity to a DEC map.
    linewidth : float, optional
        Line thickness. Default is 1.
    lookup_colormap : vtkLookupTable, optional
        Add a default lookup table to the colormap. Look at
        :func:`fury.actor.colormap_lookup_table` for more information.
    symmetric : bool, optional
        If True, peaks are drawn for both peaks_dirs and -peaks_dirs. Else,
        peaks are only drawn for directions given by peaks_dirs. Default is
        True.

    Returns
    -------
    peak_actor : PeakActor
        Actor representing the peaks directions and/or magnitudes.

    Raises
    ------
    ValueError
        If ``peaks_dirs`` is not 5D with a last dimension of 3, or if
        ``peaks_values`` is not 4D or does not match the first four
        dimensions of ``peaks_dirs``.
    """
    if peaks_dirs.ndim != 5:
        raise ValueError(
            "Invalid peak directions. The shape of the structure "
            f"must be (XxYxZxDx3). Your data has {peaks_dirs.ndim} dimensions."
        )
    if peaks_dirs.shape[4] != 3:
        raise ValueError(
            "Invalid peak directions. The shape of the last "
            "dimension must be 3. Your data has a last dimension "
            f"of {peaks_dirs.shape[4]}."
        )

    dirs_shape = peaks_dirs.shape

    if peaks_values is not None:
        if peaks_values.ndim != 4:
            raise ValueError(
                "Invalid peak values. The shape of the structure "
                f"must be (XxYxZxD). Your data has {peaks_values.ndim} dimensions."
            )
        if peaks_values.shape != dirs_shape[:4]:
            raise ValueError(
                "Invalid peak values. The shape of the values "
                "must coincide with the shape of the directions."
            )

    # Voxels in which at least one peak component is non-zero.
    valid_mask = np.abs(peaks_dirs).max(axis=(-2, -1)) > 0

    if mask is not None:
        mask_problem = None
        if mask.ndim != 3:
            mask_problem = (
                "Invalid mask. The mask must be a 3D array. The "
                f"passed mask has {mask.ndim} dimensions. Ignoring passed "
                "mask."
            )
        elif mask.shape != dirs_shape[:3]:
            mask_problem = (
                "Invalid mask. The shape of the mask must coincide "
                "with the shape of the directions. Ignoring passed "
                "mask."
            )

        if mask_problem is None:
            valid_mask = np.logical_and(valid_mask, mask)
        else:
            # Warn from this frame so the stack level matches the caller.
            warnings.warn(mask_problem, UserWarning, stacklevel=2)

    indices = np.where(valid_mask)

    return PeakActor(
        peaks_dirs,
        indices,
        values=peaks_values,
        affine=affine,
        colors=colors,
        lookup_colormap=lookup_colormap,
        linewidth=linewidth,
        symmetric=symmetric,
    )
@warning_for_keywords()
def _peaks_colors_from_points(points, *, colors=None, points_per_line=2):
    """Return a VTK scalar array containing colors information for each one of
    the peaks according to the policy defined by the parameter colors.

    Parameters
    ----------
    points : (N, 3) array or ndarray
        points coordinates array.
    colors : None or string ('rgb_standard') or tuple (3D or 4D) or
             array/ndarray (N, 3 or 4) or array/ndarray (K, 3 or 4) or
             array/ndarray(N, ) or array/ndarray (K, )
        If None a standard orientation colormap is used for every line.
        If one tuple of color is used. Then all streamlines will have the same
        color.
        If an array (N, 3 or 4) is given, where N is equal to the number of
        points. Then every point is colored with a different RGB(A) color.
        If an array (K, 3 or 4) is given, where K is equal to the number of
        lines. Then every line is colored with a different RGB(A) color.
        If an array (N, ) is given, where N is the number of points then these
        are considered as the values to be used by the colormap.
        If an array (K,) is given, where K is the number of lines then these
        are considered as the values to be used by the colormap.
    points_per_line : int (1 or 2), optional
        number of points per peak direction.

    Returns
    -------
    color_array : vtkDataArray
        vtk scalar array with name 'colors'.
    colors_are_scalars : bool
        indicates whether or not the colors are scalars to be interpreted by a
        colormap.
    global_opacity : float
        returns 1 if the colors array doesn't contain opacity otherwise -1.

    Raises
    ------
    ValueError
        If ``colors`` is an array-like that is neither 1D nor 2D, or whose
        length matches neither the number of lines nor the number of points.
    """
    num_pnts = len(points)
    num_lines = num_pnts // points_per_line
    colors_are_scalars = False
    global_opacity = 1

    # The string comparison is only done on actual strings: comparing an
    # ndarray with a string yields an elementwise array on recent NumPy, whose
    # truth value is ambiguous and would raise for every array input.
    if colors is None or (isinstance(colors, str) and colors == "rgb_standard"):
        # Automatic RGB colors
        colors = np.asarray((0, 0, 0))
        color_array = numpy_to_vtk_colors(np.tile(255 * colors, (num_pnts, 1)))
    elif isinstance(colors, tuple):
        # One RGB(A) color shared by every point.
        global_opacity = 1 if len(colors) == 3 else -1
        colors = np.asarray(colors)
        color_array = numpy_to_vtk_colors(np.tile(255 * colors, (num_pnts, 1)))
    else:
        colors = np.asarray(colors)
        if colors.ndim not in (1, 2):
            raise ValueError(
                "Invalid colors. Expected a 1D array of scalars or a 2D array "
                f"of RGB(A) colors, got an array with {colors.ndim} dimensions."
            )

        if len(colors) == num_lines:
            # Per-line entries are duplicated for each point of the line.
            pnts_colors = np.repeat(colors, points_per_line, axis=0)
        elif len(colors) == num_pnts:
            pnts_colors = colors
        else:
            raise ValueError(
                "Invalid colors. The number of colors must coincide with the "
                f"number of lines ({num_lines}) or the number of points "
                f"({num_pnts}), got {len(colors)}."
            )

        if colors.ndim == 1:
            # Scalars to be mapped through a colormap
            color_array = numpy_support.numpy_to_vtk(pnts_colors, deep=True)
            colors_are_scalars = True
        else:
            # RGB(A) colors
            global_opacity = 1 if colors.shape[1] == 3 else -1
            color_array = numpy_to_vtk_colors(255 * pnts_colors)

    color_array.SetName("colors")
    return color_array, colors_are_scalars, global_opacity


@warning_for_keywords()
def _points_to_vtk_cells(points, *, points_per_line=2):
    """Return the VTK cell array for the peaks given the set of points
    coordinates.

    Parameters
    ----------
    points : (N, 3) array or ndarray
        points coordinates array.
    points_per_line : int (1 or 2), optional
        number of points per peak direction.

    Returns
    -------
    cell_array : vtkCellArray
        connectivity + offset information.
    """
    num_pnts = len(points)
    num_cells = num_pnts // points_per_line

    cell_array = CellArray()

    # Connectivity is an array that contains the indices of the points that
    # need to be connected in the visualization. The indices start from 0.
    connectivity = np.arange(num_pnts, dtype=int)

    # Offset is an array that contains the indices of the first point of each
    # line. The indices start from 0 and given the known geometry of this
    # actor the creation of this array requires a 2 points padding between
    # indices.
    offset = np.arange(0, num_pnts + 1, points_per_line, dtype=int)

    vtk_array_type = numpy_support.get_vtk_array_type(connectivity.dtype)
    cell_array.SetData(
        numpy_support.numpy_to_vtk(offset, deep=True, array_type=vtk_array_type),
        numpy_support.numpy_to_vtk(connectivity, deep=True, array_type=vtk_array_type),
    )

    cell_array.SetNumberOfCells(num_cells)
    return cell_array
class PeaksVisualizer:
    """Wrap a peaks volume into a single peak actor.

    Parameters
    ----------
    pam : tuple
        Pair ``(peak_dirs, affine)``; it is unpacked into exactly two items.
    world_coords : bool
        If truthy the affine is handed to ``peak``; otherwise the actor is
        built without an affine.
    """

    def __init__(self, pam, world_coords):
        self._peak_dirs, self._affine = pam
        # ``peak`` defaults ``affine`` to None, so passing None explicitly is
        # the same as omitting the argument.
        affine = self._affine if world_coords else None
        self._peak_actor = peak(self._peak_dirs, affine=affine)

    @property
    def actors(self):
        """List containing the single peak actor of this visualizer."""
        return [self._peak_actor]