diff --git a/multiple_choice/multiple_choice.qmd b/multiple_choice/multiple_choice.qmd
index 143a52b..c0e7717 100644
--- a/multiple_choice/multiple_choice.qmd
+++ b/multiple_choice/multiple_choice.qmd
@@ -139,11 +139,35 @@ item_id_0 = [chr(65 + j) for j in range(K)] # A, B, C, ...
 
 ### Base model logit_0
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_0)
 dt = az.from_cmdstanpy(logit_0.sample(data=data, show_progress=False))
 ```
 
+## Numpyro
+```{.python}
+from jax import random
+import jax.numpy as jnp
+import numpyro
+from numpyro import distributions as dist
+from numpyro.infer import NUTS, MCMC
+
+def logit_0(x, y=None):  # logistic regression of y on x; y is handed to the "y" site as `obs`
+    a = numpyro.sample("a", dist.Normal(0, 1e6))  # intercept, very wide prior
+    b = numpyro.sample("b", dist.Normal(0, 1e6))  # slope, very wide prior
+    with numpyro.plate("J", len(x)):  # one Bernoulli outcome per element of x
+        numpyro.sample("y", dist.Bernoulli(logits=a + b * x), obs=y)
+
+kernel = NUTS(logit_0)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), x=jnp.array(score), y=jnp.array(correct[:, 0]))
+dt = az.from_numpyro(mcmc)
+```
+:::
+
 ```{python}
 #| label: fig-final_exams_1
 
@@ -160,6 +184,9 @@ az.summary(dt)
 
 ### Add priors
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_prior)
 data = {
@@ -172,6 +199,30 @@ data = {
 dt = az.from_cmdstanpy(logit_prior.sample(data=data, show_progress=False))
 ```
 
+## Numpyro
+
+```{.python}
+def logit_prior(J, x, mu_a, sigma_a, mu_b, sigma_b, y=None):  # logistic regression with Normal priors supplied as data
+    x_adj = (x - x.mean()) / x.std()  # predictor is standardized inside the model
+    a = numpyro.sample("a", dist.Normal(mu_a, sigma_a))
+    b = numpyro.sample("b", dist.Normal(mu_b, sigma_b))
+    with numpyro.plate("J", J):
+        numpyro.sample("y", dist.Bernoulli(logits=a + b * x_adj), obs=y)
+
+data = {
+    "J": J,
+    "x": jnp.array(score_adj),  # NOTE(review): score_adj looks pre-standardized and the model standardizes again - confirm
+    "y": jnp.array(correct[:, 0]),
+    "mu_a": 0, "sigma_a": 5,
+    "mu_b": 0, "sigma_b": 5
+}
+kernel = NUTS(logit_prior)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **data)
+dt = az.from_numpyro(mcmc)
+```
+:::
+
 ```{python}
 #| label: fig-final_exams_2
 plot_logit(
@@ -206,6 +257,9 @@ plot_logit_grid(
 )
 ```
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 data_ = {
     "J": data["J"],
@@ -217,6 +271,25 @@ data_ = {
 dt = az.from_cmdstanpy(logit_prior.sample(data=data_, show_progress=False))
 ```
 
+## Numpyro
+
+```{.python}
+data_ = {
+    "J": data["J"],
+    "x": jnp.array(score_adj),
+    "y": jnp.array(correct[:, 6]),  # same model, outcome taken from column 6 of `correct`
+    "mu_a": 0, "sigma_a": 5,
+    "mu_b": 0, "sigma_b": 5
+}
+
+kernel = NUTS(logit_prior)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **data_)
+dt = az.from_numpyro(mcmc)
+```
+:::
+
+
 ```{python}
 #| label: fig-final_exams_2_challenge
 
@@ -277,10 +350,28 @@ _ = plot_logit_grid(
 
 ### Allow for guessing
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_guessing)
 ```
 
+## Numpyro
+
+```{.python}
+import jax
+
+def logit_guessing(x, mu_a, sigma_a, mu_b, sigma_b, y=None):  # logit_prior with a 0.25 guessing floor
+    x_adj = (x - x.mean()) / x.std()
+    a = numpyro.sample("a", dist.Normal(mu_a, sigma_a))
+    b = numpyro.sample("b", dist.Normal(mu_b, sigma_b))
+    with numpyro.plate("J", len(x)):
+        p = 0.25 + 0.75 * jax.nn.sigmoid(a + b * x_adj)  # success probability lies in (0.25, 1)
+        numpyro.sample("y", dist.Bernoulli(p), obs=y)
+```
+:::
+
 ```{python}
 #| label: fig-final_exams_4
 
@@ -339,11 +430,49 @@ longdata_ = {
 
 ### Multilevel model
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_guessing_multilevel)
 dt_5 = az.from_cmdstanpy(logit_guessing_multilevel.sample(data=longdata_, show_progress=False))
 ```
 
+## Numpyro
+
+```{.python}
+import jax
+
+def logit_guessing_multilevel(x, student, item, N, J, K,
+    mu_mu_a, sigma_mu_a, mu_mu_b, sigma_mu_b,
+    mu_sigma_a, mu_sigma_b, y=None):
+    # Guessing-floor logit with one (a, b) pair per item, drawn from shared Normal priors.
+    x_adj = (x - x.mean()) / x.std()
+    mu_a = numpyro.sample("mu_a", dist.Normal(mu_mu_a, sigma_mu_a))
+    mu_b = numpyro.sample("mu_b", dist.Normal(mu_mu_b, sigma_mu_b))
+    sigma_a = numpyro.sample("sigma_a", dist.Exponential(1.0 / mu_sigma_a))  # rate = 1 / mu_sigma_a
+    sigma_b = numpyro.sample("sigma_b", dist.Exponential(1.0 / mu_sigma_b))  # rate = 1 / mu_sigma_b
+    with numpyro.plate("items", K):
+        a = numpyro.sample("a", dist.Normal(mu_a, sigma_a))
+        b = numpyro.sample("b", dist.Normal(mu_b, sigma_b))
+    with numpyro.plate("obs", N):  # x_adj is looked up by student, a and b by item
+        p = 0.25 + 0.75 * jax.nn.sigmoid(a[item] + b[item] * x_adj[student])
+        numpyro.sample("y", dist.Bernoulli(p), obs=y)
+
+# Stan uses 1-indexed student/item; numpyro needs 0-indexed arrays.
+longdata_np = {  # copy: `longdata_` stays 1-indexed for the CmdStanPy tab and re-running cannot shift twice
+    **longdata_, "x": jnp.array(longdata_["x"]), "y": jnp.array(longdata_["y"]),
+    "student": jnp.array(longdata_["student"]) - 1,
+    "item": jnp.array(longdata_["item"]) - 1,
+}
+kernel = NUTS(logit_guessing_multilevel)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **longdata_np)
+dt_5 = az.from_numpyro(mcmc)
+```
+:::
+
+
 ```{python}
 #| label: fig-final_exams_5
 
@@ -363,10 +492,36 @@ az.summary(dt_5, var_names=["mu_a", "sigma_a", "mu_b", "sigma_b"])
 
 ### Multilevel model with correlation
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_guessing_multilevel_bivariate)
 ```
 
+## Numpyro
+
+```{.python}
+def logit_guessing_multilevel_bivariate(
+    x, student, item, N, J, K, mu_mu_ab, sigma_mu_ab, mu_sigma_ab, y=None
+):  # multilevel guessing model whose per-item (a, b) come from a correlated bivariate draw
+    x_adj = (x - x.mean()) / x.std()
+    mu_ab = numpyro.sample("mu_ab", dist.Normal(mu_mu_ab, sigma_mu_ab))
+    sigma_ab = numpyro.sample("sigma_ab", dist.Exponential(1.0 / mu_sigma_ab))
+    Omega_ab = numpyro.sample("Omega_ab", dist.LKJ(2, 1.0))  # 2x2 correlation matrix, concentration 1.0
+    with numpyro.plate("items", K):
+        e_ab = numpyro.sample("e_ab", dist.MultivariateNormal(jnp.zeros(2), Omega_ab))
+        a = numpyro.deterministic("a", mu_ab[0] + sigma_ab[0] * e_ab[:, 0])  # shift/scale column 0
+        b = numpyro.deterministic("b", mu_ab[1] + sigma_ab[1] * e_ab[:, 1])  # shift/scale column 1
+    with numpyro.plate("obs", N):
+        p = 0.25 + 0.75 * jax.nn.sigmoid(a[item] + b[item] * x_adj[student])
+        numpyro.sample("y", dist.Bernoulli(p), obs=y)
+```
+:::
+
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_6
 
@@ -390,12 +545,73 @@ plot_logit_grid_2(
 az.summary(dt_6, var_names=["mu_ab", "sigma_ab", "Omega_ab"])
 ```
 
+## Numpyro
+
+```{.python}
+longdata_6 = {
+    **longdata,
+    "mu_mu_ab": jnp.array([0., 0.]),
+    "sigma_mu_ab": jnp.array([5., 10.]),
+    "mu_sigma_ab": jnp.array([5., 10.]),
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "x": jnp.array(longdata["x"]),
+    "y": jnp.array(longdata["y"]),
+}
+
+kernel = NUTS(logit_guessing_multilevel_bivariate)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **longdata_6)
+dt_6 = az.from_numpyro(mcmc)
+plot_logit_grid_2(
+    dt_6,
+    "Multilevel model with correlation",
+    "Standardized exam score",
+    score_adj_jitt,
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+az.summary(dt_6, var_names=["mu_ab", "sigma_ab", "Omega_ab"])
+```
+:::
+
+
 ### Multilevel model with correlation using Cholesky
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_guessing_multilevel_bivariate_cholesky)
 ```
 
+## Numpyro
+
+```{.python}
+def logit_guessing_multilevel_bivariate_cholesky(
+    x, student, item, N, J, K, mu_mu_ab, sigma_mu_ab, mu_sigma_ab, y=None
+):  # bivariate multilevel guessing model with the correlation sampled as a Cholesky factor
+    x_std = (x - x.mean()) / x.std()
+    mu_ab = numpyro.sample("mu_ab", dist.Normal(mu_mu_ab, sigma_mu_ab))
+    sigma_ab = numpyro.sample("sigma_ab", dist.Exponential(1.0 / mu_sigma_ab))
+    L_ab = numpyro.sample("L_ab", dist.LKJCholesky(2, 1.0))
+    with numpyro.plate("items", K):
+        z_ab = numpyro.sample("e_ab", dist.MultivariateNormal(jnp.zeros(2), scale_tril=L_ab))
+        a = numpyro.deterministic("a", mu_ab[0] + sigma_ab[0] * z_ab[:, 0])
+        b = numpyro.deterministic("b", mu_ab[1] + sigma_ab[1] * z_ab[:, 1])
+    numpyro.deterministic("Omega_ab", L_ab @ L_ab.T)  # recorded so az.summary can report Omega_ab
+    with numpyro.plate("obs", N):
+        p_correct = 0.25 + 0.75 * jax.nn.sigmoid(a[item] + b[item] * x_std[student])
+        numpyro.sample("y", dist.Bernoulli(probs=p_correct), obs=y)
+```
+:::
+
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_7
 
@@ -419,14 +635,73 @@ plot_logit_grid_2(
 az.summary(dt_7, var_names=["mu_ab", "sigma_ab", "Omega_ab"])
 ```
 
+## Numpyro
+
+```{.python}
+longdata_7 = {
+    **longdata,
+    "mu_mu_ab": jnp.array([0., 0.]),
+    "sigma_mu_ab": jnp.array([5., 10.]),
+    "mu_sigma_ab": jnp.array([5., 10.]),
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "x": jnp.array(longdata["x"]),
+    "y": jnp.array(longdata["y"]),
+}
+
+kernel = NUTS(logit_guessing_multilevel_bivariate_cholesky)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **longdata_7)
+dt_7 = az.from_numpyro(mcmc)
+plot_logit_grid_2(
+    dt_7,
+    "Multilevel model with correlation: Cholesky parameterization",
+    "Standardized exam score",
+    score_adj_jitt,
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+az.summary(dt_7, var_names=["mu_ab", "sigma_ab", "Omega_ab"])
+```
+:::
+
 ## Item-response theory (IRT) models
 
 ### Item-response model
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(irt_guessing)
 ```
 
+## Numpyro
+
+```{.python}
+def irt_guessing(
+    student, item, N, J, K, mu_mu_beta, sigma_mu_beta, mu_sigma_alpha, mu_sigma_beta, y=None
+):  # guessing-floor model driven by per-student alpha minus per-item beta
+    mu_beta = numpyro.sample("mu_beta", dist.Normal(mu_mu_beta, sigma_mu_beta))
+    sigma_alpha = numpyro.sample("sigma_alpha", dist.Exponential(1.0 / mu_sigma_alpha))
+    sigma_beta = numpyro.sample("sigma_beta", dist.Exponential(1.0 / mu_sigma_beta))
+    with numpyro.plate("students", J):
+        alpha_j = numpyro.sample("alpha", dist.Normal(0, sigma_alpha))
+    with numpyro.plate("items", K):
+        beta_k = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
+    with numpyro.plate("obs", N):
+        p_correct = 0.25 + 0.75 * jax.nn.sigmoid(alpha_j[student] - beta_k[item])
+        numpyro.sample("y", dist.Bernoulli(probs=p_correct), obs=y)
+```
+:::
+
+
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_11
 
@@ -446,12 +721,69 @@ plot_irt(
 )
 ```
 
+## Numpyro
+
+```{.python}
+irt_data_11 = {
+    **longdata,
+    "mu_mu_beta": 0., "sigma_mu_beta": 5.,
+    "mu_sigma_alpha": 5., "mu_sigma_beta": 5.,
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "y": jnp.array(longdata["y"]),
+}
+del irt_data_11["x"]  # irt_guessing has no x parameter
+
+kernel = NUTS(irt_guessing)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **irt_data_11)
+dt_11 = az.from_numpyro(mcmc)
+plot_irt(
+    dt_11,
+    "Item-response model",
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+```
+:::
+
 ### Item-response model with discrimination parameters
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(irt_guessing_discrimination)
 ```
 
+## Numpyro
+
+```{.python}
+def irt_guessing_discrimination(
+    student, item, N, J, K, mu_mu_beta, sigma_mu_beta,
+    mu_sigma_alpha, mu_sigma_beta, mu_sigma_gamma, y=None
+):  # irt_guessing with a per-item multiplier gamma applied to (alpha - beta)
+    mu_beta = numpyro.sample("mu_beta", dist.Normal(mu_mu_beta, sigma_mu_beta))
+    sigma_alpha = numpyro.sample("sigma_alpha", dist.Exponential(1.0 / mu_sigma_alpha))
+    sigma_beta = numpyro.sample("sigma_beta", dist.Exponential(1.0 / mu_sigma_beta))
+    sigma_gamma = numpyro.sample("sigma_gamma", dist.Exponential(1.0 / mu_sigma_gamma))
+    with numpyro.plate("students", J):
+        alpha_j = numpyro.sample("alpha", dist.Normal(0, sigma_alpha))
+    with numpyro.plate("items", K):
+        beta_k = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
+        gamma_k = numpyro.sample("gamma", dist.Normal(1, sigma_gamma))  # centered at 1
+    with numpyro.plate("obs", N):
+        p_correct = 0.25 + 0.75 * jax.nn.sigmoid(gamma_k[item] * (alpha_j[student] - beta_k[item]))
+        numpyro.sample("y", dist.Bernoulli(probs=p_correct), obs=y)
+```
+:::
+
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_12
 
@@ -472,8 +804,51 @@ plot_irt(
 )
 ```
 
+## Numpyro
+
+```{.python}
+irt_data_12 = {
+    **longdata,
+    "mu_mu_beta": 0., "sigma_mu_beta": 5.,
+    "mu_sigma_alpha": 5., "mu_sigma_beta": 5.,
+    "mu_sigma_gamma": 0.5,
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "y": jnp.array(longdata["y"]),
+}
+del irt_data_12["x"]  # irt_guessing_discrimination has no x parameter
+
+
+from numpyro.infer.reparam import LocScaleReparam
+
+# Apply LocScaleReparam(centered=0) to each of the three plated Normal sites.
+reparam_config = {
+    site: LocScaleReparam(centered=0)
+    for site in ("alpha", "beta", "gamma")
+}
+reparam_model = numpyro.handlers.reparam(irt_guessing_discrimination, config=reparam_config)
+kernel = NUTS(
+    reparam_model, init_strategy=numpyro.infer.init_to_uniform(radius=2)
+)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **irt_data_12)
+dt_12 = az.from_numpyro(mcmc)
+plot_irt(
+    dt_12,
+    "Item-response model with discrimination parameters",
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+```
+:::
+
 ### Item-response model with discrimination parameters with init
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_13
 
@@ -494,6 +869,45 @@ plot_irt(
 )
 ```
 
+## Numpyro
+
+```{.python}
+irt_data_13 = {
+    **longdata,
+    "mu_mu_beta": 0., "sigma_mu_beta": 5.,
+    "mu_sigma_alpha": 5., "mu_sigma_beta": 5.,
+    "mu_sigma_gamma": 0.5,
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "y": jnp.array(longdata["y"]),
+}
+del irt_data_13["x"]  # irt_guessing_discrimination has no x parameter
+
+# Same reparameterization as the previous fit; only the init radius below differs.
+reparam_config = {
+    site: LocScaleReparam(centered=0)
+    for site in ("alpha", "beta", "gamma")
+}
+
+reparam_model = numpyro.handlers.reparam(irt_guessing_discrimination, config=reparam_config)
+kernel = NUTS(
+    reparam_model, init_strategy=numpyro.infer.init_to_uniform(radius=0.1)
+)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **irt_data_13)
+dt_13 = az.from_numpyro(mcmc)
+plot_irt(
+    dt_13,
+    "Item-response model with discrimination parameters",
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+```
+:::
+
 IRT plots
 
 ```{python}
@@ -657,10 +1071,25 @@ fig.suptitle("10 prior predictive simulations with a ~ normal(0, 50) and b ~ nor
 
 ## Breaking the model
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 print_stan(logit_guessing_uncentered)
 ```
 
+## Numpyro
+
+```{.python}
+def logit_guessing_uncentered(x, mu_a, sigma_a, mu_b, sigma_b, y=None):  # guessing model applied to x as given (no standardization)
+    intercept = numpyro.sample("a", dist.Normal(mu_a, sigma_a))
+    slope = numpyro.sample("b", dist.Normal(mu_b, sigma_b))
+    with numpyro.plate("J", len(x)):
+        p_correct = 0.25 + 0.75 * jax.nn.sigmoid(intercept + slope * x)
+        numpyro.sample("y", dist.Bernoulli(probs=p_correct), obs=y)
+```
+:::
+
 Simulate data
 
 ```{python}
@@ -686,13 +1115,34 @@ break_data = {
 }
 ```
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| results: hide
 break_1_fit = logit_guessing_uncentered.sample(data=break_data, show_progress=False)
+dt_break_1 = az.from_cmdstanpy(break_1_fit)
 ```
 
+## Numpyro
+
+```{.python}
+kernel = NUTS(logit_guessing_uncentered)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(
+    random.PRNGKey(SEED),
+    x=jnp.array(x_break),
+    y=jnp.array(y_break),
+    mu_a=0., sigma_a=1000.,
+    mu_b=0., sigma_b=1000.,
+)
+dt_break_1 = az.from_numpyro(mcmc)
+```
+:::
+
 ```{python}
-posterior = az.extract(az.from_cmdstanpy(break_1_fit), group="posterior")
+posterior = az.extract(dt_break_1, group="posterior")
 a_break = posterior["a"].values.flatten()
 b_break = posterior["b"].values.flatten()
 n_sims = len(a_break)
@@ -721,14 +1171,37 @@ ax.scatter(x_break, y_jitter, s=20, color="black", alpha=0.5)
 
 ```
 
+
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| results: hide
 break_2_fit = logit_guessing.sample(data=break_data, show_progress=False)
+dt_break_2 = az.from_cmdstanpy(break_2_fit)
 print(break_2_fit)
 ```
 
+## Numpyro
+
+```{.python}
+kernel = NUTS(logit_guessing)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(
+    random.PRNGKey(SEED),
+    x=jnp.array(x_break),
+    y=jnp.array(y_break),
+    mu_a=0., sigma_a=1000.,
+    mu_b=0., sigma_b=1000.,
+)
+dt_break_2 = az.from_numpyro(mcmc)
+mcmc.print_summary()  # counterpart of print(break_2_fit) in the CmdStanPy tab
+```
+:::
+
 ```{python}
-posterior = az.extract(az.from_cmdstanpy(break_2_fit), group="posterior")
+posterior = az.extract(dt_break_2, group="posterior")
 a_break = posterior["a"].values.flatten()
 b_break = posterior["b"].values.flatten()
 n_sims = len(a_break)
@@ -757,6 +1230,9 @@ ax.scatter(x_adj_break, y_jitter, s=20, color="black", alpha=0.5)
 
 ```
 
+::: {.panel-tabset group="ppl"}
+## CmdStanPy
+
 ```{python}
 #| label: fig-final_exams_break_1
 
@@ -780,6 +1256,38 @@ plot_logit_grid_2(
 
 ```
 
+## Numpyro
+
+```{.python}
+longdata_break = {
+    **longdata,
+    "mu_mu_a": 0., "sigma_mu_a": 5.,
+    "mu_mu_b": 0., "sigma_mu_b": 5.,
+    "mu_sigma_a": 5., "mu_sigma_b": 5.,
+    "student": jnp.array(longdata["student"]) - 1,  # 1-indexed -> 0-indexed
+    "item": jnp.array(longdata["item"]) - 1,  # 1-indexed -> 0-indexed
+    "x": jnp.array(longdata["x"]),
+    "y": jnp.array(longdata["y"]),
+}
+
+kernel = NUTS(logit_guessing_multilevel)
+mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000,
+            num_chains=4, progress_bar=False)
+mcmc.run(random.PRNGKey(SEED), **longdata_break)
+dt_break = az.from_numpyro(mcmc)
+plot_logit_grid_2(
+    dt_break,
+    "Breaking the model",
+    "Exam score",  # NOTE(review): logit_guessing_multilevel standardizes x internally, yet this axis uses raw scores - confirm
+    score_jitt,
+    longdata["y"],
+    longdata["item"],
+    item_id,
+    guessprob=0.25
+)
+```
+:::
+
 ## References {.unnumbered}
 
 