Skip to content

Commit 84259bb

Browse files
committed
Added partially wrapped normal distribution
1 parent 616a605 commit 84259bb

2 files changed

Lines changed: 97 additions & 0 deletions

File tree

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .abstract_se2_distribution import AbstractSE2Distribution
2+
from .cart_prod.partially_wrapped_normal_distribution import PartiallyWrappedNormalDistribution
3+
4+
class SE2PartiallyWrappedNormalDistribution(PartiallyWrappedNormalDistribution, AbstractSE2Distribution):
5+
6+
def __init__(self, mu, C):
7+
AbstractSE2Distribution.__init__(self)
8+
PartiallyWrappedNormalDistribution.__init__(self, mu, C, bound_dim=self.bound_dim)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import unittest
3+
from itertools import product
4+
from pyrecest.distributions.se2_partially_wrapped_normal_distribution import SE2PartiallyWrappedNormalDistribution
5+
6+
from scipy.stats import multivariate_normal
7+
8+
class SE2PWNDistributionTest(unittest.TestCase):
9+
10+
def setUp(self):
11+
self.mu = np.array([2, 3, 4])
12+
self.si1, self.si2, self.si3 = 0.9, 1.5, 1.7
13+
self.rho12, self.rho13, self.rho23 = 0.5, 0.3, 0.4
14+
self.C = np.array([
15+
[self.si1**2, self.si1*self.si2*self.rho12, self.si1*self.si3*self.rho13],
16+
[self.si1*self.si2*self.rho12, self.si2**2, self.si2*self.si3*self.rho23],
17+
[self.si1*self.si3*self.rho13, self.si2*self.si3*self.rho23, self.si3**2]
18+
])
19+
self.pwn = SE2PartiallyWrappedNormalDistribution(self.mu, self.C)
20+
21+
@staticmethod
22+
def _loop_wrapped_pdf(x, mu, C, n_wrappings=10):
23+
bound_dim = 1
24+
# Ensure x is at least 2D for iteration
25+
x = np.array(np.atleast_2d(x), dtype=np.float64)
26+
27+
n_samples = x.shape[0]
28+
results = np.zeros(n_samples)
29+
30+
# Generate all combinations of offsets for the bound_dim dimensions
31+
offset_values = [i*2*np.pi for i in range(-n_wrappings, n_wrappings+1)]
32+
all_combinations = list(product(offset_values, repeat=bound_dim))
33+
34+
# Iterate over each sample
35+
for i in range(n_samples):
36+
sample = x[i]
37+
p = 0
38+
# Iterate over each offset combination and add to the sample before evaluating the PDF
39+
for offset in all_combinations:
40+
shifted_sample = sample.copy()
41+
shifted_sample[:bound_dim] += np.array(offset)
42+
p += multivariate_normal.pdf(shifted_sample, mu, C)
43+
results[i] = p
44+
45+
# If input was 1D, return a single value; otherwise, return the array
46+
return results[0] if x.shape[0] == 1 else results
47+
48+
def test_pdf(self):
49+
self.assertAlmostEqual(self.pwn.pdf(self.mu), SE2PWNDistributionTest._loop_wrapped_pdf(self.mu, self.mu, self.C), places=10)
50+
self.assertAlmostEqual(self.pwn.pdf(self.mu-1), SE2PWNDistributionTest._loop_wrapped_pdf(self.mu-1, self.mu, self.C), places=10)
51+
self.assertAlmostEqual(self.pwn.pdf(self.mu+2), SE2PWNDistributionTest._loop_wrapped_pdf(self.mu+2, self.mu, self.C), places=10)
52+
x = np.random.rand(20, 3)
53+
np.testing.assert_allclose(self.pwn.pdf(x), SE2PWNDistributionTest._loop_wrapped_pdf(x, self.mu, self.C, n_wrappings=10), rtol=1e-10)
54+
55+
def test_pdf_large_uncertainty(self):
56+
C_high = 100 * np.eye(3, 3)
57+
pwn_large_uncertainty = SE2PartiallyWrappedNormalDistribution(self.mu, C_high)
58+
for t in range(1, 7):
59+
# Verify they are equal for 3 wrappings (same number of wrappings as in the class)
60+
pdf_class = pwn_large_uncertainty.pdf(self.mu + np.array([t, 0, 0]))
61+
np.testing.assert_allclose(pdf_class,
62+
SE2PWNDistributionTest._loop_wrapped_pdf(self.mu + np.array([t, 0, 0]), self.mu, C_high, n_wrappings=3),
63+
rtol=0.00001)
64+
65+
# Verify they are unequal for 10 wrappings when the covariance is high
66+
pdf_loop_nested_10 = SE2PWNDistributionTest._loop_wrapped_pdf(self.mu + np.array([t, 0, 0]), self.mu, C_high, n_wrappings=10)
67+
68+
# Calculate the relative errors
69+
relative_errors = np.abs(pdf_class - pdf_loop_nested_10) / pdf_class
70+
# Find the maximum relative error
71+
max_relative_error = np.max(relative_errors)
72+
self.assertGreater(max_relative_error, 0.00001)
73+
74+
def test_integral(self):
75+
self.assertAlmostEqual(self.pwn.integrate(), 1, places=5)
76+
77+
def test_sampling_basic(self):
78+
np.random.seed(0)
79+
n = 10
80+
s = self.pwn.sample(n)
81+
self.assertEqual(s.shape[0], n)
82+
self.assertEqual(s.shape[1], 3)
83+
s = s[:, 0]
84+
self.assertTrue(np.all(s >= 0))
85+
self.assertTrue(np.all(s < 2 * np.pi))
86+
87+
88+
if __name__ == '__main__':
89+
unittest.main()

0 commit comments

Comments
 (0)