Skip to content

Commit b4031b6

Browse files
authored
fix: correcting decomposition logic (#59)
1 parent c65fb9a commit b4031b6

3 files changed

Lines changed: 56 additions & 6 deletions

File tree

panel/simdec_app.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def load_data(text_fname):
7070
return pd.read_csv(io.BytesIO(raw))
7171
except Exception:
7272
pn.state.notifications.error(GENERIC_ERROR_MSG, duration=0)
73+
return pd.read_csv(DEFAULT_STRESS_CSV)
7374

7475

7576
@pn.cache
@@ -183,11 +184,11 @@ def explained_variance_80(sensitivity_indices_table):
183184
si_values = df["Value"].tolist()[1:]
184185
input_names = df["Inputs"].tolist()[1:]
185186

186-
# Ensuring explained variance is at least 80% of the total
187-
target = 0.8 * sum(si_values)
188-
pos_80 = bisect.bisect_right(np.cumsum(si_values), target)
189-
190-
return input_names[: pos_80 + 1]
187+
# Find the variables needed to reach 80% of explained variance
188+
total = sum(si_values)
189+
pos = bisect.bisect_left(np.cumsum(si_values), 0.8 * total)
190+
n_vars = min(pos + 1, 4)
191+
return input_names[:n_vars]
191192

192193

193194
@pn.cache

src/simdec/decomposition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def decomposition(
138138
n_var_dec = sensitivity_indices.size
139139

140140
n_var_dec = max(1, n_var_dec) # keep at least one variable
141-
n_var_dec = min(5, n_var_dec) # use at most 5 variables
141+
n_var_dec = min(4, n_var_dec) # use at most 4 variables
142142
else:
143143
n_var_dec = inputs.shape[1]
144144

tests/test_decomposition.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,52 @@ def test_decomposition():
2424
assert res.states == [2, 2, 2, 2]
2525
assert res.statistic.shape == (2, 2, 2, 2)
2626
npt.assert_allclose(res.bins.describe().T["mean"], res.statistic.flatten())
27+
28+
29+
def test_auto_ordering_single_dominant_variable():
30+
fname = path_data / "stress.csv"
31+
data = pd.read_csv(fname)
32+
output_name, *v_names = list(data.columns)
33+
inputs, output = data[v_names], data[output_name]
34+
35+
si = np.array([0.90, 0.05, 0.03, 0.02])
36+
res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si)
37+
assert len(res.var_names) == 1
38+
39+
40+
def test_auto_ordering_two_variables_cross_threshold():
41+
fname = path_data / "stress.csv"
42+
data = pd.read_csv(fname)
43+
output_name, *v_names = list(data.columns)
44+
inputs, output = data[v_names], data[output_name]
45+
46+
# sum = 1.0, cumsum = [0.75, 0.81, ...] -> crosses 0.8 after 2nd variable
47+
si = np.array([0.75, 0.06, 0.10, 0.09])
48+
res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si)
49+
assert len(res.var_names) == 2
50+
51+
52+
def test_auto_ordering_cap_at_four():
53+
"""Even if more than 4 variables are needed to reach 0.8, cap at 4."""
54+
fname = path_data / "stress.csv"
55+
data = pd.read_csv(fname)
56+
output_name, *v_names = list(data.columns)
57+
inputs, output = data[v_names], data[output_name]
58+
59+
# sum = 1.0, each variable contributes equally -> need all 4 to reach 0.8
60+
si = np.array([0.25, 0.25, 0.25, 0.25])
61+
res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si)
62+
assert len(res.var_names) == 4
63+
64+
65+
def test_auto_ordering_si_not_summing_to_one():
66+
"""Threshold is relative to sum(si), not hardcoded 1.0."""
67+
fname = path_data / "stress.csv"
68+
data = pd.read_csv(fname)
69+
output_name, *v_names = list(data.columns)
70+
inputs, output = data[v_names], data[output_name]
71+
72+
# sum = 2.0, 0.8 * 2.0 = 1.6, cumsum = [1.8, ...] -> crosses after 1st variable
73+
si = np.array([1.80, 0.10, 0.05, 0.05])
74+
res = sd.decomposition(inputs=inputs, output=output, sensitivity_indices=si)
75+
assert len(res.var_names) == 1

0 commit comments

Comments
 (0)