Latent Variable Models (WIP)

Models that relate a set of observable variables to a set of unobserved (latent) variables. This family includes structural equation modeling and factor analysis.

General Principles

In some scenarios, the observed data does not directly reflect the underlying structure or factors influencing the outcome. Instead, latent variables (variables that are not directly observed but are inferred from the data) can help model this hidden structure. These latent variables capture unobserved factors that affect the relationship between the predictors (X) and the outcome (Y).

We model the relationship between the predictor variables (X) and the outcome variable (Y) with a latent variable (Z) as follows:

Y = f(X, Z) + \epsilon

Where:

- Y is the observed outcome variable.
- X is the observed predictor variable(s).
- Z is the latent (unobserved) variable, which we aim to infer.
- f(X, Z) is the function that relates X and Z to Y.
- \epsilon is the error term, typically assumed to be normally distributed with mean 0 and variance \sigma^2.

The latent variable Z can represent various phenomena, such as group-level effects, time-varying trends, or individual-level factors, that are not captured by the observed predictors alone.
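
To make this concrete, the short sketch below simulates data from such a model with a latent group-level effect: each observation belongs to one of a few groups, and each group's unobserved intercept Z shifts its outcomes. It uses plain NumPy, and the names (n_groups, beta, and so on) are illustrative choices, not part of any library API.

Code
import numpy as np

rng = np.random.default_rng(0)

n_groups = 5       # number of latent groups (illustrative)
n_per_group = 20   # observations per group
beta = 2.0         # effect of the observed predictor X
noise_sd = 0.5     # standard deviation of the error term epsilon

# Latent variable: one unobserved intercept per group
Z = rng.normal(0.0, 1.0, size=n_groups)

# Observed data: group membership and a predictor X
group = np.repeat(np.arange(n_groups), n_per_group)
X = rng.normal(0.0, 1.0, size=n_groups * n_per_group)

# Y = f(X, Z) + epsilon, here with f(X, Z) = beta * X + Z[group]
Y = beta * X + Z[group] + rng.normal(0.0, noise_sd, size=X.size)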

Considerations

In Bayesian regression with latent variables, we account for uncertainty in both the observed and latent quantities. In addition to the usual priors on regression coefficients and intercepts, we declare prior distributions for the latent variables themselves. These are often given Normal priors, or more flexible choices such as a multivariate normal that captures correlations among the latent variables.

The goal is to infer the posterior distribution over both the parameters and the latent variables, given the observed data.
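
Written out, and collecting the regression coefficients, intercepts, and scale parameters into \theta, the sampler targets the joint posterior

p(\theta, Z \mid X, Y) \propto p(Y \mid X, Z, \theta) \, p(Z \mid \theta) \, p(\theta)

so every posterior draw contains values for the latent variables as well as the parameters.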

Example

Below is an example demonstrating Bayesian regression with latent variables using the BI library:

Code
from BI import bi, jnp


# Setup device------------------------------------------------
m = bi(platform='cpu')

# Data Simulation ------------------------------------------------
NY = 4  # Number of outcomes (columns of Y2)
NV = 8  # Number of observations (rows of Y2)

# Generate the latent structure for the simulated data:
# means: one random normal mean per outcome
# offsets: one random normal offset per observation
means = m.dist.normal(0, 1, shape=(NY,), sample=True, seed=10)
offsets = m.dist.normal(0, 1, shape=(NV, 1), sample=True, seed=20)

# Broadcasting the (NV, 1) offsets against the (NY,) means yields an (NV, NY) matrix
Y2 = offsets + means

# Convert Y2 to a JAX array for the JAX-based model below
Y2 = jnp.array(Y2)


# Set data ------------------------------------------------
dat = dict(
    NY = NY,
    NV = NV,
    Y2 = Y2
)
m.data_on_model = dat
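# Note: the keys of `dat` mirror the argument names of `model` below, so the
# observed data and dimensions are available inside the model function.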

# Define model ------------------------------------------------
def model(NY, NV, Y2):
    # Priors: one mean per outcome, one offset per observation, one scale per outcome
    means = m.dist.normal(0, 1, shape=(NY,), name='means')
    offset = m.dist.normal(0, 1, shape=(NV, 1), name='offset')
    sigma = m.dist.exponential(1, shape=(NY,), name='sigma')
    # Tile the outcome means into an (NV, NY) matrix, then add the per-observation offsets
    tmp = jnp.tile(means, (NV, 1)).reshape(NV, NY)
    mu_l = tmp + offset
    # Likelihood: Normal observations with outcome-specific scales
    m.dist.normal(mu_l, jnp.tile(sigma, [NV, 1]), obs=Y2)

# Run sampler ------------------------------------------------
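# Note: this run uses a single chain of 1,000 iterations (500 warmup,
# 500 samples), which is why arviz warns about the chain shape below.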
m.fit(model)

# Summary ------------------------------------------------
m.summary()
sample: 100%|██████████| 1000/1000 [00:08<00:00, 123.85it/s, 1023 steps of size 1.66e-04. acc. prob=0.40]
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
              mean   sd  hdi_5.5%  hdi_94.5%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
means[0]     -1.45  0.0     -1.45      -1.45        0.0      0.0      1.35     14.79    NaN
means[1]     -0.73  0.0     -0.73      -0.72        0.0      0.0      1.35     14.79    NaN
means[2]      1.24  0.0      1.24       1.25        0.0      0.0      1.35     14.79    NaN
means[3]     -0.60  0.0     -0.61      -0.60        0.0      0.0      1.34     15.50    NaN
offset[0, 0] -0.64  0.0     -0.64      -0.64        0.0      0.0      1.35     14.79    NaN
offset[1, 0]  0.75  0.0      0.75       0.76        0.0      0.0      1.34     14.79    NaN
offset[2, 0] -1.14  0.0     -1.14      -1.13        0.0      0.0      1.35     14.79    NaN
offset[3, 0]  0.10  0.0      0.09       0.10        0.0      0.0      1.35     14.79    NaN
offset[4, 0] -0.29  0.0     -0.30      -0.29        0.0      0.0      1.35     14.79    NaN
offset[5, 0] -0.80  0.0     -0.81      -0.80        0.0      0.0      1.34     14.79    NaN
offset[6, 0]  0.17  0.0      0.17       0.17        0.0      0.0      1.35     14.79    NaN
offset[7, 0]  0.75  0.0      0.75       0.75        0.0      0.0      1.35     14.79    NaN
sigma[0]      0.00  0.0      0.00       0.00        0.0      0.0      1.41     10.90    NaN
sigma[1]      0.00  0.0      0.00       0.00        0.0      0.0     12.57     10.86    NaN
sigma[2]      0.00  0.0      0.00       0.00        0.0      0.0     17.85     12.09    NaN
sigma[3]      0.00  0.0      0.00       0.00        0.0      0.0      1.64     11.56    NaN

Because the simulated Y2 contains no residual noise, sigma collapses toward zero; this near-degenerate likelihood is also why the sampler's acceptance probability drops to about 0.40 and ess_bulk is tiny. r_hat is NaN because only a single chain was run, which is what the arviz shape warning above is flagging.
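
One way to make sigma identifiable is to add residual noise when simulating Y2. A minimal sketch, reusing only the BI calls already shown above (the noise scale of 0.5 is an arbitrary illustrative choice):

Code
# Simulate observation noise and rebuild Y2 with it (scale 0.5 is arbitrary)
noise = m.dist.normal(0, 0.5, shape=(NV, NY), sample=True, seed=30)
Y2 = jnp.array(offsets + means + noise)

With noisy data, the posterior for sigma should concentrate near the simulated noise scale rather than at zero.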

Mathematical Details