Code
from BayesForge import bf
import jax.numpy as jnp
import numpy as np
# Setup device -----------------------------------------------
m = bf(platform='cpu')

# Load data — the model's column arguments are mapped automatically from
# the DataFrame. N_groups and N_regions are dataset-specific, so they are
# supplied as keyword defaults on the model function instead.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)
# Define model ------------------------------------------------
def model_nested(y, x, group_id, region_id, N_groups=20, N_regions=5):
    """Nested varying-intercept / varying-slope regression.

    Observations belong to groups and groups belong to regions. Each region
    draws a correlated (intercept, slope) pair around a global mean, and each
    group draws its pair around the pair of its parent region. The outcome is
    normal around the group-specific line evaluated at ``x``.

    Args:
        y: Observed outcome, one value per observation.
        x: Predictor, one value per observation.
        group_id: Integer group index per observation (expected in
            ``[0, N_groups)``).
        region_id: Integer region index per observation (expected in
            ``[0, N_regions)``).
        N_groups: Number of groups; dataset-specific default.
        N_regions: Number of regions; dataset-specific default.
    """
    sigma = m.dist.exponential(1, name='sigma')

    # 1. Region level: (intercept, slope) per region ~ MVN(global mean, cov).
    mu_global = jnp.stack([
        m.dist.normal(5, 2, name='global_intercept'),
        m.dist.normal(-1, 1, name='global_beta'),
    ])
    sigma_reg = m.dist.exponential(1, shape=(2,), name='sigma_region')
    corr_reg = m.dist.lkj(2, 2, name='corr_region')
    # Covariance = diag(scales) @ correlation @ diag(scales).
    cov_reg = jnp.diag(sigma_reg) @ corr_reg @ jnp.diag(sigma_reg)
    region_effects = m.dist.multivariate_normal(
        mu_global, cov_reg, shape=(N_regions,), name='region_effects'
    )

    # 2. Group level — parent mapping via a JAX scatter (traceable under
    # pmap). region_id is cast to the buffer's int32 dtype explicitly:
    # scattering int64 values into an int32 array raises a FutureWarning
    # today and is slated to become an error in future JAX releases.
    # NOTE(review): a group index that never appears in group_id keeps the
    # fill value 0, i.e. it is silently attached to region 0.
    group_to_region = jnp.zeros(N_groups, dtype=jnp.int32).at[group_id].set(
        jnp.asarray(region_id, dtype=jnp.int32)
    )
    sigma_grp = m.dist.exponential(1, shape=(2,), name='sigma_group')
    corr_grp = m.dist.lkj(2, 2, name='corr_group')
    cov_grp = jnp.diag(sigma_grp) @ corr_grp @ jnp.diag(sigma_grp)
    # Each group's mean is the (intercept, slope) pair of its parent region.
    group_effects = m.dist.multivariate_normal(
        region_effects[group_to_region], cov_grp, name='group_effects'
    )

    # Likelihood: each observation uses its own group's intercept and slope.
    mu_est = group_effects[group_id, 0] + group_effects[group_id, 1] * x
    m.dist.normal(mu_est, sigma, obs=y)
# Run sampler ------------------------------------------------
# NOTE: with num_chains=1 the summary reports r_hat as NaN (see the table
# below); run several chains if a between-chain convergence check is needed.
m.fit(
    model_nested,
    num_samples=1000,
    num_warmup=500,
    num_chains=1,
)
m.summary()
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
bf v 0.0.48 package loaded
jax.local_device_count 32
/home/sosa/work/3.12venv/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:
scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
0%| | 0/1500 [00:00<?, ?it/s]warmup: 0%| | 1/1500 [00:01<44:49, 1.79s/it, 1 steps of size 2.34e+00. acc. prob=0.00]warmup: 6%|β | 86/1500 [00:01<00:22, 63.14it/s, 31 steps of size 2.88e-02. acc. prob=0.76]warmup: 11%|ββ | 169/1500 [00:01<00:09, 136.10it/s, 15 steps of size 3.97e-01. acc. prob=0.78]warmup: 19%|ββ | 281/1500 [00:02<00:04, 253.80it/s, 7 steps of size 2.81e-01. acc. prob=0.78] warmup: 27%|βββ | 405/1500 [00:02<00:02, 397.30it/s, 31 steps of size 1.89e-01. acc. prob=0.78]sample: 35%|ββββ | 528/1500 [00:02<00:01, 540.34it/s, 15 steps of size 2.64e-01. acc. prob=0.89]sample: 44%|βββββ | 653/1500 [00:02<00:01, 680.28it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 52%|ββββββ | 785/1500 [00:02<00:00, 820.88it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 61%|ββββββ | 910/1500 [00:02<00:00, 923.09it/s, 15 steps of size 2.64e-01. acc. prob=0.91]sample: 70%|βββββββ | 1046/1500 [00:02<00:00, 1032.28it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 78%|ββββββββ | 1171/1500 [00:02<00:00, 1087.30it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 87%|βββββββββ | 1305/1500 [00:02<00:00, 1156.08it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 96%|ββββββββββ| 1433/1500 [00:03<00:00, 1180.11it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 100%|ββββββββββ| 1500/1500 [00:03<00:00, 490.04it/s, 15 steps of size 2.64e-01. acc. prob=0.92]
/home/sosa/work/BF/BayesForge/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:
invalid value encountered in scalar divide
| | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| corr_group[0, 0] | 1.00 | 0.00 | 1.00 | 1.00 | 0.00 | 0.00 | 3000.00 | 3000.00 | NaN |
| corr_group[0, 1] | 0.42 | 0.23 | 0.06 | 0.77 | 0.01 | 0.01 | 446.31 | 458.13 | NaN |
| corr_group[1, 0] | 0.42 | 0.23 | 0.06 | 0.77 | 0.01 | 0.01 | 446.31 | 458.13 | NaN |
| corr_group[1, 1] | 1.00 | 0.00 | 1.00 | 1.00 | 0.00 | 0.00 | 996.33 | 963.91 | NaN |
| corr_region[0, 0] | 1.00 | 0.00 | 1.00 | 1.00 | 0.00 | 0.00 | 3000.00 | 3000.00 | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| sigma | 0.49 | 0.02 | 0.46 | 0.52 | 0.00 | 0.00 | 1167.56 | 640.41 | NaN |
| sigma_group[0] | 0.42 | 0.09 | 0.30 | 0.55 | 0.00 | 0.00 | 875.35 | 741.04 | NaN |
| sigma_group[1] | 0.25 | 0.06 | 0.16 | 0.35 | 0.00 | 0.00 | 342.37 | 524.02 | NaN |
| sigma_region[0] | 1.05 | 0.42 | 0.54 | 1.66 | 0.02 | 0.01 | 539.86 | 741.76 | NaN |
| sigma_region[1] | 0.47 | 0.24 | 0.15 | 0.75 | 0.01 | 0.01 | 531.66 | 569.97 | NaN |
65 rows × 9 columns