import arviz as az
import numpy as np
import pymc as pm
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.15.1+68.gc0b060b98.dirty
az.style.use("arviz-darkgrid")
Model comparison
To demonstrate the use of model comparison criteria in PyMC, we implement the 8 schools example from Section 5.5 of Gelman et al. (2003), which attempts to infer the effects of coaching on the SAT scores of students from 8 schools. Below, we fit a pooled model, which assumes a single fixed effect across all schools, and a hierarchical model that allows for a random effect that partially pools the data.
The data include the observed treatment effects (y) and associated standard deviations (sigma) in the 8 schools.
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
J = len(y)
Pooled model
with pm.Model() as pooled:
    # Latent pooled effect size
    mu = pm.Normal("mu", 0, sigma=1e6)
    obs = pm.Normal("obs", mu, sigma=sigma, observed=y)
    trace_p = pm.sample(2000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 6 seconds.
az.plot_trace(trace_p);
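Because the prior on mu (sigma=1e6) is effectively flat and the school-level standard deviations sigma are treated as known, the pooled posterior should be close to a closed-form normal distribution whose mean is the precision-weighted average of y. The NumPy sketch below computes that reference value under the flat-prior assumption; pooled_posterior is a hypothetical helper, not a PyMC function.

```python
import numpy as np

y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])


def pooled_posterior(y, sigma):
    """Closed-form posterior of mu under a flat prior (hypothetical helper).

    Each school contributes precision 1 / sigma**2; the posterior mean is the
    precision-weighted mean of y and the posterior SD is 1 / sqrt(total precision).
    """
    precision = 1.0 / sigma**2
    mean = np.sum(precision * y) / np.sum(precision)
    sd = 1.0 / np.sqrt(np.sum(precision))
    return mean, sd


post_mean, post_sd = pooled_posterior(y, sigma)
print(f"mu ~ N({post_mean:.2f}, {post_sd:.2f}^2)")  # mu ~ N(7.69, 4.07^2)
```

If the sampler has done its job, the trace of mu should be centered near this mean with a comparable spread.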
 
Hierarchical model
with pm.Model() as hierarchical:
    eta = pm.Normal("eta", 0, 1, shape=J)
    # Hierarchical mean and SD
    mu = pm.Normal("mu", 0, sigma=10)
    tau = pm.HalfNormal("tau", 10)
    # Non-centered parameterization of random effect
    theta = pm.Deterministic("theta", mu + tau * eta)
    obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
    trace_h = pm.sample(2000, target_accept=0.9)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [eta, mu, tau]
Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 24 seconds.
az.plot_trace(trace_h, var_names="mu");
 
az.plot_forest(trace_h, var_names="theta");
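The comment in the hierarchical model refers to a non-centered parameterization: rather than sampling theta ~ Normal(mu, tau) directly, the model samples a standard normal eta and sets theta = mu + tau * eta, which has the same distribution. A quick Monte Carlo check of that equivalence in NumPy, with arbitrary illustrative values of mu and tau (not posterior estimates):

```python
import numpy as np

rng = np.random.default_rng(42)
mu, tau = 4.0, 3.0  # arbitrary illustrative values, not posterior estimates
n = 200_000

eta = rng.standard_normal(n)              # eta ~ Normal(0, 1)
theta_noncentered = mu + tau * eta        # deterministic transform, as in the model
theta_centered = rng.normal(mu, tau, n)   # direct draws from Normal(mu, tau)

# Both sets of draws should have (approximately) the same mean and standard deviation
print(theta_noncentered.mean(), theta_noncentered.std())
print(theta_centered.mean(), theta_centered.std())
```

The two parameterizations define the same model; the non-centered one is usually easier for NUTS when tau is small, because the centered version then has a funnel-shaped posterior that is hard to explore with a single step size.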
 
Leave-one-out cross-validation (LOO)
LOO cross-validation estimates out-of-sample predictive fit. In cross-validation, the data are repeatedly partitioned into training and holdout sets: the model is fit to the former and its fit is evaluated on the latter. LOO is the extreme case in which each holdout set is a single observation. Vehtari et al. (2016) introduced an efficient way to compute LOO from MCMC samples, without the need to re-fit the model for each holdout. This approximation is based on importance sampling, and the importance weights are stabilized using a method known as Pareto-smoothed importance sampling (PSIS).
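The importance-sampling idea can be sketched in a few lines of NumPy. The sketch makes two simplifying assumptions of its own: it uses plain importance sampling without the Pareto smoothing step of PSIS, and instead of MCMC output it draws mu from the closed-form posterior of the pooled model under a flat prior. That flat-prior setting also admits an exact LOO, which lets us check the approximation. All function names are hypothetical helpers, not ArviZ APIs.

```python
import numpy as np

y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)


def normal_logpdf(x, loc, scale):
    return -0.5 * np.log(2 * np.pi * scale**2) - (x - loc) ** 2 / (2 * scale**2)


def logsumexp(a, axis=0):
    # Numerically stable log(sum(exp(a))) along one axis
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def is_loo_pointwise(log_lik):
    """Plain importance-sampling LOO (no Pareto smoothing), hypothetical helper.

    log_lik has shape (n_samples, n_obs). The raw weights are 1 / p(y_i | theta_s),
    so elpd_i = -log(mean_s exp(-log_lik[s, i])).
    """
    n_samples = log_lik.shape[0]
    return -(logsumexp(-log_lik, axis=0) - np.log(n_samples))


def exact_loo_pooled(y, sigma):
    """Exact LOO for the pooled model with a flat prior on mu (conjugate case)."""
    prec = 1.0 / sigma**2
    out = np.empty_like(y)
    for i in range(len(y)):
        keep = np.arange(len(y)) != i
        p = prec[keep].sum()
        m = (prec[keep] * y[keep]).sum() / p
        # Posterior predictive for the held-out y_i: Normal(m, sqrt(1/p + sigma_i^2))
        out[i] = normal_logpdf(y[i], m, np.sqrt(1.0 / p + sigma[i] ** 2))
    return out


# Draws from the flat-prior posterior of mu stand in for MCMC samples
rng = np.random.default_rng(0)
prec = 1.0 / sigma**2
mu_draws = rng.normal((prec * y).sum() / prec.sum(), 1.0 / np.sqrt(prec.sum()), 200_000)

# Pointwise log-likelihood matrix, shape (n_samples, n_obs)
log_lik = normal_logpdf(y[None, :], mu_draws[:, None], sigma[None, :])
print(is_loo_pointwise(log_lik).sum(), exact_loo_pooled(y, sigma).sum())
```

Here the raw weights are well behaved, so plain importance sampling matches the exact answer closely. When an observation is influential the raw weights can have very heavy tails; PSIS fits a generalized Pareto distribution to the largest weights to stabilize them, and the estimated shape parameter is the Pareto k diagnostic that az.loo reports.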
Widely-applicable Information Criterion (WAIC)
WAIC (Watanabe 2010) is a fully Bayesian criterion for estimating the expected out-of-sample log predictive density. It starts from the log pointwise posterior predictive density (LPPD) computed over the posterior samples and subtracts a correction for the effective number of parameters to adjust for overfitting.
By default ArviZ uses LOO, but WAIC is also available.
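ArviZ exposes WAIC as az.waic. To make the two terms in the definition concrete, here is a from-scratch NumPy sketch; the waic function is our own helper, not ArviZ's implementation, and it uses the variance-based penalty. As in the LOO sketch, it assumes flat-prior posterior draws of mu for the pooled model in place of MCMC samples.

```python
import numpy as np


def waic(log_lik):
    """Minimal WAIC on the log scale from a (n_samples, n_obs) log-likelihood matrix.

    lppd_i    = log( mean_s p(y_i | theta_s) )
    p_waic_i  = sample variance over s of log p(y_i | theta_s)
    elpd_waic = sum_i (lppd_i - p_waic_i)
    """
    n_samples = log_lik.shape[0]
    ll_max = log_lik.max(axis=0)  # subtract the max for numerical stability
    lppd_i = ll_max + np.log(np.exp(log_lik - ll_max).sum(axis=0)) - np.log(n_samples)
    p_waic_i = log_lik.var(axis=0, ddof=1)
    return (lppd_i - p_waic_i).sum(), p_waic_i.sum()


# Illustration on the pooled model; exact flat-prior posterior draws of mu stand in
# for MCMC samples (an assumption made only to keep the sketch self-contained)
y = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)
prec = 1.0 / sigma**2
rng = np.random.default_rng(1)
mu_draws = rng.normal((prec * y).sum() / prec.sum(), 1.0 / np.sqrt(prec.sum()), 100_000)

# Normal log-likelihood of each observation under each draw, shape (n_samples, n_obs)
log_lik = -0.5 * np.log(2 * np.pi * sigma**2) - (y - mu_draws[:, None]) ** 2 / (2 * sigma**2)

elpd_waic, p_waic = waic(log_lik)
print(f"elpd_waic = {elpd_waic:.2f}, p_waic = {p_waic:.2f}")
```

For a model this simple, WAIC and LOO should land very close to each other, and p_waic should be below one, consistent with a model that has a single free parameter constrained by the data.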
Model log-likelihood
In order to compute LOO and WAIC, ArviZ needs access to the model's elementwise (pointwise) log-likelihood for every posterior sample. We can add it to an existing trace with pm.compute_log_likelihood(). Alternatively, we can pass idata_kwargs={"log_likelihood": True} to pm.sample() to have it computed automatically at the end of sampling.
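compute_log_likelihood stores the pointwise values in the log_likelihood group of the InferenceData object, with one entry per chain, draw and observation. az.loo treats every (chain, draw) pair as one posterior sample, which is why it reports 8000 posterior samples (4 chains × 2000 draws) and 8 observations for these traces. The NumPy sketch below mimics that bookkeeping with random placeholder values; for a real trace the array would come from something like trace_p.log_likelihood["obs"].values, assuming the observed variable is named obs as in the models here.

```python
import numpy as np

n_chains, n_draws, n_obs = 4, 2000, 8  # matches the sampling settings used above

# Placeholder standing in for trace.log_likelihood["obs"].values, shape (chain, draw, obs)
rng = np.random.default_rng(0)
ll = rng.normal(-3.8, 0.3, size=(n_chains, n_draws, n_obs))

# Flatten chains and draws into a single sample dimension: shape (samples, obs)
ll_matrix = ll.reshape(n_chains * n_draws, n_obs)
print(ll_matrix.shape)  # (8000, 8)
```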
with pooled:
    pm.compute_log_likelihood(trace_p)
pooled_loo = az.loo(trace_p)
pooled_loo
Computed from 8000 posterior samples and 8 observations log-likelihood matrix.
         Estimate       SE
elpd_loo   -30.58     1.11
p_loo        0.69        -
------
Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)        8  100.0%
 (0.5, 0.7]   (ok)          0    0.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%
with hierarchical:
    pm.compute_log_likelihood(trace_h)
hierarchical_loo = az.loo(trace_h)
hierarchical_loo
Computed from 8000 posterior samples and 8 observations log-likelihood matrix.
         Estimate       SE
elpd_loo   -30.82     1.08
p_loo        1.17        -
------
Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.5]   (good)        4   50.0%
 (0.5, 0.7]   (ok)          4   50.0%
   (0.7, 1]   (bad)         0    0.0%
   (1, Inf)   (very bad)    0    0.0%
ArviZ includes two convenience functions to help compare LOO across models. The first of these functions is compare, which takes a dictionary mapping model names to traces (InferenceData objects), computes LOO (or WAIC) for each, and returns a DataFrame.
df_comp_loo = az.compare({"hierarchical": trace_h, "pooled": trace_p})
df_comp_loo
| | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale |
|---|---|---|---|---|---|---|---|---|---|
| pooled | 0 | -30.578116 | 0.686645 | 0.000000 | 1.0 | 1.105891 | 0.000000 | False | log | 
| hierarchical | 1 | -30.820005 | 1.167010 | 0.241889 | 0.0 | 1.080954 | 0.231679 | False | log | 
We have many columns, so let’s check out their meaning one by one:
- The index contains the names of the models, taken from the keys of the dictionary passed to compare().
- rank, the ranking of the models, from 0 (best model) to the number of models minus one.
- elpd_loo, the values of LOO (elpd_waic when WAIC is used). The DataFrame is always sorted from best LOO/WAIC to worst.
- p_loo, the value of the penalization term. We can roughly think of this value as the estimated effective number of parameters (but do not take that too seriously).
- elpd_diff, the difference between the LOO/WAIC value of the top-ranked model and that of each model. For this reason the first model always gets a value of 0.
- weight, the weights assigned to each model. These weights can be loosely interpreted as the probability of each model being true (among the compared models) given the data. By default these weights are computed using stacking.
- se, the standard error of the LOO/WAIC estimate. The standard error can be useful to assess the uncertainty of the LOO/WAIC estimates.
- dse, the standard error of the difference between the LOO/WAIC value of each model and that of the top-ranked model. Just as we can compute a standard error for each LOO/WAIC value, we can compute one for the differences. The two quantities are not necessarily the same, because the uncertainty about LOO/WAIC is correlated between models. This quantity is always 0 for the top-ranked model.
- warning, if True, the computation of LOO/WAIC may not be reliable.
- scale, the scale of the reported values. The default is the log scale (log in the table above). Other options are deviance (the log-score multiplied by -2, which reverses the order so that a lower LOO/WAIC is better) and negative_log (the log-score multiplied by -1; as with the deviance scale, a lower value is better).
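The se, elpd_diff and dse columns can all be computed from the pointwise elpd values (one per observation) of each model. The sketch below uses the standard formula se = sqrt(n * var(elpd_i)) and applies the same formula to the pointwise differences to get dse; pairing the observations this way is what lets a positive correlation between models shrink dse below the individual standard errors. We believe this mirrors ArviZ's computation, but compare_pair and to_scale are hypothetical helpers, and the pointwise values in the usage example are made up for illustration.

```python
import numpy as np


def compare_pair(elpd_i_top, elpd_i_other):
    """Summaries in the spirit of az.compare for two models (hypothetical helper).

    elpd_i_top, elpd_i_other: pointwise elpd values, one per observation, for the
    top-ranked model and another model, both on the log scale.
    """
    n = len(elpd_i_top)
    diff_i = elpd_i_top - elpd_i_other  # paired, observation by observation
    return {
        "elpd_top": elpd_i_top.sum(),
        "se_top": np.sqrt(n * np.var(elpd_i_top)),
        "elpd_other": elpd_i_other.sum(),
        "se_other": np.sqrt(n * np.var(elpd_i_other)),
        "elpd_diff": diff_i.sum(),
        "dse": np.sqrt(n * np.var(diff_i)),
    }


def to_scale(elpd, scale="log"):
    """Convert a log-scale elpd to one of the other reporting scales."""
    factor = {"log": 1.0, "negative_log": -1.0, "deviance": -2.0}[scale]
    return factor * elpd


# Made-up pointwise elpd values for two models on four observations
a = np.array([-3.1, -2.9, -4.0, -3.5])
b = np.array([-3.0, -3.2, -4.4, -3.6])
res = compare_pair(a, b)
print(res["elpd_diff"], res["dse"])
print(to_scale(-30.578116, "deviance"))  # pooled elpd_loo from the table, as a deviance
```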
The second convenience function, plot_compare, takes the output of compare and produces a summary plot in the style of the one used in the book Statistical Rethinking by Richard McElreath (check also this port of the examples in the book to PyMC).
az.plot_compare(df_comp_loo, insample_dev=False);
 
The empty circles represent the LOO values, and the black error bars associated with them show the standard error of LOO.
The highest LOO value, i.e. that of the model with the best estimated predictive accuracy, is also indicated with a vertical dashed grey line to ease comparison with the other LOO values.
For all models except the top-ranked one we also get a triangle indicating the difference in LOO between that model and the top model, and a grey error bar indicating the standard error of that difference.
Interpretation
Though we might expect the hierarchical model to outperform the complete pooling model, there is little to choose between the models in this case, given that both models give very similar values of the information criterion. This is even clearer when we take into account the uncertainty (in terms of standard errors) of LOO: the difference in elpd_loo (0.24) is about the size of its standard error (dse = 0.23) and much smaller than the standard error of either estimate (about 1.1).
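The comparison in the previous paragraph is simple arithmetic on the compare table. A plain-Python sketch with the numbers copied from that table (no model is re-run here):

```python
# Numbers copied from the az.compare output above (log scale)
elpd_diff = 0.241889  # pooled minus hierarchical elpd_loo
dse = 0.231679        # standard error of that difference
se_pooled = 1.105891  # standard error of the pooled model's elpd_loo

ratio_dse = elpd_diff / dse       # difference measured in units of its own SE
ratio_se = elpd_diff / se_pooled  # difference relative to one model's uncertainty
print(f"elpd_diff / dse = {ratio_dse:.2f}")  # elpd_diff / dse = 1.04
print(f"elpd_diff / se  = {ratio_se:.2f}")   # elpd_diff / se  = 0.22
```

A difference of roughly one standard error is well within the noise of the estimate, which is why neither model is clearly preferred here.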
Reference
%load_ext watermark
%watermark -n -u -v -iv -w -p xarray,pytensor
Last updated: Tue Jun 25 2024
Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2
xarray  : 2024.2.0
pytensor: 2.20.0+3.g66439d283.dirty
matplotlib: 3.8.3
numpy     : 1.26.4
pymc      : 5.15.0+1.g58927d608
arviz     : 0.17.1
Watermark: 2.4.3