Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/pyrecest/_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def _normalize_reduction_axes(axis, ndim_value):
axes = tuple(axis)

normalized_axes = tuple(
axis_index + ndim_value if axis_index < 0 else axis_index
for axis_index in axes
axis_index + ndim_value if axis_index < 0 else axis_index for axis_index in axes
)
if len(set(normalized_axes)) != len(normalized_axes):
raise ValueError("duplicate value in 'axis'")
Expand Down
6 changes: 1 addition & 5 deletions src/pyrecest/_backend/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,7 @@ def _validate_choice_probabilities(p, population_size):
raise ValueError("p must be 1-dimensional with one entry per population item")

p_sum = p.sum()
if (
bool(_jnp.any(p < 0))
or not bool(_jnp.isfinite(p_sum))
or bool(p_sum <= 0)
):
if bool(_jnp.any(p < 0)) or not bool(_jnp.isfinite(p_sum)) or bool(p_sum <= 0):
raise ValueError("probabilities do not sum to a positive value")
return p / p_sum

Expand Down
4 changes: 3 additions & 1 deletion src/pyrecest/_backend/pytorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def solve_sylvester(a, b, q):
a = a.to(dtype=common_dtype)
b = b.to(dtype=common_dtype)
q = q.to(dtype=common_dtype)
is_shared_factor = a.shape == b.shape and _torch.allclose(a, b, atol=1e-6, rtol=1e-6)
is_shared_factor = a.shape == b.shape and _torch.allclose(
a, b, atol=1e-6, rtol=1e-6
)
is_shared_hermitian_factor = is_shared_factor and _torch.all(
_torch.abs(a - a.transpose(-2, -1).conj()) < 1e-6
)
Expand Down
6 changes: 3 additions & 3 deletions src/pyrecest/evaluation/check_and_fix_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def _expand_meas_per_step(simulation_param):

def _validate_measurement_counts(simulation_param):
counts = simulation_param["n_meas_at_individual_time_step"]
assert len(counts) == simulation_param["n_timesteps"], (
"n_meas_at_individual_time_step must have one entry per time step"
)
assert (
len(counts) == simulation_param["n_timesteps"]
), "n_meas_at_individual_time_step must have one entry per time step"
assert all(
x > 0 for x in counts
), "n_meas_at_individual_time_step must contain positive values"
Expand Down
22 changes: 13 additions & 9 deletions src/pyrecest/filters/tracklet_viterbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,7 @@ def _fixed_lag_committed_step_cost(
return float(
unary_cost
+ config.missed_detection_cost
+ (
config.consecutive_miss_cost
if committed_miss_streak > 0
else 0.0
)
+ (config.consecutive_miss_cost if committed_miss_streak > 0 else 0.0)
)
return float(unary_cost)

Expand All @@ -543,7 +539,9 @@ def _fixed_lag_committed_step_cost(
config,
)
)
return float(unary_cost + transition(previous_committed, selected, committed_miss_streak))
return float(
unary_cost + transition(previous_committed, selected, committed_miss_streak)
)


def _reconstruct_path(
Expand Down Expand Up @@ -591,12 +589,15 @@ def _motion_cost(
else:
predicted = (
previous_position
+ np.asarray(previous.velocity, dtype=float).reshape(previous_position.shape)
+ np.asarray(previous.velocity, dtype=float).reshape(
previous_position.shape
)
* dt_s
)
position_cost = float(
np.sum(
((current_position - predicted) / float(config.transition_position_std)) ** 2
((current_position - predicted) / float(config.transition_position_std))
** 2
)
)
speed_cost = 0.0
Expand All @@ -614,7 +615,10 @@ def _motion_cost(
)
velocity_cost = float(
np.sum(
((velocity - displacement_velocity) / float(config.transition_velocity_std))
(
(velocity - displacement_velocity)
/ float(config.transition_velocity_std)
)
** 2
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code


Expand Down
1 change: 0 additions & 1 deletion tests/backend_support/test_pytorch_random_contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from tests.support.backend_runner import run_backend_code


Expand Down
1 change: 0 additions & 1 deletion tests/distributions/test_toroidal_dirac_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import numpy.testing as npt

import pyrecest.backend
from pyrecest.backend import array
from pyrecest.distributions import ToroidalDiracDistribution
Expand Down
1 change: 0 additions & 1 deletion tests/evaluation/test_check_and_fix_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest

from pyrecest.evaluation.check_and_fix_config import (
_expand_meas_per_step,
_validate_measurement_counts,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_backend_dot_contract.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import numpy as np
import numpy.testing as npt
import pytest

import pyrecest.backend as backend
import pytest
from pyrecest.backend import array


Expand Down
6 changes: 2 additions & 4 deletions tests/test_jax_choice_probability_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
@unittest.skipIf(importlib.util.find_spec("jax") is None, "JAX is not installed")
class JaxChoiceProbabilityValidationTest(unittest.TestCase):
def assert_choice_raises_value_error(self, call_source):
code = textwrap.dedent(
f"""
code = textwrap.dedent(f"""
from pyrecest.backend import array, random

try:
Expand All @@ -22,8 +21,7 @@ def assert_choice_raises_value_error(self, call_source):
) from exc
else:
raise AssertionError("expected ValueError")
"""
)
""")
result = run_backend_code("jax", code)
self.assertEqual(result.returncode, 0, result.stdout + result.stderr)

Expand Down
Loading