"""
Visualization tools for 2D projections of 3D functions on the sphere, such as
ODFs.
"""
import numpy as np
import scipy.interpolate as interp
import dipy.core.geometry as geo
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.optpkg import optional_package
# matplotlib and basemap are loaded through ``optional_package`` instead of
# plain ``import`` statements.
# NOTE(review): presumably this keeps the module importable when the plotting
# stack is not installed -- confirm against ``dipy.utils.optpkg``.
# Each call returns three values; positions bound to ``_`` are discarded.
matplotlib, has_mpl, setup_module = optional_package("matplotlib")
plt, _, _ = optional_package("matplotlib.pyplot")
mpl_tri, _, _ = optional_package("matplotlib.tri")
# ``has_basemap`` is the name referenced by the ``# skip if not has_basemap``
# directive in the ``sph_project`` doctest.
bm, has_basemap, _ = optional_package("mpl_toolkits.basemap")
@warning_for_keywords()
@doctest_skip_parser
def sph_project(
    vertices,
    val,
    *,
    ax=None,
    vmin=None,
    vmax=None,
    cmap=None,
    cbar=True,
    tri=False,
    boundary=False,
    **basemap_args,
):
    """Draw a signal on a 2D projection of the sphere.

    Parameters
    ----------
    vertices : (3, N) ndarray
        Unit vector points of the sphere, one point per column.
    val : (N,) ndarray
        Function values, one per vertex.
    ax : mpl axis, optional
        If specified, draw onto this existing axis instead of creating a new
        figure.
    vmin : float, optional
        Minimum value to cut the z. Defaults to the smallest non-NaN entry of
        `val`.
    vmax : float, optional
        Maximum value to cut the z. Defaults to the largest non-NaN entry of
        `val`.
    cmap : matplotlib.colors.Colormap, optional
        Colormap. Defaults to ``matplotlib.cm.hot``.
    cbar : bool, optional
        Whether to add the color-bar to the figure.
    tri : bool, optional
        Whether to display the plot triangulated as a pseudo-color plot.
    boundary : bool, optional
        Whether to draw the boundary around the projection in a black line.
    **basemap_args
        Extra keyword arguments forwarded to ``mpl_toolkits.basemap.Basemap``.
        Unless overridden, ``projection="ortho"``, ``lat_0=0``, ``lon_0=0``,
        ``resolution="c"`` and ``ax=ax`` are used.

    Returns
    -------
    ax : axis
        Matplotlib figure axis

    Examples
    --------
    >>> from dipy.data import default_sphere
    >>> verts = default_sphere.vertices
    >>> _ax = sph_project(verts.T, np.random.rand(len(verts)))  # skip if not has_basemap
    """  # noqa: E501
    # Import the optional dependency first so that a missing basemap does not
    # leave an empty figure behind.
    from mpl_toolkits.basemap import Basemap

    if ax is None:
        fig, ax = plt.subplots(1)
    else:
        # Reuse the figure that owns the caller's axis; the color-bar axes
        # are added to it below.
        fig = ax.figure

    if cmap is None:
        cmap = matplotlib.cm.hot

    basemap_args.setdefault("projection", "ortho")
    basemap_args.setdefault("lat_0", 0)
    basemap_args.setdefault("lon_0", 0)
    basemap_args.setdefault("resolution", "c")
    # Without an explicit axis Basemap draws on the current pyplot axes,
    # which need not be the one the caller handed in.
    basemap_args.setdefault("ax", ax)

    m = Basemap(**basemap_args)
    if boundary:
        m.drawmapboundary()

    # Rotate the coordinate system so that you are looking from the north pole:
    rotation = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]])
    verts_rot = np.array(np.dot(rotation, vertices))

    # To get the orthographic projection, when the first coordinate is
    # positive, rotate the entire b-vector around to point in the other
    # direction:
    verts_rot[:, verts_rot[0] > 0] *= -1

    _, theta, phi = geo.cart2sphere(verts_rot[0], verts_rot[1], verts_rot[2])
    lat, lon = geo.sph2latlon(theta, phi)
    x, y = m(lon, lat)

    val = np.asarray(val)
    my_min = np.nanmin(val) if vmin is None else vmin
    my_max = np.nanmax(val) if vmax is None else vmax

    if tri:
        m.pcolor(x, y, val, vmin=my_min, vmax=my_max, tri=True, cmap=cmap)
    else:
        span = float(my_max - my_min)
        if span == 0:
            # Constant signal (or vmin == vmax): avoid dividing by zero and
            # map everything to the low end of the colormap.
            r = np.zeros(val.shape, dtype=float)
        else:
            # Enforce the maximum and minimum boundaries, if there are values
            # outside those boundaries:
            r = np.clip((val - my_min) / span, 0, 1)

        # Calling the colormap is the public way to turn normalized values
        # into RGBA and works for every Colormap subclass, not only those
        # that carry tabular segment data.
        rgba = np.atleast_2d(cmap(r))
        for this_x, this_y, this_color in zip(x, y, rgba):
            # Only RGB is passed on; the colormap's alpha channel is ignored.
            m.plot(this_x, this_y, "o", c=this_color[:3].tolist())

    if cbar:
        mappable = matplotlib.cm.ScalarMappable(cmap=cmap)
        mappable.set_array([my_min, my_max])
        # Place the color-bar axes just to the right of the plot axes,
        # matching its vertical extent.
        ell, b, w, h = ax.get_position().bounds
        cax = fig.add_axes([ell + w + 0.075, b, 0.05, h], frameon=False)
        fig.colorbar(mappable, cax=cax)  # draw colorbar

    return ax