|
1 | 1 | import copy |
| 2 | +from itertools import product |
2 | 3 | from math import pi |
3 | 4 | from typing import Union |
4 | 5 |
|
5 | | -# pylint: disable=redefined-builtin,no-name-in-module,no-member |
6 | 6 | # pylint: disable=no-name-in-module,no-member |
| 7 | +import pyrecest.backend |
| 8 | + |
| 9 | +# pylint: disable=redefined-builtin,no-name-in-module,no-member |
7 | 10 | from pyrecest.backend import ( |
8 | 11 | allclose, |
9 | 12 | arange, |
|
12 | 15 | concatenate, |
13 | 16 | cos, |
14 | 17 | diag, |
| 18 | + empty, |
15 | 19 | exp, |
16 | 20 | hstack, |
17 | 21 | int32, |
18 | 22 | int64, |
19 | 23 | linalg, |
20 | | - meshgrid, |
21 | 24 | mod, |
22 | 25 | ndim, |
23 | 26 | random, |
24 | 27 | repeat, |
25 | 28 | sin, |
26 | 29 | stack, |
27 | 30 | sum, |
28 | | - tile, |
29 | 31 | where, |
| 32 | + zeros, |
30 | 33 | ) |
31 | 34 | from scipy.stats import multivariate_normal |
32 | 35 |
|
@@ -56,49 +59,95 @@ def __init__(self, mu, C, bound_dim: Union[int, int32, int64]): |
56 | 59 | self.mu = where(arange(mu.shape[0]) < bound_dim, mod(mu, 2.0 * pi), mu) |
57 | 60 | self.C = C |
58 | 61 |
|
59 | | - def pdf(self, xs, m: Union[int, int32, int64] = 3): |
60 | | - xs = atleast_2d(xs) |
61 | | - condition = ( |
62 | | - arange(xs.shape[1]) < self.bound_dim |
63 | | - ) # Create a condition based on column indices |
64 | | - xs = where( |
65 | | - # Broadcast the condition to match the shape of xs |
66 | | - condition[None, :], # noqa: E203 |
67 | | - mod(xs, 2.0 * pi), # Compute the modulus where the condition is True |
68 | | - xs, # Keep the original values where the condition is False |
69 | | - ) |
70 | | - |
71 | | - assert xs.shape[-1] == self.input_dim |
72 | | - |
73 | | - # generate multiples for wrapping |
74 | | - multiples = array(range(-m, m + 1)) * 2.0 * pi |
75 | | - |
76 | | - # create meshgrid for all combinations of multiples |
77 | | - mesh = array(meshgrid(*[multiples] * self.bound_dim)).reshape( |
78 | | - -1, self.bound_dim |
79 | | - ) |
80 | | - |
81 | | - # reshape xs for broadcasting |
82 | | - xs_reshaped = tile(xs[:, : self.bound_dim], (mesh.shape[0], 1)) # noqa: E203 |
83 | | - |
84 | | - # prepare data for wrapping (not applied to linear dimensions) |
85 | | - xs_wrapped = xs_reshaped + repeat(mesh, xs.shape[0], axis=0) |
86 | | - xs_wrapped = concatenate( |
87 | | - [ |
88 | | - xs_wrapped, |
89 | | - tile(xs[:, self.bound_dim :], (mesh.shape[0], 1)), # noqa: E203 |
90 | | - ], |
91 | | - axis=1, |
92 | | - ) |
| 62 | + # pylint: disable=too-many-locals |
| 63 | + def pdf(self, xs, m=3): |
| 64 | + """ |
| 65 | + Evaluate the PDF of the Hypercylindrical Wrapped Normal Distribution at given points. |
93 | 66 |
|
94 | | - # evaluate normal for all xs_wrapped |
95 | | - mvn = multivariate_normal(self.mu, self.C) |
96 | | - evals = array(mvn.pdf(xs_wrapped)) # For being compatible with all backends |
| 67 | + Parameters: |
| 68 | + xs (array-like): Input points of shape (n, d), where d = bound_dim + lin_dim. |
| 69 | + m (int, optional): Number of summands in each direction for wrapping. Default is 3. |
97 | 70 |
|
98 | | - # sum evaluations for the wrapped dimensions |
99 | | - summed_evals = sum(evals.reshape(-1, (2 * m + 1) ** self.bound_dim), axis=1) |
| 71 | + Returns: |
| 72 | + p (ndarray): PDF values at each input point of shape (n,). |
| 73 | + """ |
| 74 | + assert ( |
| 75 | + pyrecest.backend.__backend_name__ == "numpy" |
| 76 | + ), "Only supported for numpy backend" |
100 | 77 |
|
101 | | - return summed_evals |
| 78 | + xs = atleast_2d(xs) # Ensure xs is 2D |
| 79 | + n, d = xs.shape |
| 80 | + assert ( |
| 81 | + d == self.dim |
| 82 | + ), f"Input dimensionality {d} does not match distribution dimensionality {self.dim}." |
| 83 | + |
| 84 | + # Initialize the PDF values array |
| 85 | + p = zeros(n) |
| 86 | + |
| 87 | + # Define batch size to manage memory usage |
| 88 | + batch_size = 1000 |
| 89 | + |
| 90 | + # Generate all possible offset combinations for periodic dimensions |
| 91 | + multiples = arange(-m, m + 1) * 2.0 * pi |
| 92 | + offset_combinations = list( |
| 93 | + product(multiples, repeat=self.bound_dim) |
| 94 | + ) # Total combinations: (2m+1)^bound_dim |
| 95 | + num_offsets = len(offset_combinations) |
| 96 | + |
| 97 | + # Pre-convert offset combinations to a NumPy array for efficient computation |
| 98 | + offset_array = array(offset_combinations) # Shape: (num_offsets, bound_dim) |
| 99 | + |
| 100 | + # Process input data in batches |
| 101 | + for start in range(0, n, batch_size): |
| 102 | + end = min(start + batch_size, n) |
| 103 | + batch = xs[start:end] # Shape: (batch_size, d) |
| 104 | + |
| 105 | + # Wrap periodic dimensions using modulus |
| 106 | + batch_wrapped = batch.copy() |
| 107 | + if self.bound_dim > 0: |
| 108 | + batch_wrapped[:, : self.bound_dim] = mod( |
| 109 | + batch_wrapped[:, : self.bound_dim], 2.0 * pi |
| 110 | + ) # noqa: E203 |
| 111 | + |
| 112 | + if self.bound_dim > 0: |
| 113 | + # Correct broadcasting: batch_wrapped becomes (batch_size, 1, bound_dim) |
| 114 | + # offset_array becomes (1, num_offsets, bound_dim) |
| 115 | + wrapped_periodic = batch_wrapped[:, :self.bound_dim][:, None, :] + offset_array[None, :, :] |
| 116 | + # Now wrapped_periodic has shape (batch_size, num_offsets, bound_dim) |
| 117 | + wrapped_periodic = wrapped_periodic.reshape(-1, self.bound_dim) |
| 118 | + else: |
| 119 | + wrapped_periodic = empty((0, 0)) # No periodic dimensions |
| 120 | + |
| 121 | + # Repeat linear dimensions for each offset |
| 122 | + if self.lin_dim > 0: |
| 123 | + linear_part = repeat( |
| 124 | + batch_wrapped[:, self.bound_dim :], # noqa: E203 |
| 125 | + num_offsets, |
| 126 | + axis=0, |
| 127 | + ) # Shape: (batch_size * num_offsets, lin_dim) |
| 128 | + # Concatenate wrapped periodic and linear parts |
| 129 | + if self.bound_dim > 0: |
| 130 | + wrapped_points = hstack( |
| 131 | + (wrapped_periodic, linear_part) |
| 132 | + ) # Shape: (batch_size * num_offsets, d) |
| 133 | + else: |
| 134 | + wrapped_points = linear_part # Shape: (batch_size * num_offsets, d) |
| 135 | + else: |
| 136 | + wrapped_points = ( |
| 137 | + wrapped_periodic # Shape: (batch_size * num_offsets, d) |
| 138 | + ) |
| 139 | + |
| 140 | + mvn = multivariate_normal(mean=self.mu, cov=self.C) |
| 141 | + # Evaluate the multivariate normal PDF at all wrapped points |
| 142 | + pdf_vals = mvn.pdf(wrapped_points) # Shape: (batch_size * num_offsets,) |
| 143 | + |
| 144 | + # Reshape and sum the PDF values for each original point |
| 145 | + pdf_vals = pdf_vals.reshape( |
| 146 | + end - start, num_offsets |
| 147 | + ) # Shape: (batch_size, num_offsets) |
| 148 | + p[start:end] = sum(pdf_vals, axis=1) # Shape: (batch_size,) |
| 149 | + |
| 150 | + return p |
102 | 151 |
|
103 | 152 | def mode(self): |
104 | 153 | """ |
@@ -148,7 +197,7 @@ def sample(self, n: int): |
148 | 197 | """ |
149 | 198 | assert n > 0, "n must be positive" |
150 | 199 | s = random.multivariate_normal(mean=self.mu, cov=self.C, size=(n,)) |
151 | | - wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) |
| 200 | + wrapped_values = mod(s[:, : self.bound_dim], 2.0 * pi) # noqa: E203 |
152 | 201 | unbounded_values = s[:, self.bound_dim :] # noqa: E203 |
153 | 202 |
|
154 | 203 | # Concatenate the modified section with the unmodified section |
|
0 commit comments