Source code for dipy.nn.torch.histo_resdnn

#!/usr/bin/python
"""
Class and helper functions for fitting the Histological ResDNN model.
"""

import logging

import numpy as np

from dipy.core.gradients import get_bval_indices, unique_bvals_magnitude
from dipy.core.sphere import HemiSphere
from dipy.data import get_fnames, get_sphere
from dipy.nn.utils import set_logger_level
from dipy.reconst.shm import sf_to_sh, sh_to_sf, sph_harm_ind_list
from dipy.testing.decorators import doctest_skip_parser
from dipy.utils.optpkg import optional_package

# Import PyTorch lazily through dipy's optional-package helper so that this
# module can still be imported when PyTorch is not installed.
torch, have_torch, _ = optional_package("torch", min_version="2.2.0")
if have_torch:
    from torch.nn import Linear, Module
else:

    class Module:
        """Placeholder base class used when PyTorch is unavailable."""

    # Implicit string concatenation is used instead of backslash continuations
    # inside the literal: the latter embeds the source indentation (long runs
    # of spaces) into the emitted message.
    logging.warning(
        "This model requires Pytorch. "
        "Please install these packages using pip."
    )

logging.basicConfig()
logger = logging.getLogger("histo_resdnn")

class DenseModel(Module):
    """Six-layer fully connected network with a single residual connection.

    Parameters
    ----------
    sh_size : int
        Size of the last dimension of the input.
    num_hidden : int
        Output width of the second, fourth and sixth layers (and therefore of
        the network output).
    """

    def __init__(self, sh_size, num_hidden):
        super().__init__()
        # Consecutive entries are the (in, out) widths of fc1 .. fc6. The
        # layers are created in that order and keep the names ``fc1``..``fc6``
        # so the parameter names in the state dict stay the same.
        widths = (sh_size, 400, num_hidden, 200, num_hidden, 200, num_hidden)
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"fc{idx}", Linear(n_in, n_out))

    def forward(self, x):
        """Run the network on ``x`` and return the output of the last layer."""
        hidden = torch.relu(self.fc1(x))
        skip = torch.relu(self.fc2(hidden))
        # No activation on fc4: its raw output is summed with the skip branch.
        branch = self.fc4(torch.relu(self.fc3(skip)))
        merged = torch.relu(self.fc5(skip + branch))
        return self.fc6(merged)
class HistoResDNN:
    """
    This class is intended for the ResDNN Histology Network model.

    ResDNN :footcite:p:`Nath2019` is a deep neural network that employs
    residual blocks to predict ground truth SH coefficients from SH
    coefficients computed using DWI data. To this end, authors considered
    histology FOD-computed SH coefficients (obtained from ex vivo non-human
    primate acquisitions) as their ground truth, and the DWI-computed SH
    coefficients as the network input.

    References
    ----------
    .. footbibliography::
    """

    @doctest_skip_parser
    def __init__(self, *, sh_order_max=8, basis_type="tournier07", verbose=False):
        r"""
        The model was re-trained for usage with a different basis function
        ('tournier07') like the proposed model in :footcite:p:`Nath2019`.

        To obtain the pre-trained model, use::

        >>> resdnn_model = HistoResDNN() # skip if not have_torch
        >>> fetch_model_weights_path = get_fnames(name='histo_resdnn_torch_weights') # skip if not have_torch
        >>> resdnn_model.load_model_weights(fetch_model_weights_path) # skip if not have_torch

        This model is designed to take as input raw DWI signal on a sphere
        (ODF) represented as SH of order 8 in the tournier basis and predict
        fODF of order 8 in the tournier basis. Effectively, this model is
        mimicking a CSD fit.

        Parameters
        ----------
        sh_order_max : int, optional
            Maximum SH order (l) in the SH fit.  For ``sh_order_max``, there
            will be ``(sh_order_max + 1) * (sh_order_max + 2) / 2`` SH
            coefficients for a symmetric basis.
        basis_type : {'tournier07', 'descoteaux07'}, optional
            ``tournier07`` (default) or ``descoteaux07``.
        verbose : bool, optional
            Whether to show information about the processing.

        References
        ----------
        .. footbibliography::
        """  # noqa: E501
        if not have_torch:
            # NOTE(review): when PyTorch is missing, ``torch`` is whatever
            # placeholder ``optional_package`` returned; calling it is
            # presumably what produces the informative error -- confirm.
            raise torch()

        self.sh_order_max = sh_order_max
        # Number of SH coefficients for this order: sizes the network input.
        self.sh_size = len(sph_harm_ind_list(sh_order_max)[0])
        self.basis_type = basis_type

        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)
        if self.basis_type != "tournier07":
            logger.warning(
                "Be careful, original weights were obtained "
                "from training on the tournier07 basis, "
                "unless you re-trained the network, do not "
                "change basis!"
            )

        # ResDNN Network Flow. The network is cast to double precision; the
        # SH arrays handed to it are converted with ``torch.from_numpy``,
        # which keeps the NumPy dtype.
        num_hidden = self.sh_size
        self.model = DenseModel(self.sh_size, num_hidden).type(torch.float64)

    def fetch_default_weights(self):
        """
        Load the model pre-training weights to use for the fitting.
        Will not work if the declared SH_ORDER does not match the weights
        expected input.
        """
        fetch_model_weights_path = get_fnames(name="histo_resdnn_torch_weights")
        self.load_model_weights(fetch_model_weights_path)

    def load_model_weights(self, weights_path):
        """
        Load the custom pre-training weights to use for the fitting.
        Will not work if the declared SH_ORDER does not match the weights
        expected input.

        The weights for a sh_order of 8 can be obtained via the function:
        get_fnames('histo_resdnn_torch_weights').

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights (pth, saved by Pytorch)

        Raises
        ------
        ValueError
            If loading the weights into the model raises a ``RuntimeError``.
        """
        try:
            self.model.load_state_dict(torch.load(weights_path, weights_only=True))
            # Switch to evaluation mode once the weights are in place.
            self.model.eval()
        except RuntimeError as e:
            raise ValueError(
                "Expected input for the provided model weights do not match the "
                f"declared model ({self.sh_size})"
            ) from e

    def __predict(self, x_test):
        r"""
        Predict fODF (as SH) from input raw DWI signal (as SH)

        Parameters
        ----------
        x_test : np.ndarray
            Array of size (N, M) where M is
            ``(sh_order_max + 1) * (sh_order_max + 2) / 2``.
            N should not be too big as to limit memory usage.

        Returns
        -------
        np.ndarray (N, M)
            Predicted fODF (as SH)

        Raises
        ------
        ValueError
            If the last dimension of ``x_test`` differs from ``self.sh_size``.
        """
        if x_test.shape[-1] != self.sh_size:
            raise ValueError(
                "Expected input for the provided model weights do not match the "
                f"declared model ({self.sh_size})"
            )

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            prediction = self.model(torch.from_numpy(x_test))
        return prediction.numpy()

    def predict(self, data, gtab, *, mask=None, chunk_size=1000):
        """Wrapper function to facilitate prediction of larger dataset.
        The function will mask, normalize, split, predict and 're-assemble'
        the data as a volume.

        Parameters
        ----------
        data : np.ndarray
            DWI signal in a 4D array
        gtab : GradientTable class instance
            The acquisition scheme matching the data (must contain at least
            one b0)
        mask : np.ndarray, optional
            Binary mask of the brain to avoid unnecessary computation and
            unreliable prediction outside the brain.
            Default: Compute prediction only for nonzero voxels (with at
            least one nonzero DWI value).
        chunk_size : int, optional
            Number of masked voxels passed to the network at a time.

        Returns
        -------
        pred_sh_coef : np.ndarray (x, y, z, M)
            Predicted fODF (as SH). The volume has matching shape to the input
            data, but with
            ``(sh_order_max + 1) * (sh_order_max + 2) / 2`` as a last
            dimension.

        Raises
        ------
        ValueError
            If ``gtab`` flags no b0 volume, or if no b-value shell other than
            the lowest one is found.
        """
        if mask is None:
            logger.warning(
                "Mask should be provided to accelerate "
                "computation, and because predictions are "
                "not reliable outside of the brain."
            )
            mask = np.sum(data, axis=-1)
        mask = mask.astype(bool)

        # Extract B0's and obtain a mean B0
        b0_indices = gtab.b0s_mask
        # ``b0s_mask`` is used as a mask below, so its length says nothing
        # about the presence of a b0: check that at least one entry is set.
        if not np.any(b0_indices):
            raise ValueError("b0 must be present for DWI normalization.")

        logger.info(f"b0 indices found are: {np.argwhere(b0_indices).ravel()}")

        mean_b0 = np.mean(data[..., b0_indices], axis=-1)

        # Detect number of b-values and extract a single shell of DW-MRI Data
        unique_shells = np.sort(unique_bvals_magnitude(gtab.bvals))
        logger.info(f"Number of b-values: {unique_shells}")
        if len(unique_shells) < 2:
            raise ValueError(
                "At least one diffusion-weighted shell must be present "
                "in addition to the b0."
            )

        # Extract DWI only (second lowest b-value, the lowest being the b0)
        dw_indices = get_bval_indices(gtab.bvals, unique_shells[1])
        dw_data = data[..., dw_indices]
        dw_bvecs = gtab.bvecs[dw_indices, :]

        # Normalize the DW-MRI Data with the mean b0 (voxel-wise).
        # ``out`` is passed explicitly: with ``where`` and no ``out``,
        # ``np.divide`` leaves the skipped entries uninitialized. Writing into
        # a zero-filled array keeps voxels with a (near-)zero b0 at 0.
        valid_b0 = np.abs(mean_b0) > 0.000001
        norm_dw_data = np.zeros(dw_data.shape)
        np.divide(
            dw_data,
            mean_b0[..., None],
            out=norm_dw_data,
            where=valid_b0[..., None],
        )

        # Fit SH to the raw DWI signal
        h_sphere = HemiSphere(xyz=dw_bvecs)
        dw_sh_coef = sf_to_sh(
            norm_dw_data,
            h_sphere,
            smooth=0.0006,
            basis_type=self.basis_type,
            sh_order_max=self.sh_order_max,
        )

        # Flatten and mask the data (N, SH_SIZE) to facilitate chunks
        ori_shape = dw_sh_coef.shape
        flat_dw_sh_coef = dw_sh_coef[mask]
        flat_pred_sh_coef = np.zeros(flat_dw_sh_coef.shape)

        # The sampling sphere does not depend on the chunk: fetch it once.
        sphere = get_sphere(name="repulsion724")

        count = len(flat_dw_sh_coef) // chunk_size
        for i in range(count + 1):
            if i % 100 == 0 or i == count:
                logger.info(f"Chunk #{i} out of {count}")

            chunk = slice(i * chunk_size, (i + 1) * chunk_size)
            chunk_sh = flat_dw_sh_coef[chunk]
            if chunk_sh.shape[0] == 0:
                # Happens for the trailing chunk when the number of voxels is
                # an exact multiple of ``chunk_size`` (or when it is zero).
                continue

            tmp_sh = self.__predict(chunk_sh)

            # Removing negative values from the SF
            tmp_sf = sh_to_sf(
                sh=tmp_sh,
                sphere=sphere,
                basis_type=self.basis_type,
                sh_order_max=self.sh_order_max,
            )
            tmp_sf[tmp_sf < 0] = 0
            tmp_sh = sf_to_sh(
                tmp_sf,
                sphere,
                smooth=0.0006,
                basis_type=self.basis_type,
                sh_order_max=self.sh_order_max,
            )
            flat_pred_sh_coef[chunk] = tmp_sh

        # Put the per-voxel predictions back in the volume; voxels outside the
        # mask stay at zero.
        pred_sh_coef = np.zeros(ori_shape)
        pred_sh_coef[mask] = flat_pred_sh_coef

        return pred_sh_coef