Code
# Imports ------------------------------------------------
from BayesForge import bf
import jax.numpy as jnp
# Setup device------------------------------------------------
# Run on CPU. NOTE(review): rand_seed=False is forwarded as-is; confirm what
# it means for reproducibility in BayesForge.
m = bf(platform='cpu', rand_seed=False)

# Simulate data ------------------------------------------------
N = 50  # number of simulated individuals

# One Normal(0, 1) draw per individual, requested with shape (N, 1).
individual_predictor = m.dist.normal(0, 1, shape=(N, 1), sample=True)

# Binary N x N dyadic covariate; the diagonal (self-pairs) is forced to 0.
kinship = m.dist.bernoulli(0.3, shape=(N, N), sample=True)
kinship = kinship.at[jnp.diag_indices(N)].set(0)

# Each individual is assigned one of four equally likely categories.
category_probs = jnp.array([.25, .25, .25, .25])
category = m.dist.categorical(category_probs, sample=True, shape=(N,))

# Count the distinct categories actually drawn (N_grp) and how many
# individuals landed in each of them (N_by_grp).
observed_groups, N_by_grp = jnp.unique(category, return_counts=True)
N_grp = observed_groups.shape[0]
def sim_network(kinship, individual_predictor, category):
    """Simulate a binary network from block, sender-receiver and dyadic terms.

    Each term is drawn with ``sample=True``; the terms are summed and used as
    the logits of a Bernoulli draw. The sum has the same four terms as the
    likelihood in ``model``.

    Args:
        kinship: Dyadic predictor forwarded to ``m.net.dyadic_effect``. The
            script passes the output of ``m.net.mat_to_edgl``.
        individual_predictor: Per-individual predictor with individuals on the
            first axis. It is used as both the sender and the receiver
            covariate.
        category: Group label of each individual, forwarded to
            ``m.net.block_model``.

    Returns:
        The sampled result of ``m.dist.bernoulli``.

    Note:
        Relies on the module-level ``m``, ``N_grp`` and ``N_by_grp``.
    """
    # Take the number of individuals from the data itself (as ``model`` does)
    # instead of the global N, so the arguments alone determine the size.
    n_individuals = individual_predictor.shape[0]

    # Intercept: every individual is placed in a single block (label 0).
    B_intercept = m.net.block_model(
        jnp.full((n_individuals,), 0),
        1,
        jnp.array([n_individuals]),
        sample=True,
    )
    B_category = m.net.block_model(category, N_grp, N_by_grp, sample=True)

    # SR
    sr = m.net.sender_receiver(
        individual_predictor,
        individual_predictor,
        s_mu=0.4,
        r_mu=-0.4,
        sample=True,
    )

    # D
    dr = m.net.dyadic_effect(kinship, d_sd=2.5, sample=True)

    return m.dist.bernoulli(
        logits=B_intercept + B_category + sr + dr,
        sample=True,
    )
# Simulate the network that the model below is fitted to. The kinship matrix
# is converted with m.net.mat_to_edgl before being passed on.
network = sim_network(
    m.net.mat_to_edgl(kinship),
    individual_predictor,
    category,
)

# Predictive model ------------------------------------------------
# The keys are the parameter names of `model`; the same individual predictor
# is supplied for both the focal and the target role.
m.data_on_model = {
    "network": network,
    "dyadic_predictors": m.net.mat_to_edgl(kinship),
    "focal_individual_predictors": individual_predictor,
    "target_individual_predictors": individual_predictor,
    "category": category,
}
def model(
    network,
    dyadic_predictors,
    focal_individual_predictors,
    target_individual_predictors,
    category,
):
    """Bernoulli network model with block, sender-receiver and dyadic terms.

    The four terms are summed into the logits of a Bernoulli likelihood that
    is conditioned on ``network``.

    Args:
        network: Observed network, passed as ``obs`` to the likelihood.
        dyadic_predictors: Dyadic predictor forwarded to
            ``m.net.dyadic_effect``.
        focal_individual_predictors: Sender-side predictor; its first axis
            gives the number of individuals.
        target_individual_predictors: Receiver-side predictor.
        category: Group label of each individual for the category block term.

    Note:
        Relies on the module-level ``m``, ``N_grp`` and ``N_by_grp``.
    """
    n_individuals = focal_individual_predictors.shape[0]

    # Block ---------------------------------------
    # Intercept: one block containing every individual (all labels are 0).
    single_block_labels = jnp.full((n_individuals,), 0)
    single_block_size = jnp.array([n_individuals])
    B_intercept = m.net.block_model(
        single_block_labels, 1, single_block_size, name="B_intercept"
    )
    B_category = m.net.block_model(
        category, N_grp, N_by_grp, name="B_category"
    )

    ## SR shape = N individuals---------------------------------------
    sr = m.net.sender_receiver(
        focal_individual_predictors,
        target_individual_predictors,
        s_mu=0.4,
        r_mu=-0.4,
    )

    # Dyadic shape = N dyads--------------------------------------
    # NOTE(review): the original comment called this "intercept only", yet a
    # dyadic predictor is passed in; confirm which is intended.
    dr = m.net.dyadic_effect(dyadic_predictors, d_sd=2.5)

    # Likelihood: observed ties given the summed effects.
    logits = B_intercept + B_category + sr + dr
    m.dist.bernoulli(logits=logits, obs=network)
# Fit `model` with the progress bar disabled; presumably the inputs are taken
# from m.data_on_model set above (confirm against the BayesForge API).
m.fit(model, progress_bar=False)
m.summary()
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
bf v 0.0.48 package loaded
jax.local_device_count 32
/home/sosa/work/BF/BayesForge/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:
invalid value encountered in scalar divide
| | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| b_B_category[0, 0] | -3.46 | 1.53 | -5.76 | -0.91 | 0.04 | 0.03 | 1557.10 | 2259.67 | 1.0 |
| b_B_category[0, 1] | -4.06 | 1.35 | -6.28 | -1.97 | 0.03 | 0.02 | 1639.14 | 2153.85 | 1.0 |
| b_B_category[0, 2] | -5.53 | 1.71 | -8.12 | -2.71 | 0.03 | 0.02 | 3082.87 | 2992.52 | 1.0 |
| b_B_category[0, 3] | -6.16 | 2.36 | -9.71 | -2.11 | 0.03 | 0.02 | 5780.36 | 3089.99 | 1.0 |
| b_B_category[1, 0] | -6.05 | 1.30 | -8.06 | -3.91 | 0.03 | 0.02 | 1478.20 | 2406.86 | 1.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| sr_rf[48, 1] | -1.41 | 3.26 | -6.22 | 3.85 | 0.05 | 0.04 | 4054.08 | 3011.73 | 1.0 |
| sr_rf[49, 0] | -0.65 | 1.11 | -2.30 | 1.13 | 0.03 | 0.02 | 1669.32 | 2545.31 | 1.0 |
| sr_rf[49, 1] | 2.77 | 1.38 | 0.69 | 5.06 | 0.03 | 0.02 | 1900.88 | 2241.97 | 1.0 |
| sr_sigma[0] | 1.59 | 0.42 | 0.91 | 2.17 | 0.02 | 0.02 | 379.65 | 944.27 | 1.0 |
| sr_sigma[1] | 4.25 | 0.96 | 2.80 | 5.80 | 0.05 | 0.04 | 323.73 | 675.11 | 1.0 |
5131 rows × 9 columns