# Source code for dipy.nn.tf.synb0

#!/usr/bin/python
"""
Class and helper functions for fitting the Synb0 model.
"""

import logging

import numpy as np

from dipy.data import get_fnames
from dipy.nn.utils import normalize, set_logger_level, unnormalize
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.optpkg import optional_package

tf, have_tf, _ = optional_package("tensorflow", min_version="2.0.0")
if have_tf:
    from tensorflow.keras.layers import (
        Concatenate,
        Conv3D,
        Conv3DTranspose,
        GroupNormalization,
        Layer,
        LeakyReLU,
        MaxPool3D,
    )
    from tensorflow.keras.models import Model
else:
    # Stub base classes so this module can still be imported (and its classes
    # defined) without TensorFlow; instantiating Synb0 will still fail via the
    # have_tf check in its __init__.
    class Model:
        pass

    class Layer:
        pass

    # BUGFIX: the previous message was built with backslash line continuations
    # whose leading indentation ended up inside the string, producing a log
    # message full of embedded runs of spaces. Implicit string concatenation
    # yields a clean, single-spaced message.
    logging.warning(
        "This model requires Tensorflow. Please install these packages using "
        "pip. If using mac, please refer to this link for installation: "
        "https://github.com/apple/tensorflow_macos"
    )

logging.basicConfig()
logger = logging.getLogger("synb0")


class EncoderBlock(Layer):
    """Encoder stage for UNet3D: Conv3D -> instance norm -> LeakyReLU.

    Parameters
    ----------
    out_channels : int
        Number of output feature channels of the convolution.
    kernel_size : int or tuple of int
        Size of the 3D convolution kernel.
    strides : int or tuple of int
        Convolution strides.
    padding : str
        Padding mode passed to ``Conv3D`` (``"same"`` or ``"valid"``).
    """

    def __init__(self, out_channels, kernel_size, strides, padding):
        super(EncoderBlock, self).__init__()
        # No bias: the normalization layer that follows would cancel it out.
        self.conv3d = Conv3D(
            out_channels,
            kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        # groups=-1 gives one group per channel, i.e. instance normalization.
        self.instnorm = GroupNormalization(groups=-1)
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        """Apply convolution, normalization and activation in sequence."""
        return self.activation(self.instnorm(self.conv3d(input)))
class DecoderBlock(Layer):
    """Decoder stage for UNet3D: Conv3DTranspose -> instance norm -> LeakyReLU.

    Parameters
    ----------
    out_channels : int
        Number of output feature channels of the transposed convolution.
    kernel_size : int or tuple of int
        Size of the 3D kernel.
    strides : int or tuple of int
        Strides of the transposed convolution (``2`` performs upsampling).
    padding : str
        Padding mode passed to ``Conv3DTranspose``.
    """

    def __init__(self, out_channels, kernel_size, strides, padding):
        super(DecoderBlock, self).__init__()
        # No bias: the normalization layer that follows would cancel it out.
        self.conv3d = Conv3DTranspose(
            out_channels,
            kernel_size,
            strides=strides,
            padding=padding,
            use_bias=False,
        )
        # groups=-1 gives one group per channel, i.e. instance normalization.
        self.instnorm = GroupNormalization(groups=-1)
        self.activation = LeakyReLU(0.01)

    def call(self, input):
        """Apply transposed convolution, normalization and activation."""
        return self.activation(self.instnorm(self.conv3d(input)))
def UNet3D(input_shape):
    """Build the 3D U-Net used by the Synb0 model.

    Three encoder/pooling stages with skip connections, a bottleneck, and
    three upsampling/decoder stages. The final layers are plain (un-normalized,
    un-activated) convolutions.

    Parameters
    ----------
    input_shape : tuple of int
        Spatial + channel shape of the network input, e.g. ``(80, 80, 96, 2)``.

    Returns
    -------
    tensorflow.keras.Model
        Model mapping the input volume to a single-channel output volume.
    """
    inputs = tf.keras.Input(input_shape)

    # ---- contracting path (keep pre-pooling activations for skips) ----
    h = EncoderBlock(32, kernel_size=3, strides=1, padding="same")(inputs)
    skip0 = EncoderBlock(64, kernel_size=3, strides=1, padding="same")(h)
    h = MaxPool3D()(skip0)

    h = EncoderBlock(64, kernel_size=3, strides=1, padding="same")(h)
    skip1 = EncoderBlock(128, kernel_size=3, strides=1, padding="same")(h)
    h = MaxPool3D()(skip1)

    h = EncoderBlock(128, kernel_size=3, strides=1, padding="same")(h)
    skip2 = EncoderBlock(256, kernel_size=3, strides=1, padding="same")(h)
    h = MaxPool3D()(skip2)

    # ---- bottleneck ----
    h = EncoderBlock(256, kernel_size=3, strides=1, padding="same")(h)
    h = EncoderBlock(512, kernel_size=3, strides=1, padding="same")(h)
    # Plain convolution: no normalization or activation on this layer.
    h = Conv3D(512, kernel_size=1, strides=1, padding="same")(h)

    # ---- expanding path (upsample, concatenate skip, refine) ----
    h = DecoderBlock(512, kernel_size=2, strides=2, padding="valid")(h)
    h = Concatenate()([h, skip2])
    h = DecoderBlock(256, kernel_size=3, strides=1, padding="same")(h)
    h = DecoderBlock(256, kernel_size=3, strides=1, padding="same")(h)

    h = DecoderBlock(256, kernel_size=2, strides=2, padding="valid")(h)
    h = Concatenate()([h, skip1])
    h = DecoderBlock(128, kernel_size=3, strides=1, padding="same")(h)
    h = DecoderBlock(128, kernel_size=3, strides=1, padding="same")(h)

    h = DecoderBlock(128, kernel_size=2, strides=2, padding="valid")(h)
    h = Concatenate()([h, skip0])
    h = DecoderBlock(64, kernel_size=3, strides=1, padding="same")(h)
    h = DecoderBlock(64, kernel_size=3, strides=1, padding="same")(h)
    h = DecoderBlock(1, kernel_size=1, strides=1, padding="valid")(h)

    # Final transposed convolution without activation.
    out = Conv3DTranspose(1, kernel_size=1, strides=1, padding="valid")(h)

    return Model(inputs, out)
class Synb0:
    """
    This class is intended for the Synb0 model.

    Synb0 :footcite:p:`Schilling2019`, :footcite:p:`Schilling2020` uses a
    neural network to synthesize a b0 volume for distortion correction in
    DWI images.

    The model is the deep learning part of the Synb0-Disco pipeline, thus
    stand-alone usage is not recommended.

    References
    ----------
    .. footbibliography::
    """

    @doctest_skip_parser
    @warning_for_keywords()
    def __init__(self, *, verbose=False):
        r"""
        The model was pre-trained for usage on pre-processed images
        following the synb0-disco pipeline.
        One can load their own weights using load_model_weights.

        This model is designed to take as input a b0 image and a T1
        weighted image. It was designed to predict a b-inf image.

        Parameters
        ----------
        verbose : bool, optional
            Whether to show information about the processing.
        """
        if not have_tf:
            # tf is a TripWire object when tensorflow is missing; calling it
            # raises an informative error.
            raise tf()

        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)

        # Synb0 network load
        self.model = UNet3D(input_shape=(80, 80, 96, 2))

    def fetch_default_weights(self, idx):
        r"""
        Load the model pre-training weights to use for the fitting.
        While the user can load different weights, the function
        is mainly intended for the class function 'predict'.

        Parameters
        ----------
        idx : int
            The idx of the default weights. It can be from 0~4.
        """
        fetch_model_weights_path = get_fnames(name="synb0_default_weights")
        print(f"fetched {fetch_model_weights_path[idx]}")
        self.load_model_weights(fetch_model_weights_path[idx])

    def load_model_weights(self, weights_path):
        r"""
        Load the custom pre-training weights to use for the fitting.

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights (hdf5, saved by
            tensorflow)
        """
        try:
            self.model.load_weights(weights_path)
        except ValueError as e:
            raise ValueError(
                "Expected input for the provided model weights "
                "do not match the declared model"
            ) from e

    def __predict(self, x_test):
        r"""
        Internal prediction function

        Parameters
        ----------
        x_test : np.ndarray (batch, 80, 80, 96, 2)
            Image should match the required shape of the model.

        Returns
        -------
        np.ndarray (...) or (batch, ...)
            Reconstructed b-inf image(s)
        """
        return self.model.predict(x_test)

    def _predict_padded(self, b0, T1, batch_size, p99, shape):
        r"""
        Run one batched forward pass over padded, normalized volumes and
        undo the normalization and padding afterwards.

        Parameters
        ----------
        b0, T1 : np.ndarray (batch, 80, 96, 80)
            Padded and normalized input volumes.
        batch_size : int
            Number of images per prediction pass.
        p99 : np.ndarray (batch,)
            99th percentile of each (padded) b0 volume, used to unnormalize.
        shape : tuple
            Original (pre-padding) shape, (batch, 77, 91, 77).

        Returns
        -------
        np.ndarray (batch, 77, 91, 77)
            Unnormalized, cropped prediction.
        """
        # Stack b0/T1 as channels and move the axes so the array matches the
        # model's (batch, 80, 80, 96, 2) input layout.
        input_data = np.moveaxis(np.stack([b0, T1], -1), 3, 1).astype(np.float32)

        prediction = np.zeros((shape[0], 80, 80, 96, 1), dtype=np.float32)
        # Full batches. BUGFIX: the non-average branch previously assigned to
        # prediction[:batch_idx] instead of the current batch's slice, which
        # both overwrote earlier batches and broke on shape mismatch.
        for batch_idx in range(batch_size, shape[0] + 1, batch_size):
            temp_input = input_data[batch_idx - batch_size : batch_idx]
            prediction[batch_idx - batch_size : batch_idx] = self.__predict(temp_input)

        # Trailing partial batch, if any.
        remainder = np.mod(shape[0], batch_size)
        if remainder != 0:
            prediction[-remainder:] = self.__predict(input_data[-remainder:])

        # Map each volume back from [-1, 1] to its original intensity range.
        for j in range(shape[0]):
            prediction[j] = unnormalize(prediction[j], -1, 1, 0, p99[j])

        # Undo the padding applied in predict() and drop the channel axis,
        # then restore the original axis order.
        prediction = prediction[:, 2:-1, 2:-1, 3:-2, 0]
        return np.moveaxis(prediction, 1, -1)

    @warning_for_keywords()
    def predict(self, b0, T1, *, batch_size=None, average=True):
        r"""
        Wrapper function to facilitate prediction of larger dataset.
        The function will pad the data to meet the required shape of image.
        Note that the b0 and T1 image should have the same shape

        Parameters
        ----------
        b0 : np.ndarray (batch, 77, 91, 77) or (77, 91, 77)
            For a single image, input should be a 3D array. If multiple
            images, there should also be a batch dimension.
        T1 : np.ndarray (batch, 77, 91, 77) or (77, 91, 77)
            For a single image, input should be a 3D array. If multiple
            images, there should also be a batch dimension.
        batch_size : int
            Number of images per prediction pass. Only available if data
            is provided with a batch dimension.
            Consider lowering it if you get an out of memory error.
            Increase it if you want it to be faster and have a lot of data.
            If None, batch_size will be set to 1.
            Default is None
        average : bool
            Whether the function follows the Synb0-Disco pipeline and
            averages the prediction of 5 different models.
            If False, it uses the loaded weights for prediction.
            Default is True.

        Returns
        -------
        pred_output : np.ndarray (...) or (batch, ...)
            Reconstructed b-inf image(s)
        """
        # Check if shape is as intended
        if (
            all([b0.shape[1:] != (77, 91, 77), b0.shape != (77, 91, 77)])
            or b0.shape != T1.shape
        ):
            raise ValueError(
                "Expected shape (batch, 77, 91, 77) or "
                "(77, 91, 77) for both inputs"
            )

        dim = len(b0.shape)
        # Add batch dimension if not provided
        if dim == 3:
            T1 = np.expand_dims(T1, 0)
            b0 = np.expand_dims(b0, 0)
        shape = b0.shape  # pre-padding shape, used for cropping/averaging

        # Pad the data to match the model's (80, 80, 96) spatial input shape.
        T1 = np.pad(T1, ((0, 0), (2, 1), (3, 2), (2, 1)), "constant")
        b0 = np.pad(b0, ((0, 0), (2, 1), (3, 2), (2, 1)), "constant")

        # Normalize the data.
        p99 = np.percentile(b0, 99, axis=(1, 2, 3))
        for i in range(shape[0]):
            T1[i] = normalize(T1[i], min_v=0, max_v=150, new_min=-1, new_max=1)
            b0[i] = normalize(b0[i], min_v=0, max_v=p99[i], new_min=-1, new_max=1)

        if dim == 3:
            if batch_size is not None:
                # BUGFIX: logger.warning was previously called with two
                # positional string arguments; logging treated the second as a
                # %-format argument with no placeholder, raising a formatting
                # error instead of emitting the warning.
                logger.warning(
                    "Batch size specified, but not used "
                    "due to the input not having a batch dimension"
                )
            batch_size = 1
        elif batch_size is None:
            # BUGFIX: previously, a batched (4D) input with batch_size=None
            # crashed with range(None, ...); default to 1 as documented.
            batch_size = 1

        # Prediction stage
        if average:
            # Synb0-Disco pipeline: average the predictions of the 5 default
            # weight sets.
            mean_pred = np.zeros(shape + (5,), dtype=np.float32)
            for i in range(5):
                self.fetch_default_weights(i)
                mean_pred[..., i] = self._predict_padded(b0, T1, batch_size, p99, shape)
            prediction = np.mean(mean_pred, axis=-1)
        else:
            # Single pass with the currently loaded weights.
            prediction = self._predict_padded(b0, T1, batch_size, p99, shape)

        # Drop the batch dimension if the caller did not supply one.
        if dim == 3:
            prediction = prediction[0]

        return prediction