from BI import bi
import jax.numpy as jnp

# Setup platform ------------------------------------------------
# Instantiate the BI interface on the 'cpu' platform.
m = bi(platform='cpu')

# Import data ------------------------------------------------
# Load the semicolon-separated Howell1 dataset; the rows are then
# accessed through m.df.
m.data('Howell1.csv', sep=';')
# Keep only rows whose age is strictly greater than 18.
# NOTE(review): the canonical Howell1 adult subset (Statistical
# Rethinking, m4.3) uses `age >= 18`; confirm that excluding
# 18-year-olds is intended.
m.df = m.df[m.df.age > 18]
# Rescale the weight column with BI's scale helper (presumably
# standardization -- confirm against the BI documentation).
m.scale(data=['weight'])


# Define model ------------------------------------------------
def model(weight, height):
    """Linear regression of height on (rescaled) weight.

    Priors:
        a ~ Normal(178, 20)   intercept
        b ~ LogNormal(0, 1)   slope (log-normal support is positive)
        s ~ Uniform(0, 50)    residual scale

    Likelihood:
        height ~ Normal(a + b * weight, s), one draw per entry along the
        first dimension of ``weight``.

    Args:
        weight: array of predictor values; ``weight.shape[0]`` sets the
            number of observations.
        height: observed heights, passed as ``obs`` to the likelihood.
    """
    a = m.dist.normal(178, 20, name='a')
    b = m.dist.log_normal(0, 1, name='b')
    s = m.dist.uniform(0, 50, name='s')
    m.dist.normal(a + b * weight, s, obs=height, shape=(weight.shape[0],))


# Run sampler ------------------------------------------------
# Draw 500 samples in each of 4 chains, then print the posterior summary.
m.fit(model, num_samples=500, num_chains=4)
m.summary()
Predictions from the model based on specific data values
Code
# Predictions from the model based on the data in data_on_model.
m.sample()

# Predictions for a given weight value (0.4). NOTE(review): the model was
# fit on rescaled weight, so 0.4 is presumably on that rescaled axis --
# confirm. remove_obs=False is forwarded unchanged to BI's m.sample.
m.sample(data=dict(weight=jnp.array([0.4])), remove_obs=False)
/home/sosa/work/BI/BI/Main/main.py:417: UserWarning:
Sample's batch dimension size 2000 is different from the provided 1 num_samples argument. Defaulting to 2000.