import os
import tempfile
import time
from warnings import warn
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from tqdm import tqdm
from dipy.stats.sketching import count_sketch
from dipy.testing.decorators import warning_for_keywords
from dipy.utils.optpkg import optional_package
sklearn, has_sklearn, _ = optional_package("sklearn")
linear_model, _, _ = optional_package("sklearn.linear_model")
def _vol_split(train, vol_idx):
    """Hold out one volume of ``train`` as the regression target.

    Parameters
    ----------
    train : numpy.ndarray
        3D array whose first axis indexes volumes, whose second axis indexes
        positions inside a flattened patch and whose third axis indexes
        samples (voxels, or sketched rows).
    vol_idx : int
        Index along the first axis of the volume used as the target.

    Returns
    -------
    cur_x : numpy.ndarray of shape ((nvolumes - 1) * patch_size, nsamples)
        Patches of every volume except ``vol_idx``, stacked along the first
        axis.
    y : numpy.ndarray of shape (nsamples,)
        Values of volume ``vol_idx`` at the middle position of each patch.
    """
    n_volumes = train.shape[0]
    patch_size = train.shape[1]
    n_samples = train.shape[2]

    # Boolean selection of every volume except the held-out one (a copy).
    keep = np.ones(n_volumes, dtype=bool)
    keep[vol_idx] = False
    regressors = train[keep].reshape((n_volumes - 1) * patch_size, n_samples)

    # The middle entry of a flattened odd-sized patch is its centre voxel.
    target = train[vol_idx, patch_size // 2, :]
    return regressors, target
def _extract_3d_patches(arr, patch_radius):
    """Collect every full 3D patch of a 4D array, for all volumes at once.

    Parameters
    ----------
    arr : ndarray
        4D array (x, y, z, volumes). Only positions where the whole patch
        fits inside ``arr`` are used, so callers pad beforehand if they need
        one patch per original voxel.
    patch_radius : int or array of shape (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels). A single value is used for all three axes.

    Returns
    -------
    all_patches : ndarray of shape (nvolumes, patch_size, npatches)
        ``all_patches[v, :, i]`` is patch ``i`` of volume ``v`` flattened in
        C order, with ``patch_size = prod(2 * patch_radius + 1)``.

    Raises
    ------
    ValueError
        If ``patch_radius`` has neither 1 nor 3 elements.
    """
    radius = np.asarray(patch_radius, dtype=int)
    if radius.size == 1:
        radius = np.repeat(radius, 3)
    elif radius.size != 3:
        raise ValueError("patch_radius should have length 1 or 3")

    window = 2 * radius + 1
    n_volumes = arr.shape[-1]
    n_patches = np.prod([arr.shape[ax] - 2 * radius[ax] for ax in range(3)])

    # One window per spatial position, spanning all volumes.
    windows = sliding_window_view(arr, (*window, n_volumes))
    flat = windows.reshape(n_patches, np.prod(window), n_volumes)

    # Reorder to (volume, patch position, patch index). np.array copies, so
    # the result is writable and independent of the read-only window view.
    return np.array(np.transpose(flat, (2, 1, 0)))
def _fit_denoising_model(train, vol_idx, model, alpha):
    """Fit a linear model predicting one held-out volume from the others.

    Parameters
    ----------
    train : ndarray
        Array of all 3D patches flattened out to be 2D, with the first axis
        indexing volumes (see ``_vol_split``).
    vol_idx : int
        The volume number that needs to be held out for training.
    model : str or estimator object
        If a string, one of {'ols', 'ridge', 'lasso'} (case-insensitive),
        selecting `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Ridge` or `sklearn.linear_model.Lasso`.
        Otherwise, any object with a scikit-learn style ``fit(X, y)`` method,
        e.g. an object that inherits from `dipy.optimize.SKLearnLinearSolver`
        or from `sklearn.base.RegressorMixin`. Such an object is fitted in
        place and returned.
    alpha : float
        Regularization parameter, only used for the 'ridge' and 'lasso'
        string models.

    Returns
    -------
    model_instance : fitted estimator
        The model after calling ``fit``.
    cur_x : ndarray of shape ((nvolumes - 1) * patch_size, nsamples)
        The patches corresponding to all volumes except the held out volume.
        Its transpose is what was passed to ``fit``.

    Raises
    ------
    ValueError
        If ``model`` is an unknown string or an object without ``fit``.

    Notes
    -----
    The string models are built with ``copy_X=False``; scikit-learn documents
    that X may then be overwritten during ``fit``. NOTE(review): confirm that
    callers reusing ``cur_x`` after this call (e.g. for ``predict``) are not
    affected.
    """
    if isinstance(model, str):
        name = model.lower()
        if name == "ols":
            model_instance = linear_model.LinearRegression(copy_X=False)
        elif name == "ridge":
            model_instance = linear_model.Ridge(copy_X=False, alpha=alpha)
        elif name == "lasso":
            model_instance = linear_model.Lasso(
                copy_X=False, max_iter=50, alpha=alpha
            )
        else:
            raise ValueError(
                f"Invalid model string: {model}. Should be 'ols', 'ridge', or 'lasso'."
            )
    elif callable(getattr(model, "fit", None)):
        # Duck typing instead of an isinstance check: BaseEstimator lives in
        # sklearn.base (not sklearn.linear_model), and the documented
        # alternatives (e.g. dipy's SKLearnLinearSolver) need not derive from
        # it. All that is used here is ``fit``.
        model_instance = model
    else:
        raise ValueError(
            "Model should either be a string or an object with a `fit` "
            "method, such as a scikit-learn linear regressor."
        )
    cur_x, y = _vol_split(train, vol_idx)
    model_instance.fit(cur_x.T, y.T)
    return model_instance, cur_x
def _predict_volume(data_tmp, train, counter, vol_indices, model, alpha, rows_per_chunk):
    """Fit the model for one held-out volume and predict it for all voxels.

    Parameters
    ----------
    data_tmp : ndarray of shape (nvoxels, nvolumes)
        The full data, one row per voxel.
    train : ndarray
        Sketched training data of the group (b0 or DWI) the volume belongs
        to, first axis indexing the volumes of that group.
    counter : int
        Position of the held-out volume inside ``train``.
    vol_indices : ndarray
        Indices (last axis of the full data) of the volumes in ``train``, in
        the same order as the first axis of ``train``.
    model : str or estimator object
        Forwarded to ``_fit_denoising_model``.
    alpha : float
        Forwarded to ``_fit_denoising_model``.
    rows_per_chunk : int
        Number of voxel rows multiplied per step (bounds peak memory).

    Returns
    -------
    result : ndarray of shape (nvoxels,)
        The predicted (denoised) volume.
    """
    fit, _ = _fit_denoising_model(train, counter, model, alpha)
    # Scatter the coefficients into a weight vector over all volumes. The
    # held-out volume gets weight 0 (inserted at ``counter``), and volumes
    # outside ``vol_indices`` stay 0.
    weights = np.zeros(data_tmp.shape[-1])
    np.put(weights, vol_indices, np.insert(fit.coef_, counter, 0))
    result = np.zeros(data_tmp.shape[0])
    for start in range(0, data_tmp.shape[0], rows_per_chunk):
        stop = start + rows_per_chunk  # slicing clamps the last chunk
        result[start:stop] = (
            np.matmul(data_tmp[start:stop, :], weights) + fit.intercept_
        )
    return result


def vol_denoise(
    data_dict, b0_idx, dwi_idx, model, alpha, b0_denoising, verbose, tmp_dir
):
    """Denoise every volume of memory-mapped 4D data, writing to a new memmap.

    Each volume is predicted by a linear model fitted on the sketched
    training data of its group (b0 or DWI). The model is then applied to all
    voxels of the full data. Results are staged in an in-memory buffer of
    about a fifth of the volumes and flushed in volume order.

    Parameters
    ----------
    data_dict : dict
        Dictionary with the following keys:

        ``"data"`` : sequence (str, dtype, tuple)
            Name, dtype and shape of the memmap file holding the full data.
        ``"data_b0s"`` : ndarray
            Sketched training data of the b0 volumes, first axis indexing
            the b0 volumes.
        ``"data_dwi"`` : ndarray
            Sketched training data of the dwi volumes, first axis indexing
            the dwi volumes.
    b0_idx : ndarray
        Indices (last axis of the data) of the b0 volumes, in the same order
        as the first axis of ``data_dict["data_b0s"]``.
    dwi_idx : ndarray
        Indices (last axis of the data) of the dwi volumes, in the same order
        as the first axis of ``data_dict["data_dwi"]``.
    model : str or estimator object
        Forwarded to `_fit_denoising_model`.
    alpha : float
        Regularization parameter only for ridge and lasso regression models.
    b0_denoising : bool
        If False, b0 volumes are copied unchanged. They are also copied
        unchanged when there is a single b0 volume (nothing to regress on).
    verbose : bool
        Print a message when b0 denoising is skipped.
    tmp_dir : str
        The directory to save the temporary files.

    Returns
    -------
    denoised_arr.name : str
        The name of the memmap file containing the denoised array.
    denoised_arr.dtype : dtype
        float64 if the input is float64, float32 otherwise. A float type is
        used so that negative predictions survive until post-processing.
    denoised_arr.shape : tuple
        The shape of the denoised array (same as the input data).
    """
    data_shape = tuple(data_dict["data"][2])
    spatial_shape = data_shape[:-1]
    n_vols = data_shape[-1]
    data_tmp = np.memmap(
        data_dict["data"][0],
        dtype=data_dict["data"][1],
        mode="r",
        shape=data_dict["data"][2],
    ).reshape(np.prod(data_shape[:-1]), data_shape[-1])
    data_b0s = data_dict["data_b0s"]
    data_dwi = data_dict["data_dwi"]

    # About ten row-chunks per matmul; at least one row so tiny data works.
    rows_per_chunk = max(1, data_tmp.shape[0] // 10)
    calc_dtype = np.float64 if data_tmp.dtype == np.float64 else np.float32

    denoised_arr_file = tempfile.NamedTemporaryFile(
        delete=False, dir=tmp_dir, suffix="denoised_arr"
    )
    denoised_arr_file.close()
    denoised_arr = np.memmap(
        denoised_arr_file.name, dtype=calc_dtype, mode="w+", shape=data_shape
    )

    skip_b0 = data_b0s.shape[0] == 1 or not b0_denoising
    if skip_b0 and verbose:
        print("b0 denoising skipped....")

    b0_set = set(np.asarray(b0_idx).ravel().tolist())
    # Buffer roughly a fifth of the volumes (at least one) before writing.
    chunk = max(1, n_vols // 5)
    buffer = np.empty(spatial_shape + (chunk,), dtype=calc_dtype)
    buf_start = 0  # volume index stored in buffer[..., 0]
    n_buffered = 0
    b0_counter = 0
    dwi_counter = 0

    for vol_idx in tqdm(range(n_vols), desc="Fitting and Denoising", leave=False):
        if vol_idx in b0_set:
            if skip_b0:
                volume = data_tmp[:, vol_idx]
            else:
                volume = _predict_volume(
                    data_tmp, data_b0s, b0_counter, b0_idx, model, alpha,
                    rows_per_chunk,
                )
            b0_counter += 1
        else:
            volume = _predict_volume(
                data_tmp, data_dwi, dwi_counter, dwi_idx, model, alpha,
                rows_per_chunk,
            )
            dwi_counter += 1

        buffer[..., n_buffered] = np.reshape(volume, spatial_shape)
        n_buffered += 1
        # Flush when the buffer is full or after the last volume; the buffer
        # always holds consecutive volumes starting at ``buf_start``.
        if n_buffered == chunk or vol_idx == n_vols - 1:
            denoised_arr[..., buf_start : buf_start + n_buffered] = buffer[
                ..., :n_buffered
            ]
            buf_start += n_buffered
            n_buffered = 0

    denoised_arr.flush()
    name, dtype, shape = (
        denoised_arr_file.name,
        denoised_arr.dtype,
        denoised_arr.shape,
    )
    del denoised_arr
    return name, dtype, shape
@warning_for_keywords()
def patch2self(
    data,
    bvals,
    *,
    patch_radius=(0, 0, 0),
    model="ols",
    b0_threshold=50,
    out_dtype=None,
    alpha=1.0,
    verbose=False,
    b0_denoising=True,
    clip_negative_vals=False,
    shift_intensity=True,
    tmp_dir=None,
    version=3,
):
    """Patch2Self Denoiser.

    See :footcite:p:`Fadnavis2020` for further details about the method.
    See :footcite:p:`Fadnavis2024` for further details about the new method.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition
    patch_radius : int or array of shape (3,), optional
        The radius of the local patch to be taken around each voxel (in
        voxels). Default: (0, 0, 0), i.e. patches of 1x1x1 voxels. Only used
        by version 1; leave it at the default for version 3.
    model : string, or sklearn.base.RegressorMixin
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int, optional
        Volumes with ``bvals <= b0_threshold`` are treated as b0 volumes.
    out_dtype : str or dtype, optional
        The dtype for the output array. Default: output has the same dtype as
        the input.
    alpha : float, optional
        Regularization parameter, only used by the 'ridge' and 'lasso'
        string models.
    verbose : bool, optional
        Show progress of Patch2Self and time taken.
    b0_denoising : bool, optional
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool, optional
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool, optional
        Shifts the distribution of intensities per volume to give
        non-negative values.
    tmp_dir : str, optional
        The directory to save the temporary files (version 3 only; must be
        None for version 1). If None, the temporary files are saved in the
        system's default temporary directory. Default: None.
    version : int, optional
        Version 1 or 3 of Patch2Self to use. Default: 3

    Returns
    -------
    denoised array : ndarray
        The denoised array, of the same shape as the input data, cast to
        ``out_dtype``. Negative values are only handled according to
        ``clip_negative_vals`` / ``shift_intensity``.

    References
    ----------
    .. footbibliography::
    """
    out_dtype, tmp_dir, patch_radius = _validate_inputs(
        data, out_dtype, patch_radius, version, tmp_dir
    )
    if version == 1:
        return _patch2self_version1(
            data,
            bvals,
            patch_radius,
            model,
            b0_threshold,
            out_dtype,
            alpha,
            verbose,
            b0_denoising,
            clip_negative_vals,
            shift_intensity,
        )
    return _patch2self_version3(
        data,
        bvals,
        model,
        b0_threshold,
        out_dtype,
        alpha,
        verbose,
        b0_denoising,
        clip_negative_vals,
        shift_intensity,
        tmp_dir,
    )
def _validate_inputs(data, out_dtype, patch_radius, version, tmp_dir):
    """Validate and normalize the inputs of the patch2self function.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    out_dtype : str or dtype
        The dtype for the output array. None means the dtype of ``data``.
    patch_radius : int or array of shape (1,) or (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels). A single value is used for all three axes.
    version : int
        Version 1 or 3 of Patch2Self to use.
    tmp_dir : str
        The directory to save the temporary files. If None and version is 3,
        the system's default temporary directory is used.

    Raises
    ------
    ValueError
        If the version is not 1 or 3.
        If temporary directory is not None for Patch2Self version 1.
        If patch_radius does not have 1 or 3 elements.
        If the patch_radius is not 0 for Patch2Self version 3.
        If the temporary directory does not exist.
        If the input data is not a 4D array.

    Warns
    -----
    If the input data has less than 10 3D volumes.

    Returns
    -------
    out_dtype : str or dtype
        The dtype for the output array.
    tmp_dir : str or None
        The directory to save the temporary files (None for version 1).
    patch_radius : tuple of 3 int
        The normalized patch radius.
    """
    if out_dtype is None:
        out_dtype = data.dtype
    if tmp_dir is None and version == 3:
        tmp_dir = tempfile.gettempdir()
    if version not in [1, 3]:
        raise ValueError("Invalid version. Should be 1 or 3.")
    if version == 1 and tmp_dir is not None:
        raise ValueError(
            "Temporary directory is not supported for Patch2Self version 1. "
            "Please set tmp_dir to None."
        )
    # Normalize before comparing: 0, [0, 0, 0] and numpy arrays are all valid
    # spellings of "no patch"; comparing the raw value with (0, 0, 0)
    # rejected them (or raised for arrays).
    radius = np.asarray(patch_radius, dtype=int).ravel()
    if radius.size == 1:
        radius = np.repeat(radius, 3)
    elif radius.size != 3:
        raise ValueError("patch_radius should have length 1 or 3")
    patch_radius = tuple(int(r) for r in radius)
    if version == 3 and any(patch_radius):
        raise ValueError(
            "Patch radius is not supported for Patch2Self version 3. "
            "Please do not set patch_radius."
        )
    if version == 3 and tmp_dir is not None and not os.path.exists(tmp_dir):
        raise ValueError("The temporary directory does not exist.")
    if data.ndim != 4:
        raise ValueError("Patch2Self can only denoise on 4D arrays.", data.shape)
    if data.shape[3] < 10:
        warn(
            "The input data has less than 10 3D volumes. "
            "Patch2Self may not give optimal denoising performance.",
            stacklevel=2,
        )
    return out_dtype, tmp_dir, patch_radius
def _patch2self_version1(
    data,
    bvals,
    patch_radius,
    model,
    b0_threshold,
    out_dtype,
    alpha,
    verbose,
    b0_denoising,
    clip_negative_vals,
    shift_intensity,
):
    """Patch2Self Denoiser (version 1).

    Volumes are split into b0 and DWI groups by ``b0_threshold``. Within each
    group, every volume is predicted from the 3D patches (radius
    ``patch_radius``, zero-padded at the borders) of the other volumes of
    the same group, using ``_fit_denoising_model``.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition.
    patch_radius : int or array of shape (3,)
        The radius of the local patch to be taken around each voxel (in
        voxels). Indexed as ``patch_radius[0..2]`` below, so it must have
        three entries.
    model : string, or sklearn.base.RegressorMixin
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int
        Volumes with ``bvals <= b0_threshold`` are treated as b0 volumes.
    out_dtype : str or dtype
        The dtype for the output array.
    alpha : float
        Regularization parameter, only used by the 'ridge' and 'lasso'
        string models.
    verbose : bool
        Show progress of Patch2Self and time taken.
    b0_denoising : bool
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values.

    Returns
    -------
    denoised array : ndarray
        The denoised array, of the same shape as the input data, passed
        through ``_apply_post_processing`` (per ``shift_intensity`` /
        ``clip_negative_vals``) and cast to ``out_dtype``.
    """
    # We retain float64 precision, iff the input is in this precision:
    if data.dtype == np.float64:
        calc_dtype = np.float64
    # Otherwise, we'll calculate things in float32 (saving memory)
    else:
        calc_dtype = np.float32
    original_shape = data.shape
    # np.squeeze below would also drop a singleton spatial axis. Replicating
    # the data three times along the first such axis keeps the arrays 4D; the
    # copies are dropped again at the end. NOTE(review): only the first
    # singleton axis (data.shape.index) is replicated, so data with two
    # singleton spatial axes would still be squeezed; confirm.
    if 1 in data.shape and data.shape[-1] != 1:
        position = data.shape.index(1)
        data = np.concatenate((data, data, data), position)
    # Segregates volumes by b0 threshold
    b0_idx = np.argwhere(bvals <= b0_threshold)
    dwi_idx = np.argwhere(bvals > b0_threshold)
    # np.argwhere returns shape (k, 1), so np.take adds a trailing length-1
    # axis that np.squeeze removes. With a single volume in a group the
    # volume axis is squeezed too and the result is 3D. NOTE(review): a
    # single DWI volume would therefore fail at data_dwi.shape[3] below.
    data_b0s = np.squeeze(np.take(data, b0_idx, axis=3))
    data_dwi = np.squeeze(np.take(data, dwi_idx, axis=3))
    # create empty arrays
    denoised_b0s = np.empty(data_b0s.shape, dtype=calc_dtype)
    denoised_dwi = np.empty(data_dwi.shape, dtype=calc_dtype)
    denoised_arr = np.empty(data.shape, dtype=calc_dtype)
    # t1 only exists when verbose is True; it is read under the same check.
    if verbose is True:
        t1 = time.time()
    # if only 1 b0 volume, skip denoising it (data_b0s.ndim == 3 means its
    # volume axis was squeezed away, i.e. exactly one b0 volume)
    if data_b0s.ndim == 3 or not b0_denoising:
        if verbose:
            print("b0 denoising skipped...")
        denoised_b0s = data_b0s
    else:
        train_b0 = _extract_3d_patches(
            np.pad(
                data_b0s,
                (
                    (patch_radius[0], patch_radius[0]),
                    (patch_radius[1], patch_radius[1]),
                    (patch_radius[2], patch_radius[2]),
                    (0, 0),
                ),
                mode="constant",
            ),
            patch_radius=patch_radius,
        )
        for vol_idx in range(0, data_b0s.shape[3]):
            b0_model, cur_x = _fit_denoising_model(
                train_b0, vol_idx, model, alpha=alpha
            )
            # NOTE(review): cur_x was just passed (transposed) to fit(); the
            # string models use copy_X=False, which scikit-learn documents as
            # allowing X to be overwritten. Confirm predict() sees the
            # original values.
            denoised_b0s[..., vol_idx] = b0_model.predict(cur_x.T).reshape(
                data_b0s.shape[0], data_b0s.shape[1], data_b0s.shape[2]
            )
            if verbose is True:
                print("Denoised b0 Volume: ", vol_idx)
    # Separate denoising for DWI volumes
    train_dwi = _extract_3d_patches(
        np.pad(
            data_dwi,
            (
                (patch_radius[0], patch_radius[0]),
                (patch_radius[1], patch_radius[1]),
                (patch_radius[2], patch_radius[2]),
                (0, 0),
            ),
            mode="constant",
        ),
        patch_radius=patch_radius,
    )
    # Insert the separately denoised arrays into the respective empty arrays
    for vol_idx in range(0, data_dwi.shape[3]):
        dwi_model, cur_x = _fit_denoising_model(train_dwi, vol_idx, model, alpha=alpha)
        # NOTE(review): same copy_X=False concern as for the b0 volumes above.
        denoised_dwi[..., vol_idx] = dwi_model.predict(cur_x.T).reshape(
            data_dwi.shape[0], data_dwi.shape[1], data_dwi.shape[2]
        )
        if verbose is True:
            print("Denoised DWI Volume: ", vol_idx)
    if verbose is True:
        t2 = time.time()
        print("Total time taken for Patch2Self: ", t2 - t1, " seconds")
    # Scatter the per-group results back to their original volume positions.
    if data_b0s.ndim == 3:
        denoised_arr[:, :, :, b0_idx[0][0]] = denoised_b0s
    else:
        for i, idx in enumerate(b0_idx):
            denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_b0s[..., i])
    for i, idx in enumerate(dwi_idx):
        denoised_arr[:, :, :, idx[0]] = np.squeeze(denoised_dwi[..., i])
    # Undo the replication done at the start: keep only the first slice
    # along `position` (same condition, so `position` is defined here).
    if 1 in original_shape and original_shape[-1] != 1:
        denoised_arr = np.take(denoised_arr, [0], axis=position)
    denoised_arr = _apply_post_processing(
        denoised_arr, shift_intensity, clip_negative_vals
    )
    return np.array(denoised_arr, dtype=out_dtype)
def _patch2self_version3(
    data,
    bvals,
    model,
    b0_threshold,
    out_dtype,
    alpha,
    verbose,
    b0_denoising,
    clip_negative_vals,
    shift_intensity,
    tmp_dir,
):
    """Patch2Self Denoiser (version 3).

    The data is copied to a temporary memory-mapped file. A count sketch of
    its rows (``count_sketch``) is used as training data, and ``vol_denoise``
    denoises every volume. Temporary files are removed before returning and,
    on a best-effort basis, when any step raises.

    Parameters
    ----------
    data : ndarray
        The 4D noisy DWI data to be denoised.
    bvals : array of shape (N,)
        Array of the bvals from the DWI acquisition.
    model : string, or sklearn.base.RegressorMixin
        This will determine the algorithm used to solve the set of linear
        equations underlying this model. If it is a string it needs to be
        one of the following: {'ols', 'ridge', 'lasso'}. Otherwise,
        it can be an object that inherits from
        `dipy.optimize.SKLearnLinearSolver` or an object with a similar
        interface from Scikit-Learn:
        `sklearn.linear_model.LinearRegression`,
        `sklearn.linear_model.Lasso` or `sklearn.linear_model.Ridge`
        and other objects that inherit from `sklearn.base.RegressorMixin`.
    b0_threshold : int
        Volumes with ``bvals <= b0_threshold`` are treated as b0 volumes.
    out_dtype : str or dtype
        The dtype for the output array.
    alpha : float
        Regularization parameter, only used by the 'ridge' and 'lasso'
        string models.
    verbose : bool
        Show progress of Patch2Self and time taken.
    b0_denoising : bool
        Skips denoising b0 volumes if set to False.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values.
    tmp_dir : str
        The directory to save the temporary files. If None, the temporary
        files are saved in the system's default temporary directory.

    Returns
    -------
    denoised array : ndarray
        The denoised array, of the same shape as the input data, passed
        through ``_apply_post_processing`` (per ``shift_intensity`` /
        ``clip_negative_vals``) and cast to ``out_dtype``.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=tmp_dir, suffix="tmp_file")
    tmp_file.close()
    # Temporary files created so far; removed if any later step raises so
    # that a failure does not leave data-sized files behind in tmp_dir.
    created_files = [tmp_file.name]
    try:
        tmp = np.memmap(
            tmp_file.name,
            dtype=data.dtype,
            mode="w+",
            shape=(data.shape[0], data.shape[1], data.shape[2], data.shape[3]),
        )
        # Copy in about five chunks of volumes. The step is at least one
        # volume: with fewer than 5 volumes ``// 5`` is 0 and range() raises.
        step = max(1, data.shape[-1] // 5)
        for idx_start in range(0, data.shape[3], step):
            end_idx = min(idx_start + step, data.shape[3])
            if verbose:
                print("Loading data from {} to {}".format(idx_start, end_idx))
            tmp[..., idx_start:end_idx] = data[..., idx_start:end_idx]
        # count_sketch only receives the file name; write the copy out first.
        tmp.flush()

        sketch_rows = int(0.30 * data.shape[0] * data.shape[1] * data.shape[2])
        (
            sketched_matrix_name,
            sketched_matrix_dtype,
            sketched_matrix_shape,
        ) = count_sketch(
            tmp_file.name,
            data.dtype,
            tmp.shape,
            sketch_rows=sketch_rows,
            tmp_dir=tmp_dir,
        )
        created_files.append(sketched_matrix_name)
        sketched_matrix = np.memmap(
            sketched_matrix_name,
            dtype=sketched_matrix_dtype,
            mode="r",
            shape=sketched_matrix_shape,
        ).T
        if verbose:
            print("Sketching done.")
        b0_idx = np.argwhere(bvals <= b0_threshold)
        dwi_idx = np.argwhere(bvals > b0_threshold)
        # np.take returns new arrays, so the sketch file can be removed below.
        data_b0s = np.take(np.squeeze(sketched_matrix), b0_idx, axis=0)
        data_dwi = np.take(np.squeeze(sketched_matrix), dwi_idx, axis=0)
        data_dict = {
            "data": [tmp_file.name, data.dtype, tmp.shape],
            "data_b0s": data_b0s,
            "data_dwi": data_dwi,
        }
        if verbose:
            t1 = time.time()
        del sketched_matrix
        os.unlink(sketched_matrix_name)

        denoised_arr_name, denoised_arr_dtype, denoised_arr_shape = vol_denoise(
            data_dict,
            b0_idx,
            dwi_idx,
            model,
            alpha,
            b0_denoising,
            verbose,
            tmp_dir=tmp_dir,
        )
        created_files.append(denoised_arr_name)
        denoised_arr = np.memmap(
            denoised_arr_name,
            dtype=denoised_arr_dtype,
            mode="r+",
            shape=denoised_arr_shape,
        )
        if verbose:
            t2 = time.time()
            print("Time taken for Patch2Self: ", t2 - t1, " seconds.")
        denoised_arr = _apply_post_processing(
            denoised_arr, shift_intensity, clip_negative_vals
        )
        del tmp
        os.unlink(data_dict["data"][0])
        result = np.array(denoised_arr, dtype=out_dtype)
        del denoised_arr
        os.unlink(denoised_arr_name)
    except BaseException:
        for name in created_files:
            try:
                os.unlink(name)
            except OSError:
                # Best effort: the file may already be removed or still be in
                # use; the original exception is re-raised below.
                pass
        raise
    return result
def _apply_post_processing(denoised_arr, shift_intensity, clip_negative_vals):
    """Apply post-processing steps such as clipping and shifting intensities.

    Parameters
    ----------
    denoised_arr : ndarray
        The denoised array; volumes are indexed by the last axis. It is
        modified in place.
    shift_intensity : bool
        Shifts the distribution of intensities per volume to give
        non-negative values: a volume whose minimum is negative is shifted
        up by ``-min``; other volumes are left unchanged.
    clip_negative_vals : bool
        Sets negative values after denoising to 0 using `np.clip`. If both
        flags are True, a warning is emitted and only clipping is applied.

    Returns
    -------
    denoised_arr : ndarray
        The same array object, with post-processing applied.
    """
    if shift_intensity and not clip_negative_vals:
        for i in range(denoised_arr.shape[-1]):
            # The previous expression, min(x) - min(x), was always 0, so
            # nothing was ever shifted. Shift by the amount that moves the
            # volume minimum to 0, and only when it is negative.
            vol_min = np.min(denoised_arr[..., i])
            if vol_min < 0:
                denoised_arr[..., i] -= vol_min
    elif clip_negative_vals and not shift_intensity:
        denoised_arr.clip(min=0, out=denoised_arr)
    elif clip_negative_vals and shift_intensity:
        warn(
            "Both `clip_negative_vals` and `shift_intensity` cannot be True. \
Defaulting to `clip_negative_vals`...",
            stacklevel=2,
        )
        denoised_arr.clip(min=0, out=denoised_arr)
    return denoised_arr