import numpy as np
from dipy.segment.mrf import ConstantObservationModel, IteratedConditionalModes
from dipy.sims.voxel import add_noise
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package
sklearn, has_sklearn, _ = optional_package("sklearn")
linear_model, _, _ = optional_package("sklearn.linear_model")
[docs]
class TissueClassifierHMRF:
    """
    This class contains the methods for tissue classification using the
    Markov Random Fields modeling approach.

    Parameters
    ----------
    save_history : bool, optional
        If True, :meth:`classify` stores the segmentation, PVE, energy map
        and energy sum of every iteration in ``segmentations``, ``pves``,
        ``energies`` and ``energies_sum``. These lists are never cleared, so
        repeated calls to :meth:`classify` keep appending to them.
    verbose : bool, optional
        If True, print the iteration number at every iteration.
    """

    @warning_for_keywords()
    def __init__(self, *, save_history=False, verbose=True):
        self.save_history = save_history
        # Per-iteration history, filled by classify() only when
        # save_history is True.
        self.segmentations = []
        self.pves = []
        self.energies = []
        self.energies_sum = []
        self.verbose = verbose

    @warning_for_keywords()
    def classify(self, image, nclasses, beta, *, tolerance=1e-05, max_iter=100):
        """
        This method uses the Maximum a posteriori - Markov Random Field
        approach for segmentation by using the Iterative Conditional Modes
        and Expectation Maximization to estimate the parameters.

        Parameters
        ----------
        image : ndarray
            3D structural image. If its maximum exceeds 1 it is linearly
            rescaled so that ``[0, image.max()]`` maps to ``[0, 1]``.
        nclasses : int
            Number of desired tissue classes. One extra background class is
            added internally.
        beta : float
            Smoothing parameter, the higher this number the smoother the
            output will be.
        tolerance : float, optional
            Relative tolerance for early stopping. From the 7th iteration on,
            the loop stops when the spread (max - min) of the last five energy
            sums is smaller than ``tolerance`` times the spread of all energy
            sums recorded so far (including an initial seed value of 1e-05).
            Set ``tolerance = 0`` to disable early stopping. Default is 1e-05.
        max_iter : int, optional
            Maximum number of iterations; must be at least 1. The loop may
            stop earlier because of ``tolerance``. Default is 100.

        Returns
        -------
        initial_segmentation : ndarray
            3D segmentation obtained by maximum likelihood, before the
            ICM/EM refinement.
        final_segmentation : ndarray
            3D final refined segmentation containing all tissue types.
        PVE : ndarray
            Probability map of each tissue type, with the class index on the
            last axis. The background class (index 0) is excluded.

        Raises
        ------
        ValueError
            If ``max_iter`` is smaller than 1.
        """
        if max_iter < 1:
            # Without at least one iteration neither PVE nor the final
            # segmentation would ever be computed.
            raise ValueError(f"max_iter must be at least 1, got {max_iter}.")
        nclasses += 1  # One extra class for the background
        # Seed entry of the energy history used by the tolerance test.
        energy_sum = [1e-05]
        com = ConstantObservationModel()
        icm = IteratedConditionalModes()
        if image.max() > 1:
            image = np.interp(image, [0, image.max()], [0.0, 1.0])
        # Initial class parameters, reordered by increasing mean.
        mu, sigmasq = com.initialize_param_uniform(image, nclasses)
        p = np.argsort(mu)
        mu = mu[p]
        sigmasq = sigmasq[p]
        neglogl = com.negloglikelihood(image, mu, sigmasq, nclasses)
        seg_init = icm.initialize_maximum_likelihood(neglogl)
        mu, sigmasq = com.seg_stats(image, seg_init, nclasses)
        # Replace exact zeros with noisy values around 0.001.
        # NOTE(review): add_noise is presumably random, so results may vary
        # between runs unless the RNG is seeded -- confirm.
        zero = np.zeros_like(image) + 0.001
        zero_noise = add_noise(zero, 10000, 1, noise_type="gaussian")
        image_gauss = np.where(image == 0, zero_noise, image)
        initial_segmentation = seg_init
        for i in range(max_iter):
            if self.verbose:
                print(f">> Iteration: {i}")
            # E-step: class probabilities combining the neighbourhood prior
            # and the observation model.
            PLN = icm.prob_neighborhood(seg_init, beta, nclasses)
            PVE = com.prob_image(image_gauss, nclasses, mu, sigmasq, PLN)
            # M-step: update the parameters and keep them sorted by mean.
            mu_upd, sigmasq_upd = com.update_param(image_gauss, PVE, mu, nclasses)
            ind = np.argsort(mu_upd)
            mu_upd = mu_upd[ind]
            sigmasq_upd = sigmasq_upd[ind]
            negll = com.negloglikelihood(image_gauss, mu_upd, sigmasq_upd, nclasses)
            final_segmentation, energy = icm.icm_ising(negll, beta, seg_init)
            # Only finite energies contribute to the convergence measure.
            energy_sum.append(energy[energy > -np.inf].sum())
            if self.save_history:
                self.segmentations.append(final_segmentation)
                self.pves.append(PVE)
                self.energies.append(energy)
                self.energies_sum.append(energy_sum[-1])
            if tolerance > 0 and i > 5:
                e_sum = np.asarray(energy_sum)
                tol = tolerance * (np.amax(e_sum) - np.amin(e_sum))
                e_end = e_sum[-5:]
                test_dist = np.abs(np.amax(e_end) - np.amin(e_end))
                if test_dist < tol:
                    break
            seg_init = final_segmentation
            mu = mu_upd
            sigmasq = sigmasq_upd
        # Drop the background class added at the top of the method.
        PVE = PVE[..., 1:]
        return initial_segmentation, final_segmentation, PVE
 
[docs]
def compute_directional_average(
    data,
    bvals,
    *,
    s0_map=None,
    masks=None,
    b0_mask=None,
    b0_threshold=50,
    low_signal_threshold=50,
):
    """
    Compute the mean signal for each unique b-value shell and fit a linear model.

    The direction-averaged signal of every non-b0 shell is normalized by the
    b=0 signal and log-transformed. It is then regressed against ``-log(n)``,
    where ``n = 1, 2, ...`` is the shell index in ascending b-value order:
    ``log(S_n / S0) = P * (-log(n)) + V``. The fit is an ordinary
    least-squares line computed in closed form.

    Parameters
    ----------
    data : ndarray
        The diffusion MRI signal of a single voxel, one value per volume.
    bvals : array_like
        The b-values corresponding to the diffusion data.
    s0_map : float or ndarray, optional
        Precomputed mean b=0 signal of the voxel. Computed from ``data`` and
        ``b0_mask`` when not given.
    masks : ndarray, optional
        Precomputed boolean masks of shape (n_volumes, n_shells). Column
        ``n`` selects the volumes of the ``n``-th non-b0 shell. By default
        there is one column per unique b-value except the smallest one.
    b0_mask : ndarray, optional
        Precomputed boolean mask selecting the b=0 volumes. Defaults to
        ``bvals < b0_threshold``.
    b0_threshold : float, optional
        b-values below this threshold are treated as b=0.
    low_signal_threshold : float, optional
        If the mean b=0 signal is below this value, no fit is performed and
        ``(0, 0)`` is returned.

    Returns
    -------
    P : float
        The slope of the linear model. It is 0 if only one shell is
        available.
    V : float
        The intercept of the linear model.

    Raises
    ------
    ValueError
        If there is no non-b0 shell, or if the normalized log-signal contains
        non-finite values (e.g. NaN in ``data`` or ``s0_map``).
    """
    bvals = np.asarray(bvals)
    data = np.asarray(data)
    if b0_mask is None:
        b0_mask = bvals < b0_threshold
    if masks is None:
        unique_bvals = np.unique(bvals)
        # The smallest unique b-value is assumed to be the b=0 shell.
        masks = bvals[:, np.newaxis] == unique_bvals[np.newaxis, 1:]
    if s0_map is None:
        s0_map = data[..., b0_mask].mean(axis=-1)
    if s0_map < low_signal_threshold:
        return 0, 0
    # Direction-averaged signal of each shell.
    means = np.sum(data[:, np.newaxis] * masks, axis=0) / np.sum(masks, axis=0)
    # Normalize by s0; the 0.01 offset avoids division by zero. asarray lets
    # s0_map be a plain Python float as well as a NumPy scalar/array.
    s_bvals = means / (np.asarray(s0_map)[..., np.newaxis] + 0.01)
    # log() is undefined for non-positive values: clamp them to 0.001.
    s_bvals[s_bvals <= 0] = 0.001
    y = np.log(s_bvals)
    if y.size == 0:
        raise ValueError("At least one non-b0 shell is required for fitting.")
    if not np.all(np.isfinite(y)):
        raise ValueError("The normalized log-signal contains non-finite values.")
    x = -np.log(np.arange(1, y.shape[-1] + 1))
    # Closed-form ordinary least squares for y = P * x + V.
    x_centered = x - x.mean()
    ss_x = np.dot(x_centered, x_centered)
    # With a single shell the slope is undetermined: report 0 and use the
    # mean log-signal as the intercept.
    P = np.dot(x_centered, y - y.mean()) / ss_x if ss_x > 0 else 0.0
    V = y.mean() - P * x.mean()
    return P, V
[docs]
def dam_classifier(
    data, bvals, wm_threshold, *, b0_threshold=50, low_signal_threshold=50
):
    """Computes the P-map (fitting slope) on data to extract white and grey matter.

    See :footcite:p:`Cheng2020` for further details about the method.

    A voxel is fitted only if its mean b=0 signal is at least
    ``low_signal_threshold``. For such a voxel, the direction-averaged signal
    of each non-b0 shell is normalized by the b=0 signal. Its logarithm is
    regressed (ordinary least squares) on ``-log(n)``, where ``n = 1, 2, ...``
    is the shell index in ascending b-value order. The slope of that fit is
    the voxel's P value. All voxels are fitted at once with vectorized NumPy
    operations.

    Parameters
    ----------
    data : ndarray
        The diffusion MRI data; the last axis indexes the volumes.
    bvals : array_like
        The b-values corresponding to the diffusion data.
    wm_threshold : float
        Voxels with ``0.01 < P <= wm_threshold`` are labelled white matter,
        voxels with ``P > wm_threshold`` grey matter.
    b0_threshold : float, optional
        b-values below this threshold are treated as b=0.
    low_signal_threshold : float, optional
        Voxels whose mean b=0 signal is below this value are not fitted.
        Their P value is set to 0, so they are excluded from both masks as
        long as ``wm_threshold >= 0``.

    Returns
    -------
    wm_mask : ndarray
        A boolean mask for white matter, of shape ``data.shape[:-1]``.
    gm_mask : ndarray
        A boolean mask for grey matter, of shape ``data.shape[:-1]``.

    Raises
    ------
    ValueError
        If ``bvals`` has fewer than three unique values (a b=0 value plus at
        least two shells are needed for the fit).

    References
    ----------
    .. footbibliography::
    """
    data = np.asarray(data)
    bvals = np.asarray(bvals)
    unique_bvals = np.unique(bvals)
    if len(unique_bvals) <= 2:
        raise ValueError("Insufficient unique b-values for fitting.")
    b0_mask = bvals < b0_threshold
    # One column per non-b0 shell; the smallest unique b-value is assumed to
    # be the b=0 shell.
    masks = bvals[:, np.newaxis] == unique_bvals[np.newaxis, 1:]
    shell_weights = masks.astype(float)
    # Mean b=0 signal of every voxel.
    s0_map = data[..., b0_mask].mean(axis=-1)
    # Voxels with too little b=0 signal keep P = 0.
    valid_voxels = s0_map >= low_signal_threshold
    P_map = np.zeros(data.shape[:-1])
    if np.any(valid_voxels):
        # (n_valid, n_volumes) -> direction-averaged signal per shell.
        signal = data[valid_voxels]
        means = signal @ shell_weights / shell_weights.sum(axis=0)
        # Normalize by s0; the 0.01 offset avoids division by zero.
        s_bvals = means / (s0_map[valid_voxels][:, np.newaxis] + 0.01)
        # log() is undefined for non-positive values: clamp them to 0.001.
        s_bvals[s_bvals <= 0] = 0.001
        log_s = np.log(s_bvals)
        x = -np.log(np.arange(1, log_s.shape[-1] + 1))
        # Closed-form OLS slope for every voxel at once. There are at least
        # two shells, so x_centered is never all zeros. A voxel with
        # non-finite data gets a NaN slope and ends up in neither mask.
        x_centered = x - x.mean()
        centered_log = log_s - log_s.mean(axis=-1, keepdims=True)
        P_map[valid_voxels] = centered_log @ x_centered / (x_centered @ x_centered)
    # Adding a small slope threshold for P_map to avoid 0 sloped background voxels
    wm_mask = (P_map <= wm_threshold) & (P_map > 0.01)
    # Grey matter has a higher P value than white matter
    gm_mask = P_map > wm_threshold
    return wm_mask, gm_mask