Source code for dipy.align.cpd

"""

Note
----

This file is copied (possibly with major modifications) from the
sources of the pycpd project - https://github.com/siavashk/pycpd.
It remains licensed as the rest of PyCPD (MIT license as of October 2010).

# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyCPD package for the
#   copyright and license terms.
#
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""

import numbers
from warnings import warn

import numpy as np

from dipy.testing.decorators import warning_for_keywords


[docs] def gaussian_kernel(X, beta, Y=None): if Y is None: Y = X diff = X[:, None, :] - Y[None, :, :] diff = np.square(diff) diff = np.sum(diff, 2) return np.exp(-diff / (2 * beta**2))
[docs] def low_rank_eigen(G, num_eig): """Calculate num_eig eigenvectors and eigenvalues of gaussian matrix G. Enables lower dimensional solving. """ S, Q = np.linalg.eigh(G) eig_indices = list(np.argsort(np.abs(S))[::-1][:num_eig]) Q = Q[:, eig_indices] # eigenvectors S = S[eig_indices] # eigenvalues. return Q, S
[docs] def initialize_sigma2(X, Y): """Initialize the variance (sigma2). Parameters ---------- X: numpy array NxD array of points for target. Y: numpy array MxD array of points for source. Returns ------- sigma2: float Initial variance. """ (N, D) = X.shape (M, _) = Y.shape diff = X[None, :, :] - Y[:, None, :] err = diff**2 return np.sum(err) / (D * M * N)
[docs] @warning_for_keywords() def lowrankQS(G, beta, num_eig, *, eig_fgt=False): """Calculate eigenvectors and eigenvalues of gaussian matrix G. !!! This function is a placeholder for implementing the fast gauss transform. It is not yet implemented. !!! Parameters ---------- G: numpy array Gaussian kernel matrix. beta: float Width of the Gaussian kernel. num_eig: int Number of eigenvectors to use in lowrank calculation of G eig_fgt: bool If True, use fast gauss transform method to speed up. """ # if we do not use FGT we construct affinity matrix G and find the # first eigenvectors/values directly if eig_fgt is False: S, Q = np.linalg.eigh(G) eig_indices = list(np.argsort(np.abs(S))[::-1][:num_eig]) Q = Q[:, eig_indices] # eigenvectors S = S[eig_indices] # eigenvalues. return Q, S elif eig_fgt is True: raise Exception("Fast Gauss Transform Not Implemented!")
[docs] class DeformableRegistration: """ Deformable point cloud registration. Attributes ---------- X: numpy array NxD array of target points. Y: numpy array MxD array of source points. TY: numpy array MxD array of transformed source points. sigma2: float (positive) Initial variance of the Gaussian mixture model. N: int Number of target points. M: int Number of source points. D: int Dimensionality of source and target points iteration: int The current iteration throughout registration. max_iterations: int Registration will terminate once the algorithm has taken this many iterations. tolerance: float (positive) Registration will terminate once the difference between consecutive objective function values falls within this tolerance. w: float (between 0 and 1) Contribution of the uniform distribution to account for outliers. Valid values span 0 (inclusive) and 1 (exclusive). q: float The objective function value that represents the misalignment between source and target point clouds. diff: float (positive) The absolute difference between the current and previous objective function values. P: numpy array MxN array of probabilities. P[m, n] represents the probability that the m-th source point corresponds to the n-th target point. Pt1: numpy array Nx1 column array. Multiplication result between the transpose of P and a column vector of all 1s. P1: numpy array Mx1 column array. Multiplication result between P and a column vector of all 1s. Np: float (positive) The sum of all elements in P. alpha: float (positive) Represents the trade-off between the goodness of maximum likelihoo fit and regularization. beta: float(positive) Width of the Gaussian kernel. low_rank: bool Whether to use low rank approximation. num_eig: int Number of eigenvectors to use in lowrank calculation. """ def __init__( self, X, Y, *args, sigma2=None, alpha=None, beta=None, low_rank=False, num_eig=100, max_iterations=None, tolerance=None, w=None, **kwargs, ): if not isinstance(X, np.ndarray) or X.ndim != 2: raise ValueError("The target point cloud (X) must be at a 2D numpy array.") if not isinstance(Y, np.ndarray) or Y.ndim != 2: raise ValueError("The source point cloud (Y) must be a 2D numpy array.") if X.shape[1] != Y.shape[1]: msg = "Both point clouds need to have the same number " msg += "of dimensions." raise ValueError(msg) if sigma2 is not None and ( not isinstance(sigma2, numbers.Number) or sigma2 <= 0 ): msg = f"Expected a positive value for sigma2 instead got: {sigma2}" raise ValueError(msg) if max_iterations is not None and ( not isinstance(max_iterations, numbers.Number) or max_iterations < 0 ): msg = "Expected a positive integer for max_iterations " msg += f"instead got: {max_iterations}" raise ValueError(msg) elif isinstance(max_iterations, numbers.Number) and not isinstance( max_iterations, int ): msg = "Received a non-integer value for max_iterations: " msg += f"{max_iterations}. Casting to integer." warn(msg, stacklevel=2) max_iterations = int(max_iterations) if tolerance is not None and ( not isinstance(tolerance, numbers.Number) or tolerance < 0 ): msg = "Expected a positive float for tolerance " msg += f"instead got: {tolerance}" raise ValueError(msg) if w is not None and (not isinstance(w, numbers.Number) or w < 0 or w >= 1): msg = "Expected a value between 0 (inclusive) and 1 (exclusive) " msg += f"for w instead got: {w}" raise ValueError(msg) self.X = X self.Y = Y self.TY = Y self.sigma2 = initialize_sigma2(X, Y) if sigma2 is None else sigma2 (self.N, self.D) = self.X.shape (self.M, _) = self.Y.shape self.tolerance = 0.001 if tolerance is None else tolerance self.w = 0.0 if w is None else w self.max_iterations = 100 if max_iterations is None else max_iterations self.iteration = 0 self.diff = np.inf self.q = np.inf self.P = np.zeros((self.M, self.N)) self.Pt1 = np.zeros((self.N,)) self.P1 = np.zeros((self.M,)) self.PX = np.zeros((self.M, self.D)) self.Np = 0 if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0): msg = "Expected a positive value for regularization parameter " msg += f"alpha. Instead got: {alpha}" raise ValueError(msg) if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0): msg = "Expected a positive value for the width of the coherent " msg += f"Gaussian kernel. Instead got: {beta}" self.alpha = 2 if alpha is None else alpha self.beta = 2 if beta is None else beta self.W = np.zeros((self.M, self.D)) self.G = gaussian_kernel(self.Y, self.beta) self.low_rank = low_rank self.num_eig = num_eig if self.low_rank is True: self.Q, self.S = low_rank_eigen(self.G, self.num_eig) self.inv_S = np.diag(1.0 / self.S) self.S = np.diag(self.S) self.E = 0.0
[docs] @warning_for_keywords() def register(self, *, callback=lambda **kwargs: None): """ Perform the EM registration. Parameters ---------- callback: function A function that will be called after each iteration. Can be used to visualize the registration process. Returns ------- self.TY: numpy array MxD array of transformed source points. registration_parameters: Returned params dependent on registration method used. """ self.transform_point_cloud() while self.iteration < self.max_iterations and self.diff > self.tolerance: self.iterate() if callable(callback): kwargs = { "iteration": self.iteration, "error": self.q, "X": self.X, "Y": self.TY, } callback(**kwargs) return self.TY, self.get_registration_parameters()
[docs] def update_transform(self): """ Calculate a new estimate of the deformable transformation. See Eq. 22 of https://arxiv.org/pdf/0905.2635.pdf. """ if self.low_rank is False: A = np.dot(np.diag(self.P1), self.G) + self.alpha * self.sigma2 * np.eye( self.M ) B = self.PX - np.dot(np.diag(self.P1), self.Y) self.W = np.linalg.solve(A, B) elif self.low_rank is True: # Matlab code equivalent can be found here: # https://github.com/markeroon/matlab-computer-vision-routines/tree/master/third_party/CoherentPointDrift dP = np.diag(self.P1) dPQ = np.matmul(dP, self.Q) F = self.PX - np.matmul(dP, self.Y) self.W = ( 1 / (self.alpha * self.sigma2) * ( F - np.matmul( dPQ, ( np.linalg.solve( ( self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ) ), (np.matmul(self.Q.T, F)), ) ), ) ) ) QtW = np.matmul(self.Q.T, self.W) self.E = self.E + self.alpha / 2 * np.trace( np.matmul(QtW.T, np.matmul(self.S, QtW)) )
[docs] @warning_for_keywords() def transform_point_cloud(self, *, Y=None): """Update a point cloud using the new estimate of the deformable transformation. Parameters ---------- Y: numpy array, optional Array of points to transform - use to predict on new set of points. Best for predicting on new points not used to run initial registration. If None, self.Y used. Returns ------- If Y is None, returns None. Otherwise, returns the transformed Y. """ if Y is not None: G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y) return Y + np.dot(G, self.W) else: if self.low_rank is False: self.TY = self.Y + np.dot(self.G, self.W) elif self.low_rank is True: self.TY = self.Y + np.matmul( self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)) ) return
[docs] def update_variance(self): """Update the variance of the mixture model. This is using the new estimate of the deformable transformation. See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf. """ qprev = self.sigma2 # The original CPD paper does not explicitly calculate the objective # functional. This functional will include terms from both the negative # log-likelihood and the Gaussian kernel used for regularization. self.q = np.inf xPx = np.dot( np.transpose(self.Pt1), np.sum(np.multiply(self.X, self.X), axis=1) ) yPy = np.dot( np.transpose(self.P1), np.sum(np.multiply(self.TY, self.TY), axis=1) ) trPXY = np.sum(np.multiply(self.TY, self.PX)) self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D) if self.sigma2 <= 0: self.sigma2 = self.tolerance / 10 # Here we use the difference between the current and previous # estimate of the variance as a proxy to test for convergence. self.diff = np.abs(self.sigma2 - qprev)
[docs] def get_registration_parameters(self): """Return the current estimate of the deformable transformation parameters. Returns ------- self.G: numpy array Gaussian kernel matrix. self.W: numpy array Deformable transformation matrix. """ return self.G, self.W
[docs] def iterate(self): """Perform one iteration of the EM algorithm.""" self.expectation() self.maximization() self.iteration += 1
[docs] def expectation(self): """Compute the expectation step of the EM algorithm.""" # (M, N) P = np.sum((self.X[None, :, :] - self.TY[:, None, :]) ** 2, axis=2) P = np.exp(-P / (2 * self.sigma2)) c = ( (2 * np.pi * self.sigma2) ** (self.D / 2) * self.w / (1.0 - self.w) * self.M / self.N ) den = np.sum(P, axis=0, keepdims=True) # (1, N) den = np.clip(den, np.finfo(self.X.dtype).eps, None) + c self.P = np.divide(P, den) self.Pt1 = np.sum(self.P, axis=0) self.P1 = np.sum(self.P, axis=1) self.Np = np.sum(self.P1) self.PX = np.matmul(self.P, self.X)
[docs] def maximization(self): """Compute the maximization step of the EM algorithm.""" self.update_transform() self.transform_point_cloud() self.update_variance()