Skip to content

Commit bd74bc1

Browse files
authored
Merge pull request #111 from florisvb/kf-P0
addressing #110
2 parents e7026bd + 9d15aa5 commit bd74bc1

3 files changed

Lines changed: 24 additions & 20 deletions

File tree

pynumdiff/kalman_smooth/_kalman_smooth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _constant_derivative(x, P0, A, C, R, Q, forwardbackward):
8484
"""Helper for `constant_{velocity,acceleration,jerk}` functions, because there was a lot of
8585
repeated code.
8686
"""
87-
xhat0 = np.zeros(A.shape[0]); xhat0[0] = x[0]
87+
xhat0 = np.zeros(A.shape[0]); xhat0[0] = x[0] # See #110 for why this choice of xhat0
8888
xhat_smooth = _RTS_smooth(xhat0, P0, x, A, C, Q, R) # noisy x are the "y" in Kalman-land
8989
x_hat_forward = xhat_smooth[:, 0] # first dimension is time, so slice first element at all times
9090
dxdt_hat_forward = xhat_smooth[:, 1]
@@ -135,7 +135,7 @@ def constant_velocity(x, dt, params=None, options=None, r=None, q=None, forwardb
135135
C = np.array([[1, 0]]) # we measure only y = noisy x
136136
R = np.array([[r]])
137137
Q = np.array([[1e-16, 0], [0, q]]) # uncertainty is around the velocity
138-
P0 = np.array(100*np.eye(2)) # Why is this one magnitude 100 vs the other ones being 10?
138+
P0 = np.array(100*np.eye(2)) # See #110 for why this choice of P0
139139

140140
return _constant_derivative(x, P0, A, C, R, Q, forwardbackward)
141141

@@ -174,7 +174,7 @@ def constant_acceleration(x, dt, params=None, options=None, r=None, q=None, forw
174174
Q = np.array([[1e-16, 0, 0],
175175
[0, 1e-16, 0],
176176
[0, 0, q]]) # uncertainty is around the acceleration
177-
P0 = np.array(10*np.eye(3))
177+
P0 = np.array(100*np.eye(3)) # See #110 for why this choice of P0
178178

179179
return _constant_derivative(x, P0, A, C, R, Q, forwardbackward)
180180

@@ -215,7 +215,7 @@ def constant_jerk(x, dt, params=None, options=None, r=None, q=None, forwardbackw
215215
[0, 1e-16, 0, 0],
216216
[0, 0, 1e-16, 0],
217217
[0, 0, 0, q]]) # uncertainty is around the jerk
218-
P0 = np.array(10*np.eye(4))
218+
P0 = np.array(100*np.eye(4)) # See #110 for why this choice of P0
219219

220220
return _constant_derivative(x, P0, A, C, R, Q, forwardbackward)
221221

pynumdiff/tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from matplotlib import pyplot
44
from collections import defaultdict
55

6-
def pytest_addoption(parser): parser.addoption("--plot", action="store_true", default=False)
6+
def pytest_addoption(parser):
7+
parser.addoption("--plot", action="store_true", default=False) # whether to show plots
8+
parser.addoption("--bounds", action="store_true", default=False) # whether to print error bounds
79

810
@pytest.fixture(scope="session", autouse=True)
911
def store_plots(request):

pynumdiff/tests/test_diff_methods.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def iterated_first_order(*args, **kwargs): return first_order(*args, **kwargs)
4848
(constant_jerk, {'r':1e-4, 'q':10}), (constant_jerk, [1e-4, 10]),
4949
# TODO (known_dynamics), but presently it doesn't calculate a derivative
5050
]
51+
diff_methods_and_params = [(constant_jerk, {'r':1e-4, 'q':10})]
5152

5253
# All the testing methodology follows the exact same pattern; the only thing that changes is the
5354
# closeness to the right answer various methods achieve with the given parameterizations. So index a
@@ -133,14 +134,14 @@ def iterated_first_order(*args, **kwargs): return first_order(*args, **kwargs)
133134
[(1, 1), (2, 2), (1, 1), (2, 2)],
134135
[(1, 1), (3, 3), (1, 1), (3, 3)]],
135136
constant_acceleration: [[(-25, -25), (-25, -25), (0, 0), (1, 0)],
136-
[(-4, -4), (-2, -3), (0, 0), (1, 0)],
137-
[(-3, -3), (-1, -2), (0, 0), (1, 0)],
137+
[(-5, -5), (-3, -3), (0, 0), (1, 0)],
138+
[(-4, -4), (-2, -2), (0, 0), (1, 0)],
138139
[(0, -1), (1, 0), (0, 0), (1, 0)],
139140
[(1, 1), (3, 2), (1, 1), (3, 2)],
140141
[(1, 1), (3, 3), (1, 1), (3, 3)]],
141142
constant_jerk: [[(-25, -25), (-25, -25), (0, 0), (1, 0)],
142-
[(-4, -4), (-2, -3), (0, 0), (1, 0)],
143-
[(-3, -3), (-1, -2), (0, 0), (1, 0)],
143+
[(-5, -5), (-3, -3), (0, 0), (1, 0)],
144+
[(-4, -4), (-2, -2), (0, 0), (1, 0)],
144145
[(-1, -2), (1, 0), (0, 0), (1, 0)],
145146
[(1, 0), (2, 2), (1, 0), (2, 2)],
146147
[(1, 1), (3, 3), (1, 1), (3, 3)]]
@@ -174,18 +175,19 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
174175
else diff_method(x_noisy, dt, params, options)
175176

176177
# check x_hat and x_hat_noisy are close to x and that dxdt_hat and dxdt_hat_noisy are close to dxdt
177-
#print("]\n[", end="")
178+
if request.config.getoption("--bounds"): print("]\n[", end="")
178179
for j,(a,b) in enumerate([(x,x_hat), (dxdt,dxdt_hat), (x,x_hat_noisy), (dxdt,dxdt_hat_noisy)]):
179-
# l2_error = np.linalg.norm(a - b)
180-
# linf_error = np.max(np.abs(a - b))
181-
# print(f"({l2_error},{linf_error})", end=", ")
182-
# print(f"({int(np.ceil(np.log10(l2_error))) if l2_error > 0 else -25}, {int(np.ceil(np.log10(linf_error))) if linf_error > 0 else -25})", end=", ")
183-
184-
log_l2_bound, log_linf_bound = error_bounds[diff_method][i][j]
185-
assert np.linalg.norm(a - b) < 10**log_l2_bound
186-
assert np.max(np.abs(a - b)) < 10**log_linf_bound
187-
if 0 < np.linalg.norm(a - b) < 10**(log_l2_bound - 1) or 0 < np.max(np.abs(a - b)) < 10**(log_linf_bound - 1):
188-
print(f"Improvement detected for method {diff_method.__name__}")
180+
if request.config.getoption("--bounds"):
181+
l2_error = np.linalg.norm(a - b)
182+
linf_error = np.max(np.abs(a - b))
183+
#print(f"({l2_error},{linf_error})", end=", ")
184+
print(f"({int(np.ceil(np.log10(l2_error))) if l2_error > 0 else -25}, {int(np.ceil(np.log10(linf_error))) if linf_error > 0 else -25})", end=", ")
185+
else:
186+
log_l2_bound, log_linf_bound = error_bounds[diff_method][i][j]
187+
assert np.linalg.norm(a - b) < 10**log_l2_bound
188+
assert np.max(np.abs(a - b)) < 10**log_linf_bound
189+
if 0 < np.linalg.norm(a - b) < 10**(log_l2_bound - 1) or 0 < np.max(np.abs(a - b)) < 10**(log_linf_bound - 1):
190+
print(f"Improvement detected for method {diff_method.__name__}")
189191

190192
if request.config.getoption("--plot") and not isinstance(params, list): # Get the plot flag from pytest configuration
191193
fig, axes = request.config.plots[diff_method] # get the appropriate plot, set up by the store_plots fixture in conftest.py

0 commit comments

Comments (0)