import numpy as np
from dipy.denoise import nlmeans_block
from dipy.testing.decorators import warning_for_keywords
"""
Functions for Wavelet Transforms in 3D domain
Code adapted from
WAVELET SOFTWARE AT POLYTECHNIC UNIVERSITY, BROOKLYN, NY
https://eeweb.engineering.nyu.edu/iselesni/WaveletSoftware/
"""
[docs]
def cshift3D(x, m, d):
"""3D Circular Shift
Parameters
----------
x : 3D ndarray
N1 by N2 by N3 array
m : int
amount of shift
d : int
dimension of shift (d = 1,2,3)
Returns
-------
y : 3D ndarray
array x will be shifted by m samples down
along dimension d
"""
s = x.shape
idx = (np.array(range(s[d])) + (s[d] - m % s[d])) % s[d]
idx = np.array(idx, dtype=np.int64)
if d == 0:
return x[idx, :, :]
elif d == 1:
return x[:, idx, :]
else:
return x[:, :, idx]
[docs]
def permutationinverse(perm):
"""
Function generating inverse of the permutation
Parameters
----------
perm : 1D array
Returns
-------
inverse : 1D array
permutation inverse of the input
"""
inverse = [0] * len(perm)
for i, p in enumerate(perm):
inverse[p] = i
return inverse
[docs]
def afb3D_A(x, af, d):
"""3D Analysis Filter Bank
(along one dimension only)
Parameters
----------
x : 3D ndarray
N1xN2xN2 matrix, where min(N1,N2,N3) > 2*length(filter)
(Ni are even)
af : 2D ndarray
analysis filter for the columns
af[:, 1] - lowpass filter
af[:, 2] - highpass filter
d : int
dimension of filtering (d = 1, 2 or 3)
Returns
-------
lo : 1D array
lowpass subbands
hi : 1D array
highpass subbands
"""
lpf = af[:, 0]
hpf = af[:, 1]
# permute dimensions of x so that dimension d is first.
p = [(i + d) % 3 for i in range(3)]
x = x.transpose(p)
# filter along dimension 0
(N1, N2, N3) = x.shape
L = af.shape[0] // 2
x = cshift3D(x, -L, 0)
n1Half = N1 // 2
lo = np.zeros((L + n1Half, N2, N3))
hi = np.zeros((L + n1Half, N2, N3))
for k in range(N3):
lo[:, :, k] = nlmeans_block.firdn(x[:, :, k], lpf)
lo[:L] = lo[:L] + lo[n1Half : n1Half + L, :, :]
lo = lo[:n1Half, :, :]
for k in range(N3):
hi[:, :, k] = nlmeans_block.firdn(x[:, :, k], hpf)
hi[:L] = hi[:L] + hi[n1Half : n1Half + L, :, :]
hi = hi[:n1Half, :, :]
# permute dimensions of x (inverse permutation)
q = permutationinverse(p)
lo = lo.transpose(q)
hi = hi.transpose(q)
return lo, hi
[docs]
def sfb3D_A(lo, hi, sf, d):
"""3D Synthesis Filter Bank
(along single dimension only)
Parameters
----------
lo : 1D array
lowpass subbands
hi : 1D array
highpass subbands
sf : 2D ndarray
synthesis filters
d : int
dimension of filtering
Returns
-------
y : 3D ndarray
the N1xN2xN3 matrix
"""
lpf = sf[:, 0]
hpf = sf[:, 1]
# permute dimensions of lo and hi so that dimension d is first.
p = [(i + d) % 3 for i in range(3)]
lo = lo.transpose(p)
hi = hi.transpose(p)
(N1, N2, N3) = lo.shape
N = 2 * N1
L = sf.shape[0]
y = np.zeros((N + L - 2, N2, N3))
for k in range(N3):
y[:, :, k] = np.array(nlmeans_block.upfir(lo[:, :, k], lpf)) + np.array(
nlmeans_block.upfir(hi[:, :, k], hpf)
)
y[: (L - 2), :, :] = y[: (L - 2), :, :] + y[N : (N + L - 2), :, :]
y = y[:N, :, :]
y = cshift3D(y, 1 - L / 2, 0)
# permute dimensions of y (inverse permutation)
q = permutationinverse(p)
y = y.transpose(q)
return y
[docs]
@warning_for_keywords()
def sfb3D(lo, hi, sf1, *, sf2=None, sf3=None):
"""3D Synthesis Filter Bank
Parameters
----------
lo : 1D array
lowpass subbands
hi : 1D array
highpass subbands
sfi : 2D ndarray
synthesis filters for dimension i
Returns
-------
y : 3D ndarray
output array
"""
if sf2 is None:
sf2 = sf1
if sf3 is None:
sf3 = sf1
LLL = lo
LLH = hi[0]
LHL = hi[1]
LHH = hi[2]
HLL = hi[3]
HLH = hi[4]
HHL = hi[5]
HHH = hi[6]
# filter along dimension 2
LL = sfb3D_A(LLL, LLH, sf3, 2)
LH = sfb3D_A(LHL, LHH, sf3, 2)
HL = sfb3D_A(HLL, HLH, sf3, 2)
HH = sfb3D_A(HHL, HHH, sf3, 2)
# filter along dimension 1
L = sfb3D_A(LL, LH, sf2, 1)
H = sfb3D_A(HL, HH, sf2, 1)
# filter along dimension 0
y = sfb3D_A(L, H, sf1, 0)
return y
[docs]
@warning_for_keywords()
def afb3D(x, af1, *, af2=None, af3=None):
"""3D Analysis Filter Bank
Parameters
----------
x : 3D ndarray
N1 by N2 by N3 array matrix, where
1) N1, N2, N3 all even
2) N1 >= 2*len(af1)
3) N2 >= 2*len(af2)
4) N3 >= 2*len(af3)
afi : 2D ndarray
analysis filters for dimension i
afi[:, 1] - lowpass filter
afi[:, 2] - highpass filter
Returns
-------
lo : 1D array
lowpass subband
hi : 1D array
highpass subbands, h[d]- d = 1..7
"""
if af2 is None:
af2 = af1
if af3 is None:
af3 = af1
# filter along dimension 0
L, H = afb3D_A(x, af1, 0)
# filter along dimension 1
LL, LH = afb3D_A(L, af2, 1)
HL, HH = afb3D_A(H, af2, 1)
# filter along dimension 3
LLL, LLH = afb3D_A(LL, af3, 2)
LHL, LHH = afb3D_A(LH, af3, 2)
HLL, HLH = afb3D_A(HL, af3, 2)
HHL, HHH = afb3D_A(HH, af3, 2)
return LLL, [LLH, LHL, LHH, HLL, HLH, HHL, HHH]
[docs]
def dwt3D(x, J, af):
"""3-D Discrete Wavelet Transform
Parameters
----------
x : 3D ndarray
N1 x N2 x N3 matrix
1) Ni all even
2) min(Ni) >= 2^(J-1)*length(af)
J : int
number of stages
af : 2D ndarray
analysis filters
Returns
-------
w : cell array
wavelet coefficients
"""
w = [None] * (J + 1)
for k in range(J):
x, w[k] = afb3D(x, af, af2=af, af3=af)
w[J] = x
return w
[docs]
def idwt3D(w, J, sf):
"""
Inverse 3-D Discrete Wavelet Transform
Parameters
----------
w : cell array
wavelet coefficient
J : int
number of stages
sf : 2D ndarray
synthesis filters
Returns
-------
y : 3D ndarray
output array
"""
y = w[J]
for k in range(J)[::-1]:
y = sfb3D(y, w[k], sf, sf2=sf, sf3=sf)
return y