diff --git a/src/pyrecest/filters/gaussian_hypothesis_mixture.py b/src/pyrecest/filters/gaussian_hypothesis_mixture.py index 7bd1ff478..e9ba1b1c2 100644 --- a/src/pyrecest/filters/gaussian_hypothesis_mixture.py +++ b/src/pyrecest/filters/gaussian_hypothesis_mixture.py @@ -59,6 +59,8 @@ def normalize_log_weights(log_weights: list[float] | np.ndarray) -> np.ndarray: values = np.asarray(log_weights, dtype=float).reshape(-1) if values.size == 0: raise ValueError("log_weights must not be empty") + if np.any(np.isnan(values)): + raise ValueError("log_weights must not contain NaN values") positive_infinite = np.isposinf(values) if np.any(positive_infinite): diff --git a/tests/filters/test_gaussian_hypothesis_mixture.py b/tests/filters/test_gaussian_hypothesis_mixture.py index 6af41b986..592865bb4 100644 --- a/tests/filters/test_gaussian_hypothesis_mixture.py +++ b/tests/filters/test_gaussian_hypothesis_mixture.py @@ -24,6 +24,22 @@ def test_multiple_positive_infinite_log_weights_share_mass(self): self.assertTrue(np.allclose(weights, np.array([0.5, 0.0, 0.5]))) + def test_nan_log_weights_are_rejected(self): + with self.assertRaisesRegex(ValueError, "NaN"): + normalize_log_weights(np.array([0.0, np.nan])) + + with self.assertRaisesRegex(ValueError, "NaN"): + moment_match_gaussian_hypotheses( + [ + WeightedGaussianHypothesis( + np.array([0.0]), np.array([[1.0]]), log_weight=0.0 + ), + WeightedGaussianHypothesis( + np.array([1.0]), np.array([[1.0]]), log_weight=np.nan + ), + ] + ) + def test_moment_matching_respects_dominant_infinite_weight(self): mean, covariance, weights = moment_match_gaussian_hypotheses( [