"""
Note
----
This file is a PyTorch adaptation of the
sources of the SynthSeg project - https://github.com/BBillot/SynthSeg.
All weights and the model architecture originate from the SynthSeg project.
It remains licensed as the rest of SynthSeg
(Apache 2.0 license as of January 2026).
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
# See LICENSE.txt file distributed along with the SynthSeg package for the
# copyright and license terms.
#
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""
import numpy as np
from scipy.ndimage import gaussian_filter
from dipy.data import get_fnames
from dipy.nn.utils import (
normalize,
recover_img,
set_logger_level,
transform_img,
)
from dipy.segment.utils import remove_holes_and_islands
from dipy.testing.decorators import doctest_skip_parser
from dipy.utils.logging import logger
from dipy.utils.optpkg import optional_package
# torch is an optional dependency (version >= 2.2.0 is requested here).
# `have_torch` gates the real torch.nn imports below; when it is False,
# `torch` is whatever placeholder optional_package hands back
# (SynthSeg.__init__ calls it to raise) -- NOTE(review): confirm against
# dipy.utils.optpkg.
torch, have_torch, _ = optional_package("torch", min_version="2.2.0")
if have_torch:
    from torch.nn import (
        BatchNorm3d,
        Conv3d,
        MaxPool3d,
        Module,
        ModuleList,
        Softmax,
        Upsample,
    )
    import torch.nn.functional as F
else:
    # Minimal stand-in so that the `class X(Module)` statements further down
    # can still be evaluated when torch is not installed.
    class Module:
        pass
[docs]
class Conv3dELU(Module):
    """3D convolution immediately followed by an ELU activation.

    Mimics the fused Conv3D + ELU behavior of TensorFlow.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the convolution.
    kernel_size : int or tuple
        Kernel size, forwarded to ``torch.nn.Conv3d``.
    padding : int or tuple, optional
        Padding, forwarded to ``torch.nn.Conv3d``.
    stride : int or tuple, optional
        Stride, forwarded to ``torch.nn.Conv3d``.
    dilation : int or tuple, optional
        Dilation, forwarded to ``torch.nn.Conv3d``.
    bias : bool, optional
        Whether the convolution has a learnable bias.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        dilation=1,
        bias=True,
    ):
        super().__init__()
        conv_options = {
            "kernel_size": kernel_size,
            "padding": padding,
            "stride": stride,
            "dilation": dilation,
            "bias": bias,
        }
        # The attribute must stay named ``conv``: it is part of the
        # state-dict keys of the pre-trained weights.
        self.conv = Conv3d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Forward pass of the Conv3dELU layer.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor after convolution and ELU activation.
        """
        return F.elu(self.conv(x))
[docs]
class Block(Module):
    """Building block for the SynthSeg model.

    A block is a stack of ``n_layers`` :class:`Conv3dELU` layers followed by a
    batch normalization. Depending on ``layer_type`` an extra layer is added:

    * ``"down"``: a 2x max-pooling is appended after the normalization;
    * ``"up"``: a 2x upsampling is prepended before the convolutions;
    * anything else (``"down2"`` in this file): no extra layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    kernel_size : int
        Size of the convolutional kernel.
    padding : int
        Padding of the convolution.
    n_layers : int
        Number of convolutional layers in the block.
    layer_type : str
        Type of the block: 'down', 'down2', or 'up'.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        n_layers,
        layer_type,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.layer_list = ModuleList()
        self.layer_type = layer_type

        first_conv_in = in_channels
        if layer_type == "up":
            self.layer_list.append(Upsample(scale_factor=2))
            # The first convolution of an "up" block sees the upsampled
            # tensor concatenated with the skip tensor, hence 1.5x channels.
            first_conv_in = int(first_conv_in * 1.5)

        # Channel width entering/leaving each convolution, in order.
        widths = [first_conv_in] + [out_channels] * n_layers
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            self.layer_list.append(
                Conv3dELU(n_in, n_out, kernel_size, padding=padding)
            )

        self.layer_list.append(BatchNorm3d(out_channels, eps=1e-3, momentum=0.01))
        if layer_type == "down":
            self.layer_list.append(MaxPool3d(2, stride=2))

    def forward(self, input, passed):
        """Forward pass of the Block.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor.
        passed : torch.Tensor
            Tensor from skipped connection. Only read when layer_type is 'up'.

        Returns
        -------
        x : torch.Tensor
            Output of the last layer of the block.
        skip : torch.Tensor, optional
            Output of the last convolutional layer (taken before the batch
            normalization and the downsampling).
            Only returned if layer_type is 'down'.
        """
        is_down = self.layer_type == "down"
        is_up = self.layer_type == "up"
        last_conv_pos = self.n_layers - 1

        out = input
        for pos, module in enumerate(self.layer_list):
            out = module(out)
            if is_up and pos == 0:
                # Position 0 of an "up" block is the upsampling layer: merge
                # the skip tensor along the channel axis right after it.
                out = torch.cat([passed, out], dim=1)
            if is_down and pos == last_conv_pos:
                skip = out

        return (out, skip) if is_down else out
[docs]
class Model(Module):
    """The Model class for the SynthSeg architecture.

    A 3D U-Net made of ``n_levels - 1`` downsampling blocks, one bottleneck
    block and ``n_levels - 1`` upsampling blocks, followed by a 1x1x1
    convolution and a softmax over the channel axis.

    Parameters
    ----------
    model_scale : int, optional
        The scale of the model: number of feature channels of the first
        level. The number of channels doubles at every level.
    n_levels : int, optional
        Number of levels in the U-Net. Must be at least 1.
    output_channels : int, optional
        Number of output channels.

    Raises
    ------
    ValueError
        If ``n_levels`` is smaller than 1.
    """

    def __init__(self, *, model_scale=24, n_levels=5, output_channels=33):
        super().__init__()
        if n_levels < 1:
            raise ValueError("n_levels should be at least 1")

        self.block_list = ModuleList()
        # Encoder widths: [1, s, 2s, 4s, ...] (single-channel input first).
        self.channels = [model_scale * (2**i) for i in range(n_levels)]
        self.channels.insert(0, 1)
        # Decoder widths: [..., 4s, 2s, s, output_channels].
        self.channels_rev = [
            model_scale * (2 ** (n_levels - i - 1)) for i in range(n_levels)
        ]
        self.channels_rev.append(output_channels)
        self.model_scale = model_scale
        self.n_levels = n_levels

        conv_options = {"kernel_size": 3, "padding": 1, "n_layers": 2}

        # Encoder: n_levels - 1 blocks that downsample and emit a skip tensor.
        for level in range(n_levels - 1):
            self.block_list.append(
                Block(
                    self.channels[level],
                    self.channels[level + 1],
                    layer_type="down",
                    **conv_options,
                )
            )

        # Bottleneck: same convolutions, but no pooling and no skip tensor.
        self.block_list.append(
            Block(
                self.channels[n_levels - 1],
                self.channels[n_levels],
                layer_type="down2",
                **conv_options,
            )
        )

        # Decoder: n_levels - 1 blocks that upsample and consume a skip tensor.
        for level in range(n_levels - 1):
            self.block_list.append(
                Block(
                    self.channels_rev[level],
                    self.channels_rev[level + 1],
                    layer_type="up",
                    **conv_options,
                )
            )

        # Honor ``output_channels`` (was hard-coded to 33, which silently
        # ignored the constructor argument). The default is unchanged, so the
        # pre-trained weights still match.
        self.conv_pred = Conv3d(model_scale, output_channels, 1, padding=0)
        self.softmax = Softmax(dim=1)

    def forward(self, inputs):
        """Forward pass of the SynthSeg model.

        Parameters
        ----------
        inputs : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Predicted brain segmentation probability map (softmax over the
            channel axis, i.e. axis 1).
        """
        n_down = self.n_levels - 1
        skips = []
        x = inputs

        for block in self.block_list[:n_down]:
            x, skip = block(x, None)
            skips.append(skip)

        x = self.block_list[n_down](x, None)

        # Decoder blocks consume the skip tensors in reverse order.
        for block in self.block_list[self.n_levels :]:
            x = block(x, skips.pop())

        return self.softmax(self.conv_pred(x))
[docs]
class SynthSeg:
    """This class is intended for the SynthSeg model.

    The SynthSeg model :footcite:p:`Billot2023` is a deep learning neural network for
    brain segmentation. It uses a U-net architecture and was trained on synthetic
    images generated from label maps. The model is robust to variations in
    contrast and resolution, making it suitable for segmenting a wide range of
    brain scans.

    Note that we are not saving any PVE maps here, only the hard segmentation
    due to the size of the output probability maps.

    References
    ----------
    .. footbibliography::
    """

    # Spatial size of the volumes fed to the network; every input image is
    # resampled to this grid before prediction.
    _MODEL_SHAPE = (192, 192, 192)

    @doctest_skip_parser
    def __init__(self, *, verbose=False, use_cuda=False):
        """Model initialization

        Parameters
        ----------
        verbose : bool, optional
            Whether to show information about the processing.
        use_cuda : bool, optional
            Whether to use GPU for processing.
            If False or no CUDA is detected, CPU will be used.
        """
        if not have_torch:
            # ``torch`` is the placeholder returned by optional_package here;
            # calling it is what produces the exception to raise.
            raise torch()

        log_level = "INFO" if verbose else "CRITICAL"
        set_logger_level(log_level, logger)

        self.model = self.init_model()
        self.model.eval()

        if use_cuda and not torch.cuda.is_available():
            logger.warning("CUDA requested but not found, switching to CPU")
            use_cuda = False
        self.device = torch.device("cuda" if use_cuda else "cpu")

        self.model = self.model.to(self.device)
        self.fetch_default_weights()

        # Label value written to the output segmentation for each of the
        # network output channels (channel i -> labels_segmentation[i]).
        # fmt: off
        self.labels_segmentation = np.array(
            [
                0, 2, 3, 4, 5, 7, 8, 10, 11, 12, 13,
                14, 15, 16, 17, 18, 24, 26, 28, 41, 42, 43,
                44, 46, 47, 49, 50, 51, 52, 53, 54, 58, 60,
            ]
        )
        # Topological group of each output channel. Channels sharing a group
        # are cleaned together in ``predict``; group 0 is left untouched.
        self.topological_classes = np.array(
            [
                0, 4, 4, 4, 4, 5, 5, 6, 7, 8, 9,
                1, 2, 3, 10, 11, 0, 12, 13, 14, 14, 14,
                14, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            ]
        )
        # For each output channel, the channel holding the contralateral
        # structure (midline structures map to themselves).
        self.flip_indices = np.array(
            [
                0, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
                11, 12, 13, 29, 30, 16, 31, 32, 1, 2, 3,
                4, 5, 6, 7, 8, 9, 10, 14, 15, 17, 18,
            ]
        )
        # fmt: on
        self.label_dict = {
            0: "background",
            2: "left cerebral white matter",
            3: "left cerebral cortex",
            4: "left lateral ventricle",
            5: "left inferior lateral ventricle",
            7: "left cerebellum white matter",
            8: "left cerebellum cortex",
            10: "left thalamus",
            11: "left caudate",
            12: "left putamen",
            13: "left pallidum",
            14: "3rd ventricle",
            15: "4th ventricle",
            16: "brain-stem",
            17: "left hippocampus",
            18: "left amygdala",
            24: "CSF",
            26: "left accumbens area",
            28: "left ventral DC",
            41: "right cerebral white matter",
            42: "right cerebral cortex",
            43: "right lateral ventricle",
            44: "right inferior lateral ventricle",
            46: "right cerebellum white matter",
            47: "right cerebellum cortex",
            49: "right thalamus",
            50: "right caudate",
            51: "right putamen",
            52: "right pallidum",
            53: "right hippocampus",
            54: "right amygdala",
            58: "right accumbens area",
            60: "right ventral DC",
        }

    def init_model(self, model_scale=24, n_levels=5, output_channels=33):
        """Initialize the SynthSeg model.

        Parameters
        ----------
        model_scale : int, optional
            The scale of the model.
        n_levels : int, optional
            Number of levels in the U-Net.
        output_channels : int, optional
            Number of output channels.

        Returns
        -------
        Model
            Initialized SynthSeg model.
        """
        return Model(
            model_scale=model_scale, n_levels=n_levels, output_channels=output_channels
        )

    def fetch_default_weights(self):
        """Load the model pre-training weights to use for the fitting.

        While the user can load different weights, the function
        is mainly intended for the class function 'predict'.
        """
        fetch_model_weights_path = get_fnames(name="synthseg_torch_weights")
        self.load_model_weights(fetch_model_weights_path)

    def load_model_weights(self, weights_path):
        """Load the custom pre-training weights to use for the fitting.

        Parameters
        ----------
        weights_path : str
            Path to the file containing the weights (pth, saved by Pytorch)

        Raises
        ------
        ValueError
            If loading the weights raises a ValueError.
        """
        # NOTE(review): torch reports state-dict key/shape mismatches as
        # RuntimeError, which is not translated here -- confirm whether it
        # should be.
        try:
            self.model.load_state_dict(
                torch.load(
                    weights_path,
                    weights_only=True,
                    map_location=self.device,
                )
            )
            self.model.eval()
        except ValueError as e:
            raise ValueError(
                "Expected input for the provided model weights "
                "do not match the declared model"
            ) from e

    def _flip_img_indices(self, img):
        """Flip the label indices for a flipped image.

        Every value ``i`` (with ``1 <= i < len(self.flip_indices)``) is
        replaced by ``self.flip_indices[i]``; any other value becomes 0.

        Parameters
        ----------
        img : np.ndarray
            The image with label indices.

        Returns
        -------
        np.ndarray
            The image with flipped label indices.
        """
        new_img = np.zeros_like(img)
        for l_idx, flip_lbl in enumerate(self.flip_indices):
            if l_idx != 0:
                new_img = np.where(img == l_idx, flip_lbl, new_img)
        return new_img

    def _prepare_img(self, img):
        """Prepare the image for model input.

        Parameters
        ----------
        img : np.ndarray
            Input image(s), with the batch on the first axis.

        Returns
        -------
        torch.Tensor
            float32 tensor on ``self.device`` with a channel axis inserted at
            position 1.
        """
        img = np.expand_dims(img, 1)  # add channel dimension
        return torch.tensor(img, dtype=torch.float32).to(self.device)

    def __predict(self, x_test):
        """Internal prediction function

        Parameters
        ----------
        x_test : torch.Tensor
            Batch of images, as produced by ``_prepare_img``.

        Returns
        -------
        np.ndarray (batch, ...)
            Model output, copied to host memory.
        """
        # No gradients are needed for inference; the tensor has to be moved
        # to the CPU before conversion, otherwise ``.numpy()`` fails when the
        # model runs on CUDA.
        with torch.no_grad():
            return self.model(x_test).detach().cpu().numpy()

    def _smoothed_posteriors(self, batch):
        """Run the model on a batch and spatially smooth its output.

        Parameters
        ----------
        batch : np.ndarray (batch, x, y, z)
            Pre-processed images.

        Returns
        -------
        np.ndarray (batch, label, x, y, z)
            Smoothed model output.
        """
        posteriors = self.__predict(self._prepare_img(batch))
        # The model output is channel-first: smooth only the three spatial
        # axes, never the batch axis nor the label axis.
        return gaussian_filter(posteriors, (0, 0, 0.5, 0.5, 0.5))

    def _predict_flip_averaged(self, batch):
        """Average the prediction of a batch with the one of its mirror image.

        Parameters
        ----------
        batch : np.ndarray (batch, x, y, z)
            Pre-processed images.

        Returns
        -------
        np.ndarray (batch, x, y, z, label)
            Averaged posteriors, channel-last.
        """
        posteriors = self._smoothed_posteriors(batch)
        mirrored = self._smoothed_posteriors(np.flip(batch, axis=1).copy())
        # Undo the flip: axis 1 of the input batch is axis 2 of the
        # channel-first output. Then swap each channel with its
        # contralateral counterpart.
        mirrored = np.flip(mirrored, axis=2)[:, self.flip_indices]
        posteriors += mirrored
        posteriors /= 2
        return np.moveaxis(posteriors, 1, -1)

    def _clean_posteriors(self, output):
        """Remove spurious components from one posterior map, in place.

        Parameters
        ----------
        output : np.ndarray (x, y, z, label)
            Posteriors of a single image. Modified in place.

        Returns
        -------
        np.ndarray (x, y, z, label)
            The same array, cleaned and re-normalized over the label axis.
        """
        # Foreground = summed probability of every non-background channel.
        fg_mask = np.sum(output[..., 1:], axis=-1) > 0.25
        fg_mask = remove_holes_and_islands(fg_mask, remove_holes=False).astype(bool)
        output[..., 1:] *= fg_mask[..., np.newaxis]

        above_thr = output > 0.25
        # Skip the first (smallest) topological class, i.e. class 0.
        for topology_class in np.unique(self.topological_classes)[1:]:
            channel_ids = np.where(self.topological_classes == topology_class)[0]
            class_mask = np.any(above_thr[..., channel_ids], axis=-1)
            class_mask = remove_holes_and_islands(
                class_mask, remove_holes=False
            ).astype(bool)
            for idx in channel_ids:
                output[..., idx] *= class_mask

        output /= np.sum(output, axis=-1)[..., np.newaxis]
        return output

    def predict(
        self,
        T1,
        affine,
        *,
        batch_size=None,
        return_prob=False,
    ):
        """Wrapper function to facilitate prediction of larger dataset.

        Parameters
        ----------
        T1 : np.ndarray or list of np.ndarray
            For a single image, input should be a 3D array.
            If multiple images, it should be a list or tuple.
        affine : np.ndarray (4, 4) or (batch, 4, 4)
            or list of np.ndarrays with len of batch
            Affine matrix for the T1 image. Should have
            batch dimension if T1 has one.
        batch_size : int, optional
            Number of images per prediction pass. Only available if data
            is provided with a batch dimension.
            Consider lowering it if you get an out of memory error.
            Increase it if you want it to be faster and have a lot of data.
            If None, batch_size will be set to 1 if the provided image
            has a batch dimension.
        return_prob : bool, optional
            Whether to return the probability map instead of a
            label map. Useful for testing.

        Returns
        -------
        pred_output : np.ndarray (...) or (batch, ...)
            Predicted brain labels. If return_prob is True, it will be a
            probability map instead of a label map.
        label_dict : dict
            Dictionary mapping label indices to anatomical structure names.
            Only if return_prob is False.
        mask : np.ndarray (...) or (batch, ...)
            Predicted brain mask.
            Only if return_prob is False.

        Raises
        ------
        ValueError
            If T1 is neither a 3D array nor a list/tuple of them, if
            batch_size is smaller than 1, or if affine does not provide one
            matrix per image.
        """
        affine = np.array(affine)
        if isinstance(T1, (list, tuple)):
            has_batch = True
            T1 = np.array(T1)
        elif len(getattr(T1, "shape", ())) == 3:
            has_batch = False
            if batch_size is not None:
                logger.warning(
                    "Batch size specified, but not used "
                    "due to the input not having a batch dimension"
                )
            T1 = np.expand_dims(T1, 0)
            affine = np.expand_dims(affine, 0)
        else:
            raise ValueError(
                "Input data should be a np.ndarray of dimension 3 or a list/tuple of it"
            )

        if batch_size is None:
            batch_size = 1
        elif batch_size < 1:
            raise ValueError("batch_size should be a positive integer")

        n_img = len(T1)
        if affine.ndim != 3 or len(affine) != n_img:
            raise ValueError("affine should provide one matrix per input image")

        # Resample every image to the model grid and normalize it to [0, 1].
        ori_shape = T1.shape[1:]
        input_data = np.zeros((n_img,) + self._MODEL_SHAPE)
        params_list = []
        for i, T1_img in enumerate(T1):
            t_img, params = transform_img(
                T1_img,
                affine[i],
                target_voxsize=(1.0, 1.0, 1.0),
                final_size=self._MODEL_SHAPE,
                order=3,
            )
            min_v, max_v = np.percentile(t_img, (0.5, 99.5))
            input_data[i] = normalize(
                t_img, min_v=min_v, max_v=max_v, new_min=0, new_max=1
            )
            params_list.append(params)

        # Prediction stage. The last batch is simply shorter when the number
        # of images is not a multiple of batch_size.
        n_labels = len(self.labels_segmentation)
        prediction = np.zeros(
            (n_img,) + self._MODEL_SHAPE + (n_labels,), dtype=np.float32
        )
        for start in range(0, n_img, batch_size):
            stop = min(start + batch_size, n_img)
            prediction[start:stop] = self._predict_flip_averaged(
                input_data[start:stop]
            )

        # Post-processing, done in place on ``prediction``.
        for i in range(n_img):
            self._clean_posteriors(prediction[i])

        if return_prob:
            return prediction if has_batch else prediction[0]

        labels = np.zeros((n_img,) + ori_shape, dtype=np.int32)
        masks = np.zeros((n_img,) + ori_shape)
        for i in range(n_img):
            hard_seg = self.labels_segmentation[prediction[i].argmax(-1)].astype(
                "int32"
            )
            # Bring the segmentation back to the original image grid.
            hard_seg = recover_img(hard_seg, params_list[i], order=0)
            labels[i] = np.round(hard_seg).astype(np.int32)
            masks[i] = labels[i] > 0

        if not has_batch:
            labels = labels[0]
            masks = masks[0]
        return labels, self.label_dict, masks