"""
plotting functions
"""
from warnings import warn
import numpy as np
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package
plt, have_plt, _ = optional_package("matplotlib.pyplot")
[docs]
@warning_for_keywords()
def compare_maps(
    fits,
    maps,
    *,
    transpose=None,
    fit_labels=None,
    map_labels=None,
    fit_kwargs=None,
    map_kwargs=None,
    filename=None,
):
    """Compare one or more scalar maps for different fits or models.
    Parameters
    ----------
    fits : list
        List of fits to be compared.
    maps : list
        Names of attributes to be compared.
        Default: 'rtop'.
    transpose : bool, optional
        If False, different fits are placed on different rows and different
        maps on different columns. If True, the order is transposed. If None,
        the figures are placed such that there are more columns than rows.
        Default: None.
    fit_labels : list, optional
        Labels for the different fitting routines. If None the fits are labeled
        by number.
        Default: None.
    map_labels : list, optional
        Labels for the different attributes. If None the attribute names are
        used.
        Default: None.
    fit_kwargs : list or dict, optional
        A dict or list of dicts with imshow options for each fitting routine.
        The dicts are passed to imshow as keyword-argument pairs.
        Default: {}.
    map_kwargs : list or dict, optional
        A dict or list of dicts with imshow options for each MAP-MRI scalar.
        The dicts are passed to imshow as keyword-argument pairs.
        Default: {}.
    filename : string, optional
        Filename where the image will be saved.
        Default: None.
    """
    fit_kwargs = fit_kwargs or {}
    map_kwargs = map_kwargs or {}
    if not have_plt:
        raise ValueError("matplotlib package needed for visualization.")
    fontsize = "large"
    xscale, yscale = 12, 10
    m = len(fits)
    n = len(maps)
    if transpose is None:
        transpose = m > n
    if fit_labels is None:
        fit_labels = [f"Fit {i + 1}" for i in range(m)]
    if map_labels is None:
        map_labels = maps
    if isinstance(fit_kwargs, dict):
        fit_kwargs = [fit_kwargs] * m
    if isinstance(map_kwargs, dict):
        map_kwargs = [map_kwargs] * n
    if transpose:
        fig, ax = plt.subplots(n, m, figsize=(xscale, yscale / m * n), squeeze=False)
        ax = ax.T
        for i in range(m):
            ax[i, 0].set_title(fit_labels[i], fontsize=fontsize)
        for j in range(n):
            ax[0, j].set_ylabel(map_labels[j], fontsize=fontsize)
    else:
        fig, ax = plt.subplots(m, n, figsize=(xscale, yscale / n * m), squeeze=False)
        for i in range(m):
            ax[i, 0].set_ylabel(fit_labels[i], fontsize=fontsize)
        for j in range(n):
            ax[0, j].set_title(map_labels[j], fontsize=fontsize)
    for i in range(m):
        for j in range(n):
            try:
                attr = getattr(fits[i], maps[j])
                if callable(attr):
                    attr = attr()
            except AttributeError:
                warn(f"Could not recover attribute {maps[j]}.", stacklevel=2)
                attr = np.zeros((2, 2))
            data = np.squeeze(np.array(attr, dtype=float)).T
            ax[i, j].imshow(
                data,
                interpolation="nearest",
                origin="lower",
                cmap="gray",
                **fit_kwargs[i],
                **map_kwargs[j],
            )
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])
            ax[i, j].spines["top"].set_visible(False)
            ax[i, j].spines["right"].set_visible(False)
            ax[i, j].spines["bottom"].set_visible(False)
            ax[i, j].spines["left"].set_visible(False)
    fig.tight_layout()
    if filename:
        plt.savefig(filename)
    else:
        plt.show() 
[docs]
@warning_for_keywords()
def compare_qti_maps(
    gt,
    fit1,
    fit2,
    mask,
    *,
    maps=("fa", "ufa"),
    fitname=("QTI", "QTI+"),
    xlimits=([0, 1], [0.4, 1.5]),
    disprange=([0, 1], [0, 1]),
    slice=13,
):
    """Compare one or more qti derived maps obtained with
    different fitting routines.
    Parameters
    ----------
    gt : qti fit object
        The qti fit to be considered as ground truth
    fit1 : qti fit object
        First qti fit to be compared
    fit2 : qti fit object
        Second qti fit to be compared
    mask : np.ndarray
        Boolean array indicating which voxels to retain for comparing
        the values
    maps : array-like, optional
        QTI invariants to be compared
    fitname : array-like, optional
        Names of the used QTI fitting routines
    xlimits : array-like, optional
        X-Axis limits for the histograms visualization
    disprange : array-like, optional
        Display range for maps
    slice : int, optional
        Axial brain slice to be visualized
    """
    if not have_plt:
        raise ValueError("matplotlib package needed for visualization")
    n = len(maps)
    fig, ax = plt.subplots(n, 4, figsize=(12, 9))
    background = np.zeros(gt.S0_hat.shape[0:2])
    for i in range(n):
        for j in range(3):
            ax[i, j].imshow(background, cmap="gray")
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])
    for k in range(n):
        ax[k, 0].imshow(
            np.rot90(getattr(gt, maps[k])[:, :, slice]),
            cmap="gray",
            vmin=disprange[k][0],
            vmax=disprange[k][1],
        )
        ax[k, 0].set_title("GROUND TRUTH")
        ax[k, 0].set_ylabel(maps[k], fontsize=20)
        ax[k, 1].imshow(
            np.rot90(getattr(fit1, maps[k])[:, :, slice]),
            cmap="gray",
            vmin=disprange[k][0],
            vmax=disprange[k][1],
        )
        ax[k, 1].set_title(fitname[0])
        ax[k, 2].imshow(
            np.rot90(getattr(fit2, maps[k])[:, :, slice]),
            cmap="gray",
            vmin=disprange[k][0],
            vmax=disprange[k][1],
        )
        ax[k, 2].set_title(fitname[1])
        ax[k, 3].hist(
            (getattr(fit1, maps[k])[mask, slice]).flatten(),
            density=True,
            bins=40,
            label=fitname[0],
        )
        ax[k, 3].hist(
            (getattr(fit2, maps[k])[mask, slice]).flatten(),
            density=True,
            bins=40,
            label=fitname[1],
            alpha=0.7,
        )
        ax[k, 3].hist(
            (getattr(gt, maps[k])[mask, slice]).flatten(),
            histtype="stepfilled",
            density=True,
            bins=40,
            label="GT",
            ec="k",
            alpha=1,
            linewidth=1.5,
            fc="None",
        )
        ax[k, 3].legend()
        ax[k, 3].set_title("VALUE DISTRIBUTION")
        ax[k, 3].set_xlim(xlimits[k])
    fig.tight_layout()
    plt.show() 
[docs]
def bundle_shape_profile(x, shape_profile, std):
    """Plot bundlewarp bundle shape profile.
    Parameters
    ----------
    x : np.ndarray
        Integer array containing x-axis
    shape_profile : np.ndarray
        Float array containing bundlewarp displacement magnitudes along the
        length of the bundle
    std : np.ndarray
        Float array containing standard deviations
    """
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    std_1 = shape_profile + std
    std_2 = shape_profile - std
    ax.plot(
        x, shape_profile, "-", label="Mean", color="Purple", linewidth=3, markersize=12
    )
    ax.fill_between(x, std_1, std_2, alpha=0.2, label="Std", color="Purple")
    plt.xticks(x)
    plt.ylim(0, max(std_1) + 2)
    plt.ylabel("Average Displacement")
    plt.xlabel("Segment Number")
    plt.title("Bundle Shape Profile")
    plt.legend(loc=2)
    plt.show() 
[docs]
def image_mosaic(
    images, *, ax_labels=None, ax_kwargs=None, figsize=None, filename=None
):
    """
    Draw a mosaic of 2D images using pyplot.imshow(). A colorbar is drawn
    beside each image.
    Parameters
    ----------
    images: list of ndarray
        Images to render.
    ax_labels: list of str, optional
        Label for each image.
    ax_kwargs: list of dictionaries, optional
        keyword arguments passed to imshow for each image. One dictionary per
        image.
    figsize: tuple of ints, optional
        Figure size.
    filename: str, optional
        When given, figure is saved to disk under this name.
    Returns
    -------
    fig: pyplot.Figure
        The figure.
    ax: pyplot.Axes or array of Axes
        The subplots for each image.
    """
    fig, ax = plt.subplots(1, len(images), figsize=figsize)
    aximages = []
    for it, (im, axe, kw) in enumerate(zip(images, ax, ax_kwargs)):
        aximages.append(axe.imshow(im, **kw))
        if ax_labels is not None:
            axe.set_title(ax_labels[it])
    for it, aximage in enumerate(aximages):
        fig.colorbar(aximage, ax=ax[it])
    if filename is not None:
        plt.savefig(filename)
    else:
        plt.show()
    return fig, ax