arviz.psislw

arviz.psislw(log_weights, reff=1.0)

Pareto smoothed importance sampling (PSIS).

Parameters
log_weights: array

Array of shape (n_observations, n_samples)

reff: float

Relative MCMC efficiency, ess / n, i.e. the effective sample size divided by the number of samples. Defaults to 1.0.

Returns
lw_out: array

Smoothed log weights

kss: array

Pareto tail indices, one per observation

See also

loo

Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

Notes

If the log_weights input is a DataArray with a dimension named __sample__ (recommended), psislw interprets that dimension as samples and all other dimensions as dimensions of the observed data, looping over them to compute the smoothed log weights for each observation. If no __sample__ dimension is present, or if the input is a NumPy array, the last dimension is interpreted as __sample__.
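
As a minimal sketch of the NumPy convention described above, the snippet below arranges a hypothetical array of pointwise log-likelihood values so that samples sit on the last axis. The array `log_lik`, its (chain, draw, observation) layout, and the random values are assumptions made for illustration only; the negation mirrors the call shown in the Examples section.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical pointwise log-likelihood values, laid out as
# (chain, draw, observation). The values are random placeholders.
n_chains, n_draws, n_obs = 4, 500, 8
log_lik = rng.normal(loc=-3.0, scale=0.5, size=(n_chains, n_draws, n_obs))

# Merge chain and draw into a single sample axis, then move that axis last
# so the array has shape (n_observations, n_samples).
samples_first = log_lik.reshape(n_chains * n_draws, n_obs)
log_weights = -np.moveaxis(samples_first, 0, -1)

print(log_weights.shape)  # (8, 2000)

# With ArviZ installed, this array could then be passed on, for example:
# lw_out, kss = az.psislw(log_weights, reff=1.0)
```

Keeping the sample axis last means no dimension names are needed, but a DataArray with an explicit __sample__ dimension remains the recommended input because it removes any ambiguity about which axis holds the samples.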

Examples

Get Pareto smoothed importance sampling (PSIS) log weights:

In [1]: import arviz as az
   ...: data = az.load_arviz_data("centered_eight")
   ...: log_likelihood = data.sample_stats.log_likelihood.stack(
   ...:     __sample__=("chain", "draw")
   ...: )
   ...: az.psislw(-log_likelihood, reff=0.8)
   ...: 
Out[1]: 
(<xarray.DataArray 'log_weights' (school: 8, __sample__: 2000)>
 array([[-7.35137283, -6.31188687, -7.41236566, ..., -8.46808714,
         -6.96348725, -7.14365799],
        [-6.42972359, -7.55712511, -7.20082771, ..., -7.72636166,
         -7.07301723, -7.48221468],
        [-7.76454863, -7.76032659, -7.76434977, ..., -7.64950566,
         -6.94343535, -7.54502347],
        ...,
        [-7.3256938 , -7.75981448, -7.70411157, ..., -7.73833836,
         -7.61736749, -7.74297445],
        [-7.1045607 , -6.2092645 , -6.43506702, ..., -6.96679107,
         -8.67364049, -8.42241583],
        [-7.19870332, -7.47992943, -7.48333094, ..., -7.48056398,
         -7.64676005, -7.65892997]])
 Coordinates:
   * school      (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
   * __sample__  (__sample__) MultiIndex
   - chain       (__sample__) int64 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3
   - draw        (__sample__) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499,
 <xarray.DataArray 'pareto_shape' (school: 8)>
 array([0.35683758, 0.32524967, 0.53342172, 0.33519276, 0.25373991,
        0.64466503, 0.71238247, 0.28943932])
 Coordinates:
   * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon')
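
The second element of the returned tuple holds one Pareto shape estimate per school. As a rough sketch of how these values might be screened, the snippet below copies the eight estimates from the output above and flags those greater than 0.7, a threshold commonly used as a rule of thumb for unreliable PSIS estimates. The threshold and the variable names are assumptions for illustration, not part of the psislw API.

```python
import numpy as np

# Pareto shape estimates copied from the example output above
pareto_k = np.array([
    0.35683758, 0.32524967, 0.53342172, 0.33519276,
    0.25373991, 0.64466503, 0.71238247, 0.28943932,
])

# Assumed rule-of-thumb threshold; not a parameter of psislw
K_THRESHOLD = 0.7

# Positions along the school dimension whose estimate exceeds the threshold
flagged = np.flatnonzero(pareto_k > K_THRESHOLD)
print(flagged)  # [6]
```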