arviz_stats.loo_score

arviz_stats.loo_score(data, var_name=None, log_weights=None, pareto_k=None, kind='crps', pointwise=False, round_to='2g')

Compute PWM-based CRPS/SCRPS with PSIS-LOO-CV weights.

Implements the probability-weighted-moment (PWM) identity for the continuous ranked probability score (CRPS) with Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV) weights. Scores are returned in positive orientation as maximization scores (larger is better); for the CRPS this means its negative is returned. The estimate is only as reliable as the PSIS-LOO-CV approximation it relies on.

Specifically, the PWM identity used here is

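\[\operatorname{CRPS}_{\text{loo}}(F, y) = E_{\text{loo}}\left[|X - y|\right] + E_{\text{loo}}[X] - 2\cdot E_{\text{loo}} \left[X\,F_{\text{loo}}(X) \right],\]

where \(E_{\text{loo}}\) and \(F_{\text{loo}}\) are the PSIS-LOO-CV weighted expectation and CDF defined in the Notes.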

The PWM identity is described in [3], traditional CRPS and SCRPS are described in [1] and [2], and the PSIS-LOO-CV method is described in [4] and [5].

Parameters:
data : xarray.DataTree or InferenceData

Input data. It should contain the posterior_predictive, observed_data and log_likelihood groups.

var_name : str, optional

The name of the variable in the log_likelihood group to use. If None, the first variable in observed_data is used, and its name is assumed to match the variable names in log_likelihood and posterior_predictive.

log_weights : xarray.DataArray, optional

Smoothed log weights for PSIS-LOO-CV, with the same shape as the log-likelihood data. If None (default), they are computed via PSIS-LOO-CV. Must be provided together with pareto_k, or both must be None.

pareto_k : xarray.DataArray, optional

Pareto tail indices corresponding to the PSIS smoothing. Same shape as the log-likelihood data. If not provided, they will be computed via PSIS-LOO-CV. Must be provided together with log_weights or both must be None.

kind : str, default "crps"

The kind of score to compute. Available options are:

  • ‘crps’: continuous ranked probability score. Default.

  • ‘scrps’: scale-invariant continuous ranked probability score.

pointwise : bool, default False

If True, include per-observation score values in the return object.

round_to : int or str, default "2g"

If an integer, the number of decimal places to round the result to. If a string of the form "2g", the number of significant digits to round the result to. Use None to return unrounded values.
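The two `round_to` forms can be mimicked with plain Python. The helper below, `round_score`, is hypothetical and is not the library's internal routine; it assumes that a string such as "2g" follows Python's format-specification meaning of the "g" presentation type (significant digits).

```python
def round_score(value, round_to="2g"):
    """Hypothetical helper mimicking the documented `round_to` behavior."""
    if round_to is None:
        return value  # raw, unrounded number
    if isinstance(round_to, int):
        return round(value, round_to)  # fixed number of decimal places
    # A string such as "2g" is used as a format spec: 2 significant digits.
    return float(format(value, f".{round_to}"))


print(round_score(-6.25161276, "2g"))  # two significant digits
print(round_score(1.23456, 2))         # two decimal places
```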

Returns:
collections.namedtuple

If pointwise is False (default), a namedtuple named CRPS or SCRPS with fields mean and se. If pointwise is True, the namedtuple also includes a pointwise field with per-observation values.
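To make the return structure concrete, here is a minimal sketch that summarizes per-observation scores into a namedtuple with the documented `mean`, `se` and `pointwise` fields. The `summarize` helper and the standard-error formula (population standard deviation divided by \(\sqrt{n}\)) are assumptions, not taken from the library. Applied to the pointwise values printed in the Examples, they give a mean of about -6.25 and a standard error of about 1.49, which is consistent with the rounded `CRPS(mean=-6.3, se=1.5)`.

```python
from collections import namedtuple

import numpy as np

# Hypothetical result types mirroring the documented field names.
CRPS = namedtuple("CRPS", ["mean", "se"])
CRPSPointwise = namedtuple("CRPS", ["mean", "se", "pointwise"])


def summarize(pointwise_scores, pointwise=False):
    """Summarize per-observation scores into a CRPS-style namedtuple (sketch)."""
    scores = np.asarray(pointwise_scores, dtype=float)
    mean = float(scores.mean())
    # Assumed standard error of the mean: population std / sqrt(n).
    se = float(scores.std() / np.sqrt(scores.size))
    if pointwise:
        return CRPSPointwise(mean, se, scores)
    return CRPS(mean, se)


# Pointwise CRPS values printed in the Examples section (centered_eight).
values = [-16.15944636, -3.22980568, -5.35164536, -3.18563767,
          -3.69345458, -3.35402187, -9.34088256, -5.698008]
print(summarize(values))
```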

Notes

For a single observation with posterior-predictive draws \(x_1, \ldots, x_S\) and PSIS-LOO-CV weights \(w_i \propto \exp(\ell_i)\), where \(\ell_i\) are the smoothed log weights, normalized so that \(\sum_{i=1}^S w_i = 1\), define the PSIS-LOO-CV expectation and the left-continuous weighted CDF as

\[E_{\text{loo}}[g(X)] := \sum_{i=1}^S w_i\, g(x_i), \quad F_{\text{loo}}(x) := \sum_{i:\, x_i < x} w_i.\]
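A minimal NumPy sketch of these two definitions for a single observation, assuming `log_w` holds the unnormalized smoothed log weights \(\ell_i\). The helper names are hypothetical. The weights are normalized after a max-shift for numerical stability, and the CDF uses a strict inequality, so it is left-continuous.

```python
import numpy as np


def normalized_weights(log_w):
    """Turn unnormalized log weights into weights w_i that sum to 1."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - log_w.max())  # shift by the max to avoid overflow
    return w / w.sum()


def e_loo(g_values, w):
    """E_loo[g(X)] = sum_i w_i * g(x_i), given the values g(x_i)."""
    return float(np.sum(w * np.asarray(g_values, dtype=float)))


def f_loo(x_eval, draws, w):
    """Left-continuous weighted CDF: total weight of draws with x_i < x_eval."""
    return float(np.sum(w[np.asarray(draws) < x_eval]))


draws = np.array([0.3, -1.2, 2.0, 0.3])
# Weights proportional to [1, 2, 1, 3], i.e. [1/7, 2/7, 1/7, 3/7].
w = normalized_weights([0.0, np.log(2.0), 0.0, np.log(3.0)])
print(e_loo(draws, w), f_loo(0.3, draws, w))  # E_loo[X] and F_loo(0.3)
```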

The first probability-weighted moment is \(b_1 := E_{\text{loo}}\left[X\,F_{\text{loo}}(X)\right]\). With this, the nonnegative CRPS under PSIS-LOO-CV is

\[\operatorname{CRPS}_{\text{loo}}(F, y) = E_{\text{loo}}\left[\,|X-y|\,\right] + E_{\text{loo}}[X] - 2\,b_1.\]
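The sketch below evaluates this formula for one observation with NumPy. It is not the library's implementation; `draws` and `log_w` are random stand-ins for posterior-predictive draws and smoothed log weights. At a sample point the weighted CDF jumps, so the sketch evaluates \(b_1\) at the midpoint of each jump, \(F^-_{(i)} + w_{(i)}/2\), the same combination that appears in the order-statistic formula for \(\Delta_{\text{loo}}\) below. With that choice the result equals the kernel form \(E_{\text{loo}}|X-y| - \tfrac{1}{2}E_{\text{loo}}|X-X'|\) of the weighted sample, which is nonnegative; using only the strict left limit \(F^-_{(i)}\) would add \(\sum_i w_{(i)}^2 x_{(i)}\) to the score.

```python
import numpy as np


def _normalize(log_w):
    """Normalized weights w_i proportional to exp(log_w_i)."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - log_w.max())  # max-shift for numerical stability
    return w / w.sum()


def crps_loo_pwm(draws, log_w, y):
    """Nonnegative PWM CRPS for one observation (sketch, not the library code)."""
    x = np.asarray(draws, dtype=float)
    w = _normalize(log_w)
    order = np.argsort(x)
    xs, ws = x[order], w[order]
    f_minus = np.cumsum(ws) - ws  # F^-_(i): weight strictly below x_(i)
    b1 = np.sum(ws * xs * (f_minus + 0.5 * ws))  # CDF at the jump midpoint
    return np.sum(w * np.abs(x - y)) + np.sum(w * x) - 2.0 * b1


def crps_kernel(draws, log_w, y):
    """Reference value: E|X - y| - 0.5 * E|X - X'| over all weighted pairs."""
    x = np.asarray(draws, dtype=float)
    w = _normalize(log_w)
    pair_term = np.sum(np.outer(w, w) * np.abs(x[:, None] - x[None, :]))
    return np.sum(w * np.abs(x - y)) - 0.5 * pair_term


rng = np.random.default_rng(1)
draws = rng.normal(size=200)             # stand-in posterior-predictive draws
log_w = rng.normal(scale=0.5, size=200)  # stand-in smoothed log weights
crps = crps_loo_pwm(draws, log_w, y=0.4)
score = -crps  # loo_score reports the negative (larger is better)
# Agrees with the O(S^2) kernel form up to floating-point rounding.
agree = np.isclose(crps, crps_kernel(draws, log_w, 0.4))
```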

For the scale term of the SCRPS, we use the PSIS-LOO-CV weighted Gini mean difference \(\Delta_{\text{loo}} := E_{\text{loo}}\left[\,|X - X'|\,\right]\), where \(X\) and \(X'\) are independent draws from the weighted predictive distribution. It admits the PWM representation

\[\Delta_{\text{loo}} = 2\,E_{\text{loo}}\left[\,X\,\left(2F_{\text{loo}}(X) - 1\right)\,\right].\]
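As a quick numerical sanity check of this representation (with equal weights rather than PSIS weights), the sketch below compares \(2E[X(2F(X)-1)]\) with the known closed form \(E|X-X'| = 2\sigma/\sqrt{\pi}\) for a normal distribution. The empirical CDF at each sorted draw is taken at the midpoint \((i - 1/2)/n\); the scale, sample size and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 1.5
x = np.sort(rng.normal(scale=sigma, size=200_000))
n = x.size
f_mid = (np.arange(1, n + 1) - 0.5) / n  # empirical CDF at each sorted draw
gini_pwm = 2.0 * np.mean(x * (2.0 * f_mid - 1.0))  # 2 E[X (2F(X) - 1)]
gini_exact = 2.0 * sigma / np.sqrt(np.pi)  # E|X - X'| for N(0, sigma^2)
print(gini_pwm, gini_exact)  # the two values should be close
```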

The function uses the following finite-sample, weighted order-statistic version of the Gini mean difference:

\[\Delta_{\text{loo}} = 2 \sum_{i=1}^S w_{(i)}\, x_{(i)} \left\{\,2 F^-_{(i)} + w_{(i)} - 1\,\right\},\]

where \(x_{(i)}\) are the values sorted increasingly, \(w_{(i)}\) are the corresponding normalized weights, and \(F^-_{(i)} = \sum_{j<i} w_{(j)}\).
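The order-statistic formula is straightforward to vectorize. The sketch below, with hypothetical helper names and assuming weights `w` for draws `x`, computes \(\Delta_{\text{loo}}\) with a sort and a cumulative sum in \(O(S \log S)\) time and compares it with the \(O(S^2)\) double sum \(\sum_i \sum_j w_i w_j |x_i - x_j|\) that it rewrites.

```python
import numpy as np


def gini_md_loo(x, w):
    """Weighted Gini mean difference via the order-statistic formula (sketch)."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    w = w / w.sum()  # make sure the weights are normalized
    order = np.argsort(x)
    xs, ws = x[order], w[order]
    f_minus = np.cumsum(ws) - ws  # F^-_(i) = sum_{j<i} w_(j)
    return 2.0 * np.sum(ws * xs * (2.0 * f_minus + ws - 1.0))


def gini_md_brute(x, w):
    """Reference value: sum_i sum_j w_i w_j |x_i - x_j|."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float) / np.sum(w)
    return np.sum(np.outer(w, w) * np.abs(x[:, None] - x[None, :]))


rng = np.random.default_rng(2)
x = rng.standard_t(df=5, size=300)   # stand-in draws
w = rng.gamma(shape=2.0, size=300)   # stand-in unnormalized weights
print(gini_md_loo(x, w), gini_md_brute(x, w))  # the two values agree
```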

The locally scale-invariant score returned for kind="scrps" is

\[S_{\text{SCRPS}}(F, y) = -\frac{E_{\text{loo}}\left[\,|X-y|\,\right]}{\Delta_{\text{loo}}} - \frac{1}{2}\log \Delta_{\text{loo}}.\]
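Given \(E_{\text{loo}}|X-y|\) and \(\Delta_{\text{loo}}\), the SCRPS is a one-liner. A minimal sketch, assuming both quantities have already been computed for one observation (for example with the order-statistic formula above); the helper name is hypothetical. Unlike the CRPS, this expression is already positively oriented, so no sign flip is applied.

```python
import numpy as np


def scrps_from_parts(e_abs_err, gini_md):
    """SCRPS = -E|X - y| / Delta - 0.5 * log(Delta); larger is better."""
    if gini_md <= 0:
        raise ValueError("Delta_loo must be positive (draws must not all be equal)")
    return -e_abs_err / gini_md - 0.5 * np.log(gini_md)


# Two equally weighted draws x = (0, 1) and observation y = 0:
# E|X - y| = 0.5 and Delta = E|X - X'| = 0.5.
print(scrps_from_parts(0.5, 0.5))  # -1 + 0.5 * log(2), about -0.653
```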

When the PSIS weights are highly variable (large Pareto \(k\)), the Monte Carlo error of these estimates can increase. The function surfaces the PSIS-LOO-CV diagnostics via pareto_k and issues a warning when the tail behavior suggests that the estimates may be unreliable.
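For screening observations by hand, the sketch below flags large Pareto \(k\) values. It uses the sample-size-dependent threshold \(\min(1 - 1/\log_{10} S,\, 0.7)\) for \(S\) draws recommended in [5]; whether loo_score uses exactly this threshold for its warning is an assumption here, and the helper names and `k_hat` values are illustrative only.

```python
import numpy as np


def pareto_k_threshold(n_draws):
    """Sample-size-dependent Pareto k threshold: min(1 - 1/log10(S), 0.7)."""
    return min(1.0 - 1.0 / np.log10(n_draws), 0.7)


def flag_unreliable(pareto_k, n_draws):
    """Boolean mask of observations whose Pareto k exceeds the threshold."""
    return np.asarray(pareto_k) > pareto_k_threshold(n_draws)


k_hat = np.array([0.12, 0.45, 0.71, 0.93])  # illustrative values, not real output
print(flag_unreliable(k_hat, n_draws=2000))
```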

References

[1] Bolin, D., & Wallin, J. (2023). Local scale invariance and robustness of proper scoring rules. Statistical Science, 38(1), 140–159. https://doi.org/10.1214/22-STS864 arXiv preprint https://arxiv.org/abs/1912.05642

[2] Gneiting, T., & Raftery, A. E. (2007). Strictly proper scoring rules, prediction, and estimation. Journal of the American Statistical Association, 102(477), 359–378. https://doi.org/10.1198/016214506000001437

[3] Taillardat, M., Mestre, O., Zamo, M., & Naveau, P. (2016). Calibrated ensemble forecasts using quantile regression forests and ensemble model output statistics. Monthly Weather Review, 144(6), 2375–2393. https://doi.org/10.1175/MWR-D-15-0260.1

[4] Vehtari, A., Gelman, A., & Gabry, J. (2017). Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5), 1413–1432. https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint https://arxiv.org/abs/1507.04544

[5] Vehtari, A., et al. (2024). Pareto smoothed importance sampling. Journal of Machine Learning Research, 25(72). https://jmlr.org/papers/v25/19-556.html arXiv preprint https://arxiv.org/abs/1507.02646

Examples

Compute scores and return the mean and standard error:

In [1]: from arviz_stats import loo_score
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("centered_eight")
   ...: loo_score(dt, kind="crps")
   ...: 
Out[1]: CRPS(mean=-6.3, se=1.5)
In [2]: loo_score(dt, kind="scrps")
Out[2]: SCRPS(mean=-2.3, se=0.095)

We can also pass previously computed PSIS-LOO-CV log weights and Pareto k values, and return the pointwise scores:

In [3]: from arviz_stats import loo
   ...: loo_data = loo(dt, pointwise=True)
   ...: loo_score(dt, kind="crps",
   ...:           log_weights=loo_data.log_weights,
   ...:           pareto_k=loo_data.pareto_k,
   ...:           pointwise=True)
   ...: 
Out[3]: 
CRPS(mean=-6.3, se=1.5, pointwise=<xarray.DataArray 'obs' (school: 8)> Size: 64B
array([-16.15944636,  -3.22980568,  -5.35164536,  -3.18563767,
        -3.69345458,  -3.35402187,  -9.34088256,  -5.698008  ])
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon')