Skip to content

Commit 2f4e4c9

Browse files
committed
Fixed pdf of partially wrapped normal distribution
1 parent faf8e3d commit 2f4e4c9

2 files changed

Lines changed: 188 additions & 47 deletions

File tree

pyrecest/distributions/cart_prod/partially_wrapped_normal_distribution.py

Lines changed: 93 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import copy
2+
from itertools import product
23
from math import pi
34
from typing import Union
45

5-
# pylint: disable=redefined-builtin,no-name-in-module,no-member
66
# pylint: disable=no-name-in-module,no-member
7+
import pyrecest.backend
8+
9+
# pylint: disable=redefined-builtin,no-name-in-module,no-member
710
from pyrecest.backend import (
811
allclose,
912
arange,
@@ -12,21 +15,21 @@
1215
concatenate,
1316
cos,
1417
diag,
18+
empty,
1519
exp,
1620
hstack,
1721
int32,
1822
int64,
1923
linalg,
20-
meshgrid,
2124
mod,
2225
ndim,
2326
random,
2427
repeat,
2528
sin,
2629
stack,
2730
sum,
28-
tile,
2931
where,
32+
zeros,
3033
)
3134
from scipy.stats import multivariate_normal
3235

@@ -56,49 +59,95 @@ def __init__(self, mu, C, bound_dim: Union[int, int32, int64]):
5659
self.mu = where(arange(mu.shape[0]) < bound_dim, mod(mu, 2.0 * pi), mu)
5760
self.C = C
5861

59-
def pdf(self, xs, m: Union[int, int32, int64] = 3):
60-
xs = atleast_2d(xs)
61-
condition = (
62-
arange(xs.shape[1]) < self.bound_dim
63-
) # Create a condition based on column indices
64-
xs = where(
65-
# Broadcast the condition to match the shape of xs
66-
condition[None, :], # noqa: E203
67-
mod(xs, 2.0 * pi), # Compute the modulus where the condition is True
68-
xs, # Keep the original values where the condition is False
69-
)
70-
71-
assert xs.shape[-1] == self.input_dim
72-
73-
# generate multiples for wrapping
74-
multiples = array(range(-m, m + 1)) * 2.0 * pi
75-
76-
# create meshgrid for all combinations of multiples
77-
mesh = array(meshgrid(*[multiples] * self.bound_dim)).reshape(
78-
-1, self.bound_dim
79-
)
80-
81-
# reshape xs for broadcasting
82-
xs_reshaped = tile(xs[:, : self.bound_dim], (mesh.shape[0], 1)) # noqa: E203
83-
84-
# prepare data for wrapping (not applied to linear dimensions)
85-
xs_wrapped = xs_reshaped + repeat(mesh, xs.shape[0], axis=0)
86-
xs_wrapped = concatenate(
87-
[
88-
xs_wrapped,
89-
tile(xs[:, self.bound_dim :], (mesh.shape[0], 1)), # noqa: E203
90-
],
91-
axis=1,
92-
)
62+
# pylint: disable=too-many-locals
63+
def pdf(self, xs, m=3):
64+
"""
65+
Evaluate the PDF of the Hypercylindrical Wrapped Normal Distribution at given points.
9366
94-
# evaluate normal for all xs_wrapped
95-
mvn = multivariate_normal(self.mu, self.C)
96-
evals = array(mvn.pdf(xs_wrapped)) # For being compatible with all backends
67+
Parameters:
68+
xs (array-like): Input points of shape (n, d), where d = bound_dim + lin_dim.
69+
m (int, optional): Number of summands in each direction for wrapping. Default is 3.
9770
98-
# sum evaluations for the wrapped dimensions
99-
summed_evals = sum(evals.reshape(-1, (2 * m + 1) ** self.bound_dim), axis=1)
71+
Returns:
72+
p (ndarray): PDF values at each input point of shape (n,).
73+
"""
74+
assert (
75+
pyrecest.backend.__backend_name__ == "numpy"
76+
), "Only supported for numpy backend"
10077

101-
return summed_evals
78+
xs = atleast_2d(xs) # Ensure xs is 2D
79+
n, d = xs.shape
80+
assert (
81+
d == self.dim
82+
), f"Input dimensionality {d} does not match distribution dimensionality {self.dim}."
83+
84+
# Initialize the PDF values array
85+
p = zeros(n)
86+
87+
# Define batch size to manage memory usage
88+
batch_size = 1000
89+
90+
# Generate all possible offset combinations for periodic dimensions
91+
multiples = arange(-m, m + 1) * 2.0 * pi
92+
offset_combinations = list(
93+
product(multiples, repeat=self.bound_dim)
94+
) # Total combinations: (2m+1)^bound_dim
95+
num_offsets = len(offset_combinations)
96+
97+
# Pre-convert offset combinations to a NumPy array for efficient computation
98+
offset_array = array(offset_combinations) # Shape: (num_offsets, bound_dim)
99+
100+
# Process input data in batches
101+
for start in range(0, n, batch_size):
102+
end = min(start + batch_size, n)
103+
batch = xs[start:end] # Shape: (batch_size, d)
104+
105+
# Wrap periodic dimensions using modulus
106+
batch_wrapped = batch.copy()
107+
if self.bound_dim > 0:
108+
batch_wrapped[:, : self.bound_dim] = mod(
109+
batch_wrapped[:, : self.bound_dim], 2.0 * pi
110+
) # noqa: E203
111+
112+
if self.bound_dim > 0:
113+
# Correct broadcasting: batch_wrapped becomes (batch_size, 1, bound_dim)
114+
# offset_array becomes (1, num_offsets, bound_dim)
115+
wrapped_periodic = batch_wrapped[:, :self.bound_dim][:, None, :] + offset_array[None, :, :]
116+
# Now wrapped_periodic has shape (batch_size, num_offsets, bound_dim)
117+
wrapped_periodic = wrapped_periodic.reshape(-1, self.bound_dim)
118+
else:
119+
wrapped_periodic = empty((0, 0)) # No periodic dimensions
120+
121+
# Repeat linear dimensions for each offset
122+
if self.lin_dim > 0:
123+
linear_part = repeat(
124+
batch_wrapped[:, self.bound_dim :], # noqa: E203
125+
num_offsets,
126+
axis=0,
127+
) # Shape: (batch_size * num_offsets, lin_dim)
128+
# Concatenate wrapped periodic and linear parts
129+
if self.bound_dim > 0:
130+
wrapped_points = hstack(
131+
(wrapped_periodic, linear_part)
132+
) # Shape: (batch_size * num_offsets, d)
133+
else:
134+
wrapped_points = linear_part # Shape: (batch_size * num_offsets, d)
135+
else:
136+
wrapped_points = (
137+
wrapped_periodic # Shape: (batch_size * num_offsets, d)
138+
)
139+
140+
mvn = multivariate_normal(mean=self.mu, cov=self.C)
141+
# Evaluate the multivariate normal PDF at all wrapped points
142+
pdf_vals = mvn.pdf(wrapped_points) # Shape: (batch_size * num_offsets,)
143+
144+
# Reshape and sum the PDF values for each original point
145+
pdf_vals = pdf_vals.reshape(
146+
end - start, num_offsets
147+
) # Shape: (batch_size, num_offsets)
148+
p[start:end] = sum(pdf_vals, axis=1) # Shape: (batch_size,)
149+
150+
return p
102151

103152
def mode(self):
104153
"""
@@ -148,7 +197,7 @@ def sample(self, n: int):
148197
"""
149198
assert n > 0, "n must be positive"
150199
s = random.multivariate_normal(mean=self.mu, cov=self.C, size=(n,))
151-
wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi)
200+
wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) # noqa: E203
152201
unbounded_values = s[:, self.bound_dim :] # noqa: E203
153202

154203
# Concatenate the modified section with the unmodified section

pyrecest/tests/distributions/test_partially_wrapped_normal_distribution.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import unittest
2+
from math import pi
23

34
import numpy.testing as npt
5+
6+
# pylint: disable=no-name-in-module,no-member
7+
import pyrecest.backend
48
import scipy.linalg
59

610
# pylint: disable=no-name-in-module,no-member
7-
from pyrecest.backend import array, ones
11+
from pyrecest.backend import array, column_stack, diag, linspace, meshgrid
812
from pyrecest.distributions.cart_prod.partially_wrapped_normal_distribution import (
913
PartiallyWrappedNormalDistribution,
1014
)
@@ -16,8 +20,96 @@ def setUp(self) -> None:
1620
self.C = array([[2.0, 1.0], [1.0, 1.0]])
1721
self.dist_2d = PartiallyWrappedNormalDistribution(self.mu, self.C, 1)
1822

19-
def test_pdf(self):
20-
self.assertEqual(self.dist_2d.pdf(ones((10, 2))).shape, (10,))
23+
@unittest.skipIf(
24+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
25+
reason="Not supported on this backend",
26+
)
27+
def test_pdf_2d(self):
28+
expected_vals = array(
29+
[ # From Matlab implementation
30+
0.00719442236938856,
31+
0.0251110014500013,
32+
0.0531599904868136,
33+
0.0682587789359472,
34+
0.0531599904868136,
35+
0.0100784602259792,
36+
0.0351772826718058,
37+
0.0744703080006016,
38+
0.0956217682613369,
39+
0.0744703080006016,
40+
0.00119956714181477,
41+
0.00418690072543581,
42+
0.00886366890530323,
43+
0.0113811761595142,
44+
0.00886366890530323,
45+
0.000447592726560109,
46+
0.00156225212096022,
47+
0.00330728776602597,
48+
0.00424664155187776,
49+
0.00330728776602597,
50+
0.00719442236938856,
51+
0.0251110014500013,
52+
0.0531599904868136,
53+
0.0682587789359472,
54+
0.0531599904868136,
55+
]
56+
)
57+
58+
hwn = PartiallyWrappedNormalDistribution(
59+
array([1.0, 2.0]), diag(array([1.0, 2.0])), 1
60+
)
61+
x, y = meshgrid(linspace(0.0, 2.0 * pi, 5), linspace(-1.0, 3.0, 5))
62+
points = column_stack([x.T.ravel(), y.T.ravel()])
63+
npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7)
64+
65+
@unittest.skipIf(
66+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
67+
reason="Not supported on this backend",
68+
)
69+
def test_pdf_3d(self):
70+
expected_vals = array(
71+
[
72+
1.385492786310657e-07,
73+
1.473370096411339e-05,
74+
4.130095787341451e-04,
75+
2.310111272798270e-08,
76+
2.456634132168293e-06,
77+
6.886344649603137e-05,
78+
1.385492786310657e-07,
79+
1.473370096411339e-05,
80+
4.130095787341451e-04,
81+
2.650620509537817e-07,
82+
2.818740764494994e-05,
83+
7.901388378523353e-04,
84+
4.419530999723758e-08,
85+
4.699847505157542e-06,
86+
1.317443623260505e-04,
87+
2.650620509537817e-07,
88+
2.818740764494994e-05,
89+
7.901388378523353e-04,
90+
1.385492786310657e-07,
91+
1.473370096411339e-05,
92+
4.130095787341451e-04,
93+
2.310111272798270e-08,
94+
2.456634132168293e-06,
95+
6.886344649603137e-05,
96+
1.385492786310657e-07,
97+
1.473370096411339e-05,
98+
4.130095787341451e-04,
99+
]
100+
)
101+
102+
hwn = PartiallyWrappedNormalDistribution(
103+
array([1.0, 2.0, 7.0]), diag(array([1.0, 2.0, 3.0])), 2
104+
)
105+
x, y, z = meshgrid(
106+
linspace(0.0, 2.0 * pi, 3),
107+
linspace(0.0, 2.0 * pi, 3),
108+
linspace(-1.0, 3.0, 3),
109+
)
110+
points = column_stack([x.ravel(), y.ravel(), z.ravel()])
111+
npt.assert_allclose(hwn.pdf(points), expected_vals, atol=1e-7)
112+
21113

22114
def test_hybrid_mean_2d(self):
23115
npt.assert_allclose(self.dist_2d.hybrid_mean(), self.mu)

0 commit comments

Comments
 (0)