1+ import numpy as np
2+ import unittest
3+ from itertools import product
4+ from pyrecest .distributions .se2_partially_wrapped_normal_distribution import SE2PartiallyWrappedNormalDistribution
5+
6+ from scipy .stats import multivariate_normal
7+
8+ class SE2PWNDistributionTest (unittest .TestCase ):
9+
10+ def setUp (self ):
11+ self .mu = np .array ([2 , 3 , 4 ])
12+ self .si1 , self .si2 , self .si3 = 0.9 , 1.5 , 1.7
13+ self .rho12 , self .rho13 , self .rho23 = 0.5 , 0.3 , 0.4
14+ self .C = np .array ([
15+ [self .si1 ** 2 , self .si1 * self .si2 * self .rho12 , self .si1 * self .si3 * self .rho13 ],
16+ [self .si1 * self .si2 * self .rho12 , self .si2 ** 2 , self .si2 * self .si3 * self .rho23 ],
17+ [self .si1 * self .si3 * self .rho13 , self .si2 * self .si3 * self .rho23 , self .si3 ** 2 ]
18+ ])
19+ self .pwn = SE2PartiallyWrappedNormalDistribution (self .mu , self .C )
20+
21+ @staticmethod
22+ def _loop_wrapped_pdf (x , mu , C , n_wrappings = 10 ):
23+ bound_dim = 1
24+ # Ensure x is at least 2D for iteration
25+ x = np .array (np .atleast_2d (x ), dtype = np .float64 )
26+
27+ n_samples = x .shape [0 ]
28+ results = np .zeros (n_samples )
29+
30+ # Generate all combinations of offsets for the bound_dim dimensions
31+ offset_values = [i * 2 * np .pi for i in range (- n_wrappings , n_wrappings + 1 )]
32+ all_combinations = list (product (offset_values , repeat = bound_dim ))
33+
34+ # Iterate over each sample
35+ for i in range (n_samples ):
36+ sample = x [i ]
37+ p = 0
38+ # Iterate over each offset combination and add to the sample before evaluating the PDF
39+ for offset in all_combinations :
40+ shifted_sample = sample .copy ()
41+ shifted_sample [:bound_dim ] += np .array (offset )
42+ p += multivariate_normal .pdf (shifted_sample , mu , C )
43+ results [i ] = p
44+
45+ # If input was 1D, return a single value; otherwise, return the array
46+ return results [0 ] if x .shape [0 ] == 1 else results
47+
48+ def test_pdf (self ):
49+ self .assertAlmostEqual (self .pwn .pdf (self .mu ), SE2PWNDistributionTest ._loop_wrapped_pdf (self .mu , self .mu , self .C ), places = 10 )
50+ self .assertAlmostEqual (self .pwn .pdf (self .mu - 1 ), SE2PWNDistributionTest ._loop_wrapped_pdf (self .mu - 1 , self .mu , self .C ), places = 10 )
51+ self .assertAlmostEqual (self .pwn .pdf (self .mu + 2 ), SE2PWNDistributionTest ._loop_wrapped_pdf (self .mu + 2 , self .mu , self .C ), places = 10 )
52+ x = np .random .rand (20 , 3 )
53+ np .testing .assert_allclose (self .pwn .pdf (x ), SE2PWNDistributionTest ._loop_wrapped_pdf (x , self .mu , self .C , n_wrappings = 10 ), rtol = 1e-10 )
54+
55+ def test_pdf_large_uncertainty (self ):
56+ C_high = 100 * np .eye (3 , 3 )
57+ pwn_large_uncertainty = SE2PartiallyWrappedNormalDistribution (self .mu , C_high )
58+ for t in range (1 , 7 ):
59+ # Verify they are equal for 3 wrappings (same number of wrappings as in the class)
60+ pdf_class = pwn_large_uncertainty .pdf (self .mu + np .array ([t , 0 , 0 ]))
61+ np .testing .assert_allclose (pdf_class ,
62+ SE2PWNDistributionTest ._loop_wrapped_pdf (self .mu + np .array ([t , 0 , 0 ]), self .mu , C_high , n_wrappings = 3 ),
63+ rtol = 0.00001 )
64+
65+ # Verify they are unequal for 10 wrappings when the covariance is high
66+ pdf_loop_nested_10 = SE2PWNDistributionTest ._loop_wrapped_pdf (self .mu + np .array ([t , 0 , 0 ]), self .mu , C_high , n_wrappings = 10 )
67+
68+ # Calculate the relative errors
69+ relative_errors = np .abs (pdf_class - pdf_loop_nested_10 ) / pdf_class
70+ # Find the maximum relative error
71+ max_relative_error = np .max (relative_errors )
72+ self .assertGreater (max_relative_error , 0.00001 )
73+
74+ def test_integral (self ):
75+ self .assertAlmostEqual (self .pwn .integrate (), 1 , places = 5 )
76+
77+ def test_sampling_basic (self ):
78+ np .random .seed (0 )
79+ n = 10
80+ s = self .pwn .sample (n )
81+ self .assertEqual (s .shape [0 ], n )
82+ self .assertEqual (s .shape [1 ], 3 )
83+ s = s [:, 0 ]
84+ self .assertTrue (np .all (s >= 0 ))
85+ self .assertTrue (np .all (s < 2 * np .pi ))
86+
87+
88+ if __name__ == '__main__' :
89+ unittest .main ()
0 commit comments