# Source code for dipy.viz.projections

"""

Visualization tools for 2D projections of 3D functions on the sphere, such as
orientation distribution functions (ODFs).

"""

import numpy as np
import scipy.interpolate as interp

import dipy.core.geometry as geo
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.optpkg import optional_package

matplotlib, has_mpl, setup_module = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")
mpl_tri, _, _ = optional_package("matplotlib.tri")
bm, has_basemap, _ = optional_package("mpl_toolkits.basemap")


@warning_for_keywords()
@doctest_skip_parser
def sph_project(
    vertices,
    val,
    *,
    ax=None,
    vmin=None,
    vmax=None,
    cmap=None,
    cbar=True,
    tri=False,
    boundary=False,
    **basemap_args,
):
    """Draw a signal on a 2D projection of the sphere.

    Parameters
    ----------
    vertices : (3, N) ndarray
        Unit vector points of the sphere, one point per column (for a
        ``Sphere`` object, pass ``sphere.vertices.T``).
    val : (N,) ndarray
        Function values, one per vertex.
    ax : mpl axis, optional
        If specified, draw onto this existing axis instead of creating a new
        figure.
    vmin : float, optional
        Minimum value to cut the z. Defaults to the smallest non-NaN value
        of ``val``.
    vmax : float, optional
        Maximum value to cut the z. Defaults to the largest non-NaN value
        of ``val``.
    cmap : matplotlib.colors.Colormap, optional
        Colormap. Defaults to ``matplotlib.cm.hot``.
    cbar : bool, optional
        Whether to add the color-bar to the figure.
    tri : bool, optional
        Whether to display the plot triangulated as a pseudo-color plot.
    boundary : bool, optional
        Whether to draw the boundary around the projection in a black line.
    **basemap_args
        Extra keyword arguments forwarded to
        ``mpl_toolkits.basemap.Basemap``. Unless overridden, the defaults are
        ``projection="ortho"``, ``lat_0=0``, ``lon_0=0``, ``resolution="c"``
        and ``ax=ax``.

    Returns
    -------
    ax : axis
        Matplotlib figure axis

    Raises
    ------
    ValueError
        If ``vertices`` is not a (3, N) array or ``val`` does not hold one
        value per vertex.

    Examples
    --------
    >>> from dipy.data import default_sphere
    >>> verts = default_sphere.vertices
    >>> _ax = sph_project(verts.T, np.random.rand(len(verts))) # skip if not has_basemap
    """  # noqa: E501
    vertices = np.asarray(vertices)
    val = np.asarray(val, dtype=float)
    if vertices.ndim != 2 or vertices.shape[0] != 3:
        raise ValueError(
            f"vertices must have shape (3, N); got shape {vertices.shape}"
        )
    if len(val) != vertices.shape[1]:
        raise ValueError(
            f"val must have one value per vertex: got {len(val)} values "
            f"for {vertices.shape[1]} vertices"
        )

    if ax is None:
        fig, ax = plt.subplots(1)
    else:
        # Use the figure that owns the supplied axis (plt.subplots would
        # return a (fig, ax) tuple and open an unrelated figure).
        fig = ax.get_figure()

    if cmap is None:
        cmap = matplotlib.cm.hot

    basemap_args.setdefault("projection", "ortho")
    basemap_args.setdefault("lat_0", 0)
    basemap_args.setdefault("lon_0", 0)
    basemap_args.setdefault("resolution", "c")
    # Bind the map to ``ax`` so every drawing call below lands on it rather
    # than on whichever axis happens to be current.
    basemap_args.setdefault("ax", ax)

    from mpl_toolkits.basemap import Basemap

    m = Basemap(**basemap_args)
    if boundary:
        m.drawmapboundary()

    # Rotate the coordinate system so that you are looking from the north pole:
    rotation = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    verts_rot = np.dot(rotation, vertices)  # new array; caller's data untouched

    # The orthographic projection shows a single hemisphere: points whose
    # first coordinate is positive are moved to their antipode. Their value
    # is kept, so this presumes an antipodally symmetric signal.
    flip = verts_rot[0] > 0
    verts_rot[:, flip] *= -1

    _, theta, phi = geo.cart2sphere(verts_rot[0], verts_rot[1], verts_rot[2])
    lat, lon = geo.sph2latlon(theta, phi)
    x, y = m(lon, lat)

    my_min = np.nanmin(val) if vmin is None else vmin
    my_max = np.nanmax(val) if vmax is None else vmax

    if tri:
        m.pcolor(x, y, val, vmin=my_min, vmax=my_max, tri=True, cmap=cmap)
    else:
        span = float(my_max - my_min)
        if span == 0:
            # Constant signal: avoid 0/0 and paint every point with the
            # lowest color of the map.
            r = np.zeros_like(val)
        else:
            r = (val - my_min) / span
        # Enforce the maximum and minimum boundaries for values outside
        # [vmin, vmax]:
        r = np.clip(r, 0, 1)
        # The public Colormap.__call__ works for every Colormap subclass,
        # unlike the private ``_segmentdata`` of LinearSegmentedColormap.
        rgba = cmap(r)
        for this_x, this_y, this_rgba in zip(x, y, rgba):
            m.plot(this_x, this_y, "o", c=this_rgba[:3].tolist())

    if cbar:
        mappable = matplotlib.cm.ScalarMappable(cmap=cmap)
        mappable.set_array([my_min, my_max])
        # Place the colorbar in a new, frameless axis just right of ``ax``.
        ell, b, w, h = ax.get_position().bounds
        cax = fig.add_axes([ell + w + 0.075, b, 0.05, h], frameon=False)
        fig.colorbar(mappable, cax=cax)  # draw colorbar

    return ax