# Source code for dipy.denoise.patch2self

import os
import tempfile
import time
from warnings import warn

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from tqdm import tqdm

from dipy.stats.sketching import count_sketch
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package

# scikit-learn is loaded through ``optional_package`` rather than a plain
# ``import``; each call is unpacked into three values, of which only the
# module object (and, for the top-level package, ``has_sklearn``) is kept.
# NOTE(review): presumably this keeps dipy importable when scikit-learn is not
# installed and ``has_sklearn`` flags its availability -- confirm against
# ``dipy.utils.optpkg``.
sklearn, has_sklearn, _ = optional_package("sklearn")
linear_model, _, _ = optional_package("sklearn.linear_model")


def _vol_split(train, vol_idx):
    """Separate one held-out volume from the remaining training volumes.

    Parameters
    ----------
    train : numpy.ndarray
        Array of flattened patches indexed as
        (volume, patch element, voxel).
    vol_idx : int
        Index, along the first axis, of the volume that is held out and
        used as the regression target.

    Returns
    -------
    cur_x : numpy.ndarray of shape ((nvolumes - 1) * patch_size, nvoxels)
        Patches of every volume except the held-out one, stacked along the
        first axis.
    y : numpy.ndarray
        Central patch element of the held-out volume, one value per voxel.

    """
    shape = train.shape
    n_volumes = shape[0]
    patch_len = shape[1]

    # Boolean selector that keeps every volume except the held-out one.
    keep = np.ones(n_volumes, dtype=bool)
    keep[vol_idx] = False

    cur_x = train[keep].reshape((n_volumes - 1) * patch_len, shape[2])
    y = train[vol_idx, patch_len // 2, :]
    return cur_x, y


def _extract_3d_patches(arr, patch_radius):
    """Collect every full 3D patch of a 4D array, flattened per volume.

    Parameters
    ----------
    arr : ndarray
        The 4D noisy DWI data to be denoised; the last axis indexes volumes.
    patch_radius : int or array of shape (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels). A single value is used for all three spatial axes.

    Returns
    -------
    all_patches : ndarray of shape (nvolumes, patch_size, npatches)
        For each volume, the flattened patch around every position at which
        a complete patch fits inside ``arr``.

    Raises
    ------
    ValueError
        If ``patch_radius`` holds neither 1 nor 3 values.

    """
    radius = np.asarray(patch_radius, dtype=int)
    if radius.size == 1:
        radius = np.full(3, radius.item(), dtype=int)
    elif radius.size != 3:
        raise ValueError("patch_radius should have length 1 or 3")

    window = 2 * radius + 1
    n_vols = arr.shape[-1]

    # Number of positions at which a full window fits along the three
    # spatial axes.
    n_centres = 1
    for axis in range(3):
        n_centres = n_centres * (arr.shape[axis] - 2 * radius[axis])

    windows = sliding_window_view(arr, (*window, n_vols))

    # (position, patch element, volume) -> (volume, patch element, position)
    stacked = windows.reshape(n_centres, window.prod(), n_vols)
    return np.array(stacked.transpose(2, 1, 0))


def _fit_denoising_model(train, vol_idx, model, alpha):
    """Fit a regressor that predicts one held-out volume from the others.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D, indexed as
        (volume, patch element, voxel).
    vol_idx : int
        The volume number that is held out and used as regression target.
    model : str or regressor object
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'} (case-insensitive).
        Otherwise, it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
        A non-string ``model`` is fitted in place and returned as is.
    alpha : float
        Regularization parameter only for ridge and lasso regression models
        selected by name.

    Returns
    -------
    model_instance : fitted linear model object
        The model after ``fit`` has been called on the training patches.
    cur_x : ndarray
        The patches corresponding to all volumes except the held out volume.

    Raises
    ------
    ValueError
        If ``model`` is an unknown string, or is neither a string nor an
        object exposing a ``fit`` method.

    """
    if isinstance(model, str):
        kind = model.lower()
        if kind == "ols":
            model_instance = linear_model.LinearRegression(copy_X=False)
        elif kind == "ridge":
            model_instance = linear_model.Ridge(copy_X=False, alpha=alpha)
        elif kind == "lasso":
            model_instance = linear_model.Lasso(copy_X=False, max_iter=50, alpha=alpha)
        else:
            raise ValueError(
                f"Invalid model string: {model}. Should be 'ols', 'ridge', or 'lasso'."
            )
    elif not isinstance(model, type) and callable(getattr(model, "fit", None)):
        # Duck-typed on purpose: the documented contract accepts any
        # regressor-like object, not only scikit-learn estimators.
        # NOTE(review): ``BaseEstimator`` is defined in ``sklearn.base`` and
        # is not a documented member of ``sklearn.linear_model``, so an
        # ``isinstance(model, linear_model.BaseEstimator)`` test is not a
        # reliable way to recognise estimators.
        model_instance = model
    else:
        raise ValueError(
            "Model should either be a string or a regressor object "
            "exposing a `fit` method."
        )

    cur_x, y = _vol_split(train, vol_idx)
    model_instance.fit(cur_x.T, y.T)
    return model_instance, cur_x


def _predict_full_volume(flat_data, fitted, group_idx, held_out, chunk):
    """Apply a fitted per-volume regressor to every voxel of the full data.

    Parameters
    ----------
    flat_data : ndarray of shape (nvoxels, nvolumes)
        The full (unsketched) data with the spatial axes flattened.
    fitted : fitted linear model object
        Must expose ``coef_`` (one weight per training volume) and
        ``intercept_``.
    group_idx : ndarray
        Indices, into the volume axis, of the volumes of the group (b0 or
        DWI) the model was trained on, held-out volume included.
    held_out : int
        Position of the held-out volume inside ``group_idx``.
    chunk : int
        Number of voxels processed per matrix product (must be >= 1).

    Returns
    -------
    predicted : ndarray of shape (nvoxels,)
        The denoised intensities of the held-out volume.

    """
    n_voxels, n_vols = flat_data.shape

    # Scatter the learned coefficients into a weight vector over all volumes.
    # The held-out volume gets weight 0 so that it never predicts itself.
    weights = np.zeros(n_vols)
    np.put(weights, group_idx, np.insert(fitted.coef_, held_out, 0))

    predicted = np.zeros(n_voxels)
    # Chunked so that only a slice of the memory-mapped data is read at once.
    for start in range(0, n_voxels, chunk):
        stop = min(start + chunk, n_voxels)
        predicted[start:stop] = (
            np.matmul(flat_data[start:stop, :], weights) + fitted.intercept_
        )
    return predicted


def vol_denoise(
    data_dict, b0_idx, dwi_idx, model, alpha, b0_denoising, verbose, tmp_dir
):
    """Denoise every 3D volume of a memory-mapped 4D dataset.

    For each volume a regressor is fitted on the sketched training data of
    its group (b0 or DWI) with that volume held out, and the fitted weights
    are then applied to the full data to produce the denoised volume.

    Parameters
    ----------
    data_dict : dict
        Dictionary containing the following:
            data : sequence of (str, dtype, tuple)
                File name, dtype and 4D shape of the memmap holding the
                data to be denoised.
            data_b0s : ndarray
                Training patches of the b0 volumes; the first axis indexes
                the b0 volumes.
            data_dwi : ndarray
                Training patches of the DWI volumes; the first axis indexes
                the DWI volumes.
    b0_idx : ndarray
        The indices of the b0 volumes.
    dwi_idx : ndarray
        The indices of the dwi volumes.
    model : str or regressor object
        Forwarded to `_fit_denoising_model` for every volume.
    alpha : float
        Regularization parameter only for ridge and lasso regression models.
    b0_denoising : bool
        Skips denoising b0 volumes if set to False.
    verbose : bool
        Show progress of Patch2Self and time taken.
    tmp_dir : str
        The directory to save the temporary files.

    Returns
    -------
    denoised_arr.name : str
        The name of the memmap file containing the denoised array. The file
        is not deleted here; the caller owns it.
    denoised_arr.dtype : dtype
        The dtype of the denoised array.
    denoised_arr.shape : tuple
        The shape of the denoised array.

    Notes
    -----
    b0 volumes are copied unchanged when ``b0_denoising`` is False or when
    there is a single b0 volume (there is nothing to predict it from).

    """
    data_shape = data_dict["data"][2]
    spatial_shape = (data_shape[0], data_shape[1], data_shape[2])
    n_vols = data_shape[-1]

    data_tmp = np.memmap(
        data_dict["data"][0],
        dtype=data_dict["data"][1],
        mode="r",
        shape=data_shape,
    ).reshape(np.prod(data_shape[:-1]), n_vols)
    data_b0s = data_dict["data_b0s"]
    data_dwi = data_dict["data_dwi"]

    # Both chunk sizes are clamped to >= 1 so that tiny inputs cannot produce
    # a zero range step or an empty staging buffer.
    voxel_chunk = max(1, data_tmp.shape[0] // 10)
    vols_per_flush = max(1, n_vols // 5)

    denoised_arr_file = tempfile.NamedTemporaryFile(
        delete=False, dir=tmp_dir, suffix="denoised_arr"
    )
    denoised_arr_file.close()
    denoised_arr = np.memmap(
        denoised_arr_file.name, dtype=data_tmp.dtype, mode="w+", shape=data_shape
    )

    # Denoised volumes are staged in memory and written to disk in batches.
    staged = np.empty(spatial_shape + (vols_per_flush,))
    n_staged = 0
    start_idx = 0

    b0_volumes = set(np.asarray(b0_idx).ravel().tolist())
    skip_b0 = data_b0s.shape[0] == 1 or not b0_denoising
    if skip_b0 and verbose:
        print("b0 denoising skipped....")

    b0_counter = 0
    dwi_counter = 0
    for vol_idx in tqdm(range(n_vols), desc="Fitting and Denoising", leave=False):
        if vol_idx in b0_volumes:
            if skip_b0:
                # Pass the raw volume through at its own position, so the
                # result does not depend on where the b0s sit in the series.
                volume = data_tmp[:, vol_idx]
            else:
                b_fit, _ = _fit_denoising_model(data_b0s, b0_counter, model, alpha)
                volume = _predict_full_volume(
                    data_tmp, b_fit, b0_idx, b0_counter, voxel_chunk
                )
            b0_counter += 1
        else:
            dwi_fit, _ = _fit_denoising_model(data_dwi, dwi_counter, model, alpha)
            volume = _predict_full_volume(
                data_tmp, dwi_fit, dwi_idx, dwi_counter, voxel_chunk
            )
            dwi_counter += 1

        staged[..., n_staged] = np.reshape(volume, spatial_shape)
        n_staged += 1

        if n_staged == vols_per_flush:
            denoised_arr[..., start_idx : vol_idx + 1] = staged
            start_idx = vol_idx + 1
            n_staged = 0

    # Write whatever is left over when the number of volumes is not a
    # multiple of the batch size; only the first ``n_staged`` slots are valid.
    if n_staged:
        denoised_arr[..., start_idx : start_idx + n_staged] = staged[..., :n_staged]

    # Make sure the data is on disk before the caller re-opens the file.
    denoised_arr.flush()
    return denoised_arr_file.name, denoised_arr.dtype, denoised_arr.shape
@warning_for_keywords()
def patch2self(
    data,
    bvals,
    *,
    patch_radius=(0, 0, 0),
    model="ols",
    b0_threshold=50,
    out_dtype=None,
    alpha=1.0,
    verbose=False,
    b0_denoising=True,
    clip_negative_vals=False,
    shift_intensity=True,
    tmp_dir=None,
    version=3,
):
    """Patch2Self Denoiser.

    Validates the arguments and dispatches to the implementation selected by
    ``version``.

    See :footcite:p:`Fadnavis2020` for further details about the method.
    See :footcite:p:`Fadnavis2024` for further details about the new method.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition
    patch_radius : int or array of shape (3,), optional
        The radius of the local patch to be taken around each voxel (in
        voxels). Default: 0 (denoise in blocks of 1x1x1 voxels). Only
        forwarded to version 1.
    model : string, or sklearn.base.RegressorMixin, optional
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int, optional
        Threshold for considering volumes as b0.
    out_dtype : str or dtype, optional
        The dtype for the output array. Default: output has the same dtype as
        the input.
    alpha : float, optional
        Regularization parameter for the regularized regression models.
    verbose : bool, optional
        Show progress of Patch2Self and time taken.
    b0_denoising : bool, optional
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool, optional
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool, optional
        Shifts the distribution of intensities per volume to give
        non-negative values.
    tmp_dir : str, optional
        The directory to save the temporary files. If None, the temporary
        files are saved in the system's default temporary directory.
        Only forwarded to version 3.
    version : int, optional
        Version 1 or 3 of Patch2Self to use.

    Returns
    -------
    denoised array : ndarray
        The denoised array, as returned by the selected implementation.

    References
    ----------
    .. footbibliography::

    """
    out_dtype, tmp_dir, patch_radius = _validate_inputs(
        data, out_dtype, patch_radius, version, tmp_dir
    )

    # Arguments that both implementations take, in their shared order.
    common = (
        model,
        b0_threshold,
        out_dtype,
        alpha,
        verbose,
        b0_denoising,
        clip_negative_vals,
        shift_intensity,
    )

    if version != 1:
        return _patch2self_version3(data, bvals, *common, tmp_dir)
    return _patch2self_version1(data, bvals, patch_radius, *common)
def _validate_inputs(data, out_dtype, patch_radius, version, tmp_dir):
    """Validate and normalise the inputs of the patch2self function.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    out_dtype : str or dtype
        The dtype for the output array. If None, ``data.dtype`` is used.
    patch_radius : int or array of shape (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels). A single value is used for all three axes.
    version : int
        Version 1 or 3 of Patch2Self to use.
    tmp_dir : str
        The directory to save the temporary files. If None, the temporary
        files are saved in the system's default temporary directory.

    Raises
    ------
    ValueError
        If the version is neither 1 nor 3.
        If temporary directory is not None for Patch2Self version 1.
        If ``patch_radius`` holds neither 1 nor 3 values.
        If the patch_radius is not 0 for Patch2Self version 3.
        If the temporary directory does not exist.
        If the input data is not a 4D array.

    Warns
    -----
    If the input data has less than 10 3D volumes.

    Returns
    -------
    out_dtype : str or dtype
        The dtype for the output array.
    tmp_dir : str
        The directory to save the temporary files (None for version 1).
    patch_radius : tuple of int
        The patch radius along each of the three spatial axes.

    """
    if out_dtype is None:
        out_dtype = data.dtype

    if version not in (1, 3):
        raise ValueError("Invalid version. Should be 1 or 3.")

    if version == 1 and tmp_dir is not None:
        raise ValueError(
            "Temporary directory is not supported for Patch2Self version 1. "
            "Please set tmp_dir to None."
        )
    if version == 3 and tmp_dir is None:
        tmp_dir = tempfile.gettempdir()

    # Normalise the radius *before* testing it, so that 0, [0], [0, 0, 0] and
    # ndarray inputs are all recognised as "no patch" for version 3.
    radius = np.asarray(patch_radius, dtype=int).ravel()
    if radius.size == 1:
        radius = np.repeat(radius, 3)
    elif radius.size != 3:
        raise ValueError("patch_radius should have length 1 or 3")
    patch_radius = tuple(int(r) for r in radius)

    if version == 3 and any(patch_radius):
        raise ValueError(
            "Patch radius is not supported for Patch2Self version 3. "
            "Please do not set patch_radius."
        )

    if version == 3 and not os.path.isdir(tmp_dir):
        raise ValueError("The temporary directory does not exist.")

    if data.ndim != 4:
        raise ValueError("Patch2Self can only denoise on 4D arrays.", data.shape)

    if data.shape[3] < 10:
        warn(
            "The input data has less than 10 3D volumes. "
            "Patch2Self may not give optimal denoising performance.",
            stacklevel=2,
        )

    return out_dtype, tmp_dir, patch_radius


def _patch2self_version1(
    data,
    bvals,
    patch_radius,
    model,
    b0_threshold,
    out_dtype,
    alpha,
    verbose,
    b0_denoising,
    clip_negative_vals,
    shift_intensity,
):
    """Patch2Self Denoiser, in-memory implementation with 3D patches.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition.
    patch_radius : array of shape (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels), one value per spatial axis.
    model : string, or sklearn.base.RegressorMixin
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int
        Threshold for considering volumes as b0.
    out_dtype : str or dtype
        The dtype for the output array.
    alpha : float
        Regularization parameter for the regularized regression models.
    verbose : bool
        Show progress of Patch2Self and time taken.
    b0_denoising : bool
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values.

    Returns
    -------
    denoised array : ndarray
        The denoised array of the same shape as the input data, cast to
        ``out_dtype``.

    """
    # We retain float64 precision, iff the input is in this precision;
    # otherwise things are calculated in float32 (saving memory).
    calc_dtype = np.float64 if data.dtype == np.float64 else np.float32

    original_shape = data.shape
    has_singleton_axis = 1 in original_shape and original_shape[-1] != 1
    if has_singleton_axis:
        # Triple the data along the first singleton axis; presumably so that
        # the np.squeeze calls below do not drop that axis. Only the first
        # copy is kept at the end.
        position = original_shape.index(1)
        data = np.concatenate((data, data, data), position)

    # Segregates volumes by b0 threshold
    bvals = np.asarray(bvals)
    b0_idx = np.argwhere(bvals <= b0_threshold)
    dwi_idx = np.argwhere(bvals > b0_threshold)
    data_b0s = np.squeeze(np.take(data, b0_idx, axis=3))
    data_dwi = np.squeeze(np.take(data, dwi_idx, axis=3))

    denoised_arr = np.empty(data.shape, dtype=calc_dtype)

    # Zero-pad the three spatial axes so every voxel gets a complete patch.
    pad_width = tuple(
        (patch_radius[axis], patch_radius[axis]) for axis in range(3)
    ) + ((0, 0),)

    def _denoise_group(volumes, label):
        """Denoise each volume of one group from the group's other volumes."""
        train = _extract_3d_patches(
            np.pad(volumes, pad_width, mode="constant"),
            patch_radius=patch_radius,
        )
        denoised = np.empty(volumes.shape, dtype=calc_dtype)
        for vol_idx in range(volumes.shape[3]):
            fitted, cur_x = _fit_denoising_model(train, vol_idx, model, alpha=alpha)
            denoised[..., vol_idx] = fitted.predict(cur_x.T).reshape(
                volumes.shape[0], volumes.shape[1], volumes.shape[2]
            )
            if verbose:
                print(f"Denoised {label} Volume: ", vol_idx)
        return denoised

    if verbose:
        t1 = time.time()

    # A single b0 volume squeezes down to 3D; it cannot be denoised because
    # there is no other b0 volume to predict it from.
    single_b0 = data_b0s.ndim == 3
    if single_b0 or not b0_denoising:
        if verbose:
            print("b0 denoising skipped...")
        denoised_b0s = data_b0s
    else:
        denoised_b0s = _denoise_group(data_b0s, "b0")

    # Separate denoising for DWI volumes
    denoised_dwi = _denoise_group(data_dwi, "DWI")

    if verbose:
        t2 = time.time()
        print("Total time taken for Patch2Self: ", t2 - t1, " seconds")

    # Insert the separately denoised groups back at their original positions.
    if single_b0:
        denoised_arr[:, :, :, b0_idx[0][0]] = denoised_b0s
    else:
        for i, idx in enumerate(b0_idx):
            denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_b0s[..., i])

    for i, idx in enumerate(dwi_idx):
        denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_dwi[..., i])

    if has_singleton_axis:
        denoised_arr = np.take(denoised_arr, [0], axis=position)

    denoised_arr = _apply_post_processing(
        denoised_arr, shift_intensity, clip_negative_vals
    )
    return np.array(denoised_arr, dtype=out_dtype)


def _patch2self_version3(
    data,
    bvals,
    model,
    b0_threshold,
    out_dtype,
    alpha,
    verbose,
    b0_denoising,
    clip_negative_vals,
    shift_intensity,
    tmp_dir,
):
    """Patch2Self Denoiser, out-of-core implementation using sketching.

    The data is copied to a memory-mapped scratch file, a row-sketch of it
    is computed with `count_sketch`, and `vol_denoise` fits the regressors
    on the sketch and applies them to the full data.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition.
    model : string, or sklearn.base.RegressorMixin
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int
        Threshold for considering volumes as b0.
    out_dtype : str or dtype
        The dtype for the output array.
    alpha : float
        Regularization parameter for the regularized regression models.
    verbose : bool
        Show progress of Patch2Self and time taken.
    b0_denoising : bool
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values.
    tmp_dir : str
        The directory to save the temporary files. If None, the temporary
        files are saved in the system's default temporary directory.

    Returns
    -------
    denoised array : ndarray
        The denoised array of the same shape as the input data, cast to
        ``out_dtype``.

    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=tmp_dir, suffix="tmp_file")
    tmp_file.close()

    # Every scratch file still on disk; whatever remains here is removed in
    # the ``finally`` clause, also when sketching or denoising raises.
    scratch_files = [tmp_file.name]
    try:
        tmp = np.memmap(
            tmp_file.name,
            dtype=data.dtype,
            mode="w+",
            shape=(data.shape[0], data.shape[1], data.shape[2], data.shape[3]),
        )

        # Copy in batches of volumes; clamp to >= 1 so that inputs with fewer
        # than five volumes do not produce a zero range step.
        n_vols = data.shape[3]
        vols_per_copy = max(1, n_vols // 5)
        for start in range(0, n_vols, vols_per_copy):
            stop = min(start + vols_per_copy, n_vols)
            if verbose:
                print(f"Loading data from {start} to {stop}")
            tmp[..., start:stop] = data[..., start:stop]
        # count_sketch re-opens the file by name, so push the copy to disk.
        tmp.flush()

        sketch_rows = int(0.30 * data.shape[0] * data.shape[1] * data.shape[2])
        sketched_matrix_name, sketched_matrix_dtype, sketched_matrix_shape = (
            count_sketch(
                tmp_file.name,
                data.dtype,
                tmp.shape,
                sketch_rows=sketch_rows,
                tmp_dir=tmp_dir,
            )
        )
        scratch_files.append(sketched_matrix_name)
        sketched_matrix = np.memmap(
            sketched_matrix_name,
            dtype=sketched_matrix_dtype,
            mode="r",
            shape=sketched_matrix_shape,
        ).T
        if verbose:
            print("Sketching done.")

        bvals = np.asarray(bvals)
        b0_idx = np.argwhere(bvals <= b0_threshold)
        dwi_idx = np.argwhere(bvals > b0_threshold)
        data_b0s = np.take(np.squeeze(sketched_matrix), b0_idx, axis=0)
        data_dwi = np.take(np.squeeze(sketched_matrix), dwi_idx, axis=0)
        data_dict = {
            "data": [tmp_file.name, data.dtype, tmp.shape],
            "data_b0s": data_b0s,
            "data_dwi": data_dwi,
        }

        if verbose:
            t1 = time.time()

        # The sketch is no longer needed once the training arrays are taken;
        # drop the mapping before removing the file.
        del sketched_matrix
        os.unlink(sketched_matrix_name)
        scratch_files.remove(sketched_matrix_name)

        denoised_arr_name, denoised_arr_dtype, denoised_arr_shape = vol_denoise(
            data_dict,
            b0_idx,
            dwi_idx,
            model,
            alpha,
            b0_denoising,
            verbose,
            tmp_dir=tmp_dir,
        )
        scratch_files.append(denoised_arr_name)
        denoised_arr = np.memmap(
            denoised_arr_name,
            dtype=denoised_arr_dtype,
            mode="r+",
            shape=denoised_arr_shape,
        )
        if verbose:
            t2 = time.time()
            print("Time taken for Patch2Self: ", t2 - t1, " seconds.")

        denoised_arr = _apply_post_processing(
            denoised_arr, shift_intensity, clip_negative_vals
        )
        result = np.array(denoised_arr, dtype=out_dtype)

        # Release the mappings before the files are unlinked below.
        del tmp
        del denoised_arr
        return result
    finally:
        for path in scratch_files:
            try:
                os.unlink(path)
            except OSError:
                # Best-effort cleanup: the file may already be gone, or (on
                # some platforms) still be mapped after a failure above.
                pass


def _apply_post_processing(denoised_arr, shift_intensity, clip_negative_vals):
    """Apply post-processing steps such as clipping and shifting intensities.

    The array is modified in place and also returned.

    Parameters
    ----------
    denoised_arr : ndarray
        The denoised array; the last axis indexes volumes.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values: every volume whose minimum is negative is
        shifted up so that its minimum becomes 0.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`. Takes
        precedence over ``shift_intensity`` when both are True.

    Warns
    -----
    If both ``clip_negative_vals`` and ``shift_intensity`` are True.

    Returns
    -------
    denoised_arr : ndarray
        The denoised array with post-processing applied.

    """
    if clip_negative_vals:
        if shift_intensity:
            warn(
                "Both `clip_negative_vals` and `shift_intensity` cannot be True. "
                "Defaulting to `clip_negative_vals`...",
                stacklevel=2,
            )
        denoised_arr.clip(min=0, out=denoised_arr)
    elif shift_intensity:
        for i in range(denoised_arr.shape[-1]):
            volume = denoised_arr[..., i]  # a view: edits land in denoised_arr
            lowest = np.min(volume)
            if lowest < 0:
                volume -= lowest
    return denoised_arr