from BI import bi
import jax.numpy as jnp

# setup platform ------------------------------------------------
m = bi(platform="cpu")

# import data ------------------------------------------------
data_path = m.load.howell1(only_path=True)
m.data(data_path, sep=";")
# Keep adults only. The standard Howell1 adult subset is age >= 18; a strict
# ">" would silently drop every 18-year-old from the regression.
m.df = m.df[m.df.age >= 18]
# Scale weight after filtering. This assumes m.scale operates on the current
# m.df (i.e. the adult subset). TODO confirm against the BI docs.
m.scale(data=["weight"])


# define model ------------------------------------------------
def model(weight, height):
    """Linear regression of adult height on (scaled) weight.

    Priors:
        a ~ Normal(178, 20)   intercept
        b ~ LogNormal(0, 1)   slope, restricted to positive values
        s ~ Uniform(0, 50)    residual standard deviation
    Likelihood:
        height ~ Normal(a + b * weight, s), one observation per entry of weight.

    Args:
        weight: weight values, transformed by ``m.scale`` above; the number of
            observations is taken from axis 0. Presumably BI supplies this
            from the m.df column of the same name (confirm).
        height: observed heights, passed to the likelihood as ``obs``.
    """
    a = m.dist.normal(178, 20, name="a")
    b = m.dist.log_normal(0, 1, name="b")
    s = m.dist.uniform(0, 50, name="s")
    m.dist.normal(a + b * weight, s, obs=height, shape=(weight.shape[0],))


# Run sampler ------------------------------------------------
# 4 chains x 500 draws each = 2000 posterior draws in total.
m.fit(model, num_samples=500, num_chains=4)
m.summary()
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
BI v 0.0.45 package loaded
jax.local_device_count 32
Predictions from the model for a specific data value
Code
# Predictions from the model based on the data it already holds (data_on_model).
m.sample()

# Predictions for a single, user-chosen weight. Weight was transformed with
# m.scale before fitting, so this value is presumably on that scaled axis.
# NOTE(review): the UserWarning emitted here reports 2000 posterior draws
# (4 chains x 500) against a num_samples argument of 1, and the 2000 draws are
# used anyway. Check whether m.sample exposes num_samples to silence it.
new_weight = jnp.array([0.4])
m.sample(data={"weight": new_weight}, remove_obs=False)
/home/sosa/work/BI/BI/Main/main.py:590: UserWarning:
Sample's batch dimension size 2000 is different from the provided 1 num_samples argument. Defaulting to 2000.