
"""Continuously ranked probability scores with PSIS-LOO-CV weights."""

from collections import namedtuple

import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree, extract
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import (
    _get_r_eff,
    _prepare_loo_inputs,
    _validate_crps_input,
    _warn_pareto_k,
)
from arviz_stats.utils import round_num


def loo_score(
    data,
    var_name=None,
    log_weights=None,
    pareto_k=None,
    kind="crps",
    pointwise=False,
    round_to="2g",
):
    r"""Compute PWM-based CRPS/SCRPS with PSIS-LOO-CV weights.

    Implements the probability-weighted-moment (PWM) identity for the continuous
    ranked probability score (CRPS) with Pareto-smoothed importance sampling
    leave-one-out cross-validation (PSIS-LOO-CV) weights, but returns its negative
    as a maximization score (larger is better). This assumes that the PSIS-LOO-CV
    approximation is working well. Specifically, the PWM identity used here is

    .. math::
        \operatorname{CRPS}_{\text{loo}}(F, y) =
        E_{\text{loo}}\left[|X - y|\right] + E_{\text{loo}}[X]
        - 2\, E_{\text{loo}}\left[X\, F_{\text{loo}}(X')\right].

    The PWM identity is described in [3]_, the SCRPS and the traditional CRPS are
    described in [1]_ and [2]_ respectively, and the PSIS-LOO-CV method is
    described in [4]_ and [5]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the ``posterior_predictive``,
        ``observed_data`` and ``log_likelihood`` groups.
    var_name : str, optional
        The name of the variable in the ``log_likelihood`` group to use. If None,
        the first variable in ``observed_data`` is used and assumed to match the
        ``log_likelihood`` and ``posterior_predictive`` names.
    log_weights : DataArray, optional
        Smoothed log weights for PSIS-LOO-CV. Must have the same shape as the
        log-likelihood data. Defaults to None. If not provided, they will be
        computed via PSIS-LOO-CV. Must be provided together with ``pareto_k``
        or both must be None.
    pareto_k : DataArray, optional
        Pareto tail indices corresponding to the PSIS smoothing. Must have the
        same shape as the log-likelihood data. If not provided, they will be
        computed via PSIS-LOO-CV. Must be provided together with ``log_weights``
        or both must be None.
    kind : str, default "crps"
        The kind of score to compute. Available options are:

        - 'crps': continuous ranked probability score. Default.
        - 'scrps': scale-invariant continuous ranked probability score.
    pointwise : bool, default False
        If True, include per-observation score values in the return object.
    round_to : int or str, default "2g"
        If integer, number of decimal places to round the result. If string of
        the form ``"2g"``, number of significant digits to round the result.
        Use None to return raw numbers.

    Returns
    -------
    namedtuple
        If ``pointwise`` is False (default), a namedtuple named ``CRPS`` or
        ``SCRPS`` with fields ``mean`` and ``se``. If ``pointwise`` is True,
        the namedtuple also includes a ``pointwise`` field with
        per-observation values.

    Examples
    --------
    Compute scores and return the mean and standard error:

    .. ipython::
        :okwarning:

        In [1]: from arviz_stats import loo_score
           ...: from arviz_base import load_arviz_data
           ...: dt = load_arviz_data("centered_eight")
           ...: loo_score(dt, kind="crps")

    .. ipython::
        :okwarning:

        In [2]: loo_score(dt, kind="scrps")

    We can also pass previously computed PSIS-LOO weights and return the
    pointwise values:

    .. ipython::
        :okwarning:

        In [3]: from arviz_stats import loo
           ...: loo_data = loo(dt, pointwise=True)
           ...: loo_score(dt, kind="crps",
           ...:           log_weights=loo_data.log_weights,
           ...:           pareto_k=loo_data.pareto_k,
           ...:           pointwise=True)

    Notes
    -----
    For a single observation with posterior-predictive draws
    :math:`x_1, \ldots, x_S` and PSIS-LOO-CV weights
    :math:`w_i \propto \exp(\ell_i)` normalized so that
    :math:`\sum_{i=1}^S w_i = 1`, define the PSIS-LOO-CV expectation and the
    left-continuous weighted CDF as

    .. math::
        E_{\text{loo}}[g(X)] := \sum_{i=1}^S w_i\, g(x_i), \quad
        F_{\text{loo}}(x') := \sum_{i:\, x_i < x'} w_i.
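    For example, with draws :math:`x = (1, 2)` and normalized weights
    :math:`w = (0.3, 0.7)`, this left-continuous CDF evaluates to
    :math:`F_{\text{loo}}(1) = 0` and :math:`F_{\text{loo}}(2) = 0.3`.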
    The first probability-weighted moment is
    :math:`b_1 := E_{\text{loo}}\left[X\,F_{\text{loo}}(X')\right]`. With this,
    the nonnegative CRPS under PSIS-LOO-CV is

    .. math::
        \operatorname{CRPS}_{\text{loo}}(F, y) =
        E_{\text{loo}}\left[\,|X - y|\,\right] + E_{\text{loo}}[X] - 2\,b_1.

    For the scale term of the SCRPS, we use the PSIS-LOO-CV weighted Gini mean
    difference given by
    :math:`\Delta_{\text{loo}} := E_{\text{loo}}\left[\,|X - X'|\,\right]`.
    This admits the PWM representation given by

    .. math::
        \Delta_{\text{loo}} =
        2\,E_{\text{loo}}\left[\,X\,\left(2 F_{\text{loo}}(X') - 1\right)\,\right].

    A finite-sample weighted order-statistic version of this is used in the
    function and is given by

    .. math::
        \Delta_{\text{loo}} =
        2 \sum_{i=1}^S w_{(i)}\, x_{(i)}
        \left\{\,2 F^-_{(i)} + w_{(i)} - 1\,\right\},

    where :math:`x_{(i)}` are the values sorted in increasing order,
    :math:`w_{(i)}` are the corresponding normalized weights, and
    :math:`F^-_{(i)} = \sum_{j<i} w_{(j)}`. The locally scale-invariant score
    returned for ``kind="scrps"`` is

    .. math::
        S_{\text{SCRPS}}(F, y) =
        -\frac{E_{\text{loo}}\left[\,|X - y|\,\right]}{\Delta_{\text{loo}}}
        - \frac{1}{2}\log \Delta_{\text{loo}}.

    When PSIS weights are highly variable (large Pareto :math:`k`), Monte Carlo
    noise can increase. This function surfaces PSIS-LOO-CV diagnostics via
    ``pareto_k`` and warns when tail behavior suggests unreliability.

    References
    ----------
    .. [1] Bolin, D., & Wallin, J. (2023). *Local scale invariance and
        robustness of proper scoring rules*. Statistical Science, 38(1),
        140–159. https://doi.org/10.1214/22-STS864
        arXiv preprint https://arxiv.org/abs/1912.05642
    .. [2] Gneiting, T., & Raftery, A. E. (2007). *Strictly Proper Scoring
        Rules, Prediction, and Estimation*. Journal of the American Statistical
        Association, 102(477), 359–378.
        https://doi.org/10.1198/016214506000001437
    .. [3] Taillardat, M., Mestre, O., Zamo, M., & Naveau, P. (2016).
        *Calibrated ensemble forecasts using quantile regression forests and
        ensemble model output statistics*. Monthly Weather Review, 144(6),
        2375–2393. https://doi.org/10.1175/MWR-D-15-0260.1
    .. [4] Vehtari, A., Gelman, A., & Gabry, J. (2017). *Practical Bayesian
        model evaluation using leave-one-out cross-validation and WAIC*.
        Statistics and Computing, 27(5), 1413–1432.
        https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544
    .. [5] Vehtari, A., et al. (2024). *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72).
        https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
    """
    if kind not in {"crps", "scrps"}:
        raise ValueError(f"kind must be either 'crps' or 'scrps'. Got {kind}")

    data = convert_to_datatree(data)
    loo_inputs = _prepare_loo_inputs(data, var_name)
    var_name = loo_inputs.var_name
    log_likelihood = loo_inputs.log_likelihood
    y_pred = extract(data, group="posterior_predictive", var_names=var_name, combined=False)
    y_obs = extract(data, group="observed_data", var_names=var_name, combined=False)

    n_samples = loo_inputs.n_samples
    sample_dims = loo_inputs.sample_dims
    obs_dims = loo_inputs.obs_dims
    r_eff = _get_r_eff(data, n_samples)

    _validate_crps_input(y_pred, y_obs, log_likelihood, sample_dims=sample_dims, obs_dims=obs_dims)

    if (log_weights is None) != (pareto_k is None):
        raise ValueError(
            "Both log_weights and pareto_k must be provided together or both must be None. "
            "Only one was provided."
        )
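    # Obtain smoothed log weights and Pareto k diagnostics via PSIS unless the
    # caller supplied both precomputed; pareto_k later feeds the
    # tail-reliability warning.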
    if log_weights is None and pareto_k is None:
        log_weights_da, pareto_k = log_likelihood.azstats.psislw(r_eff=r_eff, dim=sample_dims)
    else:
        log_weights_da = log_weights

    abs_error = np.abs(y_pred - y_obs)
    loo_weighted_abs_error = _loo_weighted_mean(abs_error, log_weights_da, sample_dims)
    loo_weighted_mean_prediction = _loo_weighted_mean(y_pred, log_weights_da, sample_dims)
    pwm_first_moment_b1 = _apply_pointwise_weighted_statistic(
        y_pred, log_weights_da, sample_dims, _compute_pwm_first_moment_b1
    )

    crps_pointwise = (
        loo_weighted_abs_error + loo_weighted_mean_prediction - 2.0 * pwm_first_moment_b1
    )

    if kind == "crps":
        pointwise_scores = -crps_pointwise
    else:
        gini_mean_difference = _apply_pointwise_weighted_statistic(
            y_pred, log_weights_da, sample_dims, _compute_weighted_gini_mean_difference
        )
        pointwise_scores = -(loo_weighted_abs_error / gini_mean_difference) - 0.5 * np.log(
            gini_mean_difference
        )

    _warn_pareto_k(pareto_k, n_samples)

    n_pts = int(np.prod([pointwise_scores.sizes[d] for d in pointwise_scores.dims]))
    mean = pointwise_scores.mean().values.item()
    se = (pointwise_scores.std(ddof=0).values / (n_pts**0.5)).item()

    name = "SCRPS" if kind == "scrps" else "CRPS"
    if pointwise:
        return namedtuple(name, ["mean", "se", "pointwise"])(
            round_num(mean, round_to),
            round_num(se, round_to),
            pointwise_scores,
        )
    return namedtuple(name, ["mean", "se"])(
        round_num(mean, round_to),
        round_num(se, round_to),
    )
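

# Illustrative sketch, not part of the module API: the helper
# `_compute_pwm_first_moment_b1` defined below evaluates
# b1 = sum_i w_i * x_i * F^-(x_i) in O(S log S) via a sort and a cumulative
# sum. This check (hypothetical name, toy data only) compares it against the
# naive O(S^2) evaluation of the left-continuous CDF from the Notes.
def _example_check_first_pwm(seed=0, n_draws=200):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n_draws)  # stand-in posterior-predictive draws
    w = rng.dirichlet(np.ones(n_draws))  # already-normalized weights
    # Naive left-continuous CDF: total weight of draws strictly below each x_i
    f_minus = np.array([np.sum(w[x < xi]) for xi in x])
    np.testing.assert_allclose(
        _compute_pwm_first_moment_b1(x, w), np.sum(w * x * f_minus)
    )
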
def _compute_pwm_first_moment_b1(values, weights):
    """Compute the first PWM using a left-continuous weighted CDF."""
    values_sorted, weights_sorted = _sort_values_and_normalize_weights(values, weights)
    cumulative_weights = np.cumsum(weights_sorted)
    f_minus = cumulative_weights - weights_sorted
    return np.sum(weights_sorted * values_sorted * f_minus).item()


def _compute_weighted_gini_mean_difference(values, weights):
    """Compute the PSIS-LOO-CV weighted Gini mean difference."""
    values_sorted, weights_sorted = _sort_values_and_normalize_weights(values, weights)
    cumulative_weights = np.cumsum(weights_sorted)
    cumulative_before = cumulative_weights - weights_sorted
    bracket = 2.0 * cumulative_before + weights_sorted - 1.0
    return (2.0 * np.sum(weights_sorted * values_sorted * bracket)).item()


def _loo_weighted_mean(values, log_weights, dim):
    """Compute the PSIS-LOO-CV weighted mean of ``values`` over ``dim``."""
    log_num = logsumexp(log_weights, dims=dim, b=values)
    log_den = logsumexp(log_weights, dims=dim)
    return np.exp(log_num - log_den)


def _apply_pointwise_weighted_statistic(x, log_weights, sample_dims, stat_func):
    """Apply a weighted statistic over the sample dims, one observation at a time."""
    max_logw = log_weights.max(dim=sample_dims)
    weights = np.exp(log_weights - max_logw)
    stacked = "__sample__"
    xs = x.stack({stacked: sample_dims})
    ws = weights.stack({stacked: sample_dims})
    return xr.apply_ufunc(
        stat_func,
        xs,
        ws,
        input_core_dims=[[stacked], [stacked]],
        output_core_dims=[[]],
        vectorize=True,
        output_dtypes=[float],
    )


def _sort_values_and_normalize_weights(values, weights):
    """Sort values in ascending order and normalize the matching weights."""
    idx = np.argsort(values, kind="mergesort")
    values_sorted = values[idx]
    weights_sorted = weights[idx]
    weights_sorted = weights_sorted / np.sum(weights_sorted)
    return values_sorted, weights_sorted
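

# Illustrative sketch, not part of the module API: the order-statistic form
# used in `_compute_weighted_gini_mean_difference` should match the direct
# pairwise form sum_ij w_i w_j |x_i - x_j| exactly (see the finite-sample
# formula in the Notes). Hypothetical name, toy data only.
def _example_check_gini_mean_difference(seed=0, n_draws=200):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n_draws)
    w = rng.dirichlet(np.ones(n_draws))
    # O(S^2) double sum over all pairs of draws
    pairwise = np.sum(w[:, None] * w[None, :] * np.abs(x[:, None] - x[None, :]))
    np.testing.assert_allclose(_compute_weighted_gini_mean_difference(x, w), pairwise)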