1+ from beartype import beartype
2+ from collections .abc import Callable
3+ import numpy as np
4+ from scipy .fft import fft , ifft , fftshift , ifftshift
5+ from scipy .signal import fftconvolve
6+ from .abstract_circular_filter import AbstractCircularFilter
7+ from pyrecest .distributions import CircularFourierDistribution , AbstractCircularDistribution , AbstractHypertoroidalDistribution
8+ from pyrecest .distributions .circle .circular_uniform_distribution import CircularUniformDistribution
9+ from pyrecest .distributions .hypertorus .hypertoroidal_fourier_distribution import HypertoroidalFourierDistribution
10+ from pyrecest .distributions .hypertorus .toroidal_fourier_distribution import ToroidalFourierDistribution
11+ from .hypertoroidal_fourier_filter import HypertoroidalFourierFilter
12+ import copy
13+ import warnings
14+ import scipy
15+
16+ class CircularFourierFilter (AbstractCircularFilter , HypertoroidalFourierFilter ):
17+ def __init__ (self , no_of_coefficients , transformation = 'sqrt' ):
18+ assert transformation == 'sqrt' or transformation == 'identity'
19+ AbstractCircularFilter .__init__ (self , CircularFourierDistribution .from_distribution (
20+ CircularUniformDistribution (), no_of_coefficients , transformation ))
21+
22+ @property
23+ def filter_state (self ):
24+ return self ._filter_state
25+
26+ @filter_state .setter
27+ def filter_state (self , new_state ):
28+ assert isinstance (new_state , AbstractCircularDistribution )
29+ if not isinstance (new_state , CircularFourierDistribution ):
30+ state_to_set = CircularFourierDistribution .from_distribution (
31+ new_state , 2 * np .size (self .filter_state .a ) - 1 , self .filter_state .transformation )
32+ else :
33+ state_to_set = copy .deepcopy (new_state )
34+ if self ._filter_state .transformation != state_to_set .transformation :
35+ warnings .warn ("Warning: New density is transformed differently." )
36+ if np .size (new_state .a ) != np .size (self .filter_state .a ):
37+ warnings .warn ("Warning: New density has a different number of coefficients." )
38+ self ._filter_state = state_to_set
39+
40+ def get_estimate (self ):
41+ return self .filter_state
42+
43+ def get_point_estimate (self ):
44+ return self .filter_state .mean_direction ()
45+
46+ def predict_identity (self , d_sys ):
47+ if isinstance (d_sys , AbstractCircularDistribution ):
48+ if not isinstance (d_sys , CircularFourierDistribution ):
49+ warnings .warn ("Warning: d_sys is not a FourierDistribution. Transforming with a number of coefficients that is equal to that of the filter. For non-varying noises, transforming once is much more efficient and should be preferred." )
50+ d_sys = CircularFourierDistribution .from_distribution (d_sys , 2 * len (self .filter_state .a ) - 1 , self .filter_state .transformation )
51+ self .filter_state = self .filter_state .convolve (d_sys )
52+
53+ elif isinstance (d_sys , np .ndarray ):
54+ assert self .filter_state .transformation == 'sqrt' , "Only sqrt transformation currently supported"
55+ assert d_sys .size == self .n , "Assume that as many grid points are used as there are coefficients."
56+ fdvals = np .fft .ifftshift (self .filter_state .c ) ** 2
57+ f_pred_vals = fftconvolve (fdvals , d_sys , mode = 'same' ) * self .n * 2 * np .pi
58+ self .filter_state = CircularFourierDistribution (transformation = 'sqrt' , c = np .fft .fftshift (np .fft .fft (np .sqrt (f_pred_vals ))) / self .n )
59+
60+ else :
61+ raise ValueError ("Input format of d_sys is not supported" )
62+
63+ def update_identity (self , d_meas , z ):
64+ assert isinstance (d_meas , AbstractCircularDistribution )
65+ if not isinstance (d_meas , CircularFourierDistribution ):
66+ print ("Warning: d_meas is not a FourierDistribution. Transforming with a number of coefficients that is equal to that of the filter. For non-varying noises, transforming once is much more efficient and should be preferred." )
67+ d_meas = CircularFourierDistribution .from_distribution (d_meas , 2 * len (self .filter_state .a ) - 1 , self .filter_state .transformation )
68+ d_meas_shifted = d_meas .shift (z )
69+ self .filter_state = self .filter_state .multiply (d_meas_shifted , 2 * len (self .filter_state .a ) - 1 )
70+
71+ def predict_nonlinear (self , f , noise_distribution , truncate_joint_sqrt = True ):
72+ assert isinstance (noise_distribution , AbstractCircularDistribution )
73+ assert callable (f )
74+ f_trans = lambda xkk , xk : np .reshape (noise_distribution .pdf (xkk .T - f (xk .T )), xk .shape )
75+ self .predict_nonlinear_via_transition_density (f_trans , truncate_joint_sqrt )
76+
77+ def updateIdentity (self , dMeas , z ):
78+ """
79+ Updates assuming identity measurement model, i.e.,
80+ z(k) = x(k) + v(k) mod 2pi,
81+ where v(k) is additive noise given by dMeas.
82+ The modulo operation is carried out componentwise.
83+
84+ Parameters:
85+ dMeas (AbstractHypertoroidalDistribution):
86+ distribution of additive noise
87+ z (dim x 1 vector):
88+ measurement in [0, 2pi)^dim
89+ """
90+
91+ assert isinstance (dMeas , AbstractHypertoroidalDistribution )
92+
93+ if not isinstance (dMeas , HypertoroidalFourierDistribution ):
94+ print ("Update:automaticConversion: dMeas is not a HypertoroidalFourierDistribution. \
95+ Transforming with an amount of coefficients that is equal to that of the filter. \
96+ For non-varying noises, transforming once is much more efficient and should be preferred." )
97+ sizeHfdC = np .shape (self .hfd .C ) # Needed for workaround for 1D case
98+ dMeas = HypertoroidalFourierDistribution .from_distribution (dMeas , sizeHfdC [sizeHfdC > 1 ], self .hfd .transformation )
99+
100+ assert np .shape (z ) == (self .hfd .dim , 1 )
101+
102+ dMeasShifted = dMeas .shift (z )
103+ self .hfd = self .hfd .multiply (dMeasShifted , np .shape (self .hfd .C ))
104+
105+ def predict_nonlinear_via_transition_density (self , f_trans , truncate_joint_sqrt = True ):
106+
107+ if callable (f_trans ):
108+ f_trans = ToroidalFourierDistribution .from_function (f_trans , self .filter_state .n * np .array ([1 , 1 ]), dim = 2 , desired_transformation = self .filter_state .transformation )
109+ else :
110+ assert self .filter_state .transformation == f_trans .transformation
111+
112+ if self .filter_state .transformation == 'identity' :
113+ c_predicted_id = (2 * np .pi ) ** 2 * scipy .signal .convolve2d (f_trans .C , np .atleast_2d (self .filter_state .c ), mode = 'valid' )
114+ elif self .filter_state .transformation == 'sqrt' :
115+ if not truncate_joint_sqrt :
116+ c_joint_sqrt = scipy .signal .convolve2d (np .sqrt (2 * np .pi ) * f_trans .C , self .filter_state .c , mode = 'full' )
117+ else :
118+ c_joint_sqrt = scipy .signal .convolve2d (np .sqrt (2 * np .pi ) * f_trans .C , self .filter_state .c , mode = 'same' )
119+
120+ additional_columns = 2 * len (self .filter_state .b )
121+ c_predicted_id = 2 * np .pi * scipy .signal .convolve2d (
122+ np .pad (c_joint_sqrt , ((additional_columns , additional_columns ), (0 , 0 ))),
123+ c_joint_sqrt ,
124+ mode = 'valid'
125+ )
126+
127+ if self .filter_state .transformation == 'identity' or not truncate_joint_sqrt :
128+ self .filter_state = CircularFourierDistribution (transformation = 'identity' , c = c_predicted_id )
129+ else :
130+ self .filter_state = CircularFourierDistribution (transformation = 'identity' , c = c_predicted_id )
131+
132+ if f_trans .transformation == 'sqrt' :
133+ self .filter_state = self .filter_state .transform_via_fft ('sqrt' , self .filter_state .n )
134+
135+ def predict_identity (self , d_sys ):
136+ if isinstance (d_sys , AbstractCircularDistribution ):
137+ if not isinstance (d_sys , CircularFourierDistribution ):
138+ print ("Warning: d_sys is not a FourierDistribution. Transforming with a number of coefficients that is equal to that of the filter. For non-varying noises, transforming once is much more efficient and should be preferred." )
139+ d_sys = CircularFourierDistribution .from_distribution (d_sys , 2 * len (self .filter_state .a ) - 1 , self .filter_state .transformation )
140+ self .filter_state = self .filter_state .convolve (d_sys )
141+
142+ elif isinstance (d_sys , np .ndarray ):
143+ no_coeffs = 2 * len (self .filter_state .a ) - 1
144+ assert self .filter_state .transformation == 'sqrt' , "Only sqrt transformation currently supported"
145+ assert d_sys .size == no_coeffs , "Assume that as many grid points are used as there are coefficients."
146+ fdvals = np .fft .ifftshift (self .filter_state .c ) ** 2
147+ f_pred_vals = fftconvolve (fdvals , d_sys , mode = 'same' ) * no_coeffs * 2 * np .pi
148+ self .filter_state = CircularFourierDistribution (transformation = 'sqrt' , c = np .fft .fftshift (np .fft .fft (np .sqrt (f_pred_vals ))) / no_coeffs )
149+
150+ else :
151+ raise ValueError ("Input format of d_sys is not supported" )
152+
153+ @beartype
154+ def update_identity (self , d_meas : AbstractCircularDistribution , z ):
155+ if not isinstance (d_meas , CircularFourierDistribution ):
156+ print ("Warning: d_meas is not a FourierDistribution. Transforming with a number of coefficients that is equal to that of the filter. For non-varying noises, transforming once is much more efficient and should be preferred." )
157+ d_meas = CircularFourierDistribution .from_distribution (d_meas , 2 * len (self .filter_state .a ) - 1 , self .filter_state .transformation )
158+ d_meas_shifted = d_meas .shift (z )
159+ self .filter_state = self .filter_state .multiply (d_meas_shifted , 2 * len (self .filter_state .a ) - 1 )
160+
161+ @beartype
162+ def predict_nonlinear (self , f : Callable , noise_distribution : AbstractCircularDistribution , truncate_joint_sqrt : bool = True ):
163+ f_trans = lambda xkk , xk : np .reshape (noise_distribution .pdf (xkk .T - f (xk .T )), xk .shape )
164+ self .predict_nonlinear_via_transition_density (f_trans , truncate_joint_sqrt )
165+
166+
167+ def update_nonlinear (self , likelihood , z ):
168+ fd_meas = CircularFourierDistribution .from_function (
169+ lambda x : likelihood (z , x .ravel ()).reshape (x .shape ),
170+ 2 * len (self .filter_state .a ) - 1 ,
171+ self .filter_state .transformation
172+ )
173+ self .update_identity (fd_meas , np .zeros_like (z ))
174+
175+ def update_nonlinear_via_ifft (self , likelihood , z ):
176+ c_curr = self .filter_state .c
177+ prior_vals = ifft (ifftshift (c_curr ), overwrite_x = True , workers = - 1 ) * len (c_curr )
178+ x_vals = np .linspace (0 , 2 * np .pi , len (c_curr ) + 1 )
179+
180+ if self .filter_state .transformation == 'identity' :
181+ posterior_vals = prior_vals * likelihood (z , x_vals [:- 1 ])
182+ elif self .filter_state .transformation == 'sqrt' :
183+ posterior_vals = prior_vals * np .sqrt (likelihood (z , x_vals [:- 1 ]))
184+ else :
185+ raise ValueError ('Transformation currently not supported' )
186+
187+ self .filter_state = CircularFourierDistribution .from_complex (fftshift (fft (posterior_vals , overwrite_x = True , workers = - 1 )), self .filter_state .transformation )
188+
189+ def association_likelihood (self , likelihood ):
190+ assert len (self .get_estimate .a ) == len (likelihood .a )
191+ assert len (self .get_estimate .transformation ) == len (likelihood .transformation )
192+
193+ if self .get_estimate .transformation == 'identity' :
194+ likelihood_val = 2 * np .pi * np .real (np .dot (self .get_estimate .c , likelihood .c .T ))
195+ elif self .get_estimate .transformation == 'sqrt' :
196+ likelihood_val = 2 * np .pi * np .linalg .norm (np .convolve (self .get_estimate .c , likelihood .c ))** 2
197+ else :
198+ raise ValueError ('Transformation not supported' )
199+
200+ return likelihood_val
0 commit comments