arviz.summary

arviz.summary(data, var_names=None, fmt='wide', round_to=2, include_circ=None, stat_funcs=None, extend=True, credible_interval=0.94, order='C')

Create a data frame with summary statistics.

Parameters:
data : obj

Any object that can be converted to an az.InferenceData object. Refer to the documentation of az.convert_to_dataset for details.

var_names : list

Names of variables to include in summary

include_circ : bool

Whether to include circular statistics

fmt : {‘wide’, ‘long’, ‘xarray’}

Return format is either pandas.DataFrame {‘wide’, ‘long’} or xarray.Dataset {‘xarray’}.

round_to : int

Number of decimals used to round results. Defaults to 2.

stat_funcs : list or dict

A list of functions or a dict of functions with function names as keys used to calculate statistics. By default, the mean, standard deviation, simulation standard error, and highest posterior density intervals are included.

Each function is given one argument, the samples for a variable as an nD array. The functions should be in the style of a ufunc and return a single number. For example, np.mean or np.var would both work.

extend : bool

If True, use the statistics returned by stat_funcs in addition to, rather than in place of, the default statistics. This is only meaningful when stat_funcs is not None.

credible_interval : float, optional

Credible interval to report as the highest posterior density (HPD) bounds. Defaults to 0.94. This is only meaningful when stat_funcs is None.

order : {“C”, “F”}

If fmt is “wide”, use either C or F unpacking order. Defaults to C.
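
The difference between the two unpacking orders can be sketched with the standard library alone. The helper flat_labels below is hypothetical and not part of ArviZ, and the label format name[i,j] is an assumption made for illustration; the point is only which index varies fastest when an array-valued variable is flattened into rows.

```python
from itertools import product

def flat_labels(name, shape, order="C"):
    """Return one label per element of an array-valued variable (hypothetical helper)."""
    if order == "C":
        # C (row-major) order: the last index changes fastest.
        indices = product(*(range(n) for n in shape))
    elif order == "F":
        # F (column-major) order: the first index changes fastest.
        indices = (idx[::-1] for idx in product(*(range(n) for n in reversed(shape))))
    else:
        raise ValueError("order must be 'C' or 'F'")
    return [f"{name}[{','.join(map(str, idx))}]" for idx in indices]

c_labels = flat_labels("theta", (2, 3), order="C")
f_labels = flat_labels("theta", (2, 3), order="F")
```

For a variable of shape (2, 3), C order lists theta[0,0], theta[0,1], theta[0,2] before moving to theta[1,0], while F order lists theta[0,0], theta[1,0] before moving to theta[0,1].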

Returns:
pandas.DataFrame

With summary statistics for each variable. Default statistics are: mean, sd, hpd_3%, hpd_97%, mc_error, ess and r_hat. ess and r_hat are only computed for traces with 2 or more chains.
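
With the default credible_interval of 0.94, the remaining 6% splits into 3% per side, which appears to be where the 3 and 97 in the HPD column names come from. As a minimal sketch of what a highest posterior density interval is, the function below finds the narrowest interval containing the requested share of samples. It assumes a unimodal distribution, the name hpd_interval is hypothetical, and it is not claimed to match ArviZ's own implementation.

```python
import math

def hpd_interval(samples, credible_interval=0.94):
    """Return the narrowest (low, high) pair spanning the requested share of samples.

    Assumes a unimodal distribution; illustrative sketch only, not ArviZ's code.
    """
    ordered = sorted(samples)
    n = len(ordered)
    # Index offset between the two endpoints of every candidate interval.
    span = int(math.floor(credible_interval * n))
    if not 1 <= span < n:
        raise ValueError("need more samples for this credible_interval")
    # Slide a window with a fixed index offset and keep the narrowest one.
    start = min(range(n - span), key=lambda i: ordered[i + span] - ordered[i])
    return ordered[start], ordered[start + span]

# Skewed toy data: the outlier at 100 is left outside the 80% interval.
low, high = hpd_interval([0, 1, 2, 3, 4, 5, 6, 7, 8, 100], credible_interval=0.8)
```

Unlike an equal-tailed interval, the window is free to shift toward the dense region of the samples, so skewed posteriors get asymmetric bounds.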

Examples

>>> az.summary(trace, ['mu'])
       mean    sd  mc_error  hpd_3  hpd_97    ess  r_hat
mu[0]  0.10  0.06      0.00  -0.02    0.23  487.0   1.00
mu[1] -0.04  0.06      0.00  -0.17    0.08  379.0   1.00

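Other statistics can be calculated by passing a list or dict of functions. Each function receives the samples for one variable and returns a single number; pass extend=False to report only these statistics.

>>> import numpy as np
>>> stat_funcs = {
...     'sd': np.std,
...     '5': lambda x: np.percentile(x, 5),    # 5th percentile
...     '50': lambda x: np.percentile(x, 50),  # median
...     '95': lambda x: np.percentile(x, 95),  # 95th percentile
... }
>>> az.summary(trace, ['mu'], stat_funcs=stat_funcs, extend=False)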
         sd     5    50    95
mu[0]  0.06  0.00  0.10  0.21
mu[1]  0.07 -0.16 -0.04  0.06
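
As a rough sketch of the function style that stat_funcs describes (one array of samples in, one number out), the following defines two custom statistics using NumPy only. The names median_abs_dev and prob_positive, the dict keys, and the demo array are all hypothetical and not part of ArviZ.

```python
import numpy as np

def median_abs_dev(samples):
    """Median absolute deviation over all samples, returned as one float."""
    flat = np.asarray(samples).ravel()
    return float(np.median(np.abs(flat - np.median(flat))))

def prob_positive(samples):
    """Fraction of samples strictly greater than zero."""
    return float(np.mean(np.asarray(samples) > 0))

# Hypothetical dict keyed by the name each statistic should be reported under.
custom_stats = {"mad": median_abs_dev, "p_gt_0": prob_positive}

# Made-up samples shaped (chain, draw), purely for demonstration.
demo = np.array([[-1.0, 0.0, 1.0, 2.0], [3.0, 4.0, 5.0, 6.0]])
results = {name: func(demo) for name, func in custom_stats.items()}
```

A dict like custom_stats is the kind of object the stat_funcs parameter describes, with function names as keys; checking each function on a small array first makes it easier to confirm that it returns a single number.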