arviz_stats.loo_expectations
- arviz_stats.loo_expectations(data, var_name=None, log_weights=None, kind='mean', probs=None)
Compute weighted expectations using the PSIS-LOO-CV method.
The expectations are reliable only when the PSIS approximation is working well, which can be checked with the returned Pareto k-hat diagnostics. The PSIS-LOO-CV method is described in [1] and [2].
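For intuition, a LOO expectation for a single observation can be read as a self-normalized importance-weighted average of the posterior predictive draws, with the smoothed log weights supplying the weights. The sketch below uses plain Python and made-up numbers; the helper name `weighted_mean_from_log_weights` is hypothetical, and this is not the arviz_stats implementation.

```python
import math


def weighted_mean_from_log_weights(values, log_weights):
    """Self-normalized importance-weighted mean of `values` (illustrative helper)."""
    # Subtract the largest log weight before exponentiating for numerical stability.
    max_lw = max(log_weights)
    weights = [math.exp(lw - max_lw) for lw in log_weights]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


# Hypothetical posterior predictive draws for one observation and their log weights.
draws = [1.0, 2.0, 3.0]
log_w = [math.log(0.2), math.log(0.3), math.log(0.5)]
estimate = weighted_mean_from_log_weights(draws, log_w)
```

Because the weights are normalized inside the helper, the log weights only need to be correct up to an additive constant.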
- Parameters:
- data: DataTree or InferenceData
It should contain the groups posterior_predictive and log_likelihood.
- var_name: str, optional
The name of the variable in the log_likelihood group that stores the pointwise log likelihood data to use for the LOO computation.
- log_weights: xarray.DataArray or ELPDData, optional
Smoothed log weights. Can be either:
A DataArray with the same shape as the log likelihood data
An ELPDData object from a previous arviz_stats.loo call.
Defaults to None. If not provided, the weights are computed using the PSIS-LOO method.
- kind: str, optional
The kind of expectation to compute. Available options are:
‘mean’: the mean of the posterior predictive distribution. Default.
‘median’: the median of the posterior predictive distribution.
‘var’: the variance of the posterior predictive distribution.
‘sd’: the standard deviation of the posterior predictive distribution.
‘quantile’: the quantile of the posterior predictive distribution.
- probs: float or list of float, optional
The probabilities, each between 0 and 1, of the quantile(s) to compute when kind is ‘quantile’.
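To make the ‘quantile’ option concrete, the sketch below computes weighted quantiles for a single observation by inverting the weighted empirical CDF of the draws. It uses plain Python and invented numbers; the helper name `weighted_quantile` is hypothetical, and the interpolation rule arviz_stats actually uses may differ.

```python
import math


def weighted_quantile(values, log_weights, prob):
    """Smallest value whose cumulative normalized weight reaches `prob` (illustrative)."""
    # Normalize the weights in a numerically stable way.
    max_lw = max(log_weights)
    weights = [math.exp(lw - max_lw) for lw in log_weights]
    total = sum(weights)
    cumulative = 0.0
    # Walk through the draws in increasing order, accumulating weight.
    for value, weight in sorted(zip(values, weights)):
        cumulative += weight / total
        if cumulative >= prob:
            return value
    return max(values)  # guard against floating-point round-off


# Hypothetical draws and log weights for one observation.
draws = [3.0, 1.0, 4.0, 2.0]
log_w = [math.log(0.3), math.log(0.1), math.log(0.4), math.log(0.2)]
quantiles = [weighted_quantile(draws, log_w, p) for p in (0.25, 0.75)]
```

With unequal weights, the quantiles shift toward the draws that carry more weight, which is why they can differ from the unweighted sample quantiles.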
- Returns:
- loo_expec: xarray.DataArray
The weighted expectations.
- khat: xarray.DataArray
Function-specific Pareto k-hat diagnostics for each observation.
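One way to act on the k-hat diagnostics is to flag observations whose value exceeds a reliability threshold. The sketch below assumes the sample-size-dependent rule of thumb min(1 - 1/log10(S), 0.7) from the PSIS paper [2], where S is the number of posterior draws. The helper names and the k-hat values are hypothetical, and this page does not state whether arviz_stats applies the same cutoff internally.

```python
import math


def khat_threshold(n_samples):
    """Rule-of-thumb k-hat threshold for `n_samples` draws (assumed from [2])."""
    return min(1.0 - 1.0 / math.log10(n_samples), 0.7)


def flag_unreliable(khat_values, n_samples):
    """Indices of observations whose k-hat exceeds the threshold."""
    threshold = khat_threshold(n_samples)
    return [i for i, k in enumerate(khat_values) if k > threshold]


# Hypothetical k-hat values for five observations, with 4000 posterior draws.
khat_values = [0.04, 0.26, 0.85, 0.37, 0.71]
flagged = flag_unreliable(khat_values, n_samples=4000)
```

For flagged observations, the corresponding expectations should be treated with caution, since the importance-sampling estimate may be unreliable there.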
References
[1] Vehtari et al. Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4 arXiv preprint: https://arxiv.org/abs/1507.04544
[2] Vehtari et al. Pareto Smoothed Importance Sampling. Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html arXiv preprint: https://arxiv.org/abs/1507.02646
Examples
Calculate the 0.25 and 0.75 predictive quantiles and the function-specific Pareto k-hat diagnostics:
In [1]: from arviz_stats import loo_expectations
   ...: from arviz_base import load_arviz_data
   ...: dt = load_arviz_data("radon")
   ...: loo_expec, khat = loo_expectations(dt, kind="quantile", probs=[0.25, 0.75])
   ...: loo_expec
   ...:
Out[1]:
<xarray.DataArray 'y' (quantile: 2, obs_id: 919)> Size: 15kB
array([[-0.212921  ,  0.51393151,  0.50130407, ...,  1.00688331,
         1.19475673,  1.220047  ],
       [ 0.82123096,  1.49772976,  1.49586184, ...,  1.96597453,
         2.22125176,  2.23358734]], shape=(2, 919))
Coordinates:
  * quantile  (quantile) float64 16B 0.25 0.75
  * obs_id    (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918
In [2]: khat
Out[2]:
<xarray.DataArray 'y' (obs_id: 919)> Size: 7kB
array([ 4.46700951e-02,  2.63854121e-01, -1.07236853e-02,  3.67186708e-01,
        1.42367028e-01,  9.03724670e-02, -9.64916708e-02, -3.23023070e-02,
       -7.63117649e-04, -4.39632041e-02, -4.39632041e-02, -2.77011228e-02,
       -6.87404075e-02,  7.33058816e-03, -4.68570345e-02,  1.45208213e-01,
       -2.77011228e-02, -5.06999352e-02,  1.37056713e-01,  1.39714839e-01,
       -4.65341418e-02,  1.51855305e-01,  1.94804519e-01, -5.13443296e-02,
        2.47162116e-01,  1.42367028e-01,  2.31482545e-01,  2.44197052e-01,
        2.30050539e-01,  1.19981944e-01,  1.45208213e-01,  1.40019803e-01,
        9.59767451e-02,  1.93899206e-01, -4.91620621e-02,  1.27165696e-01,
        1.51855305e-01,  1.29457903e-01, -7.32512043e-02,  1.15699421e-01,
       -6.87404075e-02,  1.39714839e-01,  1.15699421e-01,  2.18827244e-01,
        1.18644098e-01,  7.32251179e-02, -1.97781752e-02, -7.40230544e-02,
        1.33748127e-02, -4.77234266e-02,  1.15699421e-01, -7.40230544e-02,
       -4.77234266e-02, -6.87404075e-02, -5.13443296e-02, -4.65341418e-02,
        1.54028621e-02,  3.24915289e-01,  2.24899198e-01, -2.68898698e-02,
       -2.69178091e-02, -3.40208521e-02,  1.08271366e-01,  3.75426166e-01,
        3.86313391e-02,  5.97667243e-02,  2.09480407e-01,  1.21369876e-01,
        2.44276637e-02,  1.48939483e-01,  1.71784207e-01,  1.96219227e-01,
        2.61662899e-01,  1.05695073e-01, -7.53712062e-02, -1.75064052e-01,
       -6.89317647e-02, -8.41709220e-02, -2.77060430e-03, -1.48189148e-01,
       ...
        2.06512895e-01,  2.12507258e-01,  1.40228769e-01,  8.34683490e-03,
        1.16030234e-01,  2.30149220e-01,  2.22596841e-02,  9.44759706e-02,
        1.16030234e-01,  1.80182948e-02,  1.45663627e-02, -1.00858590e-01,
        2.75357811e-03,  7.84219623e-02, -1.00858590e-01,  3.23275565e-02,
        1.24061678e-01,  1.26321601e-02,  3.61749643e-02, -9.25393105e-02,
        5.15885509e-03, -6.05826532e-02,  2.08509036e-01,  6.36732947e-02,
        1.68220083e-01,  1.45663627e-02,  1.24409684e-01, -8.44404703e-02,
        7.33142436e-02,  7.17368725e-02, -6.09886444e-02,  2.09901251e-01,
        2.75357811e-03,  2.66859194e-03,  3.61749643e-02, -1.02255701e-01,
        2.05758883e-01,  7.93527385e-02, -1.14139620e-02, -8.44404703e-02,
        3.99240568e-03,  2.90854855e-02,  3.99240568e-03,  2.96225227e-02,
        1.24409684e-01,  1.82664183e-01,  8.13822914e-02,  1.98159623e-01,
        1.93679498e-01,  2.19122933e-01,  2.41194751e-01,  5.92631532e-02,
        1.10015080e-01,  1.23326731e-02,  8.62210686e-02,  1.25478290e-01,
        1.20398193e-01,  2.25184849e-01,  1.25478290e-01,  2.53770715e-02,
        1.83820998e-02,  7.13161230e-02,  6.05499641e-03,  2.68097760e-01,
        1.69499433e-01,  6.25661205e-02,  1.53513527e-01,  2.24842604e-01,
        8.67214216e-02,  2.82213136e-02,  1.99025978e-01,  1.60546472e-01,
        3.28540934e-02,  3.41122021e-02,  1.49708207e-01,  1.52697372e-01,
        2.55589210e-03,  3.53607817e-02,  5.51508714e-02])
Coordinates:
  * obs_id   (obs_id) int64 7kB 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918