arviz_stats.loo_moment_match

arviz_stats.loo_moment_match(data, loo_orig, log_prob_upars_fn, log_lik_i_upars_fn, upars=None, var_name=None, reff=None, max_iters=30, k_threshold=None, split=True, cov=False, pointwise=None)

Compute moment matching for problematic observations in PSIS-LOO-CV.

Adjusts the results of a previously computed Pareto smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) object by applying a moment matching algorithm to observations with high Pareto k diagnostic values. The moment matching algorithm iteratively adjusts the posterior draws in the unconstrained parameter space to better approximate the leave-one-out posterior.

The moment matching algorithm is described in [1] and the PSIS-LOO-CV method is described in [2] and [3].

Parameters:
data: xarray.DataTree or InferenceData

Input data. It should contain the posterior and the log_likelihood groups.

loo_orig: ELPDData

An existing ELPDData object from a previous loo result. Must contain pointwise Pareto k values (pointwise=True must have been used).

log_prob_upars_fn: callable

Function that computes the log probability density of the full posterior distribution evaluated at unconstrained parameter draws. The function signature is log_prob_upars_fn(upars) where upars is a DataArray of unconstrained parameter draws with dimensions chain, draw, and a parameter dimension. It should return a DataArray with dimensions chain, draw.

log_lik_i_upars_fn: callable

Function that computes the log-likelihood of a single left-out observation evaluated at unconstrained parameter draws. The function signature is log_lik_i_upars_fn(upars, i) where upars is a DataArray of unconstrained parameter draws and i is the integer index of the left-out observation. It should return a DataArray with dimensions chain, draw.

upars: xarray.DataArray, optional

Posterior draws transformed to the unconstrained parameter space. Must have chain and draw dimensions, plus one additional dimension containing all parameters. Parameter names can be provided as coordinate values on this dimension. If not provided, will attempt to use the unconstrained_posterior group from the input data if available.

var_name: str, optional

The name of the variable in the log_likelihood group that stores the pointwise log-likelihood data to use for the loo computation.

reff: float, optional

Relative MCMC efficiency, ess / n, i.e. the number of effective samples divided by the number of actual samples. Computed from the trace by default.

max_iters: int, default 30

Maximum number of moment matching iterations for each problematic observation.

k_threshold: float, optional

Threshold value for Pareto k values above which moment matching is applied. Defaults to \(\min(1 - 1/\log_{10}(S), 0.7)\), where S is the number of samples.

split: bool, default True

If True, only transform half of the draws and use multiple importance sampling to combine them with untransformed draws.

cov: bool, default False

If True, match the covariance structure during the transformation, in addition to the mean and marginal variances. Ignored if split=False.

pointwise: bool, optional

If True, the pointwise predictive accuracy will be returned. Defaults to rcParams["stats.ic_pointwise"]. Moment matching always requires pointwise data from loo_orig. This argument controls whether the returned object includes pointwise data.

Returns:
ELPDData

Object with the following attributes:

  • elpd: expected log pointwise predictive density

  • se: standard error of the elpd

  • p: effective number of parameters

  • n_samples: number of samples

  • n_data_points: number of data points

  • warning: True if the estimated shape parameter of the Pareto distribution is greater than good_k.

  • elpd_i: DataArray with the pointwise predictive accuracy, only if pointwise=True

  • pareto_k: array of Pareto shape values, only if pointwise=True

  • good_k: For a sample size S, the threshold is computed as min(1 - 1/log10(S), 0.7)

  • approx_posterior: True if approximate posterior was used.
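The good_k threshold listed above depends only on the number of posterior draws S. The following is a minimal sketch of that formula; the helper name `pareto_k_threshold` is hypothetical and is not part of arviz_stats:

```python
import math


def pareto_k_threshold(n_samples: int) -> float:
    """Sample-size dependent threshold min(1 - 1/log10(S), 0.7).

    Hypothetical helper for illustration only, not an arviz_stats function.
    """
    if n_samples <= 1:
        # log10(1) == 0 would lead to a division by zero.
        raise ValueError("n_samples must be greater than 1")
    return min(1.0 - 1.0 / math.log10(n_samples), 0.7)


# The threshold grows with the number of draws and is capped at 0.7.
for s in (100, 2000, 10_000):
    print(s, round(pareto_k_threshold(s), 3))
```

With few draws the threshold is stricter than 0.7, so an observation can be flagged as problematic even when its Pareto k value is below 0.7.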

See also

loo

Standard PSIS-LOO-CV.

reloo

Exact re-fitting for problematic observations.

Notes

The moment matching algorithm considers three affine transformations of the posterior draws. For a specific draw \(\theta^{(s)}\), a generic affine transformation consists of a square matrix \(\mathbf{A}\) representing a linear map and a vector \(\mathbf{b}\) representing a translation such that

\[T : \theta^{(s)} \mapsto \mathbf{A}\theta^{(s)} + \mathbf{b} =: \theta^{*{(s)}}.\]

The first transformation, \(T_1\), is a translation that matches the mean of the sample to its importance weighted mean given by

\[\mathbf{\theta^{*{(s)}}} = T_1(\mathbf{\theta^{(s)}}) = \mathbf{\theta^{(s)}} - \bar{\theta} + \bar{\theta}_w,\]

where \(\bar{\theta}\) is the mean of the sample and \(\bar{\theta}_w\) is the importance weighted mean of the sample. The second transformation, \(T_2\), is a scaling that matches the marginal variances in addition to the means given by

\[\mathbf{\theta^{*{(s)}}} = T_2(\mathbf{\theta^{(s)}}) = \mathbf{v}^{1/2}_w \circ \mathbf{v}^{-1/2} \circ (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]

where \(\mathbf{v}\) and \(\mathbf{v}_w\) are the sample and weighted variances, and \(\circ\) denotes the pointwise product of the elements of two vectors. The third transformation, \(T_3\), is a covariance transformation that matches the covariance matrix of the sample to its importance weighted covariance matrix given by

\[\mathbf{\theta^{*{(s)}}} = T_3(\mathbf{\theta^{(s)}}) = \mathbf{L}_w \mathbf{L}^{-1} (\mathbf{\theta^{(s)}} - \bar{\theta}) + \bar{\theta}_w,\]

where \(\mathbf{L}\) and \(\mathbf{L}_w\) are the Cholesky factors of the covariance matrix and the weighted covariance matrix, respectively, that is,

\[\mathbf{LL}^T = \mathbf{\Sigma} = \frac{1}{S} \sum_{s=1}^S (\mathbf{\theta^{(s)}} - \bar{\theta}) (\mathbf{\theta^{(s)}} - \bar{\theta})^T\]

and

\[\mathbf{L}_w \mathbf{L}_w^T = \mathbf{\Sigma}_w = \frac{\sum_{s=1}^S w^{(s)} (\mathbf{\theta^{(s)}} - \bar{\theta}_w) (\mathbf{\theta^{(s)}} - \bar{\theta}_w)^T}{\sum_{s=1}^S w^{(s)}}.\]

We iterate on \(T_1\) repeatedly and move on to \(T_2\) and \(T_3\) only if \(T_1\) fails to yield a Pareto-k statistic below the threshold.
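The three transformations can be illustrated with plain NumPy. This is only a sketch under stated assumptions: the draws and the log-weights are synthetic, the weights are self-normalized to sum to one, and each transformation is applied once to the same draws. It does not reproduce the iteration, the weight updates, or the Pareto-k checks performed by loo_moment_match:

```python
import numpy as np

rng = np.random.default_rng(0)
S, D = 4000, 3

# Synthetic correlated "posterior draws", one row per draw (assumption).
mix = np.array([[1.0, 0.3, 0.0], [0.0, 1.0, 0.5], [0.0, 0.0, 2.0]])
theta = rng.normal(size=(S, D)) @ mix

# Hypothetical importance weights, self-normalized to sum to one.
log_w = 0.5 * theta[:, 0] - 0.2 * theta[:, 2]
w = np.exp(log_w - log_w.max())
w /= w.sum()

mean = theta.mean(axis=0)
mean_w = w @ theta  # importance weighted mean

# T1: translate the draws so their mean equals the weighted mean.
theta_t1 = theta - mean + mean_w

# T2: additionally rescale each margin to the weighted variance.
var = theta.var(axis=0)
var_w = w @ (theta - mean_w) ** 2
theta_t2 = np.sqrt(var_w / var) * (theta - mean) + mean_w

# T3: match the full covariance through Cholesky factors.
cov = np.cov(theta, rowvar=False, bias=True)  # 1/S normalization
centered_w = theta - mean_w
cov_w = (w[:, None] * centered_w).T @ centered_w
L = np.linalg.cholesky(cov)
L_w = np.linalg.cholesky(cov_w)
# Row-vector form of L_w L^{-1} (theta - mean) + mean_w.
theta_t3 = np.linalg.solve(L, (theta - mean).T).T @ L_w.T + mean_w

print(np.allclose(theta_t1.mean(axis=0), mean_w))
print(np.allclose(theta_t2.var(axis=0), var_w))
print(np.allclose(np.cov(theta_t3, rowvar=False, bias=True), cov_w))
```

Each transformation matches strictly more moments than the previous one, which is why the cheaper ones are tried first.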

References

[1]

Paananen, T., Piironen, J., Buerkner, P.-C., Vehtari, A. Implicitly Adaptive Importance Sampling. Statistics and Computing. 31(2) (2021) https://doi.org/10.1007/s11222-020-09982-2 arXiv preprint https://arxiv.org/abs/1906.08850.

[2]

Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544.

[3]

Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research. 25(72) (2024) https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646.

Examples

Moment matching can improve PSIS-LOO-CV estimates for observations with high Pareto k values without having to refit the model for each problematic observation. We will use the non-centered eight schools data, which has one problematic observation. In practice, moment matching is most useful when the number of problematic observations is potentially large:

In [1]: import arviz_base as azb
   ...: import numpy as np
   ...: import xarray as xr
   ...: from scipy import stats
   ...: from arviz_stats import loo
   ...: 
   ...: idata = azb.load_arviz_data("non_centered_eight")
   ...: posterior = idata.posterior
   ...: schools = posterior.theta_t.coords["school"].values
   ...: y_obs = idata.observed_data.obs
   ...: obs_dim = y_obs.dims[0]
   ...: 
   ...: loo_orig = loo(idata, pointwise=True, var_name="obs")
   ...: loo_orig
   ...: 
Out[1]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.72     1.33
p_loo        0.90        -

There has been a warning during the calculation. Please check the results.
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        7   87.5%
   (0.70, 1]   (bad)         1   12.5%
    (1, Inf)   (very bad)    0    0.0%

The moment matching algorithm applies affine transformations to posterior draws in unconstrained parameter space. To enable this, we collect the posterior parameters from their original space, transform them to unconstrained space where needed, and stack them into a single xarray.DataArray that matches the expected (chain, draw, param) structure. Parameters that are already unconstrained need no transformation; which ones these are depends on the model and the choice of parameterization:

In [2]: upars_ds = xr.Dataset(
   ...:     {
   ...:         **{
   ...:             f"theta_t_{school}": posterior.theta_t.sel(school=school, drop=True)
   ...:             for school in schools
   ...:         },
   ...:         "mu": posterior.mu,
   ...:         "log_tau": xr.apply_ufunc(np.log, posterior.tau),
   ...:     }
   ...: )
   ...: upars = azb.dataset_to_dataarray(
   ...:     upars_ds, sample_dims=["chain", "draw"], new_dim="upars_dim"
   ...: )
   ...: 

Moment matching requires two functions: one for the joint log probability (likelihood + priors) and another for the pointwise log-likelihood of a single observation. We first define functions that accept the data they need as keyword-only arguments:

In [3]: sigmas = xr.DataArray(
   ...:     [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0],
   ...:     dims=[obs_dim],
   ...: )
   ...: 
   ...: def log_prob_upars(upars, *, sigmas, y, schools, obs_dim):
   ...:     theta_t = xr.concat(
   ...:         [upars.sel(upars_dim=f"theta_t_{school}") for school in schools],
   ...:         dim=obs_dim,
   ...:     )
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     log_tau = upars.sel(upars_dim="log_tau")
   ...:     tau = xr.apply_ufunc(np.exp, log_tau)
   ...:     theta = mu + tau * theta_t
   ...: 
   ...:     log_prior = xr.apply_ufunc(stats.norm(0, 5).logpdf, mu)
   ...:     log_prior = log_prior + xr.apply_ufunc(
   ...:         stats.halfcauchy(0, 5).logpdf,
   ...:         tau,
   ...:     )
   ...:     log_prior = log_prior + log_tau
   ...:     log_prior = log_prior + xr.apply_ufunc(
   ...:         stats.norm(0, 1).logpdf,
   ...:         theta_t,
   ...:     ).sum(obs_dim)
   ...: 
   ...:     const = -0.5 * np.log(2 * np.pi)
   ...:     log_like = const - np.log(sigmas) - 0.5 * ((y - theta) / sigmas) ** 2
   ...:     log_like = log_like.sum(obs_dim)
   ...:     return log_prior + log_like
   ...: 
   ...: def log_lik_i_upars(upars, i, *, sigmas, y, schools, obs_dim):
   ...:     mu = upars.sel(upars_dim="mu")
   ...:     log_tau = upars.sel(upars_dim="log_tau")
   ...:     tau = xr.apply_ufunc(np.exp, log_tau)
   ...: 
   ...:     theta_t_i = upars.sel(upars_dim=f"theta_t_{schools[i]}")
   ...:     theta_i = mu + tau * theta_t_i
   ...: 
   ...:     sigma_i = sigmas.isel({obs_dim: i})
   ...:     y_i = y.isel({obs_dim: i})
   ...:     const = -0.5 * np.log(2 * np.pi)
   ...:     return const - np.log(sigma_i) - 0.5 * ((y_i - theta_i) / sigma_i) ** 2
   ...: 

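The `log_prior + log_tau` line in `log_prob_upars` is the log-Jacobian of the change of variables from tau to log_tau: a density over log_tau equals the density over tau multiplied by tau. The NumPy-only sketch below checks this numerically for a half-Cauchy prior with scale 5. The helper names `halfcauchy_logpdf` and `trapezoid` are hypothetical and are written out by hand here so the snippet has no dependency beyond NumPy:

```python
import numpy as np


def halfcauchy_logpdf(tau, scale=5.0):
    """Log density of a half-Cauchy(0, scale) distribution for tau > 0."""
    return np.log(2.0 / (np.pi * scale)) - np.log1p((tau / scale) ** 2)


def trapezoid(y, x):
    """Trapezoidal rule on a grid (hypothetical helper)."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


# Wide grid on the unconstrained scale.
log_tau = np.linspace(-40.0, 40.0, 200_001)
tau = np.exp(log_tau)

# With the Jacobian term, the density over log_tau is properly normalized.
with_jacobian = trapezoid(np.exp(halfcauchy_logpdf(tau) + log_tau), log_tau)

# Without it, the "density" is flat for very negative log_tau and its
# integral keeps growing with the width of the grid.
without_jacobian = trapezoid(np.exp(halfcauchy_logpdf(tau)), log_tau)

print(with_jacobian, without_jacobian)
```

Omitting the Jacobian term would make `log_prob_upars` target a different distribution than the posterior that produced the draws, so every constrained parameter needs the corresponding adjustment.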
Now, we can specialise these functions with functools.partial so the resulting functions match the signature expected by loo_moment_match:

In [4]: from functools import partial
   ...: log_prob_fn = partial(
   ...:     log_prob_upars,
   ...:     sigmas=sigmas,
   ...:     y=y_obs,
   ...:     schools=schools,
   ...:     obs_dim=obs_dim,
   ...: )
   ...: log_lik_i_fn = partial(
   ...:     log_lik_i_upars,
   ...:     sigmas=sigmas,
   ...:     y=y_obs,
   ...:     schools=schools,
   ...:     obs_dim=obs_dim,
   ...: )
   ...: 

Finally, we run moment matching with the prepared inputs. The result no longer contains any problematic observations:

In [5]: from arviz_stats import loo_moment_match
   ...: loo_mm = loo_moment_match(
   ...:     idata,
   ...:     loo_orig,
   ...:     upars=upars,
   ...:     log_prob_upars_fn=log_prob_fn,
   ...:     log_lik_i_upars_fn=log_lik_i_fn,
   ...:     var_name="obs",
   ...:     split=True,
   ...: )
   ...: loo_mm
   ...: 
Out[5]: 
Computed from 2000 posterior samples and 8 observations log-likelihood matrix.

         Estimate       SE
elpd_loo   -30.69     1.43
p_loo        0.88        -
------

Pareto k diagnostic values:
                         Count   Pct.
(-Inf, 0.70]   (good)        8  100.0%
   (0.70, 1]   (bad)         0    0.0%
    (1, Inf)   (very bad)    0    0.0%