# Source code for dipy.denoise.gibbs

from functools import partial
import multiprocessing as mp

import numpy as np
import scipy.fft

from dipy.testing.decorators import warning_for_keywords
from dipy.utils.multiproc import determine_num_processes

# Alias for the FFT module used by the helpers in this file (fft, ifft,
# fft2, ifft2, fftfreq, fftshift).
_fft = scipy.fft


@warning_for_keywords()
def _image_tv(x, *, axis=0, n_points=3):
    """Computes total variation (TV) of matrix x across a given axis and
    along two directions.

    Parameters
    ----------
    x : 2D ndarray
        matrix x
    axis : int (0 or 1)
        Axis along which TV will be calculated. Default is 0.
    n_points : int
        Number of points to be included in TV calculation. Must be >= 1.

    Returns
    -------
    ptv : 2D ndarray
        Total variation calculated from the right neighbours of each point.
    ntv : 2D ndarray
        Total variation calculated from the left neighbours of each point.

    Raises
    ------
    ValueError
        If ``n_points`` is smaller than 1.

    Notes
    -----
    The data are treated as periodic along the analysed axis, so the
    neighbours of the first/last samples wrap around to the other end.

    """
    if n_points < 1:
        raise ValueError("n_points must be at least 1.")

    # Work along axis 1 internally; transpose back at the end if needed.
    xs = x if axis else x.T
    n_len = xs.shape[1]
    if n_len == 0:
        # Empty axis: nothing to analyse, return empty TV maps.
        tv = np.absolute(x)
        return tv, tv.copy()

    # Periodically extend the data by n_points + 1 samples on each side so
    # that data extreme points are also analysed. Modular indexing yields the
    # same array as concatenating the edge slices, but unlike slicing it also
    # stays correct when the axis is shorter than the padding.
    pad = n_points + 1
    xs = xs[:, np.arange(-pad, n_len + pad) % n_len]

    def shifted(offset):
        # Window of the original length whose sample j is the neighbour at
        # distance ``offset`` of original sample j.
        start = pad + offset
        return xs[:, start : start + n_len]

    ptv = np.absolute(shifted(0) - shifted(1))
    ntv = np.absolute(shifted(0) - shifted(-1))
    for n in range(1, n_points):
        ptv = ptv + np.absolute(shifted(n) - shifted(n + 1))
        ntv = ntv + np.absolute(shifted(-n) - shifted(-n - 1))

    if axis:
        return ptv, ntv
    return ptv.T, ntv.T


@warning_for_keywords()
def _gibbs_removal_1d(x, *, axis=0, n_points=3):
    """Suppresses Gibbs ringing along a given axis using fourier sub-shifts.

    Parameters
    ----------
    x : 2D ndarray
        Matrix x.
    axis : int (0 or 1)
        Axis in which Gibbs oscillations will be suppressed.
        Default is set to 0.
    n_points : int, optional
        Number of neighbours to access local TV (see note).
        Default is set to 3.

    Returns
    -------
    xc : 2D ndarray
        Matrix with suppressed Gibbs oscillations along the given axis.
        Floating-point and complex inputs keep their dtype; any other input
        (e.g. integers) is processed and returned as float64.

    Notes
    -----
    This function suppresses the effects of Gibbs oscillations based on the
    analysis of local total variation (TV). Although artefact correction is
    done based on two adjacent points for each voxel, total variation should be
    accessed in a larger range of neighbours. The number of neighbours to be
    considered in TV calculation can be adjusted using the parameter n_points.

    For every sample, 45 sub-voxel shifts in [0.02, 0.9] are tried in both
    directions. The shift giving the lowest local TV on each side is kept,
    and the two shifted intensities are linearly interpolated back onto the
    original grid point.

    """
    dtype_float = np.promote_types(x.real.dtype, np.float32)

    ssamp = np.linspace(0.02, 0.9, num=45, dtype=dtype_float)

    # Always work in floating point. With an integer working copy, the
    # shifted (float) intensities written into isp/isn/xs would be truncated,
    # and the TV differences of unsigned data would wrap around. float64
    # matches the precision scipy.fft already uses for integer input.
    work_dtype = x.dtype if x.dtype.kind in "fc" else np.dtype(np.float64)
    xs = (x if axis else x.T).astype(work_dtype, order="C")

    # TV for shift zero (baseline)
    tvr, tvl = _image_tv(xs, axis=1, n_points=n_points)
    tvp = np.minimum(tvr, tvl)
    tvn = tvp.copy()

    # Find optimal shift for gibbs removal
    isp = xs.copy()
    isn = xs.copy()
    sp = np.zeros(xs.shape, dtype=dtype_float)
    sn = np.zeros(xs.shape, dtype=dtype_float)
    N = xs.shape[1]
    c = _fft.fft(xs, axis=1)
    # Frequencies pre-multiplied by 2*pi*j so exp(k * s) is a shift by s.
    k = _fft.fftfreq(N, 1 / (2.0j * np.pi))
    k = k.astype(c.dtype, copy=False)
    for s in ssamp:
        ks = k * s
        # Access positive shift for given s
        img_p = abs(_fft.ifft(c * np.exp(ks), axis=1))
        tvsr, tvsl = _image_tv(img_p, axis=1, n_points=n_points)
        tvs_p = np.minimum(tvsr, tvsl)

        # Access negative shift for given s
        img_n = abs(_fft.ifft(c * np.exp(-ks), axis=1))
        tvsr, tvsl = _image_tv(img_n, axis=1, n_points=n_points)
        tvs_n = np.minimum(tvsr, tvsl)

        # Update positive shift params where this shift lowers the TV
        better_p = tvp > tvs_p
        isp[better_p] = img_p[better_p]
        sp[better_p] = s
        tvp[better_p] = tvs_p[better_p]

        # Update negative shift params where this shift lowers the TV
        better_n = tvn > tvs_n
        isn[better_n] = img_n[better_n]
        sn[better_n] = s
        tvn[better_n] = tvs_n[better_n]

    # check non-zero sub-voxel shifts
    idx = np.nonzero(sp + sn)

    # use positive and negative optimal sub-voxel shifts to interpolate to
    # original grid points
    xs[idx] = (isp[idx] - isn[idx]) / (sp[idx] + sn[idx]) * sn[idx] + isn[idx]

    return xs if axis else xs.T


def _weights(shape):
    """Build the k-space weights that merge two 1D Gibbs-corrected images.

    The image corrected along axis 0 and the image corrected along axis 1
    are blended in the (centred) frequency domain with complementary weights
    that sum to one at every frequency. See :footcite:p:`Kellner2016` for
    further details about the method.

    Parameters
    ----------
    shape : tuple
        shape of the image.

    Returns
    -------
    G0 : 2D ndarray
        Weights for the image corrected along axis 0.
    G1 : 2D ndarray
        Weights for the image corrected along axis 1.

    References
    ----------
    .. footbibliography::

    """
    rows, cols = shape[0], shape[1]
    w0 = np.zeros(shape)
    w1 = np.zeros(shape)

    # Angular frequencies sampled on [-pi, pi]; the end points are left out
    # of the interior formula and handled explicitly below.
    theta_r = np.linspace(-np.pi, np.pi, num=rows)[1:-1]
    theta_c = np.linspace(-np.pi, np.pi, num=cols)[1:-1]
    lobe_r = 1.0 + np.cos(theta_r)[:, np.newaxis]
    lobe_c = 1.0 + np.cos(theta_c)[np.newaxis, :]
    total = lobe_r + lobe_c
    w1[1:-1, 1:-1] = lobe_r / total
    w0[1:-1, 1:-1] = lobe_c / total

    # First/last columns use only the axis-1 correction, first/last rows
    # only the axis-0 correction; the four corners are split evenly.
    w1[1:-1, [0, -1]] = 1
    w0[[0, -1], 1:-1] = 1
    corners = np.ix_([0, -1], [0, -1])
    w1[corners] = 1 / 2
    w0[corners] = 1 / 2

    return w0, w1


@warning_for_keywords()
def _gibbs_removal_2d(image, *, n_points=3, G0=None, G1=None):
    """Suppress Gibbs ringing of a 2D image :footcite:p:`Kellner2016`,
    :footcite:p:`NetoHenriques2018`.

    The image is corrected separately along axis 0 and along axis 1. The two
    results are then blended in the frequency domain using the weights G0
    and G1.

    Parameters
    ----------
    image : 2D ndarray
        Matrix containing the 2D image.
    n_points : int, optional
        Number of neighbours to access local TV (see note). Default is
        set to 3.
    G0 : 2D ndarray, optional
        Weights for the image corrected along axis 0. If not given, the
        function estimates them using the function :func:`_weights`.
    G1 : 2D ndarray, optional
        Weights for the image corrected along axis 1. If not given, the
        function estimates them using the function :func:`_weights`.
        If either G0 or G1 is missing, both are recomputed.

    Returns
    -------
    imagec : 2D ndarray
        Magnitude image with Gibbs oscillations reduced along both axes.

    Notes
    -----
    This function suppresses the effects of Gibbs oscillations based on the
    analysis of local total variation (TV). Although artefact correction is
    done based on two adjacent points for each voxel, total variation should be
    accessed in a larger range of neighbours. The number of neighbours to be
    considered in TV calculation can be adjusted using the parameter n_points.

    References
    ----------
    .. footbibliography::

    """
    if G0 is None or G1 is None:
        G0, G1 = _weights(image.shape)

    img_c1 = _gibbs_removal_1d(image, axis=1, n_points=n_points)
    img_c0 = _gibbs_removal_1d(image, axis=0, n_points=n_points)

    C1 = _fft.fft2(img_c1)
    C0 = _fft.fft2(img_c0)
    # The weights are defined on a frequency grid centred at zero, so the
    # spectra are fftshift-ed before weighting. No inverse shift is needed:
    # inverse-transforming a circularly shifted spectrum only multiplies the
    # image by a unit-modulus phase ramp, which abs() removes.
    imagec = abs(_fft.ifft2(_fft.fftshift(C1) * G1 + _fft.fftshift(C0) * G0))

    return imagec


@warning_for_keywords()
def gibbs_removal(
    vol, *, slice_axis=2, n_points=3, inplace=True, num_processes=1
):
    """Suppresses Gibbs ringing artefacts of image volumes.

    See :footcite:p:`Kellner2016` and :footcite:p:`NetoHenriques2018` for
    further details about the method.

    Parameters
    ----------
    vol : ndarray ([X, Y]), ([X, Y, Z]) or ([X, Y, Z, g])
        Matrix containing one volume (3D) or multiple (4D) volumes of images.
    slice_axis : int (0, 1, or 2)
        Data axis corresponding to the number of acquired slices.
    n_points : int, optional
        Number of neighbour points to access local TV (see note).
    inplace : bool, optional
        If True, the input data is replaced with results. Otherwise, returns
        a new array.
    num_processes : int or None, optional
        Split the calculation to a pool of children processes. This only
        applies to 3D or 4D `data` arrays. Default is 1. If < 0 the maximal
        number of cores minus ``num_processes + 1`` is used (enter -1 to use
        as many cores as possible). 0 raises an error.

    Returns
    -------
    vol : ndarray ([X, Y]), ([X, Y, Z]) or ([X, Y, Z, g])
        Matrix containing one volume (3D) or multiple (4D) volumes of
        corrected images.

    Raises
    ------
    ValueError
        If ``vol`` is not 2D, 3D or 4D, or if ``slice_axis`` > 2.
    TypeError
        If ``inplace`` is not a boolean.

    Notes
    -----
    For 4D matrix last element should always correspond to the number of
    diffusion gradient directions.

    References
    ----------
    .. footbibliography::

    """
    nd = vol.ndim

    # check matrix dimension
    if nd > 4:
        raise ValueError("Data have to be a 4D, 3D or 2D matrix")
    elif nd < 2:
        raise ValueError("Data is not an image")

    if not isinstance(inplace, bool):
        raise TypeError("inplace must be a boolean.")

    num_processes = determine_num_processes(num_processes)

    # Keep the caller's array so in-place results can be propagated back
    # if the 4D reshaping below has to copy the data.
    original = vol

    # check the axis corresponding to different slices
    # 1) This axis cannot be larger than 2
    if slice_axis > 2:
        raise ValueError(
            "Different slices have to be organized along "
            "one of the 3 first matrix dimensions"
        )
    # 2) Reorder axis to allow iteration over the first axis
    elif nd == 3:
        vol = np.moveaxis(vol, slice_axis, 0)
    elif nd == 4:
        vol = np.moveaxis(vol, (slice_axis, 3), (0, 1))

    if nd == 4:
        inishap = vol.shape
        # Merge slices and gradient directions into one leading axis. The
        # moved-axis view is generally non-contiguous, so this may copy.
        vol = vol.reshape((inishap[0] * inishap[1], inishap[2], inishap[3]))

    # Produce weighting functions for 2D Gibbs removal
    shap = vol.shape
    G0, G1 = _weights(shap[-2:])

    # Copy data if not inplace
    if not inplace:
        vol = vol.copy()

    # Run Gibbs removal of 2D images
    if nd == 2:
        vol[:, :] = _gibbs_removal_2d(vol, n_points=n_points, G0=G0, G1=G1)
    elif num_processes == 1:
        for i in range(shap[0]):
            vol[i, :, :] = _gibbs_removal_2d(
                vol[i, :, :], n_points=n_points, G0=G0, G1=G1
            )
    else:
        # A local "spawn" context avoids mp.set_start_method(..., force=True),
        # which would change the start method for the whole process. The
        # context manager guarantees the pool is shut down even on error.
        ctx = mp.get_context("spawn")
        partial_func = partial(
            _gibbs_removal_2d, n_points=n_points, G0=G0, G1=G1
        )
        with ctx.Pool(num_processes) as pool:
            vol[:, :, :] = pool.map(partial_func, vol)
            pool.close()
            pool.join()

    # Reshape data to original format
    if nd == 3:
        vol = np.moveaxis(vol, 0, slice_axis)
    if nd == 4:
        vol = vol.reshape(inishap)
        vol = np.moveaxis(vol, (0, 1), (slice_axis, 3))
        # If the reshape above copied the data, the caller's array was not
        # touched; write the results back so ``inplace=True`` is honoured.
        # Read-only inputs keep the previous behaviour (results returned only).
        if (
            inplace
            and original.flags.writeable
            and not np.may_share_memory(vol, original)
        ):
            original[...] = vol
            vol = original

    return vol