Skip to content

Commit 15d1be6

Browse files
committed
Added HypertoroidalFourierFilter
1 parent 9f9c951 commit 15d1be6

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import numpy as np
2+
import warnings
3+
4+
from .abstract_hypertoroidal_filter import AbstractHypertoroidalFilter
5+
from pyrecest.distributions.hypertorus.hypertoroidal_fourier_distribution import HypertoroidalFourierDistribution
6+
from pyrecest.distributions import HypertoroidalUniformDistribution
7+
8+
class HypertoroidalFourierFilter(AbstractHypertoroidalFilter):
9+
def __init__(self, noOfCoefficients, transformation='sqrt'):
10+
self.hfd = HypertoroidalFourierDistribution.from_distribution(
11+
HypertoroidalUniformDistribution(np.size(noOfCoefficients)), noOfCoefficients, transformation)
12+
13+
@property
14+
def filter_state(self):
15+
return self._filter_state
16+
17+
@filter_state.setter
18+
def filter_state(self, new_state):
19+
if np.ndim(self.hfd.C) != np.ndim(new_state.C):
20+
warnings.warn('The new state has a different dimensionality.')
21+
elif self.hfd.C.shape != new_state.C.shape:
22+
warnings.warn('The new state has a different number of coefficients.')
23+
self.hfd = new_state
24+
25+
def predict_identity(self, d_sys):
26+
size_hfd_c = self.hfd.C.shape
27+
if not isinstance(d_sys, HypertoroidalFourierDistribution):
28+
warnings.warn("PredictIdentity:automaticConversion: dSys is not a HypertoroidalFourierDistribution. "
29+
"Transforming with a number of coefficients that is equal to that of the filter. "
30+
"For non-varying noises, transforming once is much more efficient and should be preferred.")
31+
d_sys = HypertoroidalFourierDistribution.from_distribution(
32+
d_sys, size_hfd_c[size_hfd_c > 1], self.hfd.transformation)
33+
self.hfd = self.hfd.convolve(d_sys, size_hfd_c[size_hfd_c > 1])
34+
35+
def predictNonlinearViaTransitionDensity(self, fTrans, truncateJointSqrt=True):
36+
dimC = np.shape(self.hfd.C)
37+
warnStruct = warnings.catch_warnings()
38+
warnings.simplefilter('ignore')
39+
# rest of the method body...
40+
if self.hfd.transformation == 'identity' or not truncateJointSqrt:
41+
warnings.resetwarnings()
42+
self.hfd = HypertoroidalFourierDistribution(CPredictedId,'identity')
43+
else:
44+
self.hfd = HypertoroidalFourierDistribution(CPredictedId,'identity')
45+
warnings.resetwarnings()
46+
47+
if fTrans.transformation == 'sqrt':
48+
self.hfd = self.hfd.transformViaFFT('sqrt',dimC[dimC>1])
49+
50+
def updateNonlinear(self, likelihood, z: np.ndarray | None = None):
51+
"""
52+
Performs an update for an arbitrary likelihood function and a measurement. If the measurement z is not
53+
given, assume that likelihood (for varying x) is given as a hfd. Otherwise, transform it.
54+
55+
Parameters:
56+
likelihood f(z|x):
57+
Either given as HypertoroidalFourierDistribution or as a function. If given as a function, we assume
58+
that it takes matrices (same convention as .pdf) as input for both measurement and state.
59+
measurement z:
60+
Used as input for likelihood. Is repmatted if likelihood is to be evaluated at multiple points.
61+
"""
62+
63+
# Check if z is given
64+
if z is None:
65+
assert isinstance(likelihood, HypertoroidalFourierDistribution)
66+
else:
67+
# If z is given, assume likelihood is a function
68+
def func(*args):
69+
reshaped_likelihood = likelihood(np.repeat(z, len(args[0]), axis=1), np.concatenate([i.flatten() for i in args], axis=0))
70+
return reshaped_likelihood.reshape(args[0].shape)
71+
72+
likelihood = HypertoroidalFourierDistribution.from_function(func, self.hfd.C.shape, self.hfd.transformation)
73+
74+
self.hfd = self.hfd.multiply(likelihood, self.hfd.C.shape)

0 commit comments

Comments
 (0)