import warnings
import numpy as np
import pywt
from simpeg import utils
from simpeg.regularization import BaseRegularization
from scipy.sparse import eye
__all__ = ["WaveletRegularization1D"]
class WaveletRegularization1D(BaseRegularization):
    r"""
    Wavelet-based regularization.

    This class regularizes the inverse problem by minimizing the complexity of the
    model in the wavelet domain via a sparsity constraint (see Deleersnyder et al.,
    2021).

    This class fits within the modular SimPEG framework. For more information, see

    - https://simpeg.xyz/
    - Cockett, R., Kang, S., Heagy, L. J., Pidlisecky, A., & Oldenburg, D. W.
      (2015). SimPEG: An open source framework for simulation and gradient based
      parameter estimation in geophysical applications. Computers & Geosciences,
      85, 142-154.

    **Optional Inputs**

    :param discretize.base.BaseMesh mesh: SimPEG mesh
    :param int nP: number of parameters
    :param IdentityMap mapping: regularization mapping, takes the model from model
        space to the space you want to regularize in
    :param numpy.ndarray mref: reference model - this method does not support
        reference models
    :param numpy.ndarray indActive: active cell indices for reducing the size of
        differential operators in the definition of a regularization mesh
    :param int DWTlevel: level of the discrete wavelet transform; when omitted the
        level suggested by ``pywt.dwt_max_level`` is used
    :param str signal_extension: signal-extension mode of the discrete wavelet
        transform; when omitted ``"smooth"`` is used
    """

    def __init__(self, mesh, orientation="x", wav="db1", **kwargs):
        """
        Regularization for the 1D wavelet transform.

        :param mesh: SimPEG mesh
        :param orientation: orientation of the regularization ("x", "y" or "z")
        :param wav: wavelet type (default "db1", which is blocky)
        :param kwargs: ``DWTlevel`` and ``signal_extension`` configure the wavelet
            transform; all remaining keyword arguments are passed to
            ``BaseRegularization``
        :raises ValueError: if ``orientation`` is not "x", "y" or "z", or if the
            mesh has too few dimensions for the requested orientation
        """
        if orientation not in ("x", "y", "z"):
            raise ValueError("Orientation must be 'x', 'y' or 'z'")
        axis = "xyz".index(orientation)
        # Check the dimension *before* indexing ``mesh.shape_cells[axis]``; otherwise a
        # low-dimensional mesh fails with an uninformative IndexError.
        if mesh.dim <= axis:
            raise ValueError(
                "Mesh must have at least {} dimensions to regularize along the "
                "{}-direction".format(axis + 1, orientation)
            )

        self.orientation = orientation
        self.p = 1  # See Deleersnyder et al., 2021 for details.
        # Perturbing parameter of the Ekblom measure, default is 1e-6.
        # Should be smaller than 1e-4.
        self.eps = 1e-6
        self.mesh = mesh

        # The wavelet options are taken out of kwargs here. This way they are not
        # forwarded to the SimPEG base class (which throws an error on them), and the
        # base-class options (mapping, ...) are not forwarded to ``Wavelet``.
        wavelet_options = {}
        for key in ("DWTlevel", "signal_extension"):
            if key in kwargs:
                wavelet_options[key] = kwargs.pop(key)
        self.wavelets = Wavelet(wav, mesh.shape_cells[axis], **wavelet_options)
        # The level actually used by the transform (a default is resolved when omitted)
        self.DWTlevel = self.wavelets.DWTlevel

        # Generate the scale-dependent weights for each coefficient in X
        self.R = self._regularization_matrix()
        super(WaveletRegularization1D, self).__init__(mesh=mesh, **kwargs)

    @property
    def _multiplier_pair(self):
        return "alpha_{orientation}".format(orientation=self.orientation)

    def _wavelet_coefficients(self, m):
        """Apply the regularization mapping to ``m`` and return its wavelet
        coefficients X = W @ mapping(m) as a column vector."""
        return self.wavelets.W @ np.asarray(self.mapping * m).reshape(-1, 1)

    @utils.timeIt
    def __call__(self, m):
        r"""
        Evaluate the regularization for model ``m``.

        A perturbed Ekblom measure is used as a differentiable approximation of the
        :math:`\ell_1` sparsity measure:

        .. math::
            r(m) = \sum_{i,j} R_{ij}\sqrt{ \left(X_{ij}\right)^2 + \epsilon}

        where :math:`X = W\,\mathcal{M}(m)` are the wavelet coefficients of the mapped
        model and :math:`R` holds the scale-dependent weights.
        """
        X = self._wavelet_coefficients(m)
        return np.sum(self.R * np.sqrt(X**2 + self.eps))  # the actual measure

    @utils.timeIt
    def deriv(self, m):
        r"""
        Derivative of the measure with respect to the model.

        :param m: model

        In the wavelet domain the measure is

        .. math::
            r(x) = \sum_j R_j \sqrt{x_j^2 + \epsilon}

        so its derivative with respect to each coefficient is

        .. math::
            \frac{\partial r(x)}{\partial x_j} = R_j \frac{x_j}{\sqrt{x_j^2 + \epsilon}}

        The chain rule maps this back to model space, with
        :math:`\partial x / \partial \mathcal{M} = W` and
        :math:`\partial \mathcal{M} / \partial m` the mapping derivative:

        .. math::
            \nabla_m r = \left(\frac{\partial \mathcal{M}}{\partial m}\right)^T W^T
            \nabla_x r
        """
        mD = self.mapping.deriv(m)
        X = self._wavelet_coefficients(m)
        # Derivative w.r.t. the wavelet coefficients x
        deriv_x = self.R * X / np.sqrt(X**2 + self.eps)
        # Chain rule w.r.t. the (mapped) model
        deriv_m = self.wavelets.W.T @ deriv_x
        return (mD.T * deriv_m).flatten()  # Chain rule w.r.t. SimPEG mapping

    @utils.timeIt
    def deriv2(self, m, v=None):
        r"""
        Approximate second derivative of the measure.

        :param numpy.ndarray m: geophysical model
        :param numpy.ndarray v: vector to multiply
        :rtype: scipy.sparse.csr_matrix
        :return: the Hessian approximation H, or H*v (numpy.ndarray) if v is supplied

        The second derivative of the perturbed Ekblom measure is highly unstable for
        small epsilon. Most methods do not use Hessian information except for
        preconditioning (see e.g. optimization -> InexactGaussNewton). Therefore the
        unit matrix is used as the Hessian in the regularization space, which results
        in a more stable optimization routine. Mapped back to model space this gives
        :math:`H = J^T J` with :math:`J` the mapping derivative (the unit matrix for
        an identity mapping).
        """
        mD = self.mapping.deriv(m)
        if v is None:
            # mD * eye keeps the result a sparse matrix, even for identity mappings
            return mD.T * (mD * eye(m.size))
        return mD.T * (mD * v)

    def _generate_scale_dependency_vector(self, wavelet):
        r"""
        Generate the scale-dependent weights for each coefficient in X.

        Wavelet coefficients corresponding to small-scale features of the model are
        penalized more heavily. The scaling coefficients are not regularized.

        .. math::
            x = [v_{0,0}, w_{0,0}, w_{1,0}, w_{1,1}, w_{2,0}, w_{2,1}, \cdots, w_{n,k}, \cdots ]

            \phi_m(x) = \frac{1}{E} \sum_{n}^{N} 2^{n p}\sum_k \mu(w_{n,k})

        :param wavelet: ``Wavelet`` object that describes the specific wavelet transform
        :return: column vector of weights (one per row of ``wavelet.W``), normalized to
            unit Euclidean norm. This normalization is only valid for 1D inversion (as
            in Deleersnyder et al., 2021).
        """
        # Decompose a dummy signal with exactly the same settings as ``wavelet.W``.
        # The signal extension mode affects the number of coefficients per scale, so it
        # must be passed here as well for the layout to match the rows of W.
        coeffs = pywt.wavedec(
            np.ones(wavelet.n_m),
            wavelet.wav,
            level=wavelet.DWTlevel,
            mode=wavelet.signal_extension,
        )
        # coeffs = [cA_n, cD_n, ..., cD_1]: the index j increases from coarse to fine
        weights = np.hstack(
            [2 ** (j * self.p) * np.ones(c.shape) for j, c in enumerate(coeffs)]
        )
        weights[: coeffs[0].size] = 0  # No regularization on scaling coefficients
        norm = np.linalg.norm(weights)
        # All weights are zero when there are no detail coefficients (DWTlevel == 0);
        # skip the normalization instead of dividing by zero.
        if norm > 0:
            weights = weights / norm
        return weights.reshape(-1, 1)

    def _regularization_matrix(self):
        r"""
        Generate the regularization matrix, which maps the scale dependency onto each
        element of the wavelet-domain matrix X.

        :raises NotImplementedError: for meshes with more than one dimension
        """
        if self.mesh.dim == 1:
            n = 1
        else:
            raise NotImplementedError("Future release")
        scale_dependency_vector = self._generate_scale_dependency_vector(self.wavelets)
        return np.tile(scale_dependency_vector.reshape(-1, 1), (1, n))
class Wavelet:
    r"""
    The object containing all specific functionalities for a wavelet type.

    Parameters
    ----------
    wav : string
        The wavelet type (typically family + str(number of vanishing moments), e.g.
        Daubechies 1 --> db1). Must be a discrete wavelet known to PyWavelets.
    n_m : integer
        The number of model parameters in model space (i.e. the length of vector m)
    DWTlevel : integer, optional
        The level of the discrete wavelet transform. When omitted, the level given by
        ``pywt.dwt_max_level(n_m, wav)`` is used.
    signal_extension : string, optional
        The type/mode of signal extension used in the discrete wavelet transform.
        When omitted, "smooth" is used.

    Attributes
    ----------
    W : numpy array of size n_x by n_m
        Transformation matrix of the wavelet transform
    n_x : integer
        Number of wavelet coefficients, i.e. the number of rows of W
    """

    def __init__(self, wav, n_m, DWTlevel=None, signal_extension=None):
        # W stays None while the settings are initialised, so the property setters
        # below do not try to rebuild the transform from incomplete settings.
        self.W = None
        self.n_x = None
        self.n_m = n_m
        self.wav = wav
        self.DWTlevel = DWTlevel  # resolving the default level needs n_m and wav
        self.signal_extension = signal_extension
        self._update_W()  # builds W and n_x

    @property
    def n_m(self):
        """The number of model parameters (in model space)"""
        return self._n_m

    @n_m.setter
    def n_m(self, n):
        if n <= 0:
            raise ValueError(
                "n_m are the number of model parameters, thus strictly positive"
            )
        self._n_m = n
        if self.W is not None:
            self._update_W()

    @property
    def wav(self):
        r"""Wavelet family to use.

        See Deleersnyder et al, 2021 for the rationale behind the choice of the
        optimal wavelet. In general, Daubechies (db) wavelets are preferred.

        - db1 yields blocky inversion models
        - db2-db4 yield inversion models with sharp interfaces
        - db5+ yield smooth inversion models

        The discretization of the inversion model also plays a role. Changing the
        discretization may affect the optimal 'choice' for the wavelet.
        """
        if self._wav is None:
            raise Exception("The wavelet basis function is None")
        return self._wav

    @wav.setter
    def wav(self, type_):
        # Only discrete wavelets can be used by the discrete wavelet transform;
        # continuous ones (e.g. "morl") would fail later inside PyWavelets.
        discrete_wavelets = pywt.wavelist(kind="discrete")
        if type_ not in discrete_wavelets:
            raise ValueError(
                "unknown wavelet type, use names from " + str(discrete_wavelets)
            )
        self._wav = type_
        if self.W is not None:
            self._update_W()

    @property
    def DWTlevel(self):
        """The level of decomposition of the discrete wavelet transform"""
        if self._DWTlevel is None:
            raise Exception("The discrete wavelet transform level (DWTlevel) is None")
        return self._DWTlevel

    @DWTlevel.setter
    def DWTlevel(self, level):
        maxlevel = pywt.dwt_max_level(self.n_m, self.wav)
        if level is None:
            level = maxlevel
        elif level > maxlevel:
            warnings.warn(
                "Boundary effects: The user-defined DWTlevel exceeds the suggested maximum DWT level of "
                + str(maxlevel),
                stacklevel=2,
            )
        self._DWTlevel = level
        if self.W is not None:
            self._update_W()

    @property
    def signal_extension(self):
        """Signal-extension mode of the discrete wavelet transform.

        Due to the cascading filter banks algorithm, an extrapolation method is
        required. Choose the method which introduces the least artifacts. In general,
        "smooth" is a good choice.
        """
        if self._signal_extension is None:
            raise Exception("The signal extension type is None")
        return self._signal_extension

    @signal_extension.setter
    def signal_extension(self, type_):
        if type_ is None:
            type_ = "smooth"
        elif type_ not in pywt.Modes.modes:
            # pywt.Modes.modes is a list, so it must be converted before concatenation
            raise ValueError(
                "Typo in signal extension, choose from " + str(pywt.Modes.modes)
            )
        self._signal_extension = type_
        if self.W is not None:
            self._update_W()

    def _update_W(self):
        """Rebuild the transformation matrix ``W`` and its row count ``n_x``."""
        # The DWT is linear: transforming every unit vector (the rows of the identity,
        # along axis=1) yields one column of the transformation matrix.
        coeffs = pywt.wavedec(
            np.eye(self.n_m),
            self.wav,
            level=self.DWTlevel,
            mode=self.signal_extension,
            axis=1,
        )
        self.W = np.hstack(coeffs).T
        # Keep n_x consistent whenever a setter triggers a rebuild
        self.n_x = self.W.shape[0]