"""
plotting functions
"""
from warnings import warn
import numpy as np
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package
plt, have_plt, _ = optional_package("matplotlib.pyplot")
[docs]
@warning_for_keywords()
def compare_maps(
    fits,
    maps,
    *,
    transpose=None,
    fit_labels=None,
    map_labels=None,
    fit_kwargs=None,
    map_kwargs=None,
    filename=None,
):
    """Compare one or more scalar maps for different fits or models.

    Every (fit, map) pair is drawn with ``imshow`` in its own cell of a
    grid whose rows/columns are fits and maps (or the reverse, see
    ``transpose``).

    Parameters
    ----------
    fits : list
        List of fits to be compared.
    maps : list
        Names of the attributes of each fit to be compared. If an attribute
        is callable it is called without arguments and its result is used.
    transpose : bool, optional
        If False, different fits are placed on different rows and different
        maps on different columns. If True, the order is transposed. If None,
        the figures are placed such that there are more columns than rows.
        Default: None.
    fit_labels : list, optional
        Labels for the different fitting routines. If None the fits are
        labeled by number.
        Default: None.
    map_labels : list, optional
        Labels for the different attributes. If None the attribute names are
        used.
        Default: None.
    fit_kwargs : list or dict, optional
        A dict or list of dicts with imshow options for each fitting routine.
        The dicts are passed to imshow as keyword-argument pairs and override
        the defaults (``interpolation="nearest"``, ``origin="lower"``,
        ``cmap="gray"``).
        Default: {}.
    map_kwargs : list or dict, optional
        A dict or list of dicts with imshow options for each map. The dicts
        are passed to imshow as keyword-argument pairs; on conflicting keys
        they take precedence over ``fit_kwargs`` and the defaults.
        Default: {}.
    filename : str, optional
        Filename where the image will be saved. If None, the figure is shown
        instead.
        Default: None.

    Raises
    ------
    ValueError
        If matplotlib is not available.
    """
    fit_kwargs = fit_kwargs or {}
    map_kwargs = map_kwargs or {}
    if not have_plt:
        raise ValueError("matplotlib package needed for visualization.")
    fontsize = "large"
    xscale, yscale = 12, 10
    m = len(fits)
    n = len(maps)
    if transpose is None:
        transpose = m > n
    if fit_labels is None:
        fit_labels = [f"Fit {i + 1}" for i in range(m)]
    if map_labels is None:
        map_labels = maps
    # A single dict applies to every fit / every map.
    if isinstance(fit_kwargs, dict):
        fit_kwargs = [fit_kwargs] * m
    if isinstance(map_kwargs, dict):
        map_kwargs = [map_kwargs] * n

    if transpose:
        fig, ax = plt.subplots(n, m, figsize=(xscale, yscale / m * n), squeeze=False)
        # Transpose the grid so ax[i, j] is always (fit i, map j) below.
        ax = ax.T
        for i in range(m):
            ax[i, 0].set_title(fit_labels[i], fontsize=fontsize)
        for j in range(n):
            ax[0, j].set_ylabel(map_labels[j], fontsize=fontsize)
    else:
        fig, ax = plt.subplots(m, n, figsize=(xscale, yscale / n * m), squeeze=False)
        for i in range(m):
            ax[i, 0].set_ylabel(fit_labels[i], fontsize=fontsize)
        for j in range(n):
            ax[0, j].set_title(map_labels[j], fontsize=fontsize)

    default_kwargs = {"interpolation": "nearest", "origin": "lower", "cmap": "gray"}
    for i in range(m):
        for j in range(n):
            try:
                attr = getattr(fits[i], maps[j])
                if callable(attr):
                    attr = attr()
            except AttributeError:
                warn(f"Could not recover attribute {maps[j]}.", stacklevel=2)
                # Placeholder so the grid cell is still drawn.
                attr = np.zeros((2, 2))
            data = np.squeeze(np.array(attr, dtype=float)).T
            # Merge into one dict (later sources win) so user options such as
            # a different cmap override the defaults instead of raising
            # "got multiple values for keyword argument".
            imshow_kwargs = {**default_kwargs, **fit_kwargs[i], **map_kwargs[j]}
            ax[i, j].imshow(data, **imshow_kwargs)
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])
            for side in ("top", "right", "bottom", "left"):
                ax[i, j].spines[side].set_visible(False)
    fig.tight_layout()
    if filename:
        fig.savefig(filename)
    else:
        plt.show()
[docs]
@warning_for_keywords()
def compare_qti_maps(
    gt,
    fit1,
    fit2,
    mask,
    *,
    maps=("fa", "ufa"),
    fitname=("QTI", "QTI+"),
    xlimits=([0, 1], [0.4, 1.5]),
    disprange=([0, 1], [0, 1]),
    slice=13,
):
    """Compare one or more QTI-derived maps obtained with different
    fitting routines.

    For each map, one row is drawn containing the ground truth slice, the
    slices of the two fits, and the histograms of the masked values.

    Parameters
    ----------
    gt : qti fit object
        The QTI fit to be considered as ground truth. The first two
        dimensions of its ``S0_hat`` attribute set the size of the black
        background drawn under each map.
    fit1 : qti fit object
        First QTI fit to be compared.
    fit2 : qti fit object
        Second QTI fit to be compared.
    mask : np.ndarray
        Boolean array indicating which voxels to retain for comparing the
        values. It is used as ``map[mask, slice]``, so it is expected to
        cover the first two (in-plane) dimensions of the maps.
    maps : array-like, optional
        Names of the QTI invariants (fit attributes) to be compared.
    fitname : array-like, optional
        Names of the used QTI fitting routines, for ``fit1`` and ``fit2``.
    xlimits : array-like, optional
        X-axis limits of the histogram for each map.
    disprange : array-like, optional
        Display range ``[vmin, vmax]`` for each map.
    slice : int, optional
        Axial slice (index along the third axis of the maps) to be visualized.

    Raises
    ------
    ValueError
        If matplotlib is not available.
    """
    if not have_plt:
        raise ValueError("matplotlib package needed for visualization")
    n = len(maps)
    # squeeze=False keeps ``ax`` 2-D even for a single map, so ax[k, j] works.
    fig, ax = plt.subplots(n, 4, figsize=(12, 9), squeeze=False)
    # Rotated like the displayed maps so both images share the same extent
    # when the slice is not square.
    background = np.rot90(np.zeros(gt.S0_hat.shape[0:2]))
    fits = (gt, fit1, fit2)
    titles = ("GROUND TRUTH", fitname[0], fitname[1])
    for k in range(n):
        for j, (fit, title) in enumerate(zip(fits, titles)):
            # Background first, so map voxels that are not drawn show black.
            ax[k, j].imshow(background, cmap="gray")
            ax[k, j].set_xticks([])
            ax[k, j].set_yticks([])
            ax[k, j].imshow(
                np.rot90(getattr(fit, maps[k])[:, :, slice]),
                cmap="gray",
                vmin=disprange[k][0],
                vmax=disprange[k][1],
            )
            ax[k, j].set_title(title)
        ax[k, 0].set_ylabel(maps[k], fontsize=20)

        ax[k, 3].hist(
            (getattr(fit1, maps[k])[mask, slice]).flatten(),
            density=True,
            bins=40,
            label=fitname[0],
        )
        ax[k, 3].hist(
            (getattr(fit2, maps[k])[mask, slice]).flatten(),
            density=True,
            bins=40,
            label=fitname[1],
            alpha=0.7,
        )
        # Ground truth drawn as an unfilled black outline on top.
        ax[k, 3].hist(
            (getattr(gt, maps[k])[mask, slice]).flatten(),
            histtype="stepfilled",
            density=True,
            bins=40,
            label="GT",
            ec="k",
            alpha=1,
            linewidth=1.5,
            fc="None",
        )
        ax[k, 3].legend()
        ax[k, 3].set_title("VALUE DISTRIBUTION")
        ax[k, 3].set_xlim(xlimits[k])
    fig.tight_layout()
    plt.show()
[docs]
def bundle_shape_profile(x, shape_profile, std):
    """Plot a BundleWarp bundle shape profile.

    The mean displacement is drawn as a line, and the band between
    ``shape_profile - std`` and ``shape_profile + std`` is shaded.

    Parameters
    ----------
    x : np.ndarray
        Integer array containing the x-axis positions (segment numbers).
    shape_profile : np.ndarray
        Float array containing BundleWarp displacement magnitudes along the
        length of the bundle.
    std : np.ndarray
        Float array containing standard deviations.
    """
    _, axis = plt.subplots(figsize=(8, 6), dpi=300)
    upper = shape_profile + std
    lower = shape_profile - std

    axis.plot(
        x, shape_profile, "-", label="Mean", color="Purple", linewidth=3, markersize=12
    )
    axis.fill_between(x, upper, lower, alpha=0.2, label="Std", color="Purple")

    # Leave a fixed margin of 2 above the top of the shaded band.
    axis.set_xticks(x)
    axis.set_ylim(0, max(upper) + 2)
    axis.set_ylabel("Average Displacement")
    axis.set_xlabel("Segment Number")
    axis.set_title("Bundle Shape Profile")
    axis.legend(loc=2)
    plt.show()
[docs]
def image_mosaic(
    images, *, ax_labels=None, ax_kwargs=None, figsize=None, filename=None
):
    """
    Draw a mosaic of 2D images using pyplot.imshow(). A colorbar is drawn
    beside each image.

    Parameters
    ----------
    images: list of ndarray
        Images to render.
    ax_labels: list of str, optional
        Label for each image.
    ax_kwargs: list of dictionaries, optional
        keyword arguments passed to imshow for each image. One dictionary per
        image. If None, every image is drawn with imshow's default options.
    figsize: tuple of ints, optional
        Figure size.
    filename: str, optional
        When given, figure is saved to disk under this name. Otherwise the
        figure is shown.

    Returns
    -------
    fig: pyplot.Figure
        The figure.
    ax: pyplot.Axes or array of Axes
        The subplot when a single image is given, otherwise a 1-D array with
        one subplot per image.
    """
    if ax_kwargs is None:
        # Without this the default None made ``zip`` below raise TypeError.
        ax_kwargs = [{}] * len(images)
    # squeeze=False always yields a 2-D grid, so a single image does not
    # produce a bare (non-iterable) Axes.
    fig, ax_grid = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    axes = ax_grid[0]
    aximages = []
    for it, (im, axe, kw) in enumerate(zip(images, axes, ax_kwargs)):
        aximages.append(axe.imshow(im, **kw))
        if ax_labels is not None:
            axe.set_title(ax_labels[it])
    for aximage, axe in zip(aximages, axes):
        fig.colorbar(aximage, ax=axe)
    if filename is not None:
        fig.savefig(filename)
    else:
        plt.show()
    # Match plt.subplots' default squeezing for the returned value.
    ax = axes[0] if len(images) == 1 else axes
    return fig, ax