Source code for dipy.nn.tf.deepn4

#!/usr/bin/python
"""
Class and helper functions for fitting the DeepN4 model.
"""

import logging

import numpy as np
from scipy.ndimage import gaussian_filter

from dipy.data import get_fnames
from dipy.nn.utils import normalize, recover_img, set_logger_level, transform_img
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.optpkg import optional_package

# TensorFlow is an optional dependency: import the Keras building blocks when
# it is available, otherwise fall back to placeholders so that this module can
# still be imported.
tf, have_tf, _ = optional_package("tensorflow")
if have_tf:
    from tensorflow.keras.layers import (
        Concatenate,
        Conv3D,
        Conv3DTranspose,
        GroupNormalization,
        Layer,
        LeakyReLU,
        MaxPool3D,
    )
    from tensorflow.keras.models import Model
else:

    class Model:
        """Placeholder bound to the name ``Model`` when TensorFlow is missing."""

    class Layer:
        """Placeholder base class so ``class X(Layer)`` statements still run."""

    # Adjacent string literals are used instead of backslash continuations
    # inside a single literal: the latter would embed the indentation of every
    # continued source line in the emitted message.
    logging.warning(
        "This model requires Tensorflow. "
        "Please install these packages using "
        "pip. If using mac, please refer to this "
        "link for installation. "
        "https://github.com/apple/tensorflow_macos"
    )


# Install a default handler on the root logger (basicConfig does nothing when
# the root logger already has handlers), then create the module-level logger
# whose level is adjusted by DeepN4 according to its ``verbose`` flag.
logging.basicConfig()
logger = logging.getLogger("deepn4")


class EncoderBlock(Layer):
    """Encoder stage: 3D convolution, normalization, then LeakyReLU.

    Parameters
    ----------
    out_channels : int
        Number of filters of the convolution.
    kernel_size : int
        Kernel size forwarded to ``Conv3D``.
    strides : int
        Strides forwarded to ``Conv3D``.
    padding : str
        Padding mode forwarded to ``Conv3D``.
    """

    def __init__(self, out_channels, kernel_size, strides, padding):
        super().__init__()
        # The convolution is created without a bias term.
        self.conv3d = Conv3D(
            out_channels,
            kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        # groups=-1 selects instance normalization (see the Keras
        # GroupNormalization documentation); no learned offset or scale.
        self.instnorm = GroupNormalization(
            groups=-1,
            axis=-1,
            center=False,
            scale=False,
        )
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        """Run ``input`` through convolution, normalization and activation.

        Parameters
        ----------
        input : tensor
            Input of the block.

        Returns
        -------
        tensor
            Output of the activation.
        """
        return self.activation(self.instnorm(self.conv3d(input)))
class DecoderBlock(Layer):
    """Decoder stage: transposed 3D convolution, normalization, LeakyReLU.

    Parameters
    ----------
    out_channels : int
        Number of filters of the transposed convolution.
    kernel_size : int
        Kernel size forwarded to ``Conv3DTranspose``.
    strides : int
        Strides forwarded to ``Conv3DTranspose``.
    padding : str
        Padding mode forwarded to ``Conv3DTranspose``.
    """

    def __init__(self, out_channels, kernel_size, strides, padding):
        super().__init__()
        # The transposed convolution is created without a bias term.
        self.conv3d = Conv3DTranspose(
            out_channels,
            kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        # groups=-1 selects instance normalization (see the Keras
        # GroupNormalization documentation); no learned offset or scale.
        self.instnorm = GroupNormalization(
            groups=-1,
            axis=-1,
            center=False,
            scale=False,
        )
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        """Run ``input`` through the transposed convolution, normalization
        and activation.

        Parameters
        ----------
        input : tensor
            Input of the block.

        Returns
        -------
        tensor
            Output of the activation.
        """
        return self.activation(self.instnorm(self.conv3d(input)))
def UNet3D(input_shape):
    """Build the 3D U-Net used by DeepN4.

    Three encoder levels (two ``EncoderBlock`` each, followed by max pooling)
    lead to a bottleneck; three decoder levels then upsample, concatenate the
    matching encoder output and apply two ``DecoderBlock``. Layers are created
    in the same order as they are applied.

    Parameters
    ----------
    input_shape : tuple
        Shape passed to ``tf.keras.Input``.

    Returns
    -------
    Model
        Model mapping the input to a single-channel output.
    """

    def conv_down(tensor, channels):
        # 3x3x3 "same" encoder stage.
        return EncoderBlock(channels, kernel_size=3, strides=1, padding="same")(
            tensor
        )

    def conv_up(tensor, channels):
        # 3x3x3 "same" decoder stage.
        return DecoderBlock(channels, kernel_size=3, strides=1, padding="same")(
            tensor
        )

    def upsample(tensor, channels):
        # Stride-2 transposed convolution.
        return DecoderBlock(channels, kernel_size=2, strides=2, padding="valid")(
            tensor
        )

    inputs = tf.keras.Input(input_shape)

    # Encode
    x = inputs
    skips = []
    for first, second in ((32, 64), (64, 128), (128, 256)):
        x = conv_down(x, first)
        skip = conv_down(x, second)
        skips.append(skip)
        x = MaxPool3D()(skip)

    x = conv_down(x, 256)
    x = conv_down(x, 512)

    # Last layer without relu
    x = Conv3D(512, kernel_size=1, strides=1, padding="same")(x)

    # Decode: skips are consumed from the deepest level to the shallowest.
    for up_channels, channels in ((512, 256), (256, 128), (128, 64)):
        x = upsample(x, up_channels)
        x = Concatenate()([x, skips.pop()])
        x = conv_up(x, channels)
        x = conv_up(x, channels)

    x = DecoderBlock(1, kernel_size=1, strides=1, padding="valid")(x)

    # Last layer without relu
    out = Conv3DTranspose(1, kernel_size=1, strides=1, padding="valid")(x)

    return Model(inputs, out)
class DeepN4:
    """
    This class is intended for the DeepN4 model.

    The DeepN4 model :footcite:p:`Kanakaraj2024` predicts the bias field for
    magnetic field inhomogeneity correction on T1-weighted images.

    References
    ----------
    .. footbibliography::
    """

    @warning_for_keywords()
    @doctest_skip_parser
    def __init__(self, *, verbose=False):
        r"""
        To obtain the pre-trained model, use fetch_default_weights() like:
        >>> deepn4_model = DeepN4() # skip if not have_tf
        >>> deepn4_model.fetch_default_weights() # skip if not have_tf

        This model is designed to take as input file T1 signal and predict
        bias field. Effectively, this model is mimicking bias correction.

        Parameters
        ----------
        verbose : bool, optional
            Whether to show information about the processing.
        """
        if not have_tf:
            # NOTE(review): ``tf`` is the object returned by optional_package;
            # presumably calling it raises an informative error - confirm.
            raise tf()

        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)

        # DeepN4 network load
        self.model = UNet3D(input_shape=(128, 128, 128, 1))

    def fetch_default_weights(self):
        """
        Load the model pre-training weights to use for the fitting.

        The weights file is located with
        ``get_fnames(name="deepn4_default_weights")``.
        """
        fetch_model_weights_path = get_fnames(name="deepn4_default_weights")
        self.load_model_weights(fetch_model_weights_path)

    def load_model_weights(self, weights_path):
        """
        Load the custom pre-training weights to use for the fitting.

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights (hdf5, saved by tensorflow)

        Raises
        ------
        ValueError
            If loading the weights into the model raises a ValueError.
        """
        try:
            self.model.load_weights(weights_path)
        except ValueError as e:
            # Adjacent literals keep source indentation out of the message.
            raise ValueError(
                "Expected input for the provided model weights "
                "do not match the declared model"
            ) from e

    def __predict(self, x_test):
        """
        Internal prediction function
        Predict bias field from input T1 signal

        Parameters
        ----------
        x_test : np.ndarray (128, 128, 128, 1)
            Image should match the required shape of the model.

        Returns
        -------
        np.ndarray (128, 128, 128)
            Predicted bias field
        """
        return self.model.predict(x_test)

    @staticmethod
    def _axis_bounds(length, sz):
        """Return slice bounds to center an axis of ``length`` in ``sz``.

        Parameters
        ----------
        length : int
            Size of the image along the axis.
        sz : int
            Size of the target volume along the axis.

        Returns
        -------
        tuple of int
            ``(dst_lo, dst_hi, src_lo, src_hi)``: bounds in the target volume
            followed by bounds in the image.
        """
        # Floor division gives the same bounds as truncating / flooring /
        # ceiling the float halves, once clipped by max() and min().
        half = (sz - length) // 2
        dst_lo = max(half, 0)
        dst_hi = min(length + half, sz)
        src_lo = max((length - sz) // 2, 0)
        src_hi = min(length + half, length)
        return dst_lo, dst_hi, src_lo, src_hi

    def pad(self, img, sz):
        """Center ``img`` in a zero-filled cube of side ``sz``.

        Axes shorter than ``sz`` are zero-padded, longer ones are cropped.

        Parameters
        ----------
        img : np.ndarray
            3D volume.
        sz : int
            Side of the output cube.

        Returns
        -------
        tmp : np.ndarray (sz, sz, sz)
            The padded / cropped volume.
        bounds : list of int
            ``[lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ]`` where the
            ``l*`` values are slice bounds in ``tmp`` and the ``r*`` values
            are slice bounds in ``img``.
        """
        (lx, lX, rx, rX), (ly, lY, ry, rY), (lz, lZ, rz, rZ) = (
            self._axis_bounds(img.shape[axis], sz) for axis in range(3)
        )

        tmp = np.zeros((sz, sz, sz))
        tmp[lx:lX, ly:lY, lz:lZ] = img[rx:rX, ry:rY, rz:rZ]

        return tmp, [lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ]

    def load_resample(self, subj):
        """Pad / crop a volume to 128^3 and normalize it for the network.

        Intensities are rescaled to [0, 1] with ``normalize``, using the
        99.99th percentile of the non-zero voxels as the upper value.

        Parameters
        ----------
        subj : np.ndarray
            3D volume.

        Returns
        -------
        tuple
            The input tensor of shape (1, 128, 128, 128, 1) and dtype float32,
            the twelve bounds returned by ``pad`` (same order), and the
            percentile value used for normalization.
        """
        input_data, bounds = self.pad(subj, 128)

        in_max = np.percentile(input_data[np.nonzero(input_data)], 99.99)
        input_data = normalize(input_data, min_v=0, max_v=in_max, new_min=0, new_max=1)
        input_data = np.squeeze(input_data)

        # Add the batch and channel axes.
        input_vols = np.zeros((1, 128, 128, 128, 1))
        input_vols[0, :, :, :, 0] = input_data

        return (
            tf.convert_to_tensor(input_vols, dtype=tf.float32),
            *bounds,
            in_max,
        )

    def predict(self, img, img_affine, *, voxsize=(1, 1, 1)):
        """Wrapper function to facilitate prediction of larger dataset.

        The image is transformed with ``transform_img``, padded / cropped to
        the model input size and normalized. The predicted log bias field is
        exponentiated, smoothed, mapped back to the input image grid with
        ``recover_img`` and used to divide the input image.

        Parameters
        ----------
        img : np.ndarray
            T1 image to correct.
        img_affine : np.ndarray
            Affine of ``img``, forwarded to ``transform_img``.
        voxsize : tuple, optional
            Voxel size of ``img``, forwarded to ``transform_img`` and
            ``recover_img``.

        Returns
        -------
        final_corrected : np.ndarray (x, y, z)
            Predicted bias corrected image. The volume has matching shape to
            the input data
        """
        # Preprocess input data (resample, normalize, and pad)
        resampled_T1, inv_affine, mid_shape, offset_array, scale, crop_vs, pad_vs = (
            transform_img(img, img_affine, voxsize=voxsize)
        )
        (in_features, lx, lX, ly, lY, lz, lZ, rx, rX, ry, rY, rz, rZ, in_max) = (
            self.load_resample(resampled_T1)
        )

        # Run the model to get the bias field
        logfield = self.__predict(in_features)
        field = np.exp(logfield)
        field = field.squeeze()

        # Postprocess predicted field (reshape - unpad, smooth the field,
        # upsample)
        final_field = np.zeros(
            [resampled_T1.shape[0], resampled_T1.shape[1], resampled_T1.shape[2]]
        )
        final_field[rx:rX, ry:rY, rz:rZ] = field[lx:lX, ly:lY, lz:lZ]
        final_fields = gaussian_filter(final_field, sigma=3)
        upsample_final_field = recover_img(
            final_fields,
            inv_affine,
            mid_shape,
            img.shape,
            offset_array,
            voxsize,
            scale,
            crop_vs,
            pad_vs,
        )

        # Correct the image: voxels where the field magnitude is below the
        # threshold are set to 0 instead of being divided by the field.
        THRESHOLD = 0.5
        below_threshold_mask = np.abs(upsample_final_field) < THRESHOLD
        with np.errstate(divide="ignore", invalid="ignore"):
            final_corrected = np.where(
                below_threshold_mask, 0, img / upsample_final_field
            )

        return final_corrected