#!/usr/bin/python
"""Class and helper functions for fitting the DeepN4 model."""
import numpy as np
from scipy.ndimage import gaussian_filter
from dipy.data import get_fnames
from dipy.nn.utils import normalize, recover_img, set_logger_level, transform_img
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.logging import logger
from dipy.utils.optpkg import optional_package
tf, have_tf, _ = optional_package("tensorflow", min_version="2.18.0")
if have_tf:
from tensorflow.keras.layers import (
Concatenate,
Conv3D,
Conv3DTranspose,
GroupNormalization,
Layer,
LeakyReLU,
MaxPool3D,
)
from tensorflow.keras.models import Model
else:
class Model:
pass
class Layer:
pass
logger.warning(
"This model requires Tensorflow.\
Please install these packages using \
pip. If using mac, please refer to this \
link for installation. \
https://github.com/apple/tensorflow_macos"
)
[docs]
class EncoderBlock(Layer):
"""Encoder block for the 3D U-Net.
Parameters
----------
out_channels : int
Number of output channels.
kernel_size : int
Size of the convolutional kernel.
strides : int
Stride of the convolution.
padding : str
Padding for the convolution ('same' or 'valid').
"""
def __init__(self, out_channels, kernel_size, strides, padding):
super(EncoderBlock, self).__init__()
self.conv3d = Conv3D(
out_channels, kernel_size, strides=strides, padding=padding, use_bias=False
)
self.instnorm = GroupNormalization(
groups=-1, axis=-1, epsilon=1e-05, center=False, scale=False
)
self.activation = LeakyReLU(0.01)
[docs]
def call(self, input):
"""Forward pass of the EncoderBlock.
Parameters
----------
input : tf.Tensor
Input tensor.
Returns
-------
tf.Tensor
Output tensor.
"""
x = self.conv3d(input)
x = self.instnorm(x)
x = self.activation(x)
return x
[docs]
class DecoderBlock(Layer):
"""Decoder block for the 3D U-Net.
Parameters
----------
out_channels : int
Number of output channels.
kernel_size : int
Size of the convolutional kernel.
strides : int
Stride of the convolution.
padding : str
Padding for the convolution ('same' or 'valid').
"""
def __init__(self, out_channels, kernel_size, strides, padding):
super(DecoderBlock, self).__init__()
self.conv3d = Conv3DTranspose(
out_channels, kernel_size, strides=strides, padding=padding, use_bias=False
)
self.instnorm = GroupNormalization(
groups=-1, axis=-1, epsilon=1e-05, center=False, scale=False
)
self.activation = LeakyReLU(0.01)
[docs]
def call(self, input):
"""Forward pass of the DecoderBlock.
Parameters
----------
input : tf.Tensor
Input tensor.
Returns
-------
tf.Tensor
Output tensor.
"""
x = self.conv3d(input)
x = self.instnorm(x)
x = self.activation(x)
return x
[docs]
def UNet3D(input_shape):
"""3D U-Net architecture for the DeepN4 model.
Parameters
----------
input_shape : tuple
Shape of the input tensor.
Returns
-------
tf.keras.Model
The 3D U-Net model.
"""
inputs = tf.keras.Input(input_shape)
# Encode
x = EncoderBlock(32, kernel_size=3, strides=1, padding="same")(inputs)
syn0 = EncoderBlock(64, kernel_size=3, strides=1, padding="same")(x)
x = MaxPool3D()(syn0)
x = EncoderBlock(64, kernel_size=3, strides=1, padding="same")(x)
syn1 = EncoderBlock(128, kernel_size=3, strides=1, padding="same")(x)
x = MaxPool3D()(syn1)
x = EncoderBlock(128, kernel_size=3, strides=1, padding="same")(x)
syn2 = EncoderBlock(256, kernel_size=3, strides=1, padding="same")(x)
x = MaxPool3D()(syn2)
x = EncoderBlock(256, kernel_size=3, strides=1, padding="same")(x)
x = EncoderBlock(512, kernel_size=3, strides=1, padding="same")(x)
# Last layer without relu
x = Conv3D(512, kernel_size=1, strides=1, padding="same")(x)
x = DecoderBlock(512, kernel_size=2, strides=2, padding="valid")(x)
x = Concatenate()([x, syn2])
x = DecoderBlock(256, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(256, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(256, kernel_size=2, strides=2, padding="valid")(x)
x = Concatenate()([x, syn1])
x = DecoderBlock(128, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(128, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(128, kernel_size=2, strides=2, padding="valid")(x)
x = Concatenate()([x, syn0])
x = DecoderBlock(64, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(64, kernel_size=3, strides=1, padding="same")(x)
x = DecoderBlock(1, kernel_size=1, strides=1, padding="valid")(x)
# Last layer without relu
out = Conv3DTranspose(1, kernel_size=1, strides=1, padding="valid")(x)
return Model(inputs, out)
[docs]
class DeepN4:
"""This class is intended for the DeepN4 model.
The DeepN4 model :footcite:p:`Kanakaraj2024` predicts the bias field for
magnetic field inhomogeneity correction on T1-weighted images.
References
----------
.. footbibliography::
"""
@warning_for_keywords()
@doctest_skip_parser
def __init__(self, *, verbose=False):
"""Model initialization
To obtain the pre-trained model, use fetch_default_weights() like:
>>> deepn4_model = DeepN4() # skip if not have_tf
>>> deepn4_model.fetch_default_weights() # skip if not have_tf
This model is designed to take as input file T1 signal and predict
bias field. Effectively, this model is mimicking bias correction.
Parameters
----------
verbose : bool, optional
Whether to show information about the processing.
"""
if not have_tf:
raise tf()
log_level = "INFO" if verbose else "CRITICAL"
set_logger_level(log_level, logger)
# Synb0 network load
self.model = UNet3D(input_shape=(128, 128, 128, 1))
self.fetch_default_weights()
[docs]
def fetch_default_weights(self):
"""Load the model pre-training weights to use for the fitting."""
fetch_model_weights_path = get_fnames(name="deepn4_default_tf_weights")
self.load_model_weights(fetch_model_weights_path)
[docs]
def load_model_weights(self, weights_path):
"""Load the custom pre-training weights to use for the fitting.
Parameters
----------
weights_path : str
Path to the file containing the weights (hdf5, saved by tensorflow)
"""
try:
self.model.load_weights(weights_path)
except ValueError as e:
raise ValueError(
"Expected input for the provided model weights \
do not match the declared model"
) from e
def __predict(self, x_test):
"""Internal prediction function
Predict bias field from input T1 signal
Parameters
----------
x_test : np.ndarray (128, 128, 128, 1)
Image should match the required shape of the model.
Returns
-------
np.ndarray (128, 128, 128)
Predicted bias field
"""
return self.model.predict(x_test)[..., 0]
[docs]
def pad(self, img, sz):
"""Pad or crop the image to the specified size.
Parameters
----------
img : np.ndarray
The image to pad or crop.
sz : int
The target size for each dimension.
Returns
-------
tmp : np.ndarray
The padded or cropped image.
indices : list
A list of indices used for padding and cropping.
"""
tmp = np.zeros((sz, sz, sz))
diff = int((sz - img.shape[0]) / 2)
lx = max(diff, 0)
lX = min(img.shape[0] + diff, sz)
diff = (img.shape[0] - sz) / 2
rx = max(int(np.floor(diff)), 0)
rX = min(img.shape[0] - int(np.ceil(diff)), img.shape[0])
diff = int((sz - img.shape[1]) / 2)
ly = max(diff, 0)
lY = min(img.shape[1] + diff, sz)
diff = (img.shape[1] - sz) / 2
ry = max(int(np.floor(diff)), 0)
rY = min(img.shape[1] - int(np.ceil(diff)), img.shape[1])
diff = int((sz - img.shape[2]) / 2)
lz = max(diff, 0)
lZ = min(img.shape[2] + diff, sz)
diff = (img.shape[2] - sz) / 2
rz = max(int(np.floor(diff)), 0)
rZ = min(img.shape[2] - int(np.ceil(diff)), img.shape[2])
tmp[lx:lX, ly:lY, lz:lZ] = img[rx:rX, ry:rY, rz:rZ]
return tmp, [lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ]
[docs]
def load_resample(self, subj):
"""Preprocess the subject image for the model.
This includes padding, normalization, and converting to a tensor.
Parameters
----------
subj : np.ndarray
The subject image.
Returns
-------
tuple
A tuple containing the processed tensor and auxiliary parameters.
"""
input_data, [lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ] = self.pad(
subj, 128
)
in_max = np.percentile(input_data[np.nonzero(input_data)], 99.99)
input_data = normalize(input_data, min_v=0, max_v=in_max, new_min=0, new_max=1)
input_data = np.squeeze(input_data)
input_vols = np.zeros((1, 128, 128, 128, 1))
input_vols[0, :, :, :, 0] = input_data
return (
tf.convert_to_tensor(input_vols, dtype=tf.float32),
lx,
lX,
ly,
lY,
lz,
lZ,
rx,
rX,
ry,
rY,
rz,
rZ,
in_max,
)
[docs]
def predict(self, img, img_affine, *, threshold=0.5):
"""Wrapper function to facilitate prediction of larger dataset.
The function will mask, normalize, split, predict and 're-assemble'
the data as a volume.
Parameters
----------
img : np.ndarray
T1 image to predict and apply bias field
img_affine : np.ndarray (4, 4)
Affine matrix for the T1 image
threshold : float, optional
Threshold for cleaning the final correction field
Returns
-------
final_corrected : np.ndarray (x, y, z)
Predicted bias corrected image.
The volume has matching shape to the input data
"""
# Preprocess input data (resample, normalize, and pad)
resampled_T1, params = transform_img(
img,
img_affine,
target_voxsize=(2.0, 2.0, 2.0),
final_size=(128, 128, 128),
order=3,
)
(in_features, lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ, in_max) = (
self.load_resample(resampled_T1)
)
# Run the model to get the bias field
logfield = self.__predict(in_features)
field = np.exp(logfield)
field = field.squeeze()
# Postprocess predicted field (reshape - unpad, smooth the field,
# upsample)
final_field = np.zeros(
[resampled_T1.shape[0], resampled_T1.shape[1], resampled_T1.shape[2]]
)
final_field[rx:rX, ry:rY, rz:rZ] = field[lx:lX, ly:lY, lz:lZ]
final_fields = gaussian_filter(final_field, sigma=3)
upsample_final_field = recover_img(final_fields, params)
# Correct the image
below_threshold_mask = np.abs(upsample_final_field) < threshold
with np.errstate(divide="ignore", invalid="ignore"):
final_corrected = np.where(
below_threshold_mask, 0, img / upsample_final_field
)
return final_corrected