Skip to content

Commit 06a6549

Browse files
committed
added notebook to showcase robustdiff's raw output versus rtsdiff
1 parent 241321b commit 06a6549

2 files changed

Lines changed: 190 additions & 5 deletions

File tree

examples/5_robust_outliers_demo.ipynb

Lines changed: 187 additions & 0 deletions
Large diffs are not rendered by default.

pynumdiff/utils/evaluate.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth, xlim=None, show_error=True
2626
if xlim is None:
2727
xlim = [t[0], t[-1]]
2828

29-
fig, axes = plt.subplots(1, 2, figsize=(18, 6))
29+
fig, axes = plt.subplots(1, 2, figsize=(18, 6), constrained_layout=True)
3030

3131
axes[0].plot(t, x_truth, '--', color='black', linewidth=3, label=r"true $x$")
3232
axes[0].plot(t, x, '.', color='blue', zorder=-100, markersize=markersize, label=r"noisy data")
@@ -37,7 +37,6 @@ def plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth, xlim=None, show_error=True
3737
axes[0].tick_params(axis='x', labelsize=15)
3838
axes[0].tick_params(axis='y', labelsize=15)
3939
axes[0].legend(loc='lower right', fontsize=12)
40-
axes[0].set_rasterization_zorder(0)
4140

4241
axes[1].plot(t, dxdt_truth, '--', color='black', linewidth=3, label=r"true $\frac{dx}{dt}$")
4342
axes[1].plot(t, dxdt_hat, color='red', label=r"est. $\hat{\frac{dx}{dt}}$")
@@ -47,15 +46,14 @@ def plot(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth, xlim=None, show_error=True
4746
axes[1].tick_params(axis='x', labelsize=15)
4847
axes[1].tick_params(axis='y', labelsize=15)
4948
axes[1].legend(loc='lower right', fontsize=12)
50-
axes[1].set_rasterization_zorder(0)
51-
52-
fig.tight_layout()
5349

5450
if show_error:
5551
_, _, rms_dxdt = rmse(x, dt, x_hat, dxdt_hat, x_truth, dxdt_truth)
5652
R_sqr = error_correlation(dxdt_hat, dxdt_truth)
5753
axes[1].text(0.05, 0.95, f"RMSE = {rms_dxdt:.2f}\n$R^2$ = {R_sqr:.2g}",
5854
transform=axes[1].transAxes, fontsize=15, verticalalignment='top')
55+
56+
return fig, axes
5957

6058

6159
def plot_comparison(dt, dxdt_truth, dxdt_hat1, title1, dxdt_hat2, title2, dxdt_hat3, title3):

0 commit comments

Comments
 (0)