🚧Nested Varying Effects

Hierarchical Models
Regression
Modeling hierarchical data with multiple levels of nesting (e.g., individuals within groups within regions).

General Principles

To model hierarchies where groups are themselves nested within larger units (e.g., students within classes, which are within schools), we use Nested Varying Effects. This allows the model to share information across different levels of the hierarchy, improving estimates through multilevel pooling.

Note

Note
  • We have the same considerations as for Varying Intercepts and Varying Slopes.

  • In a nested model, a parameter at one level (e.g., a group-level intercept) becomes the mean for the distribution of parameters at a lower level (e.g., individual-level intercepts).

  • This structure is often represented using multiple indices: j for the group and k(j) for the super-group (e.g., region) that group j belongs to.

  • To capture the correlation between multiple varying effects (e.g., intercept and slope) at any level of the hierarchy, we use a Multivariate Normal (MVN) distribution.

  • Non-centered parameterization is highly recommended for nested models to avoid β€œfunnel” geometries that can hinder MCMC sampling.

Example

Below is an example of a nested model with both varying intercepts and varying slopes. We model the outcome y for individuals in groups nested within regions. The relationship between x and y varies at both levels.

Code
from BI import bi
import jax.numpy as jnp
import numpy as np

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data β€” column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

# Define model ------------------------------------------------
def model_nested(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')

    # 1. Region level
    mu_global = jnp.stack([m.dist.normal(5, 2,  name='global_intercept'),
                            m.dist.normal(-1, 1, name='global_beta')])
    sigma_reg  = m.dist.exponential(1, shape=(2,), name='sigma_region')
    corr_reg   = m.dist.lkj(2, 2, name='corr_region')
    cov_reg    = jnp.diag(sigma_reg) @ corr_reg @ jnp.diag(sigma_reg)
    region_effects = m.dist.multivariate_normal(
        mu_global, cov_reg, shape=(N_regions,), name='region_effects'
    )

    # 2. Group level β€” parent mapping via JAX scatter (traceable under pmap)
    group_to_region = jnp.zeros(N_groups, dtype=jnp.int32).at[group_id].set(region_id)
    sigma_grp  = m.dist.exponential(1, shape=(2,), name='sigma_group')
    corr_grp   = m.dist.lkj(2, 2, name='corr_group')
    cov_grp    = jnp.diag(sigma_grp) @ corr_grp @ jnp.diag(sigma_grp)
    group_effects = m.dist.multivariate_normal(
        region_effects[group_to_region], cov_grp, name='group_effects'
    )

    mu_est = group_effects[group_id, 0] + group_effects[group_id, 1] * x
    m.dist.normal(mu_est, sigma, obs=y)

# Run sampler ------------------------------------------------
m.fit(model_nested, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
/home/sosa/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:

scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
jax.local_device_count 32
  0%|          | 0/1500 [00:00<?, ?it/s]warmup:   0%|          | 1/1500 [00:01<36:37,  1.47s/it, 1 steps of size 2.34e+00. acc. prob=0.00]warmup:   7%|β–‹         | 106/1500 [00:01<00:14, 93.78it/s, 31 steps of size 3.49e-01. acc. prob=0.77]warmup:  15%|β–ˆβ–        | 221/1500 [00:01<00:06, 212.50it/s, 31 steps of size 2.65e-01. acc. prob=0.78]warmup:  23%|β–ˆβ–ˆβ–Ž       | 346/1500 [00:01<00:03, 355.83it/s, 15 steps of size 5.91e-01. acc. prob=0.78]warmup:  33%|β–ˆβ–ˆβ–ˆβ–Ž      | 490/1500 [00:01<00:01, 534.16it/s, 7 steps of size 1.21e-01. acc. prob=0.78] sample:  42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 631/1500 [00:01<00:01, 700.43it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 765/1500 [00:02<00:00, 836.79it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 900/1500 [00:02<00:00, 956.54it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 1030/1500 [00:02<00:00, 1036.58it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 1159/1500 [00:02<00:00, 1102.94it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  86%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 1288/1500 [00:02<00:00, 1115.03it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample:  94%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 1413/1500 [00:02<00:00, 1142.00it/s, 15 steps of size 2.64e-01. acc. prob=0.92]sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:02<00:00, 564.47it/s, 15 steps of size 2.64e-01. acc. prob=0.92] 
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:

invalid value encountered in scalar divide
mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
corr_group[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
corr_group[0, 1] 0.42 0.23 0.06 0.77 0.01 0.01 446.31 458.13 NaN
corr_group[1, 0] 0.42 0.23 0.06 0.77 0.01 0.01 446.31 458.13 NaN
corr_group[1, 1] 1.00 0.00 1.00 1.00 0.00 0.00 996.33 963.91 NaN
corr_region[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
... ... ... ... ... ... ... ... ... ...
sigma 0.49 0.02 0.46 0.52 0.00 0.00 1167.56 640.41 NaN
sigma_group[0] 0.42 0.09 0.30 0.55 0.00 0.00 875.35 741.04 NaN
sigma_group[1] 0.25 0.06 0.16 0.35 0.00 0.00 342.37 524.02 NaN
sigma_region[0] 1.05 0.42 0.54 1.66 0.02 0.01 539.86 741.76 NaN
sigma_region[1] 0.47 0.24 0.15 0.75 0.01 0.01 531.66 569.97 NaN

65 rows Γ— 9 columns

Code
# Data and m object carry over from the Centered tab above.

def model_nested_noncentered(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')

    # 1. Region level
    mu_global    = jnp.stack([m.dist.normal(5, 2,  name='global_intercept'),
                               m.dist.normal(-1, 1, name='global_beta')])
    sigma_region = m.dist.exponential(1, shape=(2,), name='sigma_region')
    L_region     = m.dist.lkj_cholesky(2, 2, name='L_region')
    z_region     = m.dist.normal(0, 1, name='z_region', shape=(2, N_regions))
    region_effects = mu_global + ((sigma_region[..., None] * L_region) @ z_region).T

    # 2. Group level β€” parent mapping via JAX scatter (traceable under pmap)
    pid          = jnp.zeros(N_groups, dtype=jnp.int32).at[group_id].set(region_id)
    sigma_group  = m.dist.exponential(1, shape=(2,), name='sigma_group')
    L_group      = m.dist.lkj_cholesky(2, 2, name='L_group')
    z_group      = m.dist.normal(0, 1, name='z_group', shape=(2, N_groups))
    group_effects = region_effects[pid] + ((sigma_group[..., None] * L_group) @ z_group).T

    mu_est = group_effects[group_id, 0] + group_effects[group_id, 1] * x
    m.dist.normal(mu_est, sigma, obs=y)

m.fit(model_nested_noncentered, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
/home/sosa/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:

scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.

  0%|          | 0/1500 [00:00<?, ?it/s]warmup:   0%|          | 1/1500 [00:01<25:55,  1.04s/it, 1 steps of size 2.34e+00. acc. prob=0.00]warmup:   3%|β–Ž         | 51/1500 [00:01<00:23, 61.06it/s, 511 steps of size 1.87e-02. acc. prob=0.75]warmup:   7%|β–‹         | 101/1500 [00:01<00:11, 125.30it/s, 255 steps of size 3.20e-01. acc. prob=0.77]warmup:  10%|β–‰         | 149/1500 [00:01<00:07, 186.11it/s, 255 steps of size 5.46e-02. acc. prob=0.77]warmup:  13%|β–ˆβ–Ž        | 193/1500 [00:01<00:05, 236.22it/s, 255 steps of size 6.08e-02. acc. prob=0.77]warmup:  17%|β–ˆβ–‹        | 254/1500 [00:01<00:03, 317.61it/s, 511 steps of size 1.23e-02. acc. prob=0.77]warmup:  20%|β–ˆβ–ˆ        | 303/1500 [00:01<00:03, 356.29it/s, 255 steps of size 3.94e-02. acc. prob=0.78]warmup:  25%|β–ˆβ–ˆβ–       | 370/1500 [00:01<00:02, 434.86it/s, 255 steps of size 4.21e-02. acc. prob=0.78]warmup:  30%|β–ˆβ–ˆβ–‰       | 445/1500 [00:01<00:02, 517.82it/s, 63 steps of size 4.01e-02. acc. prob=0.78] sample:  34%|β–ˆβ–ˆβ–ˆβ–Ž      | 506/1500 [00:01<00:01, 532.19it/s, 127 steps of size 4.82e-02. acc. prob=0.94]sample:  39%|β–ˆβ–ˆβ–ˆβ–‰      | 587/1500 [00:02<00:01, 608.76it/s, 127 steps of size 4.82e-02. acc. prob=0.86]sample:  44%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 654/1500 [00:02<00:01, 623.75it/s, 127 steps of size 4.82e-02. acc. prob=0.88]sample:  48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 724/1500 [00:02<00:01, 645.20it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 796/1500 [00:02<00:01, 665.20it/s, 127 steps of size 4.82e-02. acc. prob=0.88]sample:  58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 873/1500 [00:02<00:00, 694.84it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 946/1500 [00:02<00:00, 704.13it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 1018/1500 [00:02<00:00, 708.45it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 1090/1500 [00:02<00:00, 711.56it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 1162/1500 [00:02<00:00, 713.89it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 1236/1500 [00:02<00:00, 718.52it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 1310/1500 [00:03<00:00, 723.00it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 1385/1500 [00:03<00:00, 729.04it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 1459/1500 [00:03<00:00, 714.44it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:03<00:00, 448.47it/s, 127 steps of size 4.82e-02. acc. prob=0.89]
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:

invalid value encountered in scalar divide
mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
L_group[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
L_group[0, 1] 0.00 0.00 0.00 0.00 0.00 0.00 3000.00 3000.00 NaN
L_group[1, 0] 0.42 0.23 0.03 0.75 0.01 0.01 584.84 675.29 NaN
L_group[1, 1] 0.87 0.11 0.72 1.00 0.00 0.00 581.67 675.29 NaN
L_region[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
... ... ... ... ... ... ... ... ... ...
z_region[1, 0] -0.71 0.66 -1.71 0.38 0.02 0.02 831.50 720.19 NaN
z_region[1, 1] 1.05 0.65 0.09 2.11 0.02 0.02 934.02 733.74 NaN
z_region[1, 2] -0.58 0.66 -1.62 0.42 0.03 0.02 661.19 724.94 NaN
z_region[1, 3] 0.39 0.79 -0.75 1.73 0.02 0.02 1060.03 744.14 NaN
z_region[1, 4] -0.02 0.70 -1.10 1.08 0.02 0.02 884.18 604.35 NaN

65 rows Γ— 9 columns

Code
from BI import bi
import jax.numpy as jnp

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data β€” column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

# group_ids: one obs-level index array per level, top to bottom.
# Parent structure is derived automatically from these arrays.
def model_nested_builtin(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')

    a_g_est, b_g_est = m.effects.nested_varying_effects(
        N_vars=2,
        names=["region", "group"],
        N_groups=[N_regions, N_groups],
        group_ids=[region_id, group_id],
        centered=False,
    )

    mu_est = a_g_est + b_g_est * x
    m.dist.normal(mu_est, sigma, obs=y)

m.fit(model_nested_builtin, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
/home/sosa/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:

scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
jax.local_device_count 32
  0%|          | 0/1500 [00:00<?, ?it/s]warmup:   0%|          | 1/1500 [00:01<26:20,  1.05s/it, 1 steps of size 2.34e+00. acc. prob=0.00]warmup:   4%|β–Ž         | 55/1500 [00:01<00:22, 65.25it/s, 127 steps of size 3.98e-02. acc. prob=0.76]warmup:   7%|β–‹         | 109/1500 [00:01<00:10, 133.41it/s, 1023 steps of size 1.87e-02. acc. prob=0.76]warmup:  10%|β–ˆ         | 154/1500 [00:01<00:07, 186.77it/s, 511 steps of size 2.16e-02. acc. prob=0.77] warmup:  14%|β–ˆβ–        | 209/1500 [00:01<00:05, 256.79it/s, 127 steps of size 4.99e-02. acc. prob=0.78]warmup:  17%|β–ˆβ–‹        | 261/1500 [00:01<00:03, 311.32it/s, 1023 steps of size 2.03e-02. acc. prob=0.78]warmup:  22%|β–ˆβ–ˆβ–       | 324/1500 [00:01<00:03, 385.07it/s, 255 steps of size 5.45e-02. acc. prob=0.78] warmup:  26%|β–ˆβ–ˆβ–Œ       | 392/1500 [00:01<00:02, 458.89it/s, 255 steps of size 3.35e-02. acc. prob=0.78]warmup:  31%|β–ˆβ–ˆβ–ˆβ–      | 471/1500 [00:01<00:01, 537.81it/s, 1023 steps of size 1.32e-02. acc. prob=0.78]sample:  37%|β–ˆβ–ˆβ–ˆβ–‹      | 548/1500 [00:01<00:01, 599.05it/s, 127 steps of size 4.82e-02. acc. prob=0.84] sample:  42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 623/1500 [00:02<00:01, 637.72it/s, 127 steps of size 4.82e-02. acc. prob=0.87]sample:  46%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 692/1500 [00:02<00:01, 648.48it/s, 63 steps of size 4.82e-02. acc. prob=0.88] sample:  51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ     | 761/1500 [00:02<00:01, 625.23it/s, 127 steps of size 4.82e-02. acc. prob=0.88]sample:  56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ    | 842/1500 [00:02<00:00, 676.03it/s, 127 steps of size 4.82e-02. acc. prob=0.88]sample:  61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 919/1500 [00:02<00:00, 702.27it/s, 63 steps of size 4.82e-02. acc. prob=0.89] sample:  66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹   | 995/1500 [00:02<00:00, 716.91it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 1072/1500 [00:02<00:00, 730.73it/s, 63 steps of size 4.82e-02. acc. prob=0.89]sample:  77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 1154/1500 [00:02<00:00, 753.46it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 1231/1500 [00:02<00:00, 751.36it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 1308/1500 [00:03<00:00, 756.10it/s, 63 steps of size 4.82e-02. acc. prob=0.89] sample:  92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 1384/1500 [00:03<00:00, 749.82it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample:  97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 1462/1500 [00:03<00:00, 758.59it/s, 127 steps of size 4.82e-02. acc. prob=0.89]sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:03<00:00, 458.11it/s, 127 steps of size 4.82e-02. acc. prob=0.89]
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:

invalid value encountered in scalar divide
mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
L_corr_group[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
L_corr_group[0, 1] 0.00 0.00 0.00 0.00 0.00 0.00 3000.00 3000.00 NaN
L_corr_group[1, 0] 0.42 0.23 0.03 0.75 0.01 0.01 584.84 675.29 NaN
L_corr_group[1, 1] 0.87 0.11 0.72 1.00 0.00 0.00 581.67 675.29 NaN
L_corr_region[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
... ... ... ... ... ... ... ... ... ...
z_region[1, 0] -0.71 0.66 -1.71 0.38 0.02 0.02 831.50 720.19 NaN
z_region[1, 1] 1.05 0.65 0.09 2.11 0.02 0.02 934.02 733.74 NaN
z_region[1, 2] -0.58 0.66 -1.62 0.42 0.03 0.02 661.19 724.94 NaN
z_region[1, 3] 0.39 0.79 -0.75 1.73 0.02 0.02 1060.03 744.14 NaN
z_region[1, 4] -0.02 0.70 -1.10 1.08 0.02 0.02 884.18 604.35 NaN

65 rows Γ— 9 columns

Code
from BI import bi
import jax.numpy as jnp

# Setup device -----------------------------------------------
m = bi(platform='cpu')

# Load data β€” column args are mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=True)
m.data(data_path)

def model_nested_builtin_centered(y, x, group_id, region_id, N_groups=20, N_regions=5):
    sigma = m.dist.exponential(1, name='sigma')

    a_g_est, b_g_est = m.effects.nested_varying_effects(
        N_vars=2,
        names=["region", "group"],
        N_groups=[N_regions, N_groups],
        group_ids=[region_id, group_id],
        centered=True,
    )

    mu_est = a_g_est + b_g_est * x
    m.dist.normal(mu_est, sigma, obs=y)

m.fit(model_nested_builtin_centered, num_samples=1000, num_warmup=500, num_chains=1)
m.summary()
/home/sosa/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py:108: FutureWarning:

scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.
jax.local_device_count 32
  0%|          | 0/1500 [00:00<?, ?it/s]warmup:   0%|          | 1/1500 [00:01<36:59,  1.48s/it, 1 steps of size 2.34e+00. acc. prob=0.00]warmup:   7%|β–‹         | 108/1500 [00:01<00:14, 94.75it/s, 7 steps of size 7.43e-01. acc. prob=0.77]warmup:  15%|β–ˆβ–Œ        | 225/1500 [00:01<00:05, 214.58it/s, 15 steps of size 2.31e-01. acc. prob=0.78]warmup:  23%|β–ˆβ–ˆβ–Ž       | 352/1500 [00:01<00:03, 359.60it/s, 7 steps of size 4.14e-01. acc. prob=0.78] warmup:  33%|β–ˆβ–ˆβ–ˆβ–Ž      | 492/1500 [00:01<00:01, 529.85it/s, 31 steps of size 2.48e-01. acc. prob=0.79]sample:  42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 631/1500 [00:01<00:01, 692.48it/s, 15 steps of size 2.74e-01. acc. prob=0.92]sample:  52%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 774/1500 [00:02<00:00, 848.73it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample:  61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 908/1500 [00:02<00:00, 961.98it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample:  70%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 1044/1500 [00:02<00:00, 1060.53it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample:  79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 1188/1500 [00:02<00:00, 1158.17it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample:  89%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 1328/1500 [00:02<00:00, 1222.24it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample:  98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 1470/1500 [00:02<00:00, 1275.85it/s, 15 steps of size 2.74e-01. acc. prob=0.91]sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:02<00:00, 575.18it/s, 15 steps of size 2.74e-01. acc. prob=0.91] 
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:

invalid value encountered in scalar divide
mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat
corr_group[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
corr_group[0, 1] 0.40 0.24 0.08 0.80 0.01 0.01 605.50 556.22 NaN
corr_group[1, 0] 0.40 0.24 0.08 0.80 0.01 0.01 605.50 556.22 NaN
corr_group[1, 1] 1.00 0.00 1.00 1.00 0.00 0.00 1023.28 1031.23 NaN
corr_region[0, 0] 1.00 0.00 1.00 1.00 0.00 0.00 3000.00 3000.00 NaN
... ... ... ... ... ... ... ... ... ...
sigma 0.49 0.02 0.46 0.53 0.00 0.00 1073.98 760.39 NaN
sigma_group[0] 0.42 0.08 0.29 0.54 0.00 0.00 919.03 737.85 NaN
sigma_group[1] 0.25 0.07 0.15 0.34 0.00 0.00 457.05 510.07 NaN
sigma_region[0] 1.04 0.41 0.47 1.56 0.01 0.01 758.53 766.58 NaN
sigma_region[1] 0.46 0.24 0.17 0.75 0.01 0.01 542.06 631.21 NaN

65 rows Γ— 9 columns

using BayesianInference

# Setup device ------------------------------------------------
m = importBI(platform="cpu")

# Load data β€” column args mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m.load.sim_nested_effects(only_path=true)
m.data(data_path)

# Define model ------------------------------------------------
@BI function model(y, x, group_id, region_id, N_groups=20, N_regions=5)
    sigma = m.dist.exponential(1, name="sigma")

    a_g_est, b_g_est = m.effects.nested_varying_effects(
        N_vars    = 2,
        names     = ["region", "group"],
        N_groups  = [N_regions, N_groups],
        group_ids = [region_id, group_id],
        centered  = false,
    )

    mu = a_g_est .+ b_g_est .* x
    m.dist.normal(mu, sigma, obs=y)
end

# Run mcmc ------------------------------------------------
m.fit(model)
m.summary()
library(BayesianInference)

# Setup platform ------------------------------------------------
m = importBI(platform='cpu')

# Load data β€” column args mapped automatically from the DataFrame.
# N_groups and N_regions are dataset-specific; provide them as defaults.
data_path = m$load$sim_nested_effects(only_path=TRUE)
m$data(data_path)

# Define model ------------------------------------------------
model <- function(y, x, group_id, region_id, N_groups=20L, N_regions=5L) {
  sigma = m$dist$exponential(1, name='sigma')

  veff = m$effects$nested_varying_effects(
    N_vars    = 2L,
    names     = list('region', 'group'),
    N_groups  = list(N_regions, N_groups),
    group_ids = list(region_id, group_id),
    centered  = FALSE,
  )
  a_g_est = veff[[1]]
  b_g_est = veff[[2]]

  mu = a_g_est + b_g_est * x
  m$dist$normal(mu, sigma, obs=y)
}

# Run MCMC ------------------------------------------------
m$fit(model)
m$summary()

Mathematical Details

We model the outcome Y_i for observation i. The intercept \alpha_j and slope \beta_j for group j(i) are nested within region k(j).

Y_{i} \sim \text{Normal}(\mu_i, \sigma)

\mu_{i} = \alpha_{[j(i)]} + \beta_{[j(i)]} X_i

Centered Parameterization

The vector of group-level effects follows a Multivariate Normal distribution centered on the region-level effects:

\begin{bmatrix} \alpha_j \\ \beta_j \end{bmatrix} \sim \text{MVNormal}\left(\begin{bmatrix} \gamma_{\alpha, [k(j)]} \\ \gamma_{\beta, [k(j)]} \end{bmatrix}, \Sigma_{group}\right)

The region-level effects themselves follow their own Multivariate Normal distribution:

\begin{bmatrix} \gamma_{\alpha, k} \\ \gamma_{\beta, k} \end{bmatrix} \sim \text{MVNormal}\left(\begin{bmatrix} \bar{\gamma}_\alpha \\ \bar{\gamma}_\beta \end{bmatrix}, \Sigma_{region}\right)

Where: - \Sigma = \mathbf{S} \mathbf{R} \mathbf{S} (diagonal matrix of standard deviations \mathbf{S} and correlation matrix \mathbf{R}). - \bar{\gamma}_\alpha, \bar{\gamma}_\beta are the global average intercept and slope. - \Sigma_{region}, \Sigma_{group} are the covariance matrices for each hierarchical level.

Non-Centered Parameterization

To avoid funnel geometries, the non-centered formulation instead uses Cholesky factors L_{region} and L_{group} alongside standardized, non-correlated normal variables:

\begin{bmatrix} z_{\alpha, k} \\ z_{\beta, k} \end{bmatrix} \sim \text{Normal}(0, 1), \quad \begin{bmatrix} z_{\alpha, j} \\ z_{\beta, j} \end{bmatrix} \sim \text{Normal}(0, 1)

The region-level and group-level distributions are linearly transformed:

\begin{bmatrix} \gamma_{\alpha, k} \\ \gamma_{\beta, k} \end{bmatrix} = \begin{bmatrix} \bar{\gamma}_\alpha \\ \bar{\gamma}_\beta \end{bmatrix} + \text{diag}(\mathbf{S}_{region}) L_{region} \begin{bmatrix} z_{\alpha, k} \\ z_{\beta, k} \end{bmatrix}

\begin{bmatrix} \alpha_j \\ \beta_j \end{bmatrix} = \begin{bmatrix} \gamma_{\alpha, [k(j)]} \\ \gamma_{\beta, [k(j)]} \end{bmatrix} + \text{diag}(\mathbf{S}_{group}) L_{group} \begin{bmatrix} z_{\alpha, j} \\ z_{\beta, j} \end{bmatrix}

Notes

Note
  • By nesting the MVN prior, we allow the model to learn how the correlation between intercepts and slopes persists across different levels of the hierarchy.
  • The non-centered parameterization (used by default in the m.effects.nested_varying_effects built-in function) is critical for convergence in these complex models.
  • This approach can be generalized to N levels and M variables by increasing the dimensionality of the vectors and matrices.

Reference(s)