arviz_stats.loo_expectations


arviz_stats.loo_expectations(data, var_name=None, log_weights=None, kind='mean', probs=None)

Compute weighted expectations using the PSIS-LOO-CV method.

These expectations are reliable only when the PSIS approximation works well; the returned Pareto k-hat values diagnose this for each observation. The PSIS-LOO-CV method is described in [1] and [2].
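To illustrate the idea behind these weighted expectations, the sketch below computes plain importance-sampling LOO weights for a single observation i. Following [1], the raw leave-one-out importance ratios are r_s = 1 / p(y_i | θ^s), so the raw log ratios are the negative pointwise log likelihoods, which are then self-normalized. The sketch deliberately skips the Pareto smoothing step that PSIS applies to the largest ratios, so it is not what loo_expectations computes internally. The helper name `normalized_loo_weights` and the toy numbers are hypothetical.

```python
import numpy as np

def normalized_loo_weights(log_lik_i):
    """Self-normalized IS-LOO weights for one observation (no Pareto smoothing).

    log_lik_i: shape (S,), log p(y_i | theta^s) for S posterior draws.
    """
    log_ratios = -np.asarray(log_lik_i, dtype=float)  # raw LOO log importance ratios
    # log-sum-exp with the max subtracted, for numerical stability
    max_lr = log_ratios.max()
    log_norm = max_lr + np.log(np.exp(log_ratios - max_lr).sum())
    return np.exp(log_ratios - log_norm)

# Made-up log likelihood draws and matching posterior predictive draws for observation i
log_lik_i = np.array([-1.0, -0.5, -2.0, -1.5])
y_pred_i = np.array([0.2, 0.4, 1.1, 0.7])

w = normalized_loo_weights(log_lik_i)
loo_mean_i = np.sum(w * y_pred_i)  # importance-weighted (LOO) predictive mean
```

Draws under which observation i is poorly predicted (low log likelihood) receive the largest weights, which is what shifts the expectation toward "observation i left out".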

Parameters:
data: DataTree or InferenceData

It must contain the posterior_predictive and log_likelihood groups.

var_name: str, optional

The name of the variable in the log_likelihood group storing the pointwise log likelihood data to use for the LOO computation.

log_weights: xarray.DataArray or ELPDData, optional

Smoothed log weights. Can be either:

  • A DataArray with the same shape as the pointwise log likelihood data.

  • An ELPDData object from a previous arviz_stats.loo call.

Defaults to None, in which case the smoothed log weights are computed with the PSIS-LOO method.

kind: str, optional

The kind of expectation to compute. Available options are:

  • ‘mean’: the mean of the posterior predictive distribution. Default.

  • ‘median’: the median of the posterior predictive distribution.

  • ‘var’: the variance of the posterior predictive distribution.

  • ‘sd’: the standard deviation of the posterior predictive distribution.

  • ‘quantile’: the quantile(s) of the posterior predictive distribution given by probs.

probs: float or list of float, optional

The quantile(s) to compute when kind is ‘quantile’.
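Once normalized weights are available, each kind reduces to a weighted statistic of the posterior predictive draws for one observation. The sketch below uses a hypothetical helper, `weighted_stat`, that mirrors the five kind options. It computes the plain (uncorrected) weighted variance, and its weighted quantile uses an inverse-CDF rule on the sorted draws; both are common definitions but are assumptions here, not necessarily what arviz_stats does internally.

```python
import numpy as np

def weighted_stat(draws, weights, kind="mean", probs=None):
    """Weighted statistic of 1-D predictive draws; weights are assumed to sum to 1."""
    draws = np.asarray(draws, dtype=float)
    weights = np.asarray(weights, dtype=float)
    mean = np.sum(weights * draws)
    if kind == "mean":
        return mean
    if kind in ("var", "sd"):
        var = np.sum(weights * (draws - mean) ** 2)  # no small-sample correction
        return var if kind == "var" else np.sqrt(var)
    if kind in ("median", "quantile"):
        if kind == "quantile" and probs is None:
            raise ValueError("probs is required when kind='quantile'")
        q = 0.5 if kind == "median" else np.atleast_1d(probs)
        order = np.argsort(draws)
        cdf = np.cumsum(weights[order])
        # first sorted draw whose cumulative weight reaches each probability
        idx = np.minimum(np.searchsorted(cdf, q, side="left"), len(draws) - 1)
        return draws[order][idx]
    raise ValueError(f"unknown kind: {kind!r}")

# Equal weights reduce to ordinary (non-LOO) statistics
draws = np.array([3.0, 1.0, 4.0, 2.0])
weights = np.full(4, 0.25)
weighted_stat(draws, weights, "quantile", probs=[0.25, 0.75])  # array([1., 3.])
```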

Returns:
loo_expec: xarray.DataArray

The weighted expectations.

khat: xarray.DataArray

Function-specific Pareto k-hat diagnostics for each observation.
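The PSIS paper [2] proposes judging k-hat against a sample-size-dependent threshold of 1 - 1/log10(S) for S posterior draws, with 0.7 as a practical upper bound; above it, the importance-sampling estimate for that observation should not be trusted. The minimal sketch below flags such observations. Whether arviz_stats applies exactly this rule to the function-specific k-hat is an assumption, and the helper names and k-hat values are hypothetical.

```python
import math
import numpy as np

def khat_threshold(n_draws):
    """Sample-size-dependent k-hat threshold, capped at 0.7."""
    return min(1 - 1 / math.log10(n_draws), 0.7)

def flag_unreliable(khat, n_draws):
    """Indices of observations whose k-hat exceeds the threshold."""
    return np.flatnonzero(np.asarray(khat) > khat_threshold(n_draws))

khat = np.array([0.04, 0.26, 0.75, -0.01, 0.69])  # made-up values
flag_unreliable(khat, n_draws=2000)  # array([2]); threshold is about 0.697
```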

References

[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint: https://arxiv.org/abs/1507.04544

[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html arXiv preprint: https://arxiv.org/abs/1507.02646

Examples

Calculate the 0.25 and 0.75 predictive quantiles and the corresponding function-specific Pareto k-hat diagnostics:

In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...: 
Out[1]: 
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.212921  ,  0.51393151,  0.50130407, ...,  1.00688331,
         1.19475673,  1.220047  ],
       [ 0.82123096,  1.49772976,  1.49586184, ...,  1.96597453,
         2.22125176,  2.23358734]], shape=(2, 919))
Coordinates:
  * quantile  (quantile) float64 16B 0.25 0.75
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
In [2]: khat
Out[2]: 
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 4.46700951e-02,  2.63854121e-01, -1.07236853e-02,  3.67186708e-01,
        1.42367028e-01,  9.03724670e-02, -9.64916708e-02, -3.23023070e-02,
       -7.63117649e-04, -4.39632041e-02, -4.39632041e-02, -2.77011228e-02,
       -6.87404075e-02,  7.33058816e-03, -4.68570345e-02,  1.45208213e-01,
       -2.77011228e-02, -5.06999352e-02,  1.37056713e-01,  1.39714839e-01,
       -4.65341418e-02,  1.51855305e-01,  1.94804519e-01, -5.13443296e-02,
        2.47162116e-01,  1.42367028e-01,  2.31482545e-01,  2.44197052e-01,
        2.30050539e-01,  1.19981944e-01,  1.45208213e-01,  1.40019803e-01,
        9.59767451e-02,  1.93899206e-01, -4.91620621e-02,  1.27165696e-01,
        1.51855305e-01,  1.29457903e-01, -7.32512043e-02,  1.15699421e-01,
       -6.87404075e-02,  1.39714839e-01,  1.15699421e-01,  2.18827244e-01,
        1.18644098e-01,  7.32251179e-02, -1.97781752e-02, -7.40230544e-02,
        1.33748127e-02, -4.77234266e-02,  1.15699421e-01, -7.40230544e-02,
       -4.77234266e-02, -6.87404075e-02, -5.13443296e-02, -4.65341418e-02,
        1.54028621e-02,  3.24915289e-01,  2.24899198e-01, -2.68898698e-02,
       -2.69178091e-02, -3.40208521e-02,  1.08271366e-01,  3.75426166e-01,
        3.86313391e-02,  5.97667243e-02,  2.09480407e-01,  1.21369876e-01,
        2.44276637e-02,  1.48939483e-01,  1.71784207e-01,  1.96219227e-01,
        2.61662899e-01,  1.05695073e-01, -7.53712062e-02, -1.75064052e-01,
       -6.89317647e-02, -8.41709220e-02, -2.77060430e-03, -1.48189148e-01,
...
        2.06512895e-01,  2.12507258e-01,  1.40228769e-01,  8.34683490e-03,
        1.16030234e-01,  2.30149220e-01,  2.22596841e-02,  9.44759706e-02,
        1.16030234e-01,  1.80182948e-02,  1.45663627e-02, -1.00858590e-01,
        2.75357811e-03,  7.84219623e-02, -1.00858590e-01,  3.23275565e-02,
        1.24061678e-01,  1.26321601e-02,  3.61749643e-02, -9.25393105e-02,
        5.15885509e-03, -6.05826532e-02,  2.08509036e-01,  6.36732947e-02,
        1.68220083e-01,  1.45663627e-02,  1.24409684e-01, -8.44404703e-02,
        7.33142436e-02,  7.17368725e-02, -6.09886444e-02,  2.09901251e-01,
        2.75357811e-03,  2.66859194e-03,  3.61749643e-02, -1.02255701e-01,
        2.05758883e-01,  7.93527385e-02, -1.14139620e-02, -8.44404703e-02,
        3.99240568e-03,  2.90854855e-02,  3.99240568e-03,  2.96225227e-02,
        1.24409684e-01,  1.82664183e-01,  8.13822914e-02,  1.98159623e-01,
        1.93679498e-01,  2.19122933e-01,  2.41194751e-01,  5.92631532e-02,
        1.10015080e-01,  1.23326731e-02,  8.62210686e-02,  1.25478290e-01,
        1.20398193e-01,  2.25184849e-01,  1.25478290e-01,  2.53770715e-02,
        1.83820998e-02,  7.13161230e-02,  6.05499641e-03,  2.68097760e-01,
        1.69499433e-01,  6.25661205e-02,  1.53513527e-01,  2.24842604e-01,
        8.67214216e-02,  2.82213136e-02,  1.99025978e-01,  1.60546472e-01,
        3.28540934e-02,  3.41122021e-02,  1.49708207e-01,  1.52697372e-01,
        2.55589210e-03,  3.53607817e-02,  5.51508714e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918