|
| 1 | +import numpy as np |
| 2 | +import warnings |
| 3 | + |
| 4 | +from .abstract_hypertoroidal_filter import AbstractHypertoroidalFilter |
| 5 | +from pyrecest.distributions.hypertorus.hypertoroidal_fourier_distribution import HypertoroidalFourierDistribution |
| 6 | +from pyrecest.distributions import HypertoroidalUniformDistribution |
| 7 | + |
| 8 | +class HypertoroidalFourierFilter(AbstractHypertoroidalFilter): |
| 9 | + def __init__(self, noOfCoefficients, transformation='sqrt'): |
| 10 | + self.hfd = HypertoroidalFourierDistribution.from_distribution( |
| 11 | + HypertoroidalUniformDistribution(np.size(noOfCoefficients)), noOfCoefficients, transformation) |
| 12 | + |
| 13 | + @property |
| 14 | + def filter_state(self): |
| 15 | + return self._filter_state |
| 16 | + |
| 17 | + @filter_state.setter |
| 18 | + def filter_state(self, new_state): |
| 19 | + if np.ndim(self.hfd.C) != np.ndim(new_state.C): |
| 20 | + warnings.warn('The new state has a different dimensionality.') |
| 21 | + elif self.hfd.C.shape != new_state.C.shape: |
| 22 | + warnings.warn('The new state has a different number of coefficients.') |
| 23 | + self.hfd = new_state |
| 24 | + |
| 25 | + def predict_identity(self, d_sys): |
| 26 | + size_hfd_c = self.hfd.C.shape |
| 27 | + if not isinstance(d_sys, HypertoroidalFourierDistribution): |
| 28 | + warnings.warn("PredictIdentity:automaticConversion: dSys is not a HypertoroidalFourierDistribution. " |
| 29 | + "Transforming with a number of coefficients that is equal to that of the filter. " |
| 30 | + "For non-varying noises, transforming once is much more efficient and should be preferred.") |
| 31 | + d_sys = HypertoroidalFourierDistribution.from_distribution( |
| 32 | + d_sys, size_hfd_c[size_hfd_c > 1], self.hfd.transformation) |
| 33 | + self.hfd = self.hfd.convolve(d_sys, size_hfd_c[size_hfd_c > 1]) |
| 34 | + |
| 35 | + def predictNonlinearViaTransitionDensity(self, fTrans, truncateJointSqrt=True): |
| 36 | + dimC = np.shape(self.hfd.C) |
| 37 | + warnStruct = warnings.catch_warnings() |
| 38 | + warnings.simplefilter('ignore') |
| 39 | + # rest of the method body... |
| 40 | + if self.hfd.transformation == 'identity' or not truncateJointSqrt: |
| 41 | + warnings.resetwarnings() |
| 42 | + self.hfd = HypertoroidalFourierDistribution(CPredictedId,'identity') |
| 43 | + else: |
| 44 | + self.hfd = HypertoroidalFourierDistribution(CPredictedId,'identity') |
| 45 | + warnings.resetwarnings() |
| 46 | + |
| 47 | + if fTrans.transformation == 'sqrt': |
| 48 | + self.hfd = self.hfd.transformViaFFT('sqrt',dimC[dimC>1]) |
| 49 | + |
| 50 | + def updateNonlinear(self, likelihood, z: np.ndarray | None = None): |
| 51 | + """ |
| 52 | + Performs an update for an arbitrary likelihood function and a measurement. If the measurement z is not |
| 53 | + given, assume that likelihood (for varying x) is given as a hfd. Otherwise, transform it. |
| 54 | +
|
| 55 | + Parameters: |
| 56 | + likelihood f(z|x): |
| 57 | + Either given as HypertoroidalFourierDistribution or as a function. If given as a function, we assume |
| 58 | + that it takes matrices (same convention as .pdf) as input for both measurement and state. |
| 59 | + measurement z: |
| 60 | + Used as input for likelihood. Is repmatted if likelihood is to be evaluated at multiple points. |
| 61 | + """ |
| 62 | + |
| 63 | + # Check if z is given |
| 64 | + if z is None: |
| 65 | + assert isinstance(likelihood, HypertoroidalFourierDistribution) |
| 66 | + else: |
| 67 | + # If z is given, assume likelihood is a function |
| 68 | + def func(*args): |
| 69 | + reshaped_likelihood = likelihood(np.repeat(z, len(args[0]), axis=1), np.concatenate([i.flatten() for i in args], axis=0)) |
| 70 | + return reshaped_likelihood.reshape(args[0].shape) |
| 71 | + |
| 72 | + likelihood = HypertoroidalFourierDistribution.from_function(func, self.hfd.C.shape, self.hfd.transformation) |
| 73 | + |
| 74 | + self.hfd = self.hfd.multiply(likelihood, self.hfd.C.shape) |
0 commit comments