Source code for dipy.nn.tf.evac

#!/usr/bin/python
"""
Class and helper functions for fitting the EVAC+ model.
"""

import logging

import numpy as np

from dipy.align.reslice import reslice
from dipy.data import get_fnames
from dipy.nn.utils import (
    normalize,
    recover_img,
    set_logger_level,
    transform_img,
)
from dipy.segment.utils import remove_holes_and_islands
from dipy.testing.decorators import doctest_skip_parser, warning_for_keywords
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package

tf, have_tf, _ = optional_package("tensorflow", min_version="2.0.0")
if have_tf:
    from tensorflow.keras.layers import (
        Add,
        Concatenate,
        Conv3D,
        Conv3DTranspose,
        Dropout,
        Layer,
        LayerNormalization,
        ReLU,
        Softmax,
    )
    from tensorflow.keras.models import Model
else:

    class Model:
        pass

    class Layer:
        pass

    logging.warning(
        "This model requires TensorFlow. Please install the package using "
        "pip. If using mac, please refer to this link for installation: "
        "https://github.com/apple/tensorflow_macos"
    )

logging.basicConfig()
logger = logging.getLogger("EVAC+")


def prepare_img(image):
    """
    Prepare an image for model input. Specific to EVAC+.

    Builds the multi-resolution input pyramid the network expects.

    Parameters
    ----------
    image : np.ndarray
        Input image, with the batch dimension last.

    Returns
    -------
    input_data : dict
        Model input names mapped to the image at five resolutions.
    """
    # Full resolution: move the batch axis to the front and add a channel axis.
    input_data = {"input_1": np.expand_dims(np.moveaxis(image, -1, 0), -1)}

    # Downsampled copies at 2, 4, 8 and 16 mm isotropic voxel sizes.
    for idx, vox in enumerate((2, 4, 8, 16), start=2):
        resliced, _ = reslice(image, np.eye(4), (1, 1, 1), (vox,) * 3)
        input_data[f"input_{idx}"] = np.expand_dims(np.moveaxis(resliced, -1, 0), -1)

    return input_data
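
# Illustrative sketch (not part of the original module): shape bookkeeping
# for `prepare_img`. The data below are hypothetical; a single conformed
# volume is passed with a singleton batch axis last.
#
#     >>> import numpy as np
#     >>> vol = np.zeros((128, 128, 128, 1))
#     >>> inputs = prepare_img(vol)
#     >>> [inputs[f"input_{i}"].shape[1] for i in range(1, 6)]
#     [128, 64, 32, 16, 8]
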

class Block(Layer):
    """Convolutional building block of the EVAC+ network."""

    @warning_for_keywords()
    def __init__(
        self,
        out_channels,
        kernel_size,
        strides,
        padding,
        drop_r,
        n_layers,
        *,
        layer_type="down",
    ):
        super(Block, self).__init__()
        self.layer_list = []
        self.layer_list2 = []
        self.n_layers = n_layers

        # Stacked Conv3D -> Dropout -> LayerNorm -> ReLU units.
        for _ in range(n_layers):
            self.layer_list.append(
                Conv3D(out_channels, kernel_size, strides=strides, padding=padding)
            )
            self.layer_list.append(Dropout(drop_r))
            self.layer_list.append(LayerNormalization())
            self.layer_list.append(ReLU())

        # Resampling path: strided convolution for downsampling, transposed
        # convolution for upsampling; any other layer_type (e.g. "none")
        # leaves the resolution unchanged.
        if layer_type == "down":
            self.layer_list2.append(Conv3D(1, 2, strides=2, padding="same"))
            self.layer_list2.append(ReLU())
        elif layer_type == "up":
            self.layer_list2.append(Conv3DTranspose(1, 2, strides=2, padding="same"))
            self.layer_list2.append(ReLU())

        self.channel_sum = ChannelSum()
        self.add = Add()

    def call(self, inputs, passed):
        x = inputs
        for layer in self.layer_list:
            x = layer(x)

        # Collapse the channels to one and add the skip (passed) tensor.
        x = self.channel_sum(x)
        fwd = self.add([x, passed])

        # Resample for the next resolution level.
        x = fwd
        for layer in self.layer_list2:
            x = layer(x)

        return fwd, x
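
# Behavior sketch for `Block` (assumes TensorFlow is installed; shapes are
# illustrative): the layer returns the skip-connection tensor `fwd` at the
# input resolution and a resampled tensor for the next level.
#
#     >>> x = tf.zeros((1, 128, 128, 128, 1))
#     >>> fwd, down = Block(
#     ...     16, kernel_size=5, strides=1, padding="same", drop_r=0.2, n_layers=1
#     ... )(x, x)
#     >>> fwd.shape, down.shape
#     (TensorShape([1, 128, 128, 128, 1]), TensorShape([1, 64, 64, 64, 1]))
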

class ChannelSum(Layer):
    def __init__(self):
        super(ChannelSum, self).__init__()

    def call(self, inputs):
        # Sum across channels, keeping a singleton channel axis.
        return tf.reduce_sum(inputs, axis=-1, keepdims=True)
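
# Quick sketch of `ChannelSum` (hypothetical shapes): any number of channels
# is collapsed to a single channel by summation.
#
#     >>> ChannelSum()(tf.ones((1, 8, 8, 8, 16))).shape
#     TensorShape([1, 8, 8, 8, 1])
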

@warning_for_keywords()
def init_model(*, model_scale=16):
    """
    Create the model for EVAC+.

    Parameters
    ----------
    model_scale : int, optional
        The scale of the model.
        Should match the saved weights from the fetcher.
        Default is 16.

    Returns
    -------
    model : tf.keras.Model
    """
    inputs = tf.keras.Input(shape=(128, 128, 128, 1), name="input_1")
    raw_input_2 = tf.keras.Input(shape=(64, 64, 64, 1), name="input_2")
    raw_input_3 = tf.keras.Input(shape=(32, 32, 32, 1), name="input_3")
    raw_input_4 = tf.keras.Input(shape=(16, 16, 16, 1), name="input_4")
    raw_input_5 = tf.keras.Input(shape=(8, 8, 8, 1), name="input_5")

    # Encoder: downsample at each level and concatenate the matching
    # raw multi-resolution input.
    fwd1, x = Block(
        model_scale, kernel_size=5, strides=1, padding="same", drop_r=0.2, n_layers=1
    )(inputs, inputs)
    x = Concatenate()([x, raw_input_2])
    fwd2, x = Block(
        model_scale * 2,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=2,
    )(x, x)
    x = Concatenate()([x, raw_input_3])
    fwd3, x = Block(
        model_scale * 4,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=3,
    )(x, x)
    x = Concatenate()([x, raw_input_4])
    fwd4, x = Block(
        model_scale * 8,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=3,
    )(x, x)
    x = Concatenate()([x, raw_input_5])

    # Decoder: upsample and concatenate the stored skip connections.
    _, up = Block(
        model_scale * 16,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=3,
        layer_type="up",
    )(x, x)
    x = Concatenate()([fwd4, up])
    _, up = Block(
        model_scale * 8,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=3,
        layer_type="up",
    )(x, up)
    x = Concatenate()([fwd3, up])
    _, up = Block(
        model_scale * 4,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=3,
        layer_type="up",
    )(x, up)
    x = Concatenate()([fwd2, up])
    _, up = Block(
        model_scale * 2,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=2,
        layer_type="up",
    )(x, up)
    x = Concatenate()([fwd1, up])
    _, pred = Block(
        model_scale,
        kernel_size=5,
        strides=1,
        padding="same",
        drop_r=0.5,
        n_layers=1,
        layer_type="none",
    )(x, up)

    # Two-class softmax; the first channel holds the brain probability.
    pred = Conv3D(2, 1, padding="same")(pred)
    output = Softmax(axis=-1)(pred)

    model = Model(
        {
            "input_1": inputs,
            "input_2": raw_input_2,
            "input_3": raw_input_3,
            "input_4": raw_input_4,
            "input_5": raw_input_5,
        },
        output[..., 0],
    )

    return model
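
# Usage sketch for `init_model` (assumes TensorFlow is installed): the
# network maps the five-resolution input dict to a per-voxel probability
# volume. `model_scale` must match the fetched weights.
#
#     >>> model = init_model()
#     >>> model.output_shape
#     (None, 128, 128, 128)
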

class EVACPlus:
    """
    This class is intended for the EVAC+ model.

    The EVAC+ model :footcite:p:`Park2024` is a deep learning neural
    network for brain extraction. It uses a V-net architecture combined
    with multi-resolution input data, an additional conditional random
    field (CRF) recurrent layer and a supplementary Dice loss term for
    this recurrent layer.

    References
    ----------
    .. footbibliography::
    """

    @doctest_skip_parser
    @warning_for_keywords()
    def __init__(self, *, verbose=False):
        """
        The model was pre-trained for usage on brain extraction of
        T1-weighted images and is designed to take a T1-weighted image
        as input.

        Parameters
        ----------
        verbose : bool, optional
            Whether to show information about the processing.
        """
        if not have_tf:
            raise tf()

        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)

        # EVAC+ network load
        self.model = init_model()
        self.fetch_default_weights()

    def fetch_default_weights(self):
        """
        Load the pre-trained model weights to use for the fitting.

        While the user can load different weights, the function
        is mainly intended for the class function 'predict'.
        """
        fetch_model_weights_path = get_fnames(name="evac_default_tf_weights")
        logger.info(f"Fetched {fetch_model_weights_path}")
        self.load_model_weights(fetch_model_weights_path)

    def load_model_weights(self, weights_path):
        """
        Load custom pre-trained weights to use for the fitting.

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights
            (HDF5, saved by TensorFlow).
        """
        try:
            self.model.load_weights(weights_path)
        except ValueError as e:
            raise ValueError(
                "Expected input for the provided model weights "
                "does not match the declared model"
            ) from e
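
    # Sketch (hypothetical path): loading custom weights after construction.
    # The weights file must come from a model built with the same
    # `model_scale` as the one `init_model` used here.
    #
    #     >>> evac = EVACPlus()  # fetches the default weights first
    #     >>> evac.load_model_weights("/path/to/custom_evac_weights.h5")
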

    def __predict(self, x_test):
        """
        Internal prediction function.

        Parameters
        ----------
        x_test : dict of np.ndarray
            Multi-resolution inputs from `prepare_img`, each of shape
            (batch, ...), matching the required input shapes of the model.

        Returns
        -------
        np.ndarray (batch, ...)
            Predicted brain probability map.
        """
        return self.model.predict(x_test)

    @deprecated_params(
        "largest_area", new_name="finalize_mask", since="1.10", until="1.12"
    )
    def predict(
        self,
        T1,
        affine,
        *,
        voxsize=(1, 1, 1),
        batch_size=None,
        return_affine=False,
        return_prob=False,
        finalize_mask=True,
    ):
        """
        Wrapper function to facilitate prediction on larger datasets.

        Parameters
        ----------
        T1 : np.ndarray or list of np.ndarray
            For a single image, the input should be a 3D array.
            For multiple images, it should be a list or tuple.
        affine : np.ndarray (4, 4) or (batch, 4, 4) or list of np.ndarray
            Affine matrix for the T1 image.
            Should have a batch dimension if T1 has one.
        voxsize : np.ndarray or list or tuple, optional
            (3,) or (batch, 3) voxel size of the T1 image.
        batch_size : int, optional
            Number of images per prediction pass. Only available if data
            is provided with a batch dimension.
            Consider lowering it if you get an out of memory error.
            Increase it if you want it to be faster and have a lot of data.
            If None, batch_size will be set to 1 if the provided image
            has a batch dimension.
        return_affine : bool, optional
            Whether to return the affine matrix. Useful if the input
            was a file path.
        return_prob : bool, optional
            Whether to return the probability map instead of a
            binary mask. Useful for testing.
        finalize_mask : bool, optional
            Whether to remove potential holes or islands.
            Useful for solving minor errors.

        Returns
        -------
        pred_output : np.ndarray (...) or (batch, ...)
            Predicted brain mask.
        affine : np.ndarray (...) or (batch, ...)
            Affine matrix of the mask. Only returned if return_affine
            is True.
        """
        voxsize = np.array(voxsize)
        affine = np.array(affine)

        if isinstance(T1, (list, tuple)):
            dim = 4
            T1 = np.array(T1)
        elif len(T1.shape) == 3:
            dim = 3
            if batch_size is not None:
                logger.warning(
                    "Batch size specified, but not used "
                    "due to the input not having a batch dimension."
                )
            T1 = np.expand_dims(T1, 0)
            affine = np.expand_dims(affine, 0)
            voxsize = np.expand_dims(voxsize, 0)
        else:
            raise ValueError(
                "T1 data should be a np.ndarray of dimension 3 or a list/tuple of it"
            )

        if batch_size is None:
            batch_size = 1

        input_data = np.zeros((128, 128, 128, len(T1)))
        affines = np.zeros((len(T1), 4, 4))
        mid_shapes = np.zeros((len(T1), 3)).astype(int)
        offset_arrays = np.zeros((len(T1), 4, 4)).astype(int)
        scales = np.zeros(len(T1))
        crop_vss = np.zeros((len(T1), 3, 2))
        pad_vss = np.zeros((len(T1), 3, 2))

        # Normalize the data and conform each volume to the model space.
        n_T1 = np.zeros(T1.shape)
        for i, T1_img in enumerate(T1):
            n_T1[i] = normalize(T1_img, new_min=0, new_max=1)
            t_img, t_affine, mid_shape, offset_array, scale, crop_vs, pad_vs = (
                transform_img(n_T1[i], affine[i], voxsize=voxsize[i])
            )
            input_data[..., i] = t_img
            affines[i] = t_affine
            mid_shapes[i] = mid_shape
            offset_arrays[i] = offset_array
            scales[i] = scale
            crop_vss[i] = crop_vs
            pad_vss[i] = pad_vs

        # Prediction stage
        prediction = np.zeros((len(T1), 128, 128, 128), dtype=np.float32)
        for batch_idx in range(batch_size, len(T1) + 1, batch_size):
            batch = input_data[..., batch_idx - batch_size : batch_idx]
            temp_input = prepare_img(batch)
            temp_pred = self.__predict(temp_input)
            prediction[batch_idx - batch_size : batch_idx] = temp_pred
        remainder = np.mod(len(T1), batch_size)
        if remainder != 0:
            temp_input = prepare_img(input_data[..., -remainder:])
            temp_pred = self.__predict(temp_input)
            prediction[-remainder:] = temp_pred

        # Map each prediction back to the original image space.
        output_mask = []
        for i in range(len(T1)):
            output = recover_img(
                prediction[i],
                affines[i],
                mid_shapes[i],
                n_T1[i].shape,
                offset_arrays[i],
                voxsize=voxsize[i],
                scale=scales[i],
                crop_vs=crop_vss[i],
                pad_vs=pad_vss[i],
            )
            if not return_prob:
                output = np.where(output >= 0.5, 1, 0)
                if finalize_mask:
                    output = remove_holes_and_islands(output, slice_wise=True)
            output_mask.append(output)

        if dim == 3:
            output_mask = output_mask[0]
            affine = affine[0]

        output_mask = np.array(output_mask)
        affine = np.array(affine)

        if return_affine:
            return output_mask, affine
        else:
            return output_mask
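
# End-to-end usage sketch (hedged): brain extraction of a single T1-weighted
# volume. Loading with nibabel and the variable `t1_fname` are assumptions
# for illustration, not part of this module.
#
#     >>> import nibabel as nib
#     >>> img = nib.load(t1_fname)  # hypothetical NIfTI file path
#     >>> evac = EVACPlus()
#     >>> mask = evac.predict(
#     ...     img.get_fdata(), img.affine, voxsize=img.header.get_zooms()[:3]
#     ... )
#     >>> mask.shape == img.shape
#     True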