# Source code for arviz_stats.loo.loo_moment_match

"""Compute moment matching for problematic observations in PSIS-LOO-CV."""

import warnings
from collections import namedtuple
from copy import deepcopy

import arviz_base as azb
import numpy as np
import xarray as xr
from arviz_base import dataset_to_dataarray, rcParams
from xarray_einstats.stats import logsumexp

from arviz_stats.loo.helper_loo import (
    _get_log_likelihood_i,
    _get_r_eff,
    _get_r_eff_i,
    _get_weights_and_k_i,
    _prepare_loo_inputs,
    _shift,
    _shift_and_cov,
    _shift_and_scale,
    _warn_pareto_k,
)
from arviz_stats.sampling_diagnostics import ess
from arviz_stats.utils import ELPDData

# Lightweight result containers for the moment-matching helpers below.
# SplitMomentMatch: returned by _split_moment_match — smoothed log weights for the
# LOO posterior (lwi) and full posterior (lwfi), the log-likelihood of the left-out
# observation under the half-transformed draws, and the updated relative efficiency.
SplitMomentMatch = namedtuple("SplitMomentMatch", ["lwi", "lwfi", "log_liki", "reff"])
# UpdateQuantities: returned by _update_quantities_i — re-weighted quantities and
# Pareto k diagnostics (ki for the LOO weights, kfi for the full-posterior weights)
# after one affine transformation of the draws.
UpdateQuantities = namedtuple("UpdateQuantities", ["lwi", "lwfi", "ki", "kfi", "log_liki"])
# LooMomentMatchResult: returned by _loo_moment_match_i — final per-observation
# quantities after the iterative moment matching loop (and optional split step).
LooMomentMatchResult = namedtuple(
    "LooMomentMatchResult",
    ["final_log_liki", "final_lwi", "final_ki", "kfs_i", "reff_i", "n_eff_i", "original_ki", "i"],
)


def loo_moment_match(
    data,
    loo_orig,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    upars=None,
    var_name=None,
    reff=None,
    max_iters=30,
    k_threshold=None,
    split=True,
    cov=True,
    pointwise=None,
):
    r"""Compute moment matching for problematic observations in PSIS-LOO-CV.

    Adjusts the results of a previously computed PSIS-LOO-CV object by applying a
    moment matching algorithm to observations with high Pareto k diagnostic values.
    The algorithm iteratively adjusts the posterior draws in the unconstrained
    parameter space to better approximate the leave-one-out posterior.

    The moment matching algorithm is described in [1]_ and the PSIS-LOO-CV method
    in [2]_ and [3]_.

    Parameters
    ----------
    data : DataTree or InferenceData
        Input data. It should contain the posterior and the log_likelihood groups.
    loo_orig : ELPDData
        An existing ELPDData object from a previous `loo` result. Must contain
        pointwise Pareto k values (``pointwise=True`` must have been used).
    log_prob_upars_fn : callable
        ``log_prob_upars_fn(upars)`` -> DataArray with dims ``chain``, ``draw``:
        log density of the full posterior at unconstrained parameter draws.
    log_lik_i_upars_fn : callable
        ``log_lik_i_upars_fn(upars, i)`` -> DataArray with dims ``chain``, ``draw``:
        log-likelihood of left-out observation ``i`` at unconstrained draws.
    upars : DataArray, optional
        Posterior draws in the unconstrained space, with ``chain`` and ``draw``
        dimensions plus one parameter dimension. If not provided, the
        ``unconstrained_posterior`` group from ``data`` is used when available.
    var_name : str, optional
        Name of the variable in the log_likelihood group to use.
    reff : float, optional
        Relative MCMC efficiency, ``ess / n``. Computed from the trace by default.
    max_iters : int, default 30
        Maximum number of moment matching iterations per problematic observation.
    k_threshold : float, optional
        Pareto k threshold above which moment matching is applied. Defaults to
        :math:`\min(1 - 1/\log_{10}(S), 0.7)` where S is the number of samples.
    split : bool, default True
        If True, transform only half of the draws and combine with the untransformed
        half via multiple importance sampling.
    cov : bool, default True
        If True, also match the covariance structure (not just mean and marginal
        variances) during the transformation.
    pointwise : bool, optional
        Whether the returned object includes pointwise data. Defaults to
        ``rcParams["stats.ic_pointwise"]``. Moment matching always *requires*
        pointwise data on ``loo_orig``; this only controls the returned object.

    Returns
    -------
    ELPDData
        Updated LOO results with moment-matched ``elpd``/``pareto_k`` for the
        observations where matching improved k, plus ``influence_pareto_k``
        (original k values) and ``n_eff_i`` when ``pointwise=True``.

    References
    ----------
    .. [1] Paananen et al. *Implicitly Adaptive Importance Sampling*.
       Statistics and Computing. 31(2) (2021) https://doi.org/10.1007/s11222-020-09982-2
    .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out
       cross-validation and WAIC*. Statistics and Computing. 27(5) (2017)
       https://doi.org/10.1007/s11222-016-9696-4
    .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*.
       Journal of Machine Learning Research, 25(72) (2024)
       https://jmlr.org/papers/v25/19-556.html
    """
    if not isinstance(loo_orig, ELPDData):
        raise TypeError("loo_orig must be an ELPDData object.")
    if loo_orig.pareto_k is None or loo_orig.elpd_i is None:
        raise ValueError(
            "Moment matching requires pointwise LOO results with Pareto k values. "
            "Please compute the initial LOO with pointwise=True."
        )

    sample_dims = ["chain", "draw"]

    # Resolve unconstrained draws: caller-supplied, or extracted from the data.
    if upars is None:
        if hasattr(data, "unconstrained_posterior"):
            upars_ds = azb.get_unconstrained_samples(data, return_dataset=True)
            upars = dataset_to_dataarray(
                upars_ds, sample_dims=sample_dims, new_dim="unconstrained_parameter"
            )
        else:
            raise ValueError(
                "upars must be provided or data must contain an 'unconstrained_posterior' group."
            )

    if not isinstance(upars, xr.DataArray):
        raise TypeError("upars must be a DataArray.")
    if not all(dim_name in upars.dims for dim_name in sample_dims):
        raise ValueError(f"upars must have dimensions {sample_dims}.")

    # Exactly one parameter dimension besides chain/draw; a scalar parameter
    # space gets a singleton dimension so downstream shape logic is uniform.
    param_dim_list = [dim for dim in upars.dims if dim not in sample_dims]
    if len(param_dim_list) == 0:
        param_dim_name = "upars_dim"
        upars = upars.expand_dims(dim={param_dim_name: 1})
    elif len(param_dim_list) == 1:
        param_dim_name = param_dim_list[0]
    else:
        raise ValueError("upars must have at most one dimension besides 'chain' and 'draw'.")

    # Work on a copy so loo_orig is left untouched (needed for reverting below).
    loo_data = deepcopy(loo_orig)
    loo_data.method = "loo_moment_match"
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

    loo_inputs = _prepare_loo_inputs(data, var_name)
    log_likelihood = loo_inputs.log_likelihood
    obs_dims = loo_inputs.obs_dims
    n_samples = loo_inputs.n_samples
    var_name = loo_inputs.var_name
    n_params = upars.sizes[param_dim_name]
    n_data_points = loo_orig.n_data_points

    if reff is None:
        reff = _get_r_eff(data, n_samples)

    # Baseline log density of the full posterior at the original draws; validated
    # here once so per-observation iterations can trust its shape.
    try:
        orig_log_prob = log_prob_upars_fn(upars)
        if not isinstance(orig_log_prob, xr.DataArray):
            raise TypeError("log_prob_upars_fn must return a DataArray.")
        if not all(dim in orig_log_prob.dims for dim in sample_dims):
            raise ValueError(f"Original log probability must have dimensions {sample_dims}.")
        if len(orig_log_prob.dims) != len(sample_dims):
            raise ValueError(
                f"Original log probability should only have dimensions {sample_dims}, "
                f"found {orig_log_prob.dims}"
            )
    except Exception as e:
        raise ValueError(f"Error executing log_prob_upars_fn: {e}") from e

    if k_threshold is None:
        k_threshold = min(1 - 1 / np.log10(n_samples), 0.7) if n_samples > 1 else 0.7

    # Preserve the pre-moment-matching k values for diagnostics.
    loo_data.influence_pareto_k = loo_data.pareto_k.copy()

    # Flatten the (possibly multidimensional) observation dims so observations
    # can be addressed by a single integer index.
    ks = (
        loo_data.pareto_k.stack(__pareto_obs_stacked__=obs_dims)
        .transpose("__pareto_obs_stacked__")
        .values
    )
    bad_obs_indices = np.where(ks > k_threshold)[0]

    if len(bad_obs_indices) == 0:
        warnings.warn("No Pareto k values exceed the threshold. Returning original LOO data.")
        if not pointwise:
            loo_data.elpd_i = None
            loo_data.pareto_k = None
            loo_data.influence_pareto_k = None
            if hasattr(loo_data, "p_loo_i"):
                loo_data.p_loo_i = None
            if hasattr(loo_data, "n_eff_i"):
                loo_data.n_eff_i = None
        return loo_data

    # Pointwise lpd and effective number of parameters.
    lpd = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples)
    loo_data.p_loo_i = lpd - loo_data.elpd_i
    kfs = np.zeros(n_data_points)
    log_weights = getattr(loo_data, "log_weights", None)
    r_eff_data = getattr(loo_data, "r_eff", reff)

    # Moment matching algorithm
    for i in bad_obs_indices:
        mm_result = _loo_moment_match_i(
            i=i,
            upars=upars,
            log_likelihood=log_likelihood,
            log_prob_upars_fn=log_prob_upars_fn,
            log_lik_i_upars_fn=log_lik_i_upars_fn,
            max_iters=max_iters,
            k_threshold=k_threshold,
            split=split,
            cov=cov,
            orig_log_prob=orig_log_prob,
            ks=ks,
            log_weights=log_weights,
            pareto_k=loo_data.pareto_k,
            r_eff=r_eff_data,
            sample_dims=sample_dims,
            obs_dims=obs_dims,
            n_samples=n_samples,
            n_params=n_params,
            param_dim_name=param_dim_name,
            var_name=var_name,
        )

        kfs[i] = mm_result.kfs_i

        if mm_result.final_ki < mm_result.original_ki:
            # Matching improved k: recompute elpd_i from the matched weights and
            # write all per-observation quantities back into loo_data.
            new_elpd_i = logsumexp(
                mm_result.final_log_liki + mm_result.final_lwi, dims=sample_dims
            ).item()
            original_log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims)

            _update_loo_data_i(
                loo_data,
                i,
                new_elpd_i,
                mm_result.final_ki,
                mm_result.final_log_liki,
                sample_dims,
                obs_dims,
                n_samples,
                mm_result.n_eff_i,
                original_log_liki,
                mm_result.final_lwi,
                suppress_warnings=True,
            )
        else:
            # No improvement: keep the original results for this observation.
            warnings.warn(
                f"Observation {i}: Moment matching did not improve k "
                f"({mm_result.original_ki:.2f} -> {mm_result.final_ki:.2f}). Reverting.",
                UserWarning,
                stacklevel=2,
            )
            if hasattr(loo_orig, "p_loo_i") and loo_orig.p_loo_i is not None:
                if len(obs_dims) == 1:
                    idx_dict = {obs_dims[0]: i}
                else:
                    # Flat index -> per-dimension coordinates.
                    coords = np.unravel_index(
                        i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims)
                    )
                    idx_dict = dict(zip(obs_dims, coords))
                loo_data.p_loo_i[idx_dict] = loo_orig.p_loo_i[idx_dict]

    # Post-run diagnostics on the (possibly updated) k values.
    final_ks = (
        loo_data.pareto_k.stack(__pareto_obs_stacked__=obs_dims)
        .transpose("__pareto_obs_stacked__")
        .values
    )
    if np.any(final_ks[bad_obs_indices] > k_threshold):
        warnings.warn(
            f"After Moment Matching, {np.sum(final_ks > k_threshold)} observations still have "
            f"Pareto k > {k_threshold:.2f}.",
            UserWarning,
            stacklevel=2,
        )
    if not split and np.any(kfs > k_threshold):
        warnings.warn(
            "The accuracy of self-normalized importance sampling may be bad. "
            "Setting the argument 'split' to 'True' will likely improve accuracy.",
            UserWarning,
            stacklevel=2,
        )

    # Refresh the global effective number of parameters from the updated elpd.
    elpd_raw = logsumexp(log_likelihood, dims=sample_dims, b=1 / n_samples).sum().values
    loo_data.p = elpd_raw - loo_data.elpd

    if not pointwise:
        loo_data.elpd_i = None
        loo_data.pareto_k = None
        loo_data.influence_pareto_k = None
        if hasattr(loo_data, "p_loo_i"):
            loo_data.p_loo_i = None
        if hasattr(loo_data, "n_eff_i"):
            loo_data.n_eff_i = None

    return loo_data
def _split_moment_match(
    upars,
    cov,
    total_shift,
    total_scaling,
    total_mapping,
    i,
    reff,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
):
    r"""Split moment matching importance sampling for PSIS-LOO-CV.

    Applies the accumulated affine transformation to half of the posterior draws,
    leaving the other half unchanged, and combines the two halves with multiple
    importance sampling. Based on the implicit adaptive importance sampling
    algorithm of Paananen et al. (2021).

    Parameters
    ----------
    upars : DataArray
        Posterior draws in the unconstrained space with dims ``chain``, ``draw``
        and one parameter dimension.
    cov : bool
        Whether to apply the full covariance mapping (True) or only the marginal
        shift/scaling (False).
    total_shift : ndarray
        Accumulated translation vector (length = parameter dimension).
    total_scaling : ndarray
        Accumulated marginal scaling factors (length = parameter dimension).
    total_mapping : ndarray
        Accumulated (d, d) covariance mapping matrix.
    i : int
        Index of the left-out observation.
    reff : float
        Relative MCMC efficiency, ``ess / n``.
    log_prob_upars_fn : callable
        ``log_prob_upars_fn(upars)`` -> DataArray (``chain``, ``draw``): full
        posterior log density at unconstrained draws.
    log_lik_i_upars_fn : callable
        ``log_lik_i_upars_fn(upars, i)`` -> DataArray (``chain``, ``draw``):
        log-likelihood of observation ``i`` at unconstrained draws.

    Returns
    -------
    SplitMomentMatch
        lwi (LOO log weights), lwfi (full-posterior log weights), log_liki
        (log-likelihood at the half-transformed draws), reff (updated efficiency).
    """
    sample_dims = ["chain", "draw"]
    param_dim = next(dim for dim in upars.dims if dim not in sample_dims)
    dim = upars.sizes[param_dim]
    n_chains = upars.sizes["chain"]
    n_draws = upars.sizes["draw"]
    n_samples = n_chains * n_draws
    n_samples_half = n_samples // 2

    # Stack draw-major so the first half of rows corresponds to the first half
    # of draws across chains.
    stack_dims = ["draw", "chain"]
    upars_stacked = upars.stack(__sample__=stack_dims).transpose("__sample__", param_dim)
    mean_original = upars_stacked.mean(dim="__sample__")

    # Forward transformation: center, scale, (optionally) map covariance, then
    # translate by the total shift around the original mean.
    upars_trans = upars_stacked - mean_original
    upars_trans = upars_trans * xr.DataArray(total_scaling, dims=param_dim)
    if cov and dim > 0:
        upars_trans = xr.DataArray(
            upars_trans.data @ total_mapping.T,
            coords=upars_trans.coords,
            dims=upars_trans.dims,
        )
    upars_trans = upars_trans + (xr.DataArray(total_shift, dims=param_dim) + mean_original)

    # Inverse Transformation (applied to the *other* half of the draws).
    upars_trans_inv = upars_stacked - (xr.DataArray(total_shift, dims=param_dim) + mean_original)
    if cov and dim > 0:
        try:
            inv_mapping_t = np.linalg.inv(total_mapping.T)
            upars_trans_inv = xr.DataArray(
                upars_trans_inv.data @ inv_mapping_t,
                coords=upars_trans_inv.coords,
                dims=upars_trans_inv.dims,
            )
        except np.linalg.LinAlgError:
            # Singular mapping: skip the covariance part of the inverse and
            # fall through with only shift/scale undone.
            warnings.warn("Could not invert mapping matrix. Using identity.", UserWarning)
    upars_trans_inv = upars_trans_inv / xr.DataArray(total_scaling, dims=param_dim)
    upars_trans_inv = upars_trans_inv + (
        mean_original - xr.DataArray(total_shift, dims=param_dim)
    )

    # First half of the draws transformed forward, second half untouched.
    upars_trans_half_stacked = upars_stacked.copy(deep=True)
    upars_trans_half_stacked.data[:n_samples_half, :] = upars_trans.data[:n_samples_half, :]
    upars_trans_half = upars_trans_half_stacked.unstack("__sample__").transpose(
        *reversed(stack_dims), param_dim
    )

    # First half untouched, second half transformed with the inverse map.
    upars_trans_half_inv_stacked = upars_stacked.copy(deep=True)
    upars_trans_half_inv_stacked.data[n_samples_half:, :] = upars_trans_inv.data[
        n_samples_half:, :
    ]
    upars_trans_half_inv = upars_trans_half_inv_stacked.unstack("__sample__").transpose(
        *reversed(stack_dims), param_dim
    )

    # Evaluate the posterior log density on both half-transformed draw sets,
    # validating the user-supplied callback output each time.
    try:
        log_prob_half_trans = log_prob_upars_fn(upars_trans_half)
        if not isinstance(log_prob_half_trans, xr.DataArray):
            raise TypeError("log_prob_upars_fn must return a DataArray.")
        if not all(dim in log_prob_half_trans.dims for dim in sample_dims) or len(
            log_prob_half_trans.dims
        ) != len(sample_dims):
            raise ValueError(
                f"log_prob_upars_fn must return a DataArray with dimensions {sample_dims}, "
                f"but got {log_prob_half_trans.dims}"
            )
        log_prob_half_trans_inv = log_prob_upars_fn(upars_trans_half_inv)
        if not isinstance(log_prob_half_trans_inv, xr.DataArray):
            raise TypeError("log_prob_upars_fn must return a DataArray.")
        if not all(dim in log_prob_half_trans_inv.dims for dim in sample_dims) or len(
            log_prob_half_trans_inv.dims
        ) != len(sample_dims):
            raise ValueError(
                f"log_prob_upars_fn must return a DataArray with dimensions {sample_dims}, "
                f"but got {log_prob_half_trans_inv.dims}"
            )
    except Exception as e:
        raise ValueError(
            f"Could not compute log probabilities for transformed parameters: {e}"
        ) from e

    try:
        log_liki_half = log_lik_i_upars_fn(upars_trans_half, i)
        if not all(dim in log_liki_half.dims for dim in sample_dims) or len(
            log_liki_half.dims
        ) != len(sample_dims):
            raise ValueError(
                f"log_lik_i_upars_fn must return a DataArray with dimensions {sample_dims}"
            )
        if (
            log_liki_half.sizes["chain"] != upars.sizes["chain"]
            or log_liki_half.sizes["draw"] != upars.sizes["draw"]
        ):
            raise ValueError(
                "log_lik_i_upars_fn output shape does not match input sample dimensions"
            )
    except Exception as e:
        raise ValueError(f"Could not compute log likelihood for observation {i}: {e}") from e

    # Log Jacobian determinant of the inverse transform (density correction).
    log_jacobian_det = 0.0
    if dim > 0:
        log_jacobian_det = -np.sum(np.log(total_scaling))
        try:
            log_jacobian_det -= np.log(np.linalg.det(total_mapping))
        except np.linalg.LinAlgError:
            log_jacobian_det -= np.inf
    log_prob_half_trans_inv_adj = log_prob_half_trans_inv + log_jacobian_det

    # Multiple importance sampling: combine the two proposal halves with a
    # numerically stable log-sum (larger term factored out via log1p).
    use_forward_log_prob = log_prob_half_trans > log_prob_half_trans_inv_adj
    raw_log_weights_half = -log_liki_half + log_prob_half_trans
    log_sum_terms = xr.where(
        use_forward_log_prob,
        log_prob_half_trans
        + xr.ufuncs.log1p(np.exp(log_prob_half_trans_inv_adj - log_prob_half_trans)),
        log_prob_half_trans_inv_adj
        + xr.ufuncs.log1p(np.exp(log_prob_half_trans - log_prob_half_trans_inv_adj)),
    )
    raw_log_weights_half -= log_sum_terms
    # NaN or +inf weights are invalid; map them to -inf (zero weight).
    raw_log_weights_half = xr.where(
        np.isnan(raw_log_weights_half), -np.inf, raw_log_weights_half
    )
    raw_log_weights_half = xr.where(
        np.isposinf(raw_log_weights_half), -np.inf, raw_log_weights_half
    )

    # PSIS smoothing for half posterior
    lwi_psis_da, _ = _wrap__psislw(raw_log_weights_half, sample_dims, reff)

    lr_full = lwi_psis_da + log_liki_half
    lr_full = xr.where(np.isnan(lr_full) | (np.isinf(lr_full) & (lr_full > 0)), -np.inf, lr_full)

    # PSIS smoothing for full posterior
    lwfi_psis_da, _ = _wrap__psislw(lr_full, sample_dims, reff)

    # Re-estimate relative efficiency from the two halves; the more pessimistic
    # (smaller) value is used.
    n_chains = upars.sizes["chain"]
    if n_chains == 1:
        reff_updated = reff
    else:
        log_liki_half_1 = log_liki_half.isel(
            chain=slice(None), draw=slice(0, n_samples_half // n_chains)
        )
        log_liki_half_2 = log_liki_half.isel(
            chain=slice(None), draw=slice(n_samples_half // n_chains, None)
        )
        liki_half_1 = np.exp(log_liki_half_1)
        liki_half_2 = np.exp(log_liki_half_2)
        ess_1 = liki_half_1.azstats.ess(method="mean")
        ess_2 = liki_half_2.azstats.ess(method="mean")
        ess_1_value = ess_1.values if hasattr(ess_1, "values") else ess_1
        ess_2_value = ess_2.values if hasattr(ess_2, "values") else ess_2
        n_samples_1 = log_liki_half_1.size
        n_samples_2 = log_liki_half_2.size
        r_eff_1 = ess_1_value / n_samples_1
        r_eff_2 = ess_2_value / n_samples_2
        reff_updated = min(r_eff_1, r_eff_2)

    return SplitMomentMatch(
        lwi=lwi_psis_da,
        lwfi=lwfi_psis_da,
        log_liki=log_liki_half,
        reff=reff_updated,
    )


def _loo_moment_match_i(
    i,
    upars,
    log_likelihood,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    max_iters,
    k_threshold,
    split,
    cov,
    orig_log_prob,
    ks,
    log_weights,
    pareto_k,
    r_eff,
    sample_dims,
    obs_dims,
    n_samples,
    n_params,
    param_dim_name,
    var_name,
):
    """Compute moment matching for a single observation.

    Iteratively tries mean shift, then shift+scale, then shift+covariance
    transformations of the draws, keeping any transformation that lowers the
    Pareto k of the LOO importance weights, until k falls below ``k_threshold``
    or ``max_iters`` is reached. Optionally finishes with split moment matching.

    Returns a ``LooMomentMatchResult`` with the final weights, k values, and the
    per-observation effective sample size.
    """
    n_chains = upars.sizes["chain"]
    n_draws = upars.sizes["draw"]
    log_liki = _get_log_likelihood_i(log_likelihood, i, obs_dims).squeeze(drop=True)

    # Per-observation relative efficiency: from a DataArray of r_eff values,
    # a scalar, or estimated from the likelihood draws as a fallback.
    if isinstance(r_eff, xr.DataArray):
        reff_i = _get_r_eff_i(r_eff, i, obs_dims)
    elif r_eff is not None:
        reff_i = r_eff
    else:
        liki = np.exp(log_liki)
        liki_reshaped = liki.values.reshape(n_chains, n_draws).T
        ess_val = ess(liki_reshaped, method="mean").item()
        reff_i = ess_val / n_samples if n_samples > 0 else 1.0

    original_ki = ks[i]

    # Initial LOO weights: reuse stored smoothed weights when available,
    # otherwise smooth the raw ratios -log p(y_i | theta) with PSIS.
    if log_weights is not None:
        log_weights_i, ki = _get_weights_and_k_i(
            log_weights=log_weights,
            pareto_k=pareto_k,
            i=i,
            obs_dims=obs_dims,
            sample_dims=sample_dims,
            data=log_likelihood,
            n_samples=n_samples,
            reff=reff_i,
            log_lik_i=log_liki,
            var_name=var_name,
        )
        lwi = log_weights_i.squeeze(drop=True).transpose(*sample_dims).astype(np.float64)
    else:
        log_ratio_i_init = -log_liki
        lwi, ki = _wrap__psislw(log_ratio_i_init, sample_dims, reff_i)

    # Uniform full-posterior weights to start with.
    lwfi = xr.full_like(lwi, -np.log(n_samples))

    upars_i = upars.copy(deep=True)
    # Accumulated affine transformation (identity to start).
    total_shift = np.zeros(upars_i.sizes[param_dim_name])
    total_scaling = np.ones(upars_i.sizes[param_dim_name])
    total_mapping = np.eye(upars_i.sizes[param_dim_name])

    iterind = 1
    transformations_applied = False
    kfs_i = 0

    while iterind <= max_iters and ki > k_threshold:
        if iterind == max_iters:
            warnings.warn(
                f"Maximum number of moment matching iterations ({max_iters}) reached "
                f"for observation {i}. Final Pareto k is {ki:.2f}.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Mean Shift
        try:
            shift_res = _shift(upars_i, lwi)
            quantities_i = _update_quantities_i(
                shift_res.upars,
                i,
                orig_log_prob,
                log_prob_upars_fn,
                log_lik_i_upars_fn,
                reff_i,
                sample_dims,
            )
            if quantities_i.ki < ki:
                ki = quantities_i.ki
                lwi = quantities_i.lwi
                lwfi = quantities_i.lwfi
                log_liki = quantities_i.log_liki
                kfs_i = quantities_i.kfi
                upars_i = shift_res.upars
                total_shift = total_shift + shift_res.shift
                transformations_applied = True
                iterind += 1
                continue  # Restart, try mean shift again
        except RuntimeError as e:
            warnings.warn(
                f"Error during mean shift calculation for observation {i}: {e}. "
                "Stopping moment matching for this observation.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Scale Shift
        try:
            scale_res = _shift_and_scale(upars_i, lwi)
            quantities_i = _update_quantities_i(
                scale_res.upars,
                i,
                orig_log_prob,
                log_prob_upars_fn,
                log_lik_i_upars_fn,
                reff_i,
                sample_dims,
            )
            if quantities_i.ki < ki:
                ki = quantities_i.ki
                lwi = quantities_i.lwi
                lwfi = quantities_i.lwfi
                log_liki = quantities_i.log_liki
                kfs_i = quantities_i.kfi
                upars_i = scale_res.upars
                total_shift = total_shift + scale_res.shift
                total_scaling = total_scaling * scale_res.scaling
                transformations_applied = True
                iterind += 1
                continue  # Restart, try mean shift again
        except RuntimeError as e:
            warnings.warn(
                f"Error during scale shift calculation for observation {i}: {e}. "
                "Stopping moment matching for this observation.",
                UserWarning,
                stacklevel=2,
            )
            break

        # Try Covariance Shift — only when there are enough draws per parameter
        # for a stable weighted covariance estimate.
        if cov and n_samples >= 10 * n_params:
            try:
                cov_res = _shift_and_cov(upars_i, lwi)
                quantities_i = _update_quantities_i(
                    cov_res.upars,
                    i,
                    orig_log_prob,
                    log_prob_upars_fn,
                    log_lik_i_upars_fn,
                    reff_i,
                    sample_dims,
                )
                if quantities_i.ki < ki:
                    ki = quantities_i.ki
                    lwi = quantities_i.lwi
                    lwfi = quantities_i.lwfi
                    log_liki = quantities_i.log_liki
                    kfs_i = quantities_i.kfi
                    upars_i = cov_res.upars
                    total_shift = total_shift + cov_res.shift
                    total_mapping = cov_res.mapping @ total_mapping
                    transformations_applied = True
                    iterind += 1
                    continue  # Restart, try mean shift again
            except RuntimeError as e:
                warnings.warn(
                    f"Error during covariance shift calculation for observation {i}: {e}. "
                    "Stopping moment matching for this observation.",
                    UserWarning,
                    stacklevel=2,
                )
                break

        # No transformation improved k this round: stop iterating.
        break

    if split and transformations_applied:
        try:
            split_res = _split_moment_match(
                upars=upars,
                cov=cov,
                total_shift=total_shift,
                total_scaling=total_scaling,
                total_mapping=total_mapping,
                i=i,
                reff=reff_i,
                log_prob_upars_fn=log_prob_upars_fn,
                log_lik_i_upars_fn=log_lik_i_upars_fn,
            )
            final_log_liki = split_res.log_liki
            final_lwi = split_res.lwi
            final_lwfi = split_res.lwfi
            final_ki = ki
            reff_i = split_res.reff
        except RuntimeError as e:
            warnings.warn(
                f"Error during split moment matching for observation {i}: {e}. "
                "Using non-split transformation result.",
                UserWarning,
                stacklevel=2,
            )
            final_log_liki = log_liki
            final_lwi = lwi
            final_lwfi = lwfi
            final_ki = ki
    else:
        final_log_liki = log_liki
        final_lwi = lwi
        final_lwfi = lwfi
        final_ki = ki

    # Final relative efficiency from the final likelihood draws.
    liki_final = np.exp(final_log_liki)
    liki_final_reshaped = liki_final.values.reshape(n_chains, n_draws).T
    ess_val_final = ess(liki_final_reshaped, method="mean").item()
    reff_i = ess_val_final / n_samples if n_samples > 0 else 1.0

    # Effective sample size of the (normalized) importance weights:
    # 1 / sum(w^2), taking the smaller of the LOO and full-posterior values.
    lwi_vals = final_lwi.values.flatten()
    lwfi_vals = final_lwfi.values.flatten()
    n_eff_loo = 1.0 / np.sum(np.exp(2 * lwi_vals))
    n_eff_full = 1.0 / np.sum(np.exp(2 * lwfi_vals))
    n_eff_i = min(n_eff_loo, n_eff_full) * reff_i

    return LooMomentMatchResult(
        final_log_liki=final_log_liki,
        final_lwi=final_lwi,
        final_ki=final_ki,
        kfs_i=kfs_i,
        reff_i=reff_i,
        n_eff_i=n_eff_i,
        original_ki=original_ki,
        i=i,
    )


def _update_loo_data_i(
    loo_data,
    i,
    new_elpd_i,
    new_pareto_k,
    log_liki,
    sample_dims,
    obs_dims,
    n_samples,
    n_eff_i=None,
    original_log_liki=None,
    log_weights_i=None,
    suppress_warnings=False,
):
    """Update the ELPDData object in place for a single observation.

    Writes the moment-matched elpd_i, pareto_k, p_loo_i (and optionally n_eff_i
    and log_weights) at flat observation index ``i``, then recomputes the global
    elpd, its standard error, and the Pareto k warning flag.
    """
    if loo_data.elpd_i is None or loo_data.pareto_k is None:
        raise ValueError("loo_data must contain pointwise elpd_i and pareto_k values.")

    # p_loo_i = lpd_i - elpd_i; lpd is computed from the *original* likelihood
    # draws when provided (the matched draws would bias it).
    lpd_i_log_lik = original_log_liki if original_log_liki is not None else log_liki
    lpd_i = logsumexp(lpd_i_log_lik, dims=sample_dims, b=1 / n_samples).item()
    p_loo_i = lpd_i - new_elpd_i

    # Flat index -> labeled indexer over the observation dims.
    if len(obs_dims) == 1:
        idx_dict = {obs_dims[0]: i}
    else:
        coords = np.unravel_index(i, tuple(loo_data.elpd_i.sizes[d] for d in obs_dims))
        idx_dict = dict(zip(obs_dims, coords))

    loo_data.elpd_i[idx_dict] = new_elpd_i
    loo_data.pareto_k[idx_dict] = new_pareto_k

    if getattr(loo_data, "p_loo_i", None) is None:
        loo_data.p_loo_i = xr.full_like(loo_data.elpd_i, np.nan)
    loo_data.p_loo_i[idx_dict] = p_loo_i

    if n_eff_i is not None:
        if getattr(loo_data, "n_eff_i", None) is None:
            loo_data.n_eff_i = xr.full_like(loo_data.elpd_i, np.nan)
        loo_data.n_eff_i[idx_dict] = n_eff_i

    if log_weights_i is not None:
        loo_data.log_weights[idx_dict] = log_weights_i

    # Refresh the aggregate statistics from the updated pointwise values.
    loo_data.elpd = np.nansum(loo_data.elpd_i.values)
    loo_data.se = np.sqrt(loo_data.n_data_points * np.nanvar(loo_data.elpd_i.values, ddof=1))
    loo_data.warning, loo_data.good_k = _warn_pareto_k(
        loo_data.pareto_k.values[~np.isnan(loo_data.pareto_k.values)],
        loo_data.n_samples,
        suppress=suppress_warnings,
    )


def _update_quantities_i(
    upars,
    i,
    orig_log_prob,
    log_prob_upars_fn,
    log_lik_i_upars_fn,
    reff_i,
    sample_dims,
):
    """Update the moment matching quantities for a single observation.

    Given transformed draws, recomputes the PSIS-smoothed LOO log weights
    (ratio excludes observation ``i``) and full-posterior log weights, together
    with their Pareto k diagnostics.
    """
    log_prob_new = log_prob_upars_fn(upars)
    log_liki_new = log_lik_i_upars_fn(upars, i)

    # LOO importance ratio: p_new(theta) / (p_orig(theta) * p(y_i | theta)).
    log_ratio_i = -log_liki_new + log_prob_new - orig_log_prob
    log_ratio_i = xr.where(np.isnan(log_ratio_i), -np.inf, log_ratio_i)
    lwi_new, ki_new = _wrap__psislw(log_ratio_i, sample_dims, reff_i)

    # Full-posterior ratio: p_new(theta) / p_orig(theta).
    log_ratio_full = log_prob_new - orig_log_prob
    log_ratio_full = xr.where(np.isnan(log_ratio_full), -np.inf, log_ratio_full)
    lwfi_new, kfi_new = _wrap__psislw(log_ratio_full, sample_dims, reff_i)

    return UpdateQuantities(
        lwi=lwi_new,
        lwfi=lwfi_new,
        ki=ki_new,
        kfi=kfi_new,
        log_liki=log_liki_new,
    )


def _wrap__psislw(log_weights, sample_dims, r_eff):
    """Apply PSIS smoothing over the sample dimensions.

    Stacks ``sample_dims`` into one dimension, runs PSIS on the negated weights,
    and returns the smoothed log weights (original dim layout) plus a scalar
    Pareto k. Falls back to plain self-normalization with k = inf when the PSIS
    tail fit cannot be computed.
    """
    if not isinstance(log_weights, xr.DataArray):
        raise TypeError("log_weights must be an xarray.DataArray")

    missing_dims = [dim for dim in sample_dims if dim not in log_weights.dims]
    if missing_dims:
        raise ValueError(
            f"All sample dimensions must be present in the input; missing {missing_dims}."
        )
    other_dims = [dim for dim in log_weights.dims if dim not in sample_dims]
    if other_dims:
        raise ValueError(
            "_wrap__psislw expects `log_weights` to include only sample dimensions; "
            f"found extra dims {other_dims}."
        )

    stacked = log_weights.stack(__sample__=sample_dims)
    # NOTE(review): the negation here matches the psislw accessor's expected
    # sign convention — confirm against the azstats.psislw implementation.
    stacked_for_psis = -stacked
    try:
        lw_stacked, k = stacked_for_psis.azstats.psislw(dim="__sample__", r_eff=r_eff)
    except ValueError as err:
        err_message = str(err)
        # Degenerate-tail cases: fall back to simple normalization, k = inf.
        fallback_errors = ("All tail values are the same", "n_draws_tail must be at least 5")
        if not any(msg in err_message for msg in fallback_errors):
            raise
        log_norm = logsumexp(stacked, dims="__sample__")
        lw_stacked = stacked - log_norm
        k = np.inf

    lw = lw_stacked.unstack("__sample__").transpose(*log_weights.dims)

    # Normalize the Pareto k output to a plain scalar.
    if isinstance(k, xr.DataArray):
        if k.dims:
            raise ValueError("Unexpected dimensions on Pareto k output; expected scalar result.")
        k_val = k.item()
    elif isinstance(k, np.ndarray):
        if k.ndim != 0:
            raise ValueError("Unexpected array shape for Pareto k; expected scalar result.")
        k_val = k.item()
    else:
        # NOTE(review): a bare assignment cannot raise TypeError/ValueError, so
        # this except branch is dead code — `float(k)` was likely intended here.
        try:
            k_val = k
        except (TypeError, ValueError) as exc:
            raise TypeError("Unable to convert PSIS tail index to float") from exc

    return lw, k_val