"""Bias field correction for diffusion MRI data.
Provides classical regression-based bias field correction via Legendre
polynomial regression and cubic B-spline regression.
The bias field is estimated exclusively from the mean b0 volume in the
log domain and applied uniformly to all DWI volumes.
"""
import numpy as np
from scipy import linalg as scipy_linalg, ndimage, sparse
try:
from dipy.denoise._bias_correction import (
compute_tukey_weights,
evaluate_bspline_rows,
gram_matrix_csr,
masked_voxel_coords,
)
_HAVE_CYTHON = True
except ImportError:
_HAVE_CYTHON = False
from dipy.core.gradients import extract_b0
from dipy.segment.mask import applymask, median_otsu
from dipy.utils.logging import logger
try:
from dipy.align.vector_fields import gradient as _vf_gradient
_HAVE_VF_GRADIENT = True
except ImportError:
_HAVE_VF_GRADIENT = False
def _get_mean_b0(data, gtab):
    """Average all b0 volumes of *data* and return the result as float64.

    Parameters
    ----------
    data : ndarray
        4D DWI data (X, Y, Z, N).
    gtab : GradientTable
        Gradient table exposing a ``b0s_mask`` attribute.

    Returns
    -------
    mean_b0 : ndarray
        3D mean b0 volume, dtype float64.
    """
    mean_b0 = extract_b0(data, gtab.b0s_mask, strategy="mean")
    return mean_b0.astype(np.float64)
def _get_mask(mean_b0, mask):
"""Return binary brain mask, computing via median_otsu if not provided.
Parameters
----------
mean_b0 : ndarray
3D mean b0 volume.
mask : ndarray or None
Existing 3D binary mask, or None to auto-compute.
Returns
-------
mask : ndarray
3D boolean brain mask.
"""
if mask is None:
_, mask = median_otsu(mean_b0, median_radius=4, numpass=4)
else:
mask = np.asarray(mask, dtype=bool)
return mask
def _gradient_weights(*, log_b0, alpha=1.0):
    """Exponential edge-suppression weights derived from the image gradient.

    Voxels with large gradient magnitude (edges) receive weights close to
    zero so they contribute little to the bias-field regression.

    Parameters
    ----------
    log_b0 : ndarray
        3D log-domain b0 image, shape (S, R, C).
    alpha : float, optional
        Edge suppression strength; larger values suppress edges harder.

    Returns
    -------
    weights : ndarray
        Float64 weight map ``exp(-alpha * |grad|)``, same shape as log_b0.
    """
    volume = np.ascontiguousarray(log_b0, dtype=np.float64)
    if _HAVE_VF_GRADIENT:
        identity = np.eye(4, dtype=np.float64)
        grad, _ = _vf_gradient(
            volume,
            identity,
            np.ones(3, dtype=np.float64),
            np.array(volume.shape, dtype=np.int32),
            identity,
        )
        magnitude = np.sqrt(np.sum(grad**2, axis=-1))
    else:
        # Fallback: Sobel derivatives along each axis.  NOTE(review): the
        # Sobel kernel is not normalized, so the magnitude scale (and thus
        # the effective alpha) differs from the vector-field backend.
        partials = [ndimage.sobel(volume, axis=ax) for ax in range(3)]
        magnitude = np.sqrt(partials[0] ** 2 + partials[1] ** 2 + partials[2] ** 2)
    return np.exp(-alpha * magnitude)
def _normalize_coords(*, shape, coords):
"""Normalize voxel coordinates to [-1, 1] along each axis.
Parameters
----------
shape : tuple of int
Volume shape (S, R, C).
coords : ndarray
Integer coordinates, shape (N, 3).
Returns
-------
coords_norm : ndarray
Normalized float64 coordinates, shape (N, 3).
"""
coords_norm = coords.astype(np.float64)
for d, n in enumerate(shape):
if n > 1:
coords_norm[:, d] = 2.0 * coords_norm[:, d] / (n - 1) - 1.0
else:
coords_norm[:, d] = 0.0
return coords_norm
def _legendre_basis(*, coords_flat, order):
"""Build Legendre polynomial design matrix.
Parameters
----------
coords_flat : ndarray
Normalized coordinates in [-1, 1], shape (N, 3).
order : int
Maximum total polynomial degree (terms where i+j+k <= order).
Returns
-------
X : ndarray
Design matrix, shape (N, K) where K is the number of terms.
"""
from numpy.polynomial.legendre import legval
terms = [
(i, j, k)
for i in range(order + 1)
for j in range(order + 1 - i)
for k in range(order + 1 - i - j)
]
N = coords_flat.shape[0]
K = len(terms)
X = np.zeros((N, K), dtype=np.float64)
for col, (i, j, k) in enumerate(terms):
ei = np.zeros(i + 1)
ei[i] = 1.0
ej = np.zeros(j + 1)
ej[j] = 1.0
ek = np.zeros(k + 1)
ek[k] = 1.0
X[:, col] = (
legval(coords_flat[:, 0], ei)
* legval(coords_flat[:, 1], ej)
* legval(coords_flat[:, 2], ek)
)
return X
def _weighted_ridge_solve(*, X, y, weights, lambda_reg):
"""Solve weighted ridge regression min ||W^(1/2)(y - Xβ)||² + λ||β||².
Parameters
----------
X : ndarray
Design matrix, shape (N, K).
y : ndarray
Target values, shape (N,).
weights : ndarray
Non-negative regression weights, shape (N,).
lambda_reg : float
Ridge regularization strength.
Returns
-------
beta : ndarray
Coefficient vector, shape (K,).
"""
K = X.shape[1]
WX = weights[:, None] * X
A = X.T @ WX + lambda_reg * np.eye(K)
b = X.T @ (weights * y)
try:
beta = np.linalg.solve(A, b)
except np.linalg.LinAlgError:
beta, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
return beta
def _tukey_weights_py(*, residuals, c):
"""Compute Tukey biweight weights (pure Python/NumPy).
Parameters
----------
residuals : ndarray
Regression residuals, shape (N,).
c : float
Tukey breakdown constant.
Returns
-------
weights : ndarray
Tukey biweight weights in [0, 1], shape (N,).
"""
mad = np.median(np.abs(residuals)) / 0.6745
if mad < 1e-15:
return np.ones(len(residuals), dtype=np.float64)
u = residuals / (c * mad)
w = np.where(np.abs(u) < 1.0, (1.0 - u**2) ** 2, 0.0)
return w.astype(np.float64)
def _tukey_weights(*, residuals, c=4.685):
    """Tukey biweight weights, dispatching to the Cython kernel when built.

    Parameters
    ----------
    residuals : ndarray
        Regression residuals, shape (N,).
    c : float, optional
        Tukey breakdown constant.

    Returns
    -------
    weights : ndarray
        Tukey biweight weights in [0, 1], shape (N,).
    """
    if not _HAVE_CYTHON:
        return _tukey_weights_py(residuals=residuals, c=c)
    # The Cython kernel fills the output buffer in place.
    out = np.ones(len(residuals), dtype=np.float64)
    compute_tukey_weights(np.ascontiguousarray(residuals, dtype=np.float64), out, c=c)
    return out
def _polynomial_pyramid_fit(
    *,
    log_b0,
    mask,
    order,
    pyramid_levels,
    n_iter,
    lambda_reg,
    robust,
    gradient_weighting,
    sigma_factor=0.2,
):
    """Coarse-to-fine polynomial bias field regression.

    At each pyramid level the remaining residual is (for factor > 1)
    smoothed and downsampled, a Legendre surface is fit to the masked
    voxels with iteratively reweighted ridge regression, and the fitted
    surface -- evaluated at full resolution -- is accumulated into the
    log-bias estimate and subtracted from the residual before the next
    (finer) level.

    Parameters
    ----------
    log_b0 : ndarray
        3D log-domain b0 image, shape (S, R, C).
    mask : ndarray
        3D boolean brain mask.
    order : int
        Maximum Legendre polynomial order.
    pyramid_levels : tuple of int
        Downsampling factors, ordered coarse-first (e.g. (4, 2, 1)).
    n_iter : int
        Reweighting iterations per pyramid level.
    lambda_reg : float
        Ridge regularization strength.
    robust : bool
        Apply Tukey biweight robust reweighting.
    gradient_weighting : bool
        Apply gradient-based edge suppression weights.
    sigma_factor : float, optional
        Gaussian pre-smoothing sigma is ``factor * sigma_factor`` at each
        downsampled level.

    Returns
    -------
    log_bias : ndarray
        Estimated log-domain bias field, same shape as log_b0, centered to
        zero mean within the mask.
    """
    full_shape = log_b0.shape
    if gradient_weighting:
        grad_w_full = _gradient_weights(log_b0=log_b0)
    # Full-resolution design matrix, built once and reused to evaluate
    # every level's fitted surface on the original grid.
    ii, jj, kk = np.meshgrid(
        np.arange(full_shape[0]),
        np.arange(full_shape[1]),
        np.arange(full_shape[2]),
        indexing="ij",
    )
    full_vox_coords = np.column_stack([ii.ravel(), jj.ravel(), kk.ravel()])
    full_coords_norm = _normalize_coords(shape=full_shape, coords=full_vox_coords)
    X_full = _legendre_basis(coords_flat=full_coords_norm, order=order)
    # Center log_b0 so the polynomial fits only spatial variation, not the
    # DC offset (overall intensity level).
    log_b0_dc = log_b0[mask].mean()
    residual = log_b0 - log_b0_dc
    log_bias = np.zeros(full_shape, dtype=np.float64)
    for factor in pyramid_levels:
        if factor == 1:
            level_residual = residual
            level_mask = mask
        else:
            # Anti-alias with a Gaussian before downsampling; the mask is
            # downsampled nearest-neighbor and re-thresholded.
            sigma = factor * sigma_factor
            smoothed = ndimage.gaussian_filter(residual, sigma=sigma)
            level_residual = ndimage.zoom(smoothed, zoom=1.0 / factor, order=1)
            level_mask = (
                ndimage.zoom(mask.astype(np.float64), zoom=1.0 / factor, order=0) > 0.5
            )
        level_shape = level_residual.shape
        mask_flat = level_mask.ravel()
        n_masked = mask_flat.sum()
        # Need at least as many data points as parameters (count matches
        # the i+j+k <= order term enumeration in _legendre_basis).
        n_params = sum(
            1
            for i in range(order + 1)
            for j in range(order + 1 - i)
            for _ in range(order + 1 - i - j)
        )
        if n_masked < n_params:
            # Under-determined at this resolution: skip the level entirely.
            continue
        y = level_residual.ravel()[mask_flat]
        ii_l, jj_l, kk_l = np.meshgrid(
            np.arange(level_shape[0]),
            np.arange(level_shape[1]),
            np.arange(level_shape[2]),
            indexing="ij",
        )
        level_coords = np.column_stack(
            [
                ii_l.ravel()[mask_flat],
                jj_l.ravel()[mask_flat],
                kk_l.ravel()[mask_flat],
            ]
        )
        coords_norm = _normalize_coords(shape=level_shape, coords=level_coords)
        X = _legendre_basis(coords_flat=coords_norm, order=order)
        w = np.ones(n_masked, dtype=np.float64)
        if gradient_weighting:
            if factor == 1:
                gw = grad_w_full.ravel()[mask_flat]
            else:
                # Downsample the full-resolution gradient weights with the
                # same zoom as the residual so shapes match.
                gw_down = ndimage.zoom(grad_w_full, zoom=1.0 / factor, order=1)
                gw = gw_down.ravel()[mask_flat]
            w = w * gw
        beta = None
        # Iteratively reweighted ridge: refit, then (optionally) damp
        # outliers via multiplicative Tukey weights on current residuals.
        for _ in range(n_iter):
            beta = _weighted_ridge_solve(X=X, y=y, weights=w, lambda_reg=lambda_reg)
            residuals_iter = y - X @ beta
            if robust:
                w = w * _tukey_weights(residuals=residuals_iter)
        if beta is None:
            # n_iter == 0: still produce a single unreweighted fit.
            beta = _weighted_ridge_solve(X=X, y=y, weights=w, lambda_reg=lambda_reg)
        level_bias = (X_full @ beta).reshape(full_shape)
        log_bias += level_bias
        residual = residual - level_bias
    # Center: ensure bias_field has unit mean within mask
    log_bias -= log_bias[mask].mean()
    return log_bias
def polynomial_bias_field_dwi(
    data,
    gtab,
    *,
    mask=None,
    order=3,
    pyramid_levels=(4, 2, 1),
    n_iter=4,
    lambda_reg=1e-3,
    robust=True,
    gradient_weighting=True,
    zero_background=False,
):
    """DWI bias field correction via multi-resolution Legendre polynomial regression.

    The bias field is estimated from the mean b0 volume in log space with
    coarse-to-fine Legendre polynomial regression and divided out of every
    DWI volume.

    Parameters
    ----------
    data : ndarray
        4D DWI data (X, Y, Z, N).
    gtab : GradientTable
        Gradient table.
    mask : ndarray, optional
        3D binary brain mask. Auto-computed via median_otsu if None.
    order : int, optional
        Maximum Legendre polynomial order (terms where i+j+k <= order).
    pyramid_levels : tuple of int, optional
        Downsampling factors for coarse-to-fine pyramid (descending order).
    n_iter : int, optional
        Reweighting iterations per pyramid level.
    lambda_reg : float, optional
        Ridge regularization strength.
    robust : bool, optional
        Apply Tukey biweight robust reweighting.
    gradient_weighting : bool, optional
        Apply gradient-based edge suppression.
    zero_background : bool, optional
        If True, set the bias field to 1.0 (no correction) outside the brain
        mask. If False, the raw extrapolated field values are preserved in
        the returned bias_field array. Has no effect on the corrected DWI
        data (background voxels are always zeroed by the brain mask).

    Returns
    -------
    corrected : ndarray
        Bias-corrected 4D DWI data, same dtype as input.
    bias_field : ndarray
        Estimated 3D multiplicative bias field.
    """
    input_dtype = data.dtype
    mean_b0 = _get_mean_b0(data, gtab)
    brain_mask = _get_mask(mean_b0, mask)
    # Work in the log domain so the multiplicative field becomes additive.
    log_b0 = np.log(np.clip(mean_b0, 1e-10, None))
    log_bias = _polynomial_pyramid_fit(
        log_b0=log_b0,
        mask=brain_mask,
        order=order,
        pyramid_levels=pyramid_levels,
        n_iter=n_iter,
        lambda_reg=lambda_reg,
        robust=robust,
        gradient_weighting=gradient_weighting,
    )
    if zero_background:
        # log(1) == 0: unit gain outside the brain.
        log_bias[~brain_mask] = 0.0
    bias_field = np.exp(log_bias)
    corrected = data.astype(np.float64) / bias_field[..., None]
    corrected = applymask(corrected, brain_mask).astype(input_dtype)
    return corrected, bias_field
def _build_bspline_design_matrix_py(*, log_b0_shape, n_control, mask_flat):
"""Build sparse B-spline design matrix (pure Python fallback).
Parameters
----------
log_b0_shape : tuple of int
Shape of the 3D volume (S, R, C).
n_control : tuple of int
Control grid dimensions (ns, nr, nc).
mask_flat : ndarray
Flattened boolean mask, shape (S*R*C,).
Returns
-------
X : scipy.sparse.csr_matrix
Design matrix, shape (N_masked, K_ctrl_total).
"""
S, R, C = log_b0_shape
ns, nr, nc = n_control
K = ns * nr * nc
mask_3d = mask_flat.reshape(log_b0_shape)
iz_all, iy_all, ix_all = np.where(mask_3d)
N = len(iz_all)
def _vox_to_ctrl_arr(vox, shape_d, n_ctrl_d):
if shape_d <= 1 or n_ctrl_d <= 1:
return np.zeros(len(vox), dtype=np.float64)
return vox.astype(np.float64) * (n_ctrl_d - 1) / (shape_d - 1)
tz = _vox_to_ctrl_arr(iz_all, S, ns)
ty = _vox_to_ctrl_arr(iy_all, R, nr)
tx = _vox_to_ctrl_arr(ix_all, C, nc)
def _bspline_basis_batch(t, n_ctrl):
"""Vectorized cubic B-spline basis.
Returns (N,4) basis values and (N,4) control indices.
"""
t = np.clip(t, 0.0, n_ctrl - 1 - 1e-10)
k = np.floor(t).astype(np.int64)
k = np.minimum(k, n_ctrl - 2)
u = t - k
u2 = u * u
u3 = u2 * u
b = np.stack(
[
(1.0 - u) ** 3 / 6.0,
(3.0 * u3 - 6.0 * u2 + 4.0) / 6.0,
(-3.0 * u3 + 3.0 * u2 + 3.0 * u + 1.0) / 6.0,
u3 / 6.0,
],
axis=-1,
) # (N, 4)
ctrl = np.stack([k - 1, k, k + 1, k + 2], axis=-1) # (N, 4)
return b, ctrl
bz, cz = _bspline_basis_batch(tz, ns) # (N, 4)
by_, cy = _bspline_basis_batch(ty, nr)
bx, cx = _bspline_basis_batch(tx, nc)
# Tensor product: (N, 4, 4, 4) via broadcasting
vals = (
bz[:, :, np.newaxis, np.newaxis]
* by_[:, np.newaxis, :, np.newaxis]
* bx[:, np.newaxis, np.newaxis, :]
)
cols = (
cz[:, :, np.newaxis, np.newaxis] * (nr * nc)
+ cy[:, np.newaxis, :, np.newaxis] * nc
+ cx[:, np.newaxis, np.newaxis, :]
)
rows = np.broadcast_to(
np.arange(N, dtype=np.int64)[:, np.newaxis, np.newaxis, np.newaxis],
(N, 4, 4, 4),
)
# Validity: all three ctrl indices must be in bounds
valid = (
(cz[:, :, np.newaxis, np.newaxis] >= 0)
& (cz[:, :, np.newaxis, np.newaxis] < ns)
& (cy[:, np.newaxis, :, np.newaxis] >= 0)
& (cy[:, np.newaxis, :, np.newaxis] < nr)
& (cx[:, np.newaxis, np.newaxis, :] >= 0)
& (cx[:, np.newaxis, np.newaxis, :] < nc)
)
return sparse.csr_matrix(
(vals[valid], (rows[valid], cols[valid])),
shape=(N, K),
dtype=np.float64,
)
def _build_bspline_design_matrix(*, log_b0_shape, n_control, mask_flat):
    """Build sparse B-spline design matrix, using Cython backend if available.

    Parameters
    ----------
    log_b0_shape : tuple of int
        Shape of the 3D volume (S, R, C).
    n_control : tuple of int
        Control grid dimensions (ns, nr, nc).
    mask_flat : ndarray
        Flattened boolean mask, shape (S*R*C,).

    Returns
    -------
    X : scipy.sparse.csr_matrix
        Design matrix, shape (N_masked, K_ctrl_total).
    """
    if _HAVE_CYTHON:
        S, R, C = log_b0_shape
        ns, nr, nc = n_control
        K = ns * nr * nc
        mask_3d = mask_flat.reshape(log_b0_shape).astype(np.uint8)
        # Gather (z, y, x) coordinates of masked voxels into a
        # preallocated buffer; the kernel reports how many rows it filled.
        N_max = int(mask_flat.sum())
        out_coords = np.zeros((N_max, 3), dtype=np.int64)
        N_actual = int(masked_voxel_coords(np.ascontiguousarray(mask_3d), out_coords))
        out_coords = out_coords[:N_actual]
        def _scale(axis_coords, shape_d, n_ctrl_d):
            # Map voxel indices [0, shape_d-1] onto control-grid units
            # [0, n_ctrl_d-1]; degenerate axes collapse to 0.
            if shape_d <= 1 or n_ctrl_d <= 1:
                return np.zeros(len(axis_coords), dtype=np.float64)
            return axis_coords.astype(np.float64) * (n_ctrl_d - 1) / (shape_d - 1)
        grid_coords = np.column_stack(
            [
                _scale(out_coords[:, 0], S, ns),
                _scale(out_coords[:, 1], R, nr),
                _scale(out_coords[:, 2], C, nc),
            ]
        ).astype(np.float64)
        n_ctrl_arr = np.array([ns, nr, nc], dtype=np.int64)
        # CSR buffers sized for the worst case of 4x4x4 = 64 tensor-product
        # basis entries per voxel row; the kernel returns the actual nnz.
        row_ptr = np.zeros(N_actual + 1, dtype=np.int64)
        col_idx = np.zeros(N_actual * 64, dtype=np.int64)
        values = np.zeros(N_actual * 64, dtype=np.float64)
        nnz = int(
            evaluate_bspline_rows(
                np.ascontiguousarray(grid_coords),
                n_ctrl_arr,
                row_ptr,
                col_idx,
                values,
            )
        )
        # Trim the over-allocated buffers to the entries actually written.
        col_idx = col_idx[:nnz]
        values = values[:nnz]
        return sparse.csr_matrix(
            (values, col_idx, row_ptr),
            shape=(N_actual, K),
            dtype=np.float64,
        )
    # Pure NumPy fallback when the compiled extension is unavailable.
    return _build_bspline_design_matrix_py(
        log_b0_shape=log_b0_shape, n_control=n_control, mask_flat=mask_flat
    )
def _sparse_weighted_ridge_solve(*, X_sparse, y, weights, lambda_reg):
    """Solve sparse weighted ridge regression.

    The Gram matrix ``A = X^T W X`` is accumulated either by a Cython CSR
    kernel or, as a fallback, by densifying CSR rows block-by-block and
    using BLAS DGEMM -- substantially faster than scipy sparse x sparse
    multiplication when the K x K result is nearly dense (K < 1000).

    Parameters
    ----------
    X_sparse : scipy.sparse.csr_matrix
        Design matrix, shape (N, K).
    y : ndarray
        Target values, shape (N,).
    weights : ndarray
        Non-negative regression weights, shape (N,).
    lambda_reg : float
        Ridge regularization strength.

    Returns
    -------
    beta : ndarray
        Coefficient vector, shape (K,).
    """
    n_rows, n_cols = X_sparse.shape
    gram = np.zeros((n_cols, n_cols), dtype=np.float64)
    rhs = np.zeros(n_cols, dtype=np.float64)
    if _HAVE_CYTHON:
        # Fast path: accumulate X^T W X and X^T W y directly from the
        # raw CSR arrays.
        gram_matrix_csr(
            np.asarray(X_sparse.data, dtype=np.float64),
            np.asarray(X_sparse.indices, dtype=np.int32),
            np.asarray(X_sparse.indptr, dtype=np.int32),
            np.ascontiguousarray(weights, dtype=np.float64),
            np.ascontiguousarray(y, dtype=np.float64),
            gram,
            rhs,
        )
    else:
        # Chunked BLAS DGEMM fallback: convert sparse rows to dense in
        # blocks so the accumulation runs in BLAS instead of sparse code.
        step = min(4096, n_rows)
        for start in range(0, n_rows, step):
            block = X_sparse[start : start + step].toarray()  # (step, K)
            w_block = weights[start : start + step]
            gram += block.T @ (w_block[:, np.newaxis] * block)  # BLAS DGEMM
            rhs += block.T.dot(w_block * y[start : start + step])
    gram += lambda_reg * np.eye(n_cols)
    try:
        return scipy_linalg.solve(gram, rhs, assume_a="pos")
    except scipy_linalg.LinAlgError:
        # Not positive definite (e.g. lambda_reg == 0 with an empty
        # control-point column): minimum-norm least squares instead.
        return np.linalg.lstsq(gram, rhs, rcond=None)[0]
def _refine_control_coeffs(*, coeffs, n_ctrl_coarse, n_ctrl_fine):
"""Trilinear interpolation of control grid coefficients.
Parameters
----------
coeffs : ndarray
Flattened control point coefficients at coarse resolution.
n_ctrl_coarse : tuple of int
Coarse control grid dimensions.
n_ctrl_fine : tuple of int
Fine control grid dimensions.
Returns
-------
fine_coeffs : ndarray
Flattened coefficients at fine resolution.
"""
coarse_grid = coeffs.reshape(n_ctrl_coarse)
zoom_factors = tuple(f / c for f, c in zip(n_ctrl_fine, n_ctrl_coarse))
fine_grid = ndimage.zoom(coarse_grid, zoom=zoom_factors, order=1)
# Trim or pad to exactly match target shape
slices = tuple(slice(0, n) for n in n_ctrl_fine)
fine_grid = fine_grid[slices]
if fine_grid.shape != tuple(n_ctrl_fine):
padded = np.zeros(n_ctrl_fine, dtype=np.float64)
src_slices = tuple(slice(0, s) for s in fine_grid.shape)
padded[src_slices] = fine_grid
fine_grid = padded
return fine_grid.ravel()
def _eval_bspline_field(*, coeffs, n_control, out_shape):
"""Evaluate B-spline field at all voxel positions.
Uses ``scipy.ndimage.map_coordinates`` with ``prefilter=False`` so that
``coeffs`` are treated directly as B-spline weights (not as values to
interpolate through). This avoids building a full-resolution sparse
design matrix and is O(N) in C rather than O(N × 64) in Python.
Parameters
----------
coeffs : ndarray
Flattened control point coefficients (B-spline weights).
n_control : tuple of int
Control grid dimensions (ns, nr, nc).
out_shape : tuple of int
Output volume shape (S, R, C).
Returns
-------
field : ndarray
Evaluated field, shape out_shape.
"""
S, R, C = out_shape
ns, nr, nc = n_control
coeff_grid = np.ascontiguousarray(coeffs.reshape(n_control), dtype=np.float64)
iz = np.linspace(0, ns - 1, S) if ns > 1 else np.zeros(S)
iy = np.linspace(0, nr - 1, R) if nr > 1 else np.zeros(R)
ix_ = np.linspace(0, nc - 1, C) if nc > 1 else np.zeros(C)
II, JJ, KK = np.meshgrid(iz, iy, ix_, indexing="ij")
coords = np.vstack([II.ravel(), JJ.ravel(), KK.ravel()])
field = ndimage.map_coordinates(
coeff_grid, coords, order=3, mode="nearest", prefilter=False
)
return field.reshape(out_shape)
def _bspline_pyramid_fit(
    *,
    log_b0,
    mask,
    n_control_points,
    pyramid_levels,
    n_iter,
    lambda_reg,
    robust,
    gradient_weighting,
    sigma_factor=0.2,
):
    """Coarse-to-fine B-spline bias field regression.

    Each pyramid level fits a cubic B-spline surface -- on a control grid
    shrunk by the level's factor -- to the masked residual via iteratively
    reweighted sparse ridge regression.  Coefficients are warm-started by
    upsampling the previous (coarser) level's solution, and each level's
    field is evaluated at full resolution, accumulated into ``log_bias``
    and subtracted from the residual.

    Parameters
    ----------
    log_b0 : ndarray
        3D log-domain b0 image, shape (S, R, C).
    mask : ndarray
        3D boolean brain mask.
    n_control_points : tuple of int
        Control grid dimensions at finest level.
    pyramid_levels : tuple of int
        Downsampling factors, ordered coarse-first (e.g. (4, 2, 1)).
    n_iter : int
        Reweighting iterations per pyramid level.
    lambda_reg : float
        Ridge regularization strength.
    robust : bool
        Apply Tukey biweight robust reweighting.
    gradient_weighting : bool
        Apply gradient-based edge suppression weights.
    sigma_factor : float, optional
        Gaussian pre-smoothing sigma is ``factor * sigma_factor`` at each
        downsampled level.

    Returns
    -------
    log_bias : ndarray
        Estimated log-domain bias field, same shape as log_b0, centered to
        zero mean within the mask.
    """
    full_shape = log_b0.shape
    if gradient_weighting:
        grad_w_full = _gradient_weights(log_b0=log_b0)
    # Center log_b0 so the B-spline fits only spatial variation, not the
    # DC offset (overall intensity level).
    log_b0_dc = log_b0[mask].mean()
    residual = log_b0 - log_b0_dc
    log_bias = np.zeros(full_shape, dtype=np.float64)
    prev_coeffs = None
    prev_n_ctrl = None
    for factor in pyramid_levels:
        # Shrink the control grid with the level factor (at least 2 points
        # per axis so the cubic basis stays defined).
        n_ctrl = tuple(max(2, int(np.round(n / factor))) for n in n_control_points)
        if factor == 1:
            level_residual = residual
            level_mask = mask
        else:
            # Anti-alias with a Gaussian before downsampling; the mask is
            # downsampled nearest-neighbor and re-thresholded.
            sigma = factor * sigma_factor
            smoothed = ndimage.gaussian_filter(residual, sigma=sigma)
            level_residual = ndimage.zoom(smoothed, zoom=1.0 / factor, order=1)
            level_mask = (
                ndimage.zoom(mask.astype(np.float64), zoom=1.0 / factor, order=0) > 0.5
            )
        level_shape = level_residual.shape
        mask_flat_level = level_mask.ravel()
        n_masked = mask_flat_level.sum()
        K = n_ctrl[0] * n_ctrl[1] * n_ctrl[2]
        if n_masked < K:
            # Fewer data points than control coefficients: skip the level.
            continue
        y = level_residual.ravel()[mask_flat_level]
        X = _build_bspline_design_matrix(
            log_b0_shape=level_shape,
            n_control=n_ctrl,
            mask_flat=mask_flat_level,
        )
        # Warm-start: refine coefficients from previous coarser level
        if prev_coeffs is not None and prev_n_ctrl is not None:
            coeffs = _refine_control_coeffs(
                coeffs=prev_coeffs,
                n_ctrl_coarse=prev_n_ctrl,
                n_ctrl_fine=n_ctrl,
            )
        else:
            coeffs = np.zeros(K, dtype=np.float64)
        w = np.ones(n_masked, dtype=np.float64)
        if gradient_weighting:
            if factor == 1:
                gw = grad_w_full.ravel()[mask_flat_level]
            else:
                # Downsample the gradient weights with the same zoom as the
                # residual so shapes match.
                gw_down = ndimage.zoom(grad_w_full, zoom=1.0 / factor, order=1)
                gw = gw_down.ravel()[mask_flat_level]
            w = w * gw
        # Iteratively reweighted ridge: refit, then (optionally) damp
        # outliers via multiplicative Tukey weights on current residuals.
        # NOTE(review): when n_iter == 0 the warm-started coeffs are used
        # unmodified (no solve happens at this level).
        for _ in range(n_iter):
            coeffs = _sparse_weighted_ridge_solve(
                X_sparse=X, y=y, weights=w, lambda_reg=lambda_reg
            )
            residuals_iter = y - X @ coeffs
            if robust:
                w = w * _tukey_weights(residuals=residuals_iter)
        prev_coeffs = coeffs
        prev_n_ctrl = n_ctrl
        level_bias = _eval_bspline_field(
            coeffs=coeffs, n_control=n_ctrl, out_shape=full_shape
        )
        log_bias += level_bias
        residual = residual - level_bias
    # Center: ensure bias_field has unit mean within mask
    log_bias -= log_bias[mask].mean()
    return log_bias
def _auto_select_fit(
    *,
    log_b0,
    mean_b0,
    mask,
    order,
    n_control_points,
    pyramid_levels,
    n_iter,
    lambda_reg,
    robust,
    gradient_weighting,
):
    """Fit both poly and bspline fields; keep whichever flattens b0 more.

    The winner is the field whose corrected mean-b0 has the lower
    coefficient of variation (CoV) within the mask; the choice is logged
    at INFO level.

    Parameters
    ----------
    log_b0 : ndarray
        Log-domain mean b0, shape (X, Y, Z), float64.
    mean_b0 : ndarray
        Mean b0 in signal domain, shape (X, Y, Z), float64.
    mask : ndarray
        3D boolean brain mask.
    order : int
        Legendre polynomial order for poly fit.
    n_control_points : tuple of int
        B-spline control grid dimensions for bspline fit.
    pyramid_levels : tuple of int
        Downsampling factors for coarse-to-fine pyramid.
    n_iter : int
        Reweighting iterations per pyramid level.
    lambda_reg : float
        Ridge regularization strength.
    robust : bool
        Apply Tukey biweight robust reweighting.
    gradient_weighting : bool
        Apply gradient-based edge suppression.

    Returns
    -------
    log_bias : ndarray
        Log-domain bias field from the winning method.
    """
    shared = dict(
        log_b0=log_b0,
        mask=mask,
        pyramid_levels=pyramid_levels,
        n_iter=n_iter,
        lambda_reg=lambda_reg,
        robust=robust,
        gradient_weighting=gradient_weighting,
    )
    field_poly = _polynomial_pyramid_fit(order=order, **shared)
    field_bspline = _bspline_pyramid_fit(n_control_points=n_control_points, **shared)

    def _cov(log_bf):
        """CoV of mean b0 corrected by the given log bias field."""
        # Tiny gains are clamped to 1 to avoid division blow-ups.
        gain = np.exp(log_bf)
        vals = (mean_b0 / np.where(gain > 1e-10, gain, 1.0))[mask]
        return vals.std() / (vals.mean() + 1e-12)

    cov_poly = _cov(field_poly)
    cov_bspline = _cov(field_bspline)
    if cov_poly <= cov_bspline:
        logger.info(
            "bias_field_correction auto: selected 'poly' (CoV %.4f vs bspline %.4f)",
            cov_poly,
            cov_bspline,
        )
        return field_poly
    logger.info(
        "bias_field_correction auto: selected 'bspline' (CoV %.4f vs poly %.4f)",
        cov_bspline,
        cov_poly,
    )
    return field_bspline
def bias_field_correction(
    data,
    gtab,
    *,
    mask=None,
    method="bspline",
    order=3,
    n_control_points=(8, 8, 8),
    pyramid_levels=(4, 2, 1),
    n_iter=4,
    lambda_reg=1e-3,
    robust=True,
    gradient_weighting=True,
    return_bias_field=False,
    zero_background=False,
):
    """Top-level DWI bias field correction via regression.

    A smooth multiplicative bias field is estimated from the mean b0
    volume using polynomial or B-spline regression in log space, then
    divided out of every DWI volume.

    Parameters
    ----------
    data : ndarray
        4D DWI data (X, Y, Z, N).
    gtab : GradientTable
        Gradient table.
    mask : ndarray, optional
        3D binary brain mask. If None, computed via median_otsu.
    method : str, optional
        Bias correction method:

        - ``"poly"``: Legendre polynomial regression — fast, low-parameter.
        - ``"bspline"``: Cubic B-spline regression — more flexible.
        - ``"auto"``: Run both methods and return the one with lower
          Coefficient of Variation within the brain mask. The chosen
          method is logged at INFO level.
    order : int, optional
        Maximum Legendre polynomial degree (used only for method="poly").
    n_control_points : tuple of int, optional
        Control grid dimensions at finest level (used only for
        method="bspline").
    pyramid_levels : tuple of int, optional
        Downsampling factors for coarse-to-fine pyramid (descending order).
    n_iter : int, optional
        Reweighting iterations per pyramid level.
    lambda_reg : float, optional
        Ridge regularization strength.
    robust : bool, optional
        Apply Tukey biweight robust reweighting at each level.
    gradient_weighting : bool, optional
        Weight regression by edge-suppression map derived from the
        image gradient.
    return_bias_field : bool, optional
        If True, return the bias field alongside the corrected data.
    zero_background : bool, optional
        If True, set the bias field to 1.0 (no correction) outside the brain
        mask. If False, the raw extrapolated field values are preserved in
        the returned bias_field array. Has no effect on the corrected DWI
        data (background voxels are always zeroed by the brain mask).

    Returns
    -------
    corrected : ndarray
        Bias-corrected DWI, same dtype as input.
    bias_field : ndarray
        3D multiplicative bias field (only returned if
        return_bias_field=True).

    Raises
    ------
    ValueError
        If *method* is not one of "poly", "bspline", or "auto".
    """
    input_dtype = data.dtype
    mean_b0 = _get_mean_b0(data, gtab)
    brain_mask = _get_mask(mean_b0, mask)
    # Log domain turns the multiplicative field into an additive one.
    log_b0 = np.log(np.clip(mean_b0.astype(np.float64), 1e-10, None))
    shared = dict(
        log_b0=log_b0,
        mask=brain_mask,
        pyramid_levels=pyramid_levels,
        n_iter=n_iter,
        lambda_reg=lambda_reg,
        robust=robust,
        gradient_weighting=gradient_weighting,
    )
    if method == "poly":
        log_bias = _polynomial_pyramid_fit(order=order, **shared)
    elif method == "bspline":
        log_bias = _bspline_pyramid_fit(n_control_points=n_control_points, **shared)
    elif method == "auto":
        log_bias = _auto_select_fit(
            mean_b0=mean_b0,
            order=order,
            n_control_points=n_control_points,
            **shared,
        )
    else:
        raise ValueError(f"method must be 'poly', 'bspline', or 'auto', got '{method}'")
    if zero_background:
        # log(1) == 0: unit gain outside the brain.
        log_bias[~brain_mask] = 0.0
    bias_field = np.exp(log_bias)
    corrected = applymask(
        data.astype(np.float64) / bias_field[..., None], brain_mask
    ).astype(input_dtype)
    if return_bias_field:
        return corrected, bias_field
    return corrected