import warnings
import numpy as np
from scipy.optimize import linear_sum_assignment
from dipy.align.bundlemin import distance_matrix_mdf
from dipy.align.cpd import DeformableRegistration
from dipy.align.streamlinear import slr_with_qbx
from dipy.segment.clustering import QuickBundles
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
from dipy.stats.analysis import assignment_map
from dipy.testing.decorators import warning_for_keywords
from dipy.tracking.streamline import Streamlines, length, unlist_streamlines
from dipy.utils.optpkg import optional_package
from dipy.viz.plotting import bundle_shape_profile
pd, have_pd, _ = optional_package("pandas")
def average_bundle_length(bundle):
    """Estimate the average Euclidean length of a bundle in mm.

    The bundle is clustered with QuickBundles (threshold of 85.0, average
    pointwise Euclidean metric) and the length of the first resulting
    centroid is reported.

    Parameters
    ----------
    bundle : Streamlines
        Bundle whose average length is to be calculated.

    Returns
    -------
    float
        Length of the first QuickBundles centroid of the bundle, in mm.

    """
    quickbundles = QuickBundles(
        threshold=85.0, metric=AveragePointwiseEuclideanMetric()
    )
    centroid_lines = Streamlines(quickbundles.cluster(bundle).centroids)
    centroid_lengths = length(centroid_lines)
    # Only the first centroid is used as the bundle's representative length.
    return centroid_lengths[0]
def find_missing(lst, cb):
    """Find unmatched streamline indices in moving bundle.

    Parameters
    ----------
    lst : list of int
        Indices of the moving bundle's streamlines that have already been
        matched.
    cb : sequence
        The moving bundle (or any sized container with one entry per moving
        streamline). Only its length is used: the candidate indices are
        ``0 .. len(cb) - 1``.

    Returns
    -------
    list of int
        Indices in ``range(len(cb))`` that do not appear in `lst`, in
        ascending order.

    """
    try:
        # Hash the matched indices once so that each membership test is O(1)
        # instead of a linear scan of `lst` for every candidate index.
        matched = set(lst)
    except TypeError:
        # `lst` holds unhashable entries: fall back to scanning it directly.
        matched = lst
    return [idx for idx in range(len(cb)) if idx not in matched]
def _match_moving_to_static(dist, moving):
    """Pair every moving streamline with one static streamline.

    Parameters
    ----------
    dist : np.ndarray
        Distance matrix whose rows index moving streamlines and whose columns
        index static streamlines.
    moving : Streamlines
        Moving bundle; used for its length and passed to `find_missing`.

    Returns
    -------
    matched_pairs : np.ndarray
        Float array of shape ``(len(moving), 2)``. Row ``i`` holds
        ``[i, j]`` where ``j`` is the index of the static streamline paired
        with moving streamline ``i``.

    """
    n_moving = len(moving)
    matched_pairs = np.zeros((n_moving, 2))

    # First pass: optimal one-to-one assignment. When there are more moving
    # than static streamlines this leaves some moving streamlines unmatched.
    pairs = np.asarray(linear_sum_assignment(dist)).T
    matched_pairs[pairs[:, 0]] = pairs
    matched_rows = list(pairs[:, 0])

    # Further passes: re-run the assignment on the still-unmatched moving
    # streamlines against all static streamlines until every one has a pair.
    while len(matched_rows) < n_moving:
        missing = find_missing(matched_rows, moving)
        extra = np.asarray(linear_sum_assignment(dist[missing])).T
        if len(extra) == 0:
            # Nothing can be assigned (e.g. no static streamlines); without
            # this guard the loop would never terminate.
            raise ValueError(
                "Unable to match the remaining moving streamlines to the "
                "static bundle."
            )
        # Rows of `extra` are positions within `missing`; map them back to
        # indices into the moving bundle.
        extra[:, 0] = np.asarray(missing)[extra[:, 0]]
        matched_pairs[extra[:, 0]] = extra
        matched_rows.extend(extra[:, 0])

    return matched_pairs


@warning_for_keywords()
def bundlewarp(
    static, moving, *, dist=None, alpha=0.5, beta=20, max_iter=15, affine=True
):
    """Register two bundles using nonlinear method.

    See :footcite:p:`Chandio2023` for further details about the method.

    The moving bundle is first linearly aligned to the static bundle with
    `slr_with_qbx`. Every moving streamline is then paired with a static
    streamline by solving a linear sum assignment on the distance matrix,
    and each pair is deformed with `DeformableRegistration`.

    Parameters
    ----------
    static : Streamlines
        Reference/fixed bundle.
    moving : Streamlines
        Target bundle that will be moved/registered to match the static bundle.
    dist : np.ndarray, optional
        Precomputed distance matrix whose rows index the (linearly aligned)
        moving streamlines and whose columns index the static streamlines.
        When None, it is computed with `distance_matrix_mdf`.
    alpha : float, optional
        Represents the trade-off between regularizing the deformation and
        having points match very closely. Lower value of alpha means high
        deformations. A warning is issued when ``alpha <= 0.01``.
    beta : int, optional
        Represents the strength of the interaction between points
        Gaussian kernel size. Overridden with 10 when the average length of
        the static bundle is at most 50.
    max_iter : int, optional
        Maximum number of iterations for deformation process in ml-CPD method.
    affine : boolean, optional
        If False, use rigid registration as starting point.

    Returns
    -------
    deformed_bundle : Streamlines
        Nonlinearly moved bundle (warped bundle)
    moving_aligned : Streamlines
        Linearly moved bundle (affinely moved)
    dist : np.ndarray
        Float array containing distance between moving and static bundle
    matched_pairs : np.ndarray
        Array of shape ``(len(moving), 2)`` containing streamline
        correspondences between two bundles as
        ``[moving index, static index]`` rows. The indices are stored as
        floats.
    warp : pandas.DataFrame
        Nonlinear warp map generated by BundleWarp, one row per matched pair
        with columns ``"gaussian_kernel"`` and ``"transforms"``.

    Notes
    -----
    The returned `warp` table is built with pandas, which this module loads
    as an optional package.

    References
    ----------
    .. footbibliography::

    """
    if alpha <= 0.01:
        warnings.warn(
            "Using alpha<=0.01 will result in extreme deformations", stacklevel=2
        )

    # Short static bundles use a smaller kernel regardless of the `beta`
    # passed by the caller.
    if average_bundle_length(static) <= 50:
        beta = 10

    x0 = "affine" if affine else "rigid"
    moving_aligned, _, _, _ = slr_with_qbx(static, moving, x0=x0, rm_small_clusters=0)

    if dist is not None:
        print("using pre-computed distances")
    else:
        # Transposed so that rows are moving and columns are static.
        dist = distance_matrix_mdf(static, moving_aligned).T

    matched_pairs = _match_moving_to_static(dist, moving)

    deformed_bundle = Streamlines([])
    warp = []
    # Deform each moving streamline towards its matched static streamline.
    for moving_idx, static_idx in matched_pairs:
        reg = DeformableRegistration(
            X=static[int(static_idx)],
            Y=moving_aligned[int(moving_idx)],
            alpha=alpha,
            beta=beta,
            max_iterations=max_iter,
        )
        ty, pr = reg.register()
        deformed_bundle.append(ty.astype(float))
        warp.append(pr)

    warp = pd.DataFrame(warp, columns=["gaussian_kernel", "transforms"])

    # Returns deformed bundle, affinely moved bundle, distance matrix,
    # streamline correspondences, and warp field
    return deformed_bundle, moving_aligned, dist, matched_pairs, warp
def bundlewarp_vector_filed(moving_aligned, deformed_bundle):
    """Calculate vector fields.

    Vector field computation as the difference between each streamline point
    in the deformed and linearly aligned bundles

    Parameters
    ----------
    moving_aligned : Streamlines
        Linearly (affinely) moved bundle
    deformed_bundle : Streamlines
        Nonlinearly (warped) bundle

    Returns
    -------
    offsets : np.ndarray
        Vector field modules (one Euclidean norm per point)
    directions : np.ndarray
        Unitary vector directions. Points with zero displacement get a zero
        vector.
    colors : np.ndarray
        Colors for bundle warping field vectors, computed as the absolute
        value of `directions`. Colors follow the convention used in
        DTI-derived maps (e.g. color FA) :footcite:p:`Pajevic1999`.

    References
    ----------
    .. footbibliography::

    """
    points_aligned, _ = unlist_streamlines(moving_aligned)
    points_deformed, _ = unlist_streamlines(deformed_bundle)
    vector_field = points_deformed - points_aligned

    # Vector field modules
    offsets = np.sqrt(np.sum(vector_field**2, axis=1))

    # Normalize vectors to be unitary (directions). A point that did not move
    # has a zero module; divide it by 1 instead so it stays a zero vector
    # rather than turning into NaN through 0 / 0.
    safe_offsets = offsets.copy()
    safe_offsets[safe_offsets == 0] = 1
    directions = vector_field / safe_offsets[:, np.newaxis]

    # Define colors mapping the direction vectors to RGB.
    # Absolute value generates DTI-like colors (and yields an array that is
    # independent of `directions`).
    colors = np.abs(directions)

    return offsets, directions, colors
@warning_for_keywords()
def bundlewarp_shape_analysis(
    moving_aligned, deformed_bundle, *, no_disks=10, plotting=False
):
    """Calculate bundle shape difference profile.

    Bundle shape difference analysis using magnitude from BundleWarp
    displacements and BUAN.

    Depending on the number of points of a streamline, and the number of
    segments requested, multiple points may be considered for the computation
    of a given segment; a segment may contain information from a single point;
    or some segments may not contain information from any points. In the latter
    case, the segment will contain an ``np.nan`` value. The point-to-segment
    mapping is defined by the :func:`assignment_map`: for each segment index,
    the point information of the matching index positions, as returned by
    :func:`assignment_map`, are considered for the computation.

    Parameters
    ----------
    moving_aligned : Streamlines
        Linearly (affinely) moved bundle
    deformed_bundle : Streamlines
        Nonlinearly (warped) moved bundle
    no_disks : int, optional
        Number of segments to be created along the length of the bundle
    plotting : Boolean, optional
        Plot bundle shape profile

    Returns
    -------
    shape_profile : np.ndarray
        Float array of length `no_disks` containing the mean bundlewarp
        displacement magnitude of each segment along the length of the
        bundle
    stdv : np.ndarray
        Float array of length `no_disks` containing the standard deviation
        of the displacement magnitudes of each segment

    """
    n = no_disks

    # Only the displacement magnitudes are needed here.
    offsets, _, _ = bundlewarp_vector_filed(moving_aligned, deformed_bundle)

    # Segment index of every point of the deformed bundle.
    indx = np.asarray(assignment_map(deformed_bundle, deformed_bundle, n))

    # Segments without any assigned point keep the NaN fill value.
    shape_profile = np.full(n, np.nan)
    stdv = np.full(n, np.nan)
    for i in range(n):
        mask = indx == i
        if mask.any():
            segment_offsets = offsets[mask]
            shape_profile[i] = np.mean(segment_offsets)
            stdv[i] = np.std(segment_offsets)

    if plotting:
        # Segment numbers 1..n for the plot's x axis.
        x = np.arange(1, n + 1)
        bundle_shape_profile(x, shape_profile, stdv)

    return shape_profile, stdv