#!/usr/bin/python
"""
Class and helper functions for fitting the EVAC+ model.
"""
import logging
import numpy as np
from dipy.align.reslice import reslice
from dipy.data import get_fnames
from dipy.nn.utils import (
    normalize,
    recover_img,
    set_logger_level,
    transform_img,
)
from dipy.segment.utils import remove_holes_and_islands
from dipy.testing.decorators import doctest_skip_parser
from dipy.utils.deprecator import deprecated_params
from dipy.utils.optpkg import optional_package
torch, have_torch, _ = optional_package("torch", min_version="2.2.0")
if have_torch:
    from torch.nn import (
        Conv3d,
        ConvTranspose3d,
        Dropout3d,
        LayerNorm,
        Module,
        ModuleList,
        ReLU,
        Softmax,
    )
else:
    # Minimal stand-in so the ``Module`` subclasses defined below can still be
    # created (and this module imported) when PyTorch is not installed.
    class Module:
        pass

    # Single-line literal: the previous backslash-continued string embedded
    # the source indentation into the emitted message.
    logging.warning(
        "This model requires Pytorch. Please install these packages using pip."
    )

logging.basicConfig()
logger = logging.getLogger("EVAC+")
def prepare_img(image):
    """
    Function to prepare image for model input.

    Specific to EVAC+: builds the five multi-resolution inputs of the network,
    the original image plus copies resliced by zoom factors 2, 4, 8 and 16.

    Parameters
    ----------
    image : np.ndarray
        Input image. The last axis is treated as the batch axis; it is moved
        to the front and a singleton channel axis is inserted after it.

    Returns
    -------
    input_data : list of torch.Tensor
        Five float tensors, ordered from full resolution to the coarsest
        (zoom 16) copy, each with layout ``(batch, 1, ...)``.
    """

    def _to_tensor(volume):
        # (..., batch) -> (batch, 1, ...) float tensor.
        volume = np.expand_dims(np.moveaxis(volume, -1, 0), 1)
        return torch.from_numpy(volume).float()

    input_data = [_to_tensor(image)]
    for factor in (2, 4, 8, 16):
        resliced, _ = reslice(image, np.eye(4), (1, 1, 1), (factor,) * 3)
        input_data.append(_to_tensor(resliced))
    return input_data
class MoveDimLayer(Module):
    """
    Move one tensor dimension to another position (``torch.movedim``).

    Parameters
    ----------
    source_dim : int
        Dimension to move.
    dest_dim : int
        Destination position of that dimension.
    """

    def __init__(self, source_dim, dest_dim):
        super().__init__()
        self.source_dim = source_dim
        self.dest_dim = dest_dim

    def forward(self, x):
        """Return ``x`` with ``source_dim`` moved to ``dest_dim``."""
        return torch.movedim(x, self.source_dim, self.dest_dim)
 
class ChannelSum(Module):
    """Sum a tensor over dimension 1, keeping that dimension with size 1."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` summed over dim 1 (``keepdim=True``)."""
        return torch.sum(inputs, dim=1, keepdim=True)
 
class Add(Module):
    """Element-wise sum of two tensors (standard broadcasting applies)."""

    def __init__(self):
        super().__init__()

    def forward(self, x, passed):
        """Return ``x + passed``."""
        return x + passed
 
class Block(Module):
    """
    One convolutional stage of the EVAC+ network.

    The main path applies ``n_layers`` repetitions of
    Conv3d -> Dropout3d -> LayerNorm (over the channel axis) -> ReLU,
    sums the result over channels and adds the ``passed`` tensor to it.
    That sum is then optionally resampled by a second stage selected by
    ``layer_type``.

    Parameters
    ----------
    in_channels : int
        Channels of the ``input`` tensor given to ``forward``. Also the input
        channels of the stride-2 Conv3d used when ``layer_type == "down"``.
    out_channels : int
        Number of filters of every convolution in the main path.
    kernel_size, strides, padding : int
        Passed to every Conv3d of the main path.
    drop_r : float
        Dropout3d probability used after each convolution.
    n_layers : int
        Number of convolution stages in the main path.
    passed_channel : int, optional
        Input channels of the ConvTranspose3d used when
        ``layer_type == "up"``.
    layer_type : str, optional
        ``"down"``: stride-2 Conv3d to 1 channel + ReLU.
        ``"up"``: stride-2 ConvTranspose3d to 1 channel + ReLU.
        Any other value: no second stage.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        strides,
        padding,
        drop_r,
        n_layers,
        *,
        passed_channel=1,
        layer_type="down",
    ):
        super().__init__()
        self.n_layers = n_layers
        # NOTE: attribute names and the order in which modules are appended
        # define the state_dict keys (e.g. ``layer_list.0.weight``); keep them
        # stable or pretrained weights will no longer load.
        self.layer_list = ModuleList()
        self.layer_list2 = ModuleList()
        cur_channel = in_channels
        for _ in range(n_layers):
            self.layer_list.append(
                Conv3d(
                    cur_channel,
                    out_channels,
                    kernel_size,
                    stride=strides,
                    padding=padding,
                )
            )
            cur_channel = out_channels
            self.layer_list.append(Dropout3d(drop_r))
            # LayerNorm normalizes the last dimension, so channels are moved
            # last for the normalization and moved back to dim 1 afterwards.
            self.layer_list.append(MoveDimLayer(1, -1))
            self.layer_list.append(LayerNorm(out_channels))
            self.layer_list.append(MoveDimLayer(-1, 1))
            self.layer_list.append(ReLU())
        if layer_type == "down":
            self.layer_list2.append(Conv3d(in_channels, 1, 2, stride=2, padding=0))
            self.layer_list2.append(ReLU())
        elif layer_type == "up":
            self.layer_list2.append(
                ConvTranspose3d(passed_channel, 1, 2, stride=2, padding=0)
            )
            self.layer_list2.append(ReLU())
        self.channel_sum = ChannelSum()
        self.add = Add()

    def forward(self, input, passed):
        """
        Run the block.

        Parameters
        ----------
        input : torch.Tensor
            Tensor fed to the main convolution path.
        passed : torch.Tensor
            Tensor added (with broadcasting) to the channel-summed main path.

        Returns
        -------
        fwd : torch.Tensor
            Channel-summed main path plus ``passed``, before resampling.
        x : torch.Tensor
            ``fwd`` after the second stage (identical to ``fwd`` when the
            block has no second stage).
        """
        x = input
        for layer in self.layer_list:
            x = layer(x)
        x = self.channel_sum(x)
        fwd = self.add(x, passed)
        x = fwd
        for layer in self.layer_list2:
            x = layer(x)
        return fwd, x
 
class Model(Module):
    """
    EVAC+ encoder/decoder network with multi-resolution inputs.

    Parameters
    ----------
    model_scale : int, optional
        Base number of filters; deeper blocks use 2x, 4x, 8x and 16x this
        value.

    Notes
    -----
    Submodule attribute names (``block1`` ... ``up_block4``, ``conv_pred``)
    define the state_dict keys the pretrained weights are loaded into, so they
    must not be renamed.
    """

    def __init__(self, model_scale=16):
        super().__init__()
        # Block structure
        self.block1 = Block(
            1, model_scale, kernel_size=5, strides=1, padding=2, drop_r=0.2, n_layers=1
        )
        self.block2 = Block(
            2,
            model_scale * 2,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=2,
        )
        self.block3 = Block(
            2,
            model_scale * 4,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=3,
        )
        self.block4 = Block(
            2,
            model_scale * 8,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=3,
        )
        # Bottleneck: its ``passed`` tensor has 2 channels (x concatenated with
        # raw_input_5), hence passed_channel=2 for the transposed convolution.
        self.block5 = Block(
            2,
            model_scale * 16,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=3,
            passed_channel=2,
            layer_type="up",
        )
        # Upsample/decoder blocks
        self.up_block1 = Block(
            3,
            model_scale * 8,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=3,
            passed_channel=1,
            layer_type="up",
        )
        self.up_block2 = Block(
            3,
            model_scale * 4,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=3,
            passed_channel=1,
            layer_type="up",
        )
        self.up_block3 = Block(
            3,
            model_scale * 2,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=2,
            passed_channel=1,
            layer_type="up",
        )
        # Final decoder block has no resampling stage (layer_type="none").
        self.up_block4 = Block(
            2,
            model_scale,
            kernel_size=5,
            strides=1,
            padding=2,
            drop_r=0.5,
            n_layers=1,
            passed_channel=1,
            layer_type="none",
        )
        self.conv_pred = Conv3d(1, 2, 1, padding=0)
        self.softmax = Softmax(dim=1)

    def forward(self, inputs, raw_input_2, raw_input_3, raw_input_4, raw_input_5):
        """
        Run the network on one multi-resolution batch.

        Parameters
        ----------
        inputs : torch.Tensor
            Full-resolution input; also used as the residual term of block1.
        raw_input_2, raw_input_3, raw_input_4, raw_input_5 : torch.Tensor
            Progressively coarser copies of the input, each concatenated along
            dim 1 with the output of the preceding encoder block. Presumably
            the ``(batch, 1, ...)`` tensors at zoom 2, 4, 8 and 16 built by
            ``prepare_img`` — confirm against the caller.

        Returns
        -------
        torch.Tensor
            Softmax over dim 1 of the 2-channel prediction.
        """
        # Encoder: each block returns its residual output (kept for the skip
        # connection) and a stride-2 downsampled map.
        fwd1, x = self.block1(inputs, inputs)
        x = torch.cat([x, raw_input_2], dim=1)
        fwd2, x = self.block2(x, x)
        x = torch.cat([x, raw_input_3], dim=1)
        fwd3, x = self.block3(x, x)
        x = torch.cat([x, raw_input_4], dim=1)
        fwd4, x = self.block4(x, x)
        x = torch.cat([x, raw_input_5], dim=1)
        # Decoding path: concatenate each skip connection with the upsampled
        # map; the upsampled map is also the residual ``passed`` term.
        _, up = self.block5(x, x)
        x = torch.cat([fwd4, up], dim=1)
        _, up = self.up_block1(x, up)
        x = torch.cat([fwd3, up], dim=1)
        _, up = self.up_block2(x, up)
        x = torch.cat([fwd2, up], dim=1)
        _, up = self.up_block3(x, up)
        x = torch.cat([fwd1, up], dim=1)
        _, pred = self.up_block4(x, up)
        pred = self.conv_pred(pred)
        output = self.softmax(pred)
        return output
 
class EVACPlus:
    """
    This class is intended for the EVAC+ model.

    The EVAC+ model :footcite:p:`Park2024` is a deep learning neural network for
    brain extraction. It uses a V-net architecture combined with
    multi-resolution input data, an additional conditional random field (CRF)
    recurrent layer and supplementary Dice loss term for this recurrent layer.

    References
    ----------
    .. footbibliography::
    """

    @doctest_skip_parser
    def __init__(self, *, verbose=False):
        """
        The model was pre-trained for usage on
        brain extraction of T1 images.

        This model is designed to take as input
        a T1 weighted image.

        Parameters
        ----------
        verbose : bool, optional
            Whether to show information about the processing.
        """
        if not have_torch:
            # NOTE(review): ``torch`` is the placeholder returned by
            # ``optional_package`` for the missing package; calling it is
            # expected to raise an error describing the missing dependency.
            raise torch()
        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)
        # EVAC+ network load
        self.model = self.init_model()
        self.fetch_default_weights()

    def init_model(self, model_scale=16):
        """
        Build an untrained EVAC+ network.

        Parameters
        ----------
        model_scale : int, optional
            Base number of convolution filters of the network.

        Returns
        -------
        Model
        """
        return Model(model_scale)

    def fetch_default_weights(self):
        """
        Load the model pre-training weights to use for the fitting.

        While the user can load different weights, the function
        is mainly intended for the class function 'predict'.
        """
        fetch_model_weights_path = get_fnames(name="evac_default_torch_weights")
        self.load_model_weights(fetch_model_weights_path)

    def load_model_weights(self, weights_path):
        """
        Load the custom pre-training weights to use for the fitting.

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights (pth, saved by Pytorch)

        Raises
        ------
        ValueError
            If loading the weights raises a ValueError.
        """
        try:
            # The model is never moved off the CPU in this class (predictions
            # are converted with ``.numpy()``), so map tensors to CPU; this
            # also lets weights saved from a GPU session load without CUDA.
            state_dict = torch.load(
                weights_path, weights_only=True, map_location="cpu"
            )
            # NOTE(review): with the default strict=True, PyTorch reports
            # missing/unexpected keys and shape mismatches from
            # ``load_state_dict`` as RuntimeError, which is not wrapped here.
            self.model.load_state_dict(state_dict)
            self.model.eval()
        except ValueError as e:
            raise ValueError(
                "Expected input for the provided model weights "
                "do not match the declared model"
            ) from e

    def __predict(self, x_test):
        """
        Internal prediction function

        Parameters
        ----------
        x_test : list of torch.Tensor
            Multi-resolution inputs; should match the required shape of the
            model.

        Returns
        -------
        np.ndarray (batch, ...)
            Channel 0 of the network's softmax output.
        """
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            output = self.model(*x_test)
        return output[:, 0].detach().numpy()

    @deprecated_params(
        "largest_area", new_name="finalize_mask", since="1.10", until="1.12"
    )
    def predict(
        self,
        T1,
        affine,
        *,
        voxsize=(1, 1, 1),
        batch_size=None,
        return_affine=False,
        return_prob=False,
        finalize_mask=True,
    ):
        """
        Wrapper function to facilitate prediction of larger dataset.

        Parameters
        ----------
        T1 : np.ndarray or list of np.ndarray
            For a single image, input should be a 3D array.
            If multiple images, it should be a list or tuple.
        affine : np.ndarray (4, 4) or (batch, 4, 4)
            or list of np.ndarrays with len of batch
            Affine matrix for the T1 image. Should have
            batch dimension if T1 has one; a single (4, 4) matrix given with
            a batch is applied to every image.
        voxsize : np.ndarray or list or tuple, optional
            (3,) or (batch, 3)
            voxel size of the T1 image. A single (3,) value given with a
            batch is applied to every image.
        batch_size : int, optional
            Number of images per prediction pass. Only available if data
            is provided with a batch dimension.
            Consider lowering it if you get an out of memory error.
            Increase it if you want it to be faster and have a lot of data.
            If None, batch_size will be set to 1 if the provided image
            has a batch dimension.
        return_affine : bool, optional
            Whether to return the affine matrix. Useful if the input was a
            file path.
        return_prob : bool, optional
            Whether to return the probability map instead of a
            binary mask. Useful for testing.
        finalize_mask : bool, optional
            Whether to remove potential holes or islands.
            Useful for solving minor errors.

        Returns
        -------
        pred_output : np.ndarray (...) or (batch, ...)
            Predicted brain mask
        affine : np.ndarray (4, 4) or (batch, 4, 4)
            affine matrix of mask
            only if return_affine is True

        Raises
        ------
        ValueError
            If T1 is neither a 3D array nor a list/tuple of arrays, or if
            batch_size is smaller than 1.
        """
        voxsize = np.array(voxsize)
        affine = np.array(affine)
        if isinstance(T1, (list, tuple)):
            dim = 4
            T1 = np.array(T1)
        elif len(T1.shape) == 3:
            dim = 3
            if batch_size is not None:
                # One message string: extra positional arguments are treated
                # as %-format arguments by logging.
                logger.warning(
                    "Batch size specified, but not used "
                    "due to the input not having a batch dimension"
                )
            T1 = np.expand_dims(T1, 0)
            affine = np.expand_dims(affine, 0)
            voxsize = np.expand_dims(voxsize, 0)
        else:
            raise ValueError(
                "T1 data should be a np.ndarray of dimension 3 or a list/tuple of it"
            )
        n_images = len(T1)
        # Broadcast a single voxel size / affine over the whole batch so the
        # per-image indexing below (voxsize[i], affine[i]) stays valid.
        if voxsize.ndim == 1:
            voxsize = np.tile(voxsize, (n_images, 1))
        if affine.ndim == 2:
            affine = np.tile(affine, (n_images, 1, 1))
        if batch_size is None:
            batch_size = 1
        elif batch_size < 1:
            raise ValueError(
                f"batch_size should be a positive integer, got {batch_size}"
            )
        input_data = np.zeros((128, 128, 128, n_images))
        affines = np.zeros((n_images, 4, 4))
        mid_shapes = np.zeros((n_images, 3)).astype(int)
        offset_arrays = np.zeros((n_images, 4, 4)).astype(int)
        scales = np.zeros(n_images)
        crop_vss = np.zeros((n_images, 3, 2))
        pad_vss = np.zeros((n_images, 3, 2))
        # Normalize each image to [0, 1] and map it onto the 128^3 model grid,
        # keeping the per-image parameters needed to undo the transform.
        n_T1 = np.zeros(T1.shape)
        for i, T1_img in enumerate(T1):
            n_T1[i] = normalize(T1_img, new_min=0, new_max=1)
            t_img, t_affine, mid_shape, offset_array, scale, crop_vs, pad_vs = (
                transform_img(n_T1[i], affine[i], voxsize=voxsize[i])
            )
            input_data[..., i] = t_img
            affines[i] = t_affine
            mid_shapes[i] = mid_shape
            offset_arrays[i] = offset_array
            scales[i] = scale
            crop_vss[i] = crop_vs
            pad_vss[i] = pad_vs
        # Prediction stage: consecutive chunks of at most ``batch_size``
        # images; each chunk's result goes to its own slots (the last chunk
        # holds any remainder).
        prediction = np.zeros((n_images, 128, 128, 128), dtype=np.float32)
        for start in range(0, n_images, batch_size):
            stop = min(start + batch_size, n_images)
            temp_input = prepare_img(input_data[..., start:stop])
            prediction[start:stop] = self.__predict(temp_input)
        output_mask = []
        for i in range(n_images):
            output = recover_img(
                prediction[i],
                affines[i],
                mid_shapes[i],
                n_T1[i].shape,
                offset_arrays[i],
                voxsize=voxsize[i],
                scale=scales[i],
                crop_vs=crop_vss[i],
                pad_vs=pad_vss[i],
            )
            if not return_prob:
                output = np.where(output >= 0.5, 1, 0)
                if finalize_mask:
                    output = remove_holes_and_islands(output, slice_wise=True)
            output_mask.append(output)
        if dim == 3:
            output_mask = output_mask[0]
            affine = affine[0]
        output_mask = np.array(output_mask)
        affine = np.array(affine)
        if return_affine:
            return output_mask, affine
        else:
            return output_mask