From 4e7d9607498f229dc00b13fbca398499678fe80f Mon Sep 17 00:00:00 2001 From: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:48:05 +0000 Subject: [PATCH 1/2] [MegaLinter] Apply linters automatic fixes --- src/pyrecest/_backend/_common.py | 3 +-- src/pyrecest/_backend/jax/random.py | 6 +---- src/pyrecest/_backend/pytorch/linalg.py | 4 +++- .../evaluation/check_and_fix_config.py | 6 ++--- src/pyrecest/filters/tracklet_viterbi.py | 22 +++++++++++-------- ...ytorch_fractional_matrix_power_contract.py | 1 - .../test_pytorch_random_contract.py | 1 - .../test_toroidal_dirac_distribution.py | 1 - tests/evaluation/test_check_and_fix_config.py | 1 - tests/test_backend_dot_contract.py | 3 +-- .../test_jax_choice_probability_validation.py | 6 ++--- 11 files changed, 24 insertions(+), 30 deletions(-) diff --git a/src/pyrecest/_backend/_common.py b/src/pyrecest/_backend/_common.py index 880a6d3f2..d497c95bd 100644 --- a/src/pyrecest/_backend/_common.py +++ b/src/pyrecest/_backend/_common.py @@ -56,8 +56,7 @@ def _normalize_reduction_axes(axis, ndim_value): axes = tuple(axis) normalized_axes = tuple( - axis_index + ndim_value if axis_index < 0 else axis_index - for axis_index in axes + axis_index + ndim_value if axis_index < 0 else axis_index for axis_index in axes ) if len(set(normalized_axes)) != len(normalized_axes): raise ValueError("duplicate value in 'axis'") diff --git a/src/pyrecest/_backend/jax/random.py b/src/pyrecest/_backend/jax/random.py index 749588a1d..eefd97ea2 100644 --- a/src/pyrecest/_backend/jax/random.py +++ b/src/pyrecest/_backend/jax/random.py @@ -271,11 +271,7 @@ def _validate_choice_probabilities(p, population_size): raise ValueError("p must be 1-dimensional with one entry per population item") p_sum = p.sum() - if ( - bool(_jnp.any(p < 0)) - or not bool(_jnp.isfinite(p_sum)) - or bool(p_sum <= 0) - ): + if bool(_jnp.any(p < 0)) or not bool(_jnp.isfinite(p_sum)) or bool(p_sum <= 0): raise ValueError("probabilities do not sum to a positive value") return p / p_sum diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py index 6c7d8e4b1..5264de67d 100644 --- a/src/pyrecest/_backend/pytorch/linalg.py +++ b/src/pyrecest/_backend/pytorch/linalg.py @@ -182,7 +182,9 @@ def solve_sylvester(a, b, q): a = a.to(dtype=common_dtype) b = b.to(dtype=common_dtype) q = q.to(dtype=common_dtype) - is_shared_factor = a.shape == b.shape and _torch.allclose(a, b, atol=1e-6, rtol=1e-6) + is_shared_factor = a.shape == b.shape and _torch.allclose( + a, b, atol=1e-6, rtol=1e-6 + ) is_shared_hermitian_factor = is_shared_factor and _torch.all( _torch.abs(a - a.transpose(-2, -1).conj()) < 1e-6 ) diff --git a/src/pyrecest/evaluation/check_and_fix_config.py b/src/pyrecest/evaluation/check_and_fix_config.py index d8d3814ee..1c0f98f8d 100644 --- a/src/pyrecest/evaluation/check_and_fix_config.py +++ b/src/pyrecest/evaluation/check_and_fix_config.py @@ -20,9 +20,9 @@ def _expand_meas_per_step(simulation_param): def _validate_measurement_counts(simulation_param): counts = simulation_param["n_meas_at_individual_time_step"] - assert len(counts) == simulation_param["n_timesteps"], ( - "n_meas_at_individual_time_step must have one entry per time step" - ) + assert ( + len(counts) == simulation_param["n_timesteps"] + ), "n_meas_at_individual_time_step must have one entry per time step" assert all( x > 0 for x in counts ), "n_meas_at_individual_time_step must contain positive values" diff --git a/src/pyrecest/filters/tracklet_viterbi.py b/src/pyrecest/filters/tracklet_viterbi.py index d1c9f9e0e..48b3605ca 100644 --- a/src/pyrecest/filters/tracklet_viterbi.py +++ b/src/pyrecest/filters/tracklet_viterbi.py @@ -527,11 +527,7 @@ def _fixed_lag_committed_step_cost( return float( unary_cost + config.missed_detection_cost - + ( - config.consecutive_miss_cost - if committed_miss_streak > 0 - else 0.0 - ) + + (config.consecutive_miss_cost if committed_miss_streak > 0 else 0.0) ) return float(unary_cost) @@ -543,7 +539,9 @@ def _fixed_lag_committed_step_cost( config, ) ) - return float(unary_cost + transition(previous_committed, selected, committed_miss_streak)) + return float( + unary_cost + transition(previous_committed, selected, committed_miss_streak) + ) def _reconstruct_path( @@ -591,12 +589,15 @@ def _motion_cost( else: predicted = ( previous_position - + np.asarray(previous.velocity, dtype=float).reshape(previous_position.shape) + + np.asarray(previous.velocity, dtype=float).reshape( + previous_position.shape + ) * dt_s ) position_cost = float( np.sum( - ((current_position - predicted) / float(config.transition_position_std)) ** 2 + ((current_position - predicted) / float(config.transition_position_std)) + ** 2 ) ) speed_cost = 0.0 @@ -614,7 +615,10 @@ def _motion_cost( ) velocity_cost = float( np.sum( - ((velocity - displacement_velocity) / float(config.transition_velocity_std)) + ( + (velocity - displacement_velocity) + / float(config.transition_velocity_std) + ) ** 2 ) ) diff --git a/tests/backend_support/test_pytorch_fractional_matrix_power_contract.py b/tests/backend_support/test_pytorch_fractional_matrix_power_contract.py index 29ab3b950..d6612faa8 100644 --- a/tests/backend_support/test_pytorch_fractional_matrix_power_contract.py +++ b/tests/backend_support/test_pytorch_fractional_matrix_power_contract.py @@ -1,7 +1,6 @@ import importlib.util import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/backend_support/test_pytorch_random_contract.py b/tests/backend_support/test_pytorch_random_contract.py index ec9149891..cf17fd848 100644 --- a/tests/backend_support/test_pytorch_random_contract.py +++ b/tests/backend_support/test_pytorch_random_contract.py @@ -1,5 +1,4 @@ import pytest - from tests.support.backend_runner import run_backend_code diff --git a/tests/distributions/test_toroidal_dirac_distribution.py b/tests/distributions/test_toroidal_dirac_distribution.py index c6b45dbd0..1b5e5c4e6 100644 --- a/tests/distributions/test_toroidal_dirac_distribution.py +++ b/tests/distributions/test_toroidal_dirac_distribution.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt - import pyrecest.backend from pyrecest.backend import array from pyrecest.distributions import ToroidalDiracDistribution diff --git a/tests/evaluation/test_check_and_fix_config.py b/tests/evaluation/test_check_and_fix_config.py index 0de1314c7..0abe9d9e3 100644 --- a/tests/evaluation/test_check_and_fix_config.py +++ b/tests/evaluation/test_check_and_fix_config.py @@ -1,5 +1,4 @@ import pytest - from pyrecest.evaluation.check_and_fix_config import ( _expand_meas_per_step, _validate_measurement_counts, diff --git a/tests/test_backend_dot_contract.py b/tests/test_backend_dot_contract.py index a84d3a18f..90c21484c 100644 --- a/tests/test_backend_dot_contract.py +++ b/tests/test_backend_dot_contract.py @@ -1,8 +1,7 @@ import numpy as np import numpy.testing as npt -import pytest - import pyrecest.backend as backend +import pytest from pyrecest.backend import array diff --git a/tests/test_jax_choice_probability_validation.py b/tests/test_jax_choice_probability_validation.py index 86593e784..11076e307 100644 --- a/tests/test_jax_choice_probability_validation.py +++ b/tests/test_jax_choice_probability_validation.py @@ -8,8 +8,7 @@ @unittest.skipIf(importlib.util.find_spec("jax") is None, "JAX is not installed") class JaxChoiceProbabilityValidationTest(unittest.TestCase): def assert_choice_raises_value_error(self, call_source): - code = textwrap.dedent( - f""" + code = textwrap.dedent(f""" from pyrecest.backend import array, random try: @@ -22,8 +21,7 @@ def assert_choice_raises_value_error(self, call_source): ) from exc else: raise AssertionError("expected ValueError") - """ - ) + """) result = run_backend_code("jax", code) self.assertEqual(result.returncode, 0, result.stdout + result.stderr) From 4e07a885349d6da76dc0526d5dd3559896f00b04 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Sat, 13 Jun 2026 04:05:23 +0200 Subject: [PATCH 2/2] Trigger MegaLinter fix checks