9  Moment matching for improved PSIS-LOO-CV

In this chapter, we demonstrate how to apply importance weighted moment matching (IWMM) to improve PSIS-LOO-CV estimates when influential observations produce high Pareto-\(k\) values. We walk through a complete applied example using the roaches dataset from Gelman and Hill (2007), showing how to specify the required functions and apply moment matching to resolve problematic importance sampling approximations without expensive model refitting. The workflow demonstrates ArviZ’s loo_moment_match() function with a deliberately misspecified Poisson regression model where multiple observations are flagged as problematic. For theoretical background and details of how moment matching works, we recommend you read Section 7.7 and references therein.

9.1 Roaches data and Poisson regression model

Let’s now walk through a concrete example using the roaches dataset to see this algorithm in action. The roaches dataset from Gelman and Hill (2007) examines the efficacy of a pest management system at reducing cockroach infestations in urban apartments. The study followed 262 apartments over several months, recording the number of roaches caught during follow-up (y), pre-treatment roach counts (roach1, square root transformed), treatment status (treatment), whether the building is restricted to elderly residents (senior), and trap exposure time in days (exposure2).

We intentionally use a Poisson model rather than negative binomial regression to demonstrate how moment matching handles misspecified models. The exposure time varies across apartments, so we include log(exposure2) as an offset term.

9.1.1 Loading and preparing the data

We start by loading the roaches dataset and examining its basic structure.

import numpy as np
import pandas as pd

roaches = pd.read_csv("../data/roaches.csv", index_col=0)
# Log trap exposure, used as the offset term in the Poisson model
roaches['log_exposure2'] = np.log(roaches['exposure2'])

roaches.describe()
y roach1 treatment senior exposure2 log_exposure2
count 262.000000 262.000000 262.000000 262.000000 262.000000 262.000000
mean 25.648855 4.419861 0.603053 0.305344 1.021047 -0.012538
std 50.846539 4.769184 0.490201 0.461434 0.320757 0.249755
min 0.000000 0.000000 0.000000 0.000000 0.200000 -1.609438
25% 0.000000 1.000000 0.000000 0.000000 1.000000 0.000000
50% 3.000000 2.645751 1.000000 0.000000 1.000000 0.000000
75% 24.000000 7.106071 1.000000 1.000000 1.000000 0.000000
max 357.000000 21.213203 1.000000 1.000000 4.285714 1.455287

The dataset exhibits substantial overdispersion: the standard deviation (50.8) is about twice the mean (25.6), so the variance (about 2585) is roughly 100 times the mean, whereas a Poisson model assumes the two are equal. Approximately 36% of apartments had zero roaches, while the maximum count reached 357. This combination of high variability, many zeros, and extreme counts makes certain observations potentially influential, which is precisely where moment matching becomes valuable for improving LOO-CV estimates.
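The dispersion and zero-count figures can be computed directly from the outcome column. This is a minimal sketch: count_summary is a hypothetical helper (not part of pandas or ArviZ), and on the roaches data it would be called as count_summary(roaches['y']). The toy input below is not the roaches data.

```python
import numpy as np


def count_summary(counts):
    """Return (variance-to-mean ratio, fraction of zeros) for a 1-D array of counts.

    A Poisson model implies a ratio close to 1; values far above 1 indicate overdispersion.
    """
    counts = np.asarray(counts, dtype=float)
    # ddof=1 matches the sample standard deviation reported by pandas' describe()
    ratio = counts.var(ddof=1) / counts.mean()
    zero_fraction = np.mean(counts == 0)
    return ratio, zero_fraction


# Toy example: many zeros and one extreme count
ratio, zeros = count_summary([0, 0, 0, 1, 2, 50])
print(f"variance/mean = {ratio:.1f}, zeros = {zeros:.0%}")
```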

9.1.2 Model specification and fitting

We fit a Poisson regression model using Bambi, which provides a high-level interface for Bayesian modeling. The model includes the pre-treatment roach count, treatment indicator, and senior building indicator as predictors, with log exposure as an offset.

import bambi as bmb

model = bmb.Model(
    'y ~ roach1 + treatment + senior + offset(log_exposure2)',
    data=roaches,
    family='poisson',
    priors={
        'roach1': bmb.Prior('Normal', mu=0, sigma=2.5),
        'treatment': bmb.Prior('Normal', mu=0, sigma=2.5),
        'senior': bmb.Prior('Normal', mu=0, sigma=2.5),
        'Intercept': bmb.Prior('Normal', mu=0, sigma=5.0)
    }
)

idata = model.fit(
    draws=1000,
    tune=1000,
    chains=4,
    random_seed=SEED,  # SEED: integer random seed, assumed to be defined before this cell
    idata_kwargs={'log_likelihood': True}
)
azp.summary(idata.posterior.ds, var_names=['roach1', 'treatment', 'senior', 'Intercept'])
mean sd eti89_lb eti89_ub ess_bulk ess_tail r_hat mcse_mean mcse_sd
roach1 0.16 0.00 0.16 0.16 3269.45 3057.47 1.0 0.0 0.0
treatment -0.57 0.03 -0.61 -0.53 3412.93 2837.12 1.0 0.0 0.0
senior -0.32 0.03 -0.37 -0.26 3535.11 2733.84 1.0 0.0 0.0
Intercept 2.53 0.03 2.49 2.57 3294.58 2853.49 1.0 0.0 0.0

9.1.3 Initial PSIS-LOO-CV evaluation

With the posterior draws in hand, we compute the PSIS-LOO-CV estimate of the model’s predictive performance.

loo_result = azp.loo(idata, pointwise=True, var_name="y")
loo_result
UserWarning: Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
Computed from 4000 posterior samples and 262 observations log-likelihood matrix.

         Estimate       SE
elpd_loo -5461.14   693.85
p_loo      258.87        -

There has been a warning during the calculation. Please check the results.
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)      249   95.0%
   (0.70, 1]   (bad)         6    2.3%
    (1, Inf)   (very bad)    7    2.7%

The output shows that 13 observations have Pareto-\(k\) values exceeding 0.7, indicating that the importance sampling approximation is unreliable for these cases. For these problematic observations, the most accurate (but computationally expensive) approach would be to refit the model, leaving out each observation in turn. However, this quickly becomes impractical as the number of flagged observations grows. Moment matching offers an efficient solution by improving the reliability of the PSIS-LOO-CV estimates for exactly those cases where brute-force refitting would otherwise be required, reducing computational burden without sacrificing much accuracy.
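The three rows of the diagnostic table come from binning the pointwise Pareto-\(k\) values. Below is a minimal NumPy sketch of that binning, assuming the values are available as an array (for example from the pointwise LOO result); summarize_pareto_k is a hypothetical helper. The good/bad cut-off follows the sample-size-dependent rule \(\min(1 - 1/\log_{10} S, 0.7)\) used in PSIS diagnostics, which equals 0.7 for the \(S = 4000\) draws used here.

```python
import numpy as np


def summarize_pareto_k(pareto_k, n_draws):
    """Count observations in the good / bad / very bad Pareto-k bins."""
    k = np.asarray(pareto_k, dtype=float)
    # Sample-size-dependent threshold; for n_draws = 4000 it evaluates to 0.7
    k_threshold = min(1 - 1 / np.log10(n_draws), 0.7)
    return {
        "good": int(np.sum(k <= k_threshold)),
        "bad": int(np.sum((k > k_threshold) & (k <= 1))),
        "very bad": int(np.sum(k > 1)),
        "threshold": k_threshold,
    }


# Toy k values, not taken from the roaches fit
example = summarize_pareto_k([0.1, 0.5, 0.69, 0.72, 0.95, 1.3], n_draws=4000)
print(example)
```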

9.2 Applying moment matching

Now we apply importance weighted moment matching to improve the PSIS-LOO-CV estimates for the problematic observations. ArviZ implements this through the loo_moment_match() function, which requires us to provide the unconstrained parameter draws and functions for evaluating the model’s log posterior and log likelihood.

9.2.1 Function specification and implementation

To apply moment matching, we need to provide functions that compute the log posterior density \(\log p(\theta \mid y)\), up to an additive constant, and the log likelihood of a single observation \(\log p(y_i \mid \theta)\), both evaluated in the unconstrained parameter space, where all parameters are real-valued.

Our Poisson regression model is

\[ y_i \sim \text{Poisson}(\lambda_i), \quad \lambda_i = \exp(\eta_i), \]

where the linear predictor is

\[ \eta_i = \alpha + \sum_{j=1}^3 \beta_j x_{ij} + \log(\text{exposure2}_i), \]

with \(x_i = (\text{roach1}_i, \text{treatment}_i, \text{senior}_i)\). We use independent normal priors \(\beta_j \sim N(0, 2.5^2)\) and \(\alpha \sim N(0, 5^2)\).

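The log posterior combines the log likelihood with the log prior. Dropping the normalizing term \(\log p(y)\), which does not depend on \(\theta\), we have

\[ \log p(\theta \mid y) = \sum_{i=1}^n \log p(y_i \mid \theta) + \log p(\theta) + \text{const}, \]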

where the log likelihood for observation \(i\) is

\[ \log p(y_i \mid \theta) = y_i \log(\lambda_i) - \lambda_i - \log(y_i!). \]
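Writing \(\log \lambda_i = \eta_i\), this becomes \(y_i \eta_i - \exp(\eta_i) - \log(y_i!)\), the form used in the likelihood functions later in this section. As a quick sanity check with arbitrary toy values, the sketch below compares it with SciPy's Poisson log probability mass function; poisson_loglik is a hypothetical helper.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import poisson


def poisson_loglik(y, eta):
    """Poisson log pmf parameterized by the linear predictor eta = log(lambda)."""
    y = np.asarray(y, dtype=float)
    eta = np.asarray(eta, dtype=float)
    # y * log(lambda) - lambda - log(y!), with log(lambda) = eta
    return y * eta - np.exp(eta) - gammaln(y + 1)


y_toy = np.array([0, 3, 24, 357])
eta_toy = np.array([-0.5, 1.2, 3.0, 5.5])
manual = poisson_loglik(y_toy, eta_toy)
reference = poisson.logpmf(y_toy, np.exp(eta_toy))
print(np.max(np.abs(manual - reference)))
```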

We start by converting the data into xarray DataArrays with labeled dimensions, so that the computations below can rely on xarray’s dimension-aware operations.

import xarray as xr
from scipy.special import gammaln
from functools import partial

X = roaches[['roach1', 'treatment', 'senior']].values
y = roaches['y'].values
log_offset = roaches['log_exposure2'].values

n_obs = len(roaches)
coef_names = ['roach1', 'treatment', 'senior']

design_matrix = xr.DataArray(
    X,
    dims=['obs_id', 'coef'],
    coords={'obs_id': range(n_obs), 'coef': coef_names}
)
y_da = xr.DataArray(y, dims=['obs_id'], coords={'obs_id': range(n_obs)})
offset_da = xr.DataArray(log_offset, dims=['obs_id'], coords={'obs_id': range(n_obs)})
factorial_term = xr.DataArray(gammaln(y + 1), dims=['obs_id'], coords={'obs_id': range(n_obs)})

beta_prior_scale = 2.5
alpha_prior_scale = 5.0

First, we construct an array of unconstrained parameters from the posterior draws. All of the parameters are unconstrained in this model, so we can simply concatenate the posterior draws along the uparam dimension.

posterior = idata.posterior
upars = xr.concat([
    posterior.ds['roach1'].expand_dims({'uparam': ['roach1']}),
    posterior.ds['treatment'].expand_dims({'uparam': ['treatment']}),
    posterior.ds['senior'].expand_dims({'uparam': ['senior']}),
    posterior.ds['Intercept'].expand_dims({'uparam': ['Intercept']})
], dim='uparam').transpose('chain', 'draw', 'uparam')

We can now define the two required functions. The log posterior function returns the unnormalized log posterior density for each draw of the unconstrained parameters, and the pointwise log likelihood function returns \(\log p(y_i \mid \theta)\) for each draw and a given observation index i.

def log_prob_upars_fn(upars, design_matrix, y_da, offset_da, factorial_term,
                       coef_names, beta_prior_scale, alpha_prior_scale):
    """Compute the unnormalized log posterior for unconstrained parameters."""
    # Split the stacked parameters into slopes (coef dimension) and intercept
    beta = upars.sel(uparam=coef_names).rename({'uparam': 'coef'})
    intercept = upars.sel(uparam='Intercept')

    # Linear predictor eta_i for every observation and every draw
    lin = xr.dot(design_matrix, beta, dim='coef') + intercept + offset_da
    exp_lin = xr.ufuncs.exp(lin)
    # Poisson log pmf: y * eta - exp(eta) - log(y!)
    log_lik = y_da * lin - exp_lin - factorial_term

    # Normal log priors, omitting constants that do not depend on the parameters
    log_prior_beta = (-0.5 * (beta / beta_prior_scale) ** 2).sum('coef')
    log_prior_intercept = -0.5 * (intercept / alpha_prior_scale) ** 2
    return log_lik.sum('obs_id') + log_prior_beta + log_prior_intercept


def log_lik_i_upars_fn(upars, i, design_matrix, y_da, offset_da, factorial_term, coef_names):
    """Compute the log likelihood of observation i for unconstrained parameters."""
    beta = upars.sel(uparam=coef_names).rename({'uparam': 'coef'})
    intercept = upars.sel(uparam='Intercept')

    # Linear predictor for observation i only
    features_i = design_matrix.isel(obs_id=i)
    lin_i = (beta * features_i).sum('coef') + intercept + offset_da.isel(obs_id=i)

    exp_lin_i = xr.ufuncs.exp(lin_i)
    log_lik_i = y_da.isel(obs_id=i) * lin_i - exp_lin_i - factorial_term.isel(obs_id=i)
    return log_lik_i

Now we bind the data dependencies to these functions using functools.partial. This creates new functions where all the data parameters are fixed, leaving only the unconstrained parameters (and observation index for the likelihood function) as free parameters.

log_prob_upars = partial(
    log_prob_upars_fn,
    design_matrix=design_matrix,
    y_da=y_da,
    offset_da=offset_da,
    factorial_term=factorial_term,
    coef_names=coef_names,
    beta_prior_scale=beta_prior_scale,
    alpha_prior_scale=alpha_prior_scale
)

log_lik_i_upars = partial(
    log_lik_i_upars_fn,
    design_matrix=design_matrix,
    y_da=y_da,
    offset_da=offset_da,
    factorial_term=factorial_term,
    coef_names=coef_names
)

9.2.2 Computational considerations

While moment matching requires additional density evaluations compared to standard PSIS-LOO-CV, the computational cost remains modest. For each transformed draw \(\theta^{*(s)}\), we must evaluate both the full-data posterior density \(p(\theta^{*(s)} \mid y)\) and the likelihood \(p(y_i \mid \theta^{*(s)})\), rather than just the likelihood as in standard PSIS. Even with multiple transformation iterations, this cost is substantially smaller than refitting the model for each problematic observation.
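To see why both densities are needed, suppose the posterior draws \(\theta^{(s)}\) are mapped through an affine transformation \(T(\theta) = A\theta + b\). The transformed draws \(\theta^{*(s)} = T(\theta^{(s)})\) then have density \(p(T^{-1}(\theta^*) \mid y) / |\det A|\). Since the leave-one-out posterior satisfies \(p(\theta \mid y_{-i}) \propto p(\theta \mid y) / p(y_i \mid \theta)\), the unnormalized log importance weights are

\[ \log w^{(s)} = \log p(\theta^{*(s)} \mid y) - \log p(y_i \mid \theta^{*(s)}) - \log p(\theta^{(s)} \mid y) + \log|\det A|. \]

The following NumPy sketch computes these weights for a toy Gaussian "posterior". It assumes an affine transformation and uses hypothetical function names; it illustrates the bookkeeping, not ArviZ's internal implementation. Because the weights are self-normalized afterwards, the log posterior only needs to be known up to a constant.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm


def transformed_loo_log_weights(theta, A, b, log_post, log_lik_i):
    """Unnormalized log LOO importance weights for affinely transformed draws.

    theta: (S, D) draws from the full-data posterior.
    A, b: affine map T(theta) = A @ theta + b applied to every draw.
    log_post: returns the (unnormalized) full-data log posterior, shape (S,).
    log_lik_i: returns log p(y_i | theta) for the left-out observation, shape (S,).
    """
    theta_star = theta @ A.T + b
    _, log_abs_det = np.linalg.slogdet(A)
    # Proposal density of theta_star is p(T^{-1}(theta_star) | y) / |det A|,
    # and T^{-1}(theta_star) is just the original draw theta.
    log_proposal = log_post(theta) - log_abs_det
    # New evaluations at the transformed draws: p(theta* | y) / p(y_i | theta*)
    log_target = log_post(theta_star) - log_lik_i(theta_star)
    return log_target - log_proposal, theta_star


# Toy stand-ins (not the roaches model): standard normal "posterior" in 2-D
toy_post = multivariate_normal(mean=np.zeros(2), cov=np.eye(2))


def toy_log_lik_i(th):
    # Gaussian likelihood of one observation y_i = 1.5 that depends on th[:, 0]
    return norm.logpdf(1.5, loc=th[:, 0], scale=1.0)


rng = np.random.default_rng(0)
theta = rng.normal(size=(2000, 2))
A = np.diag([0.9, 1.0])
b = np.array([-0.3, 0.0])
log_w, theta_star = transformed_loo_log_weights(theta, A, b, toy_post.logpdf, toy_log_lik_i)
```

With the identity transformation the formula reduces to \(-\log p(y_i \mid \theta^{(s)})\), the raw ratio used by standard PSIS-LOO-CV.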

The key trade-off is between accuracy and computational efficiency. When several observations have high Pareto-\(k\) values, moment matching provides reliable estimates at a fraction of the cost of exact LOO-CV by refitting, making it practical for routine model assessment even with computationally expensive models. In our roaches example with 13 problematic observations, moment matching adds only a few seconds of computation, whereas exact LOO-CV for those observations would require 13 additional model fits, each as costly as the original one.

9.2.3 Running moment matching

With the required functions defined, we can now run PSIS-LOO-CV with importance weighted moment matching. We set cov=True so that, in addition to matching the means and marginal variances of the draws, the algorithm also tries a transformation that matches their full covariance, which can help when the posterior parameters are correlated.
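To illustrate what cov=True adds, the sketch below applies three affine transformations of the kind used in importance weighted moment matching: a shift to the importance-weighted mean, a shift plus per-coordinate rescaling to the weighted variances, and a shift plus full-covariance matching through Cholesky factors. It assumes draws stored as an (S, D) NumPy array and arbitrary nonnegative importance weights; moment_match_transforms is a hypothetical name, and ArviZ's implementation may differ in details such as how many iterations it runs and when it accepts a transformation.

```python
import numpy as np


def moment_match_transforms(theta, weights):
    """Apply mean, scale and covariance matching transformations to draws theta (S, D)."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # self-normalize the importance weights
    mean = theta.mean(axis=0)
    mean_w = w @ theta  # importance-weighted mean

    # 1) Shift the draws so their mean equals the weighted mean
    t_mean = theta - mean + mean_w

    # 2) Additionally rescale each coordinate to the weighted variance
    var = theta.var(axis=0)
    var_w = w @ (theta - mean_w) ** 2
    t_scale = np.sqrt(var_w / var) * (theta - mean) + mean_w

    # 3) Match the full weighted covariance: theta -> L_w L^{-1} (theta - mean) + mean_w
    centered_w = theta - mean_w
    cov = np.cov(theta, rowvar=False, ddof=0)
    cov_w = (w[:, None] * centered_w).T @ centered_w
    chol = np.linalg.cholesky(cov)
    chol_w = np.linalg.cholesky(cov_w)
    z = np.linalg.solve(chol, (theta - mean).T).T  # whitened draws, L^{-1}(theta - mean)
    t_cov = z @ chol_w.T + mean_w
    return {"mean": t_mean, "scale": t_scale, "cov": t_cov, "mean_w": mean_w, "cov_w": cov_w}


rng = np.random.default_rng(1)
theta = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.5], [0.5, 1.0]], size=4000)
# Toy weights (not LOO weights from the roaches model) favoring larger first coordinates
weights = np.exp(0.5 * theta[:, 0])
out = moment_match_transforms(theta, weights)
```

After the third transformation, the unweighted sample covariance of the transformed draws equals the weighted covariance of the original draws, which is what "matching the covariance" means here.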

Keep in mind that the split argument, a boolean that controls whether the split proposal density is used, defaults to True; this default is recommended in most applied cases. When split=True, the split proposal density is

\[ g_{\text{split,loo}}(\theta) \propto p(\theta \mid y) + |\mathbf{J}_{T_w}|^{-1}\, p(T_w^{-1}(\theta) \mid y), \]

so for each problematic observation the proposal is an equal-weight mixture of the full-data posterior and the distribution of the transformed draws \(T_w(\theta)\) from the denominator adaptation.
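The split density is evaluated on the log scale, so its two mixture components are naturally combined with a log-sum-exp. Below is a minimal sketch, assuming an affine map \(T_w(\theta) = A\theta + b\), for which \(|\mathbf{J}_{T_w}| = |\det A|\) and \(T_w^{-1}(\theta) = A^{-1}(\theta - b)\). log_split_proposal is a hypothetical helper, and its output, like the formula, is defined only up to an additive constant.

```python
import numpy as np
from scipy.stats import multivariate_normal


def log_split_proposal(theta, A, b, log_post):
    """Unnormalized log split proposal density at points theta (S, D).

    Computes log[p(theta | y) + |det A|^{-1} p(T^{-1}(theta) | y)] for T(theta) = A theta + b.
    """
    _, log_abs_det = np.linalg.slogdet(A)
    # T^{-1}(theta) = A^{-1} (theta - b), applied row-wise
    theta_inv = np.linalg.solve(A, (theta - b).T).T
    return np.logaddexp(log_post(theta), log_post(theta_inv) - log_abs_det)


# Toy standard normal "posterior" standing in for p(theta | y)
toy_post = multivariate_normal(mean=np.zeros(2), cov=np.eye(2))
points = np.array([[0.0, 0.0], [1.0, -0.5], [2.0, 1.0]])
A = np.diag([0.8, 1.2])
b = np.array([0.5, 0.0])
log_g = log_split_proposal(points, A, b, toy_post.logpdf)
```

With the identity transformation both components coincide, and the result is simply the log posterior plus \(\log 2\).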

loo_mm_result = azp.loo_moment_match(
    idata,
    loo_result,
    log_prob_upars_fn=log_prob_upars,
    log_lik_i_upars_fn=log_lik_i_upars,
    upars=upars,
    var_name="y",
    cov=True,
)

loo_mm_result
Computed from 4000 posterior samples and 262 observations log-likelihood matrix.

         Estimate       SE
elpd_loo -5477.60   699.88
p_loo      275.34        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)      262  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%

After moment matching, all 262 observations have Pareto-\(k\) values below 0.7, so the 13 problematic importance sampling approximations have all been resolved and the PSIS-LOO-CV estimate is now reliable for every observation. Notably, the ELPD estimate decreases from -5461.14 to -5477.60, indicating that the original PSIS-LOO-CV estimate was too optimistic: loo() overestimated the model’s predictive performance. The effective number of parameters p_loo, about 275 for a model with only four regression coefficients, is also a clear sign of the deliberate misspecification, since the Poisson likelihood cannot capture the overdispersion in the counts.
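As a consistency check on the two summaries printed above: p_loo is defined as the difference between the in-sample log pointwise predictive density (lpd) and elpd_loo, and moment matching only changes the leave-one-out terms. We would therefore expect lpd to stay the same and elpd_loo and p_loo to move by equal and opposite amounts. A few lines of arithmetic on the printed values confirm this up to rounding:

```python
# Values copied from the two LOO summaries printed in this chapter
elpd_psis, p_loo_psis = -5461.14, 258.87
elpd_mm, p_loo_mm = -5477.60, 275.34

# lpd = elpd_loo + p_loo should not change with moment matching
lpd_psis = elpd_psis + p_loo_psis
lpd_mm = elpd_mm + p_loo_mm
print(round(lpd_psis, 2), round(lpd_mm, 2))

# The drop in elpd_loo mirrors the increase in p_loo
print(round(elpd_mm - elpd_psis, 2), round(p_loo_mm - p_loo_psis, 2))
```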

9.3 Summary

Importance weighted moment matching provides an efficient solution for improving PSIS-LOO-CV when influential observations lead to unreliable importance sampling approximations. Rather than refitting models, IWMM transforms existing posterior draws to match importance-weighted moments in the unconstrained parameter space. The approach is fully automated, requires no user tuning, and works with arbitrary posterior samples from probabilistic programming frameworks where density evaluation is possible.

The method has important limitations worth keeping in mind:

  • It targets only first and second moments, so improvements depend on whether these moments adequately capture differences between proposal and target distributions
  • When importance weights have large variance, the computation of weighted moments can become unreliable (mitigated through weight regularization or larger sample sizes)
  • The algorithm may fail to find sufficiently helpful transformations when target and proposal distributions differ substantially in tail behavior, correlation structure, or number of modes
  • For extremely high-dimensional problems, the most sophisticated transformation (matching full covariance) may become numerically unstable
  • The split proposal approximation introduces some inefficiency by placing unnecessary probability mass in regions where the integrand is near its expectation, though this trade-off is least problematic precisely when adaptive methods are most needed
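The second limitation can be monitored with the effective sample size of the self-normalized importance weights, \(\text{ESS} = 1 / \sum_s w_s^2\), which falls towards 1 when a few draws carry most of the weight. The sketch below computes it from log weights in a numerically stable way; it is an illustrative diagnostic, not necessarily the check ArviZ performs internally, and importance_ess is a hypothetical name.

```python
import numpy as np
from scipy.special import logsumexp


def importance_ess(log_weights):
    """Effective sample size 1 / sum(w^2) of self-normalized importance weights."""
    lw = np.asarray(log_weights, dtype=float)
    # Normalize on the log scale to avoid overflow: w_s = exp(lw_s - logsumexp(lw))
    w = np.exp(lw - logsumexp(lw))
    return 1.0 / np.sum(w ** 2)


rng = np.random.default_rng(2)
ess_equal = importance_ess(np.zeros(4000))  # equal weights: ESS equals the number of draws
ess_heavy = importance_ess(rng.standard_t(df=1, size=4000))  # heavy-tailed log weights
print(ess_equal, ess_heavy)
```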

Despite these limitations, moment matching succeeds in most practical applications when the original posterior simulation was reasonably successful, providing both computational efficiency and improved reliability for model assessment without requiring complex tuning or auxiliary assumptions.