#!/usr/bin/python
"""
Class and helper functions for fitting the EVAC+ model.
"""
import logging
import numpy as np
from dipy.align.reslice import reslice
from dipy.data import get_fnames
from dipy.nn.utils import (
normalize,
recover_img,
set_logger_level,
transform_img,
)
from dipy.segment.utils import remove_holes_and_islands
from dipy.testing.decorators import doctest_skip_parser
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package
torch, have_torch, _ = optional_package("torch", min_version="2.2.0")
if have_torch:
from torch.nn import (
Conv3d,
ConvTranspose3d,
Dropout3d,
LayerNorm,
Module,
ModuleList,
ReLU,
Softmax,
)
else:

    class Module:
        pass

    logging.warning(
        "This model requires PyTorch. Please install it using pip."
    )
logging.basicConfig()
logger = logging.getLogger("EVAC+")
def prepare_img(image):
"""
Function to prepare image for model input
Specific to EVAC+
Parameters
----------
image : np.ndarray
Input image
Returns
-------
input_data : dict
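
    Examples
    --------
    A minimal sketch; ``img`` is assumed to be a pre-transformed
    (128, 128, 128, 1) array:

    >>> tensors = prepare_img(img)  # doctest: +SKIP
    >>> len(tensors)  # doctest: +SKIP
    5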
"""
    # The batch axis is last in `image`; move it to the front and add a
    # channel axis so each tensor is shaped (batch, channel, x, y, z).
    input_data = []
    for vox in (1, 2, 4, 8, 16):
        if vox == 1:
            resliced = image
        else:
            resliced, _ = reslice(image, np.eye(4), (1, 1, 1), (vox, vox, vox))
        resliced = np.expand_dims(np.moveaxis(resliced, -1, 0), 1)
        input_data.append(torch.from_numpy(resliced).float())
    return input_data
class MoveDimLayer(Module):
def __init__(self, source_dim, dest_dim):
super(MoveDimLayer, self).__init__()
self.source_dim = source_dim
self.dest_dim = dest_dim
def forward(self, x):
return torch.movedim(x, self.source_dim, self.dest_dim)
class ChannelSum(Module):
def __init__(self):
super(ChannelSum, self).__init__()
def forward(self, inputs):
return torch.sum(inputs, dim=1, keepdim=True)
class Add(Module):
def __init__(self):
super(Add, self).__init__()
def forward(self, x, passed):
return x + passed
class Block(Module):
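    """Convolutional block of the EVAC+ V-net.

    Stacks ``n_layers`` of Conv3d -> Dropout3d -> LayerNorm -> ReLU,
    sums the result across channels and adds the ``passed`` skip tensor.
    A second path then resamples that sum: a strided Conv3d for
    ``layer_type="down"``, a ConvTranspose3d for ``layer_type="up"``,
    and no resampling otherwise.
    """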
def __init__(
self,
in_channels,
out_channels,
kernel_size,
strides,
padding,
drop_r,
n_layers,
*,
passed_channel=1,
layer_type="down",
):
super(Block, self).__init__()
self.n_layers = n_layers
self.layer_list = ModuleList()
self.layer_list2 = ModuleList()
cur_channel = in_channels
for _ in range(n_layers):
self.layer_list.append(
Conv3d(
cur_channel,
out_channels,
kernel_size,
stride=strides,
padding=padding,
)
)
cur_channel = out_channels
self.layer_list.append(Dropout3d(drop_r))
self.layer_list.append(MoveDimLayer(1, -1))
self.layer_list.append(LayerNorm(out_channels))
self.layer_list.append(MoveDimLayer(-1, 1))
self.layer_list.append(ReLU())
if layer_type == "down":
self.layer_list2.append(Conv3d(in_channels, 1, 2, stride=2, padding=0))
self.layer_list2.append(ReLU())
elif layer_type == "up":
self.layer_list2.append(
ConvTranspose3d(passed_channel, 1, 2, stride=2, padding=0)
)
self.layer_list2.append(ReLU())
self.channel_sum = ChannelSum()
self.add = Add()
def forward(self, input, passed):
x = input
for layer in self.layer_list:
x = layer(x)
x = self.channel_sum(x)
fwd = self.add(x, passed)
x = fwd
for layer in self.layer_list2:
x = layer(x)
return fwd, x
class Model(Module):
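    """EVAC+ network: a V-net style multi-resolution encoder-decoder.

    Each encoder block receives the previous output concatenated with
    the raw input resliced to the matching resolution; decoder blocks
    merge upsampled features with the stored encoder outputs.
    """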
def __init__(self, model_scale=16):
super(Model, self).__init__()
# Block structure
self.block1 = Block(
1, model_scale, kernel_size=5, strides=1, padding=2, drop_r=0.2, n_layers=1
)
self.block2 = Block(
2,
model_scale * 2,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=2,
)
self.block3 = Block(
2,
model_scale * 4,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=3,
)
self.block4 = Block(
2,
model_scale * 8,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=3,
)
self.block5 = Block(
2,
model_scale * 16,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=3,
passed_channel=2,
layer_type="up",
)
# Upsample/decoder blocks
self.up_block1 = Block(
3,
model_scale * 8,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=3,
passed_channel=1,
layer_type="up",
)
self.up_block2 = Block(
3,
model_scale * 4,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=3,
passed_channel=1,
layer_type="up",
)
self.up_block3 = Block(
3,
model_scale * 2,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=2,
passed_channel=1,
layer_type="up",
)
self.up_block4 = Block(
2,
model_scale,
kernel_size=5,
strides=1,
padding=2,
drop_r=0.5,
n_layers=1,
passed_channel=1,
layer_type="none",
)
self.conv_pred = Conv3d(1, 2, 1, padding=0)
self.softmax = Softmax(dim=1)
def forward(self, inputs, raw_input_2, raw_input_3, raw_input_4, raw_input_5):
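        # Encoding path: concatenate each block's output with the raw
        # input resliced to the matching resolution.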
fwd1, x = self.block1(inputs, inputs)
x = torch.cat([x, raw_input_2], dim=1)
fwd2, x = self.block2(x, x)
x = torch.cat([x, raw_input_3], dim=1)
fwd3, x = self.block3(x, x)
x = torch.cat([x, raw_input_4], dim=1)
fwd4, x = self.block4(x, x)
x = torch.cat([x, raw_input_5], dim=1)
# Decoding path
_, up = self.block5(x, x)
x = torch.cat([fwd4, up], dim=1)
_, up = self.up_block1(x, up)
x = torch.cat([fwd3, up], dim=1)
_, up = self.up_block2(x, up)
x = torch.cat([fwd2, up], dim=1)
_, up = self.up_block3(x, up)
x = torch.cat([fwd1, up], dim=1)
_, pred = self.up_block4(x, up)
pred = self.conv_pred(pred)
output = self.softmax(pred)
return output
class EVACPlus:
"""
This class is intended for the EVAC+ model.
The EVAC+ model :footcite:p:`Park2024` is a deep learning neural network for
brain extraction. It uses a V-net architecture combined with
multi-resolution input data, an additional conditional random field (CRF)
recurrent layer and supplementary Dice loss term for this recurrent layer.
References
----------
.. footbibliography::
"""
@doctest_skip_parser
def __init__(self, *, verbose=False):
"""
The model was pre-trained for usage on
brain extraction of T1 images.
This model is designed to take as input
a T1 weighted image.
Parameters
----------
verbose : bool, optional
Whether to show information about the processing.
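
        Examples
        --------
        A minimal sketch; ``t1`` and its ``affine`` are assumed to be
        loaded beforehand (e.g. with ``dipy.io.image.load_nifti``):

        >>> evac = EVACPlus()  # doctest: +SKIP
        >>> mask = evac.predict(t1, affine)  # doctest: +SKIP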
"""
if not have_torch:
raise torch()
log_level = "INFO" if verbose else "CRITICAL"
set_logger_level(log_level, logger)
# EVAC+ network load
self.model = self.init_model()
self.fetch_default_weights()
def init_model(self, model_scale=16):
return Model(model_scale)
def fetch_default_weights(self):
"""
Load the model pre-training weights to use for the fitting.
While the user can load different weights, the function
is mainly intended for the class function 'predict'.
"""
fetch_model_weights_path = get_fnames(name="evac_default_torch_weights")
self.load_model_weights(fetch_model_weights_path)
def load_model_weights(self, weights_path):
"""
Load the custom pre-training weights to use for the fitting.
Parameters
----------
weights_path : str
Path to the file containing the weights (pth, saved by Pytorch)
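
        Examples
        --------
        A minimal sketch; ``custom_weights.pth`` is a hypothetical path:

        >>> evac = EVACPlus()  # doctest: +SKIP
        >>> evac.load_model_weights("custom_weights.pth")  # doctest: +SKIP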
"""
try:
self.model.load_state_dict(torch.load(weights_path, weights_only=True))
self.model.eval()
except ValueError as e:
raise ValueError(
"Expected input for the provided model weights \
do not match the declared model"
) from e
def __predict(self, x_test):
"""
Internal prediction function
Parameters
----------
x_test : list of np.ndarray
Image should match the required shape of the model.
Returns
-------
np.ndarray (batch, ...)
Predicted brain mask
"""
return self.model(*x_test)[:, 0].detach().numpy()
@deprecated_params(
"largest_area", new_name="finalize_mask", since="1.10", until="1.12"
)
def predict(
self,
T1,
affine,
*,
voxsize=(1, 1, 1),
batch_size=None,
return_affine=False,
return_prob=False,
finalize_mask=True,
):
"""
Wrapper function to facilitate prediction of larger dataset.
Parameters
----------
T1 : np.ndarray or list of np.ndarray
For a single image, input should be a 3D array.
If multiple images, it should be a a list or tuple.
affine : np.ndarray (4, 4) or (batch, 4, 4)
or list of np.ndarrays with len of batch
Affine matrix for the T1 image. Should have
batch dimension if T1 has one.
voxsize : np.ndarray or list or tuple, optional
(3,) or (batch, 3)
voxel size of the T1 image.
batch_size : int, optional
Number of images per prediction pass. Only available if data
is provided with a batch dimension.
Consider lowering it if you get an out of memory error.
Increase it if you want it to be faster and have a lot of data.
If None, batch_size will be set to 1 if the provided image
has a batch dimension.
return_affine : bool, optional
Whether to return the affine matrix. Useful if the input was a
file path.
return_prob : bool, optional
Whether to return the probability map instead of a
binary mask. Useful for testing.
finalize_mask : bool, optional
Whether to remove potential holes or islands.
Useful for solving minor errors.
Returns
-------
pred_output : np.ndarray (...) or (batch, ...)
Predicted brain mask
affine : np.ndarray (...) or (batch, ...)
affine matrix of mask
only if return_affine is True
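
        Examples
        --------
        A minimal sketch; ``t1``, ``affine`` and ``vox`` are assumed to
        be loaded beforehand (e.g. with ``dipy.io.image.load_nifti``):

        >>> evac = EVACPlus()  # doctest: +SKIP
        >>> mask = evac.predict(t1, affine, voxsize=vox)  # doctest: +SKIP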
"""
voxsize = np.array(voxsize)
affine = np.array(affine)
if isinstance(T1, (list, tuple)):
dim = 4
T1 = np.array(T1)
elif len(T1.shape) == 3:
dim = 3
            if batch_size is not None:
                logger.warning(
                    "Batch size specified, but not used "
                    "due to the input not having a batch dimension"
                )
T1 = np.expand_dims(T1, 0)
affine = np.expand_dims(affine, 0)
voxsize = np.expand_dims(voxsize, 0)
else:
raise ValueError(
"T1 data should be a np.ndarray of dimension 3 or a list/tuple of it"
)
if batch_size is None:
batch_size = 1
input_data = np.zeros((128, 128, 128, len(T1)))
affines = np.zeros((len(T1), 4, 4))
mid_shapes = np.zeros((len(T1), 3)).astype(int)
offset_arrays = np.zeros((len(T1), 4, 4)).astype(int)
scales = np.zeros(len(T1))
crop_vss = np.zeros((len(T1), 3, 2))
pad_vss = np.zeros((len(T1), 3, 2))
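        # transform_img maps each image onto the 128**3 model grid; the
        # affines, offsets, scales and crop/pad values are stored so
        # recover_img can invert the transform afterwards.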
# Normalize the data.
n_T1 = np.zeros(T1.shape)
for i, T1_img in enumerate(T1):
n_T1[i] = normalize(T1_img, new_min=0, new_max=1)
t_img, t_affine, mid_shape, offset_array, scale, crop_vs, pad_vs = (
transform_img(n_T1[i], affine[i], voxsize=voxsize[i])
)
input_data[..., i] = t_img
affines[i] = t_affine
mid_shapes[i] = mid_shape
offset_arrays[i] = offset_array
scales[i] = scale
crop_vss[i] = crop_vs
pad_vss[i] = pad_vs
# Prediction stage
prediction = np.zeros((len(T1), 128, 128, 128), dtype=np.float32)
for batch_idx in range(batch_size, len(T1) + 1, batch_size):
batch = input_data[..., batch_idx - batch_size : batch_idx]
temp_input = prepare_img(batch)
temp_pred = self.__predict(temp_input)
            prediction[batch_idx - batch_size : batch_idx] = temp_pred
remainder = np.mod(len(T1), batch_size)
if remainder != 0:
temp_input = prepare_img(input_data[..., -remainder:])
temp_pred = self.__predict(temp_input)
prediction[-remainder:] = temp_pred
output_mask = []
for i in range(len(T1)):
output = recover_img(
prediction[i],
affines[i],
mid_shapes[i],
n_T1[i].shape,
offset_arrays[i],
voxsize=voxsize[i],
scale=scales[i],
crop_vs=crop_vss[i],
pad_vs=pad_vss[i],
)
if not return_prob:
output = np.where(output >= 0.5, 1, 0)
if finalize_mask:
output = remove_holes_and_islands(output, slice_wise=True)
output_mask.append(output)
if dim == 3:
output_mask = output_mask[0]
affine = affine[0]
output_mask = np.array(output_mask)
affine = np.array(affine)
if return_affine:
return output_mask, affine
else:
return output_mask