arviz.summary

arviz.summary(data, var_names: Optional[List[str]] = None, filter_vars=None, group=None, fmt: Literal['wide', 'long', 'xarray'] = 'wide', kind: Literal['all', 'stats', 'diagnostics'] = 'all', round_to=None, circ_var_names=None, stat_funcs=None, extend=True, hdi_prob=None, skipna=False, labeller=None, coords=None, index_origin=None, order=None) -> Union[pandas.core.frame.DataFrame, xarray.core.dataset.Dataset]

Create a data frame with summary statistics.

Parameters
data: obj

Any object that can be converted to an arviz.InferenceData object. Refer to the documentation of arviz.convert_to_dataset() for details.

var_names: list

Names of variables to include in summary. Prefix the variables by ~ when you want to exclude them from the summary: [“~beta”] instead of [“beta”] (see examples below).

filter_vars: {None, “like”, “regex”}, optional, default=None

If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names. A la pandas.filter.

coords: Dict[str, List[Any]], optional

Coordinate subset for which to calculate the summary.

group: str

Select a group for summary. Defaults to “posterior”, “prior” or the first group, in that order, depending on which groups exist.

fmt: {‘wide’, ‘long’, ‘xarray’}

Return format is either pandas.DataFrame {‘wide’, ‘long’} or xarray.Dataset {‘xarray’}.

kind: {‘all’, ‘stats’, ‘diagnostics’}

Whether to include the stats: mean, sd, hdi_3%, hdi_97%, or the diagnostics: mcse_mean, mcse_sd, ess_bulk, ess_tail, and r_hat. Defaults to including all of them.

round_to: int

Number of decimals used to round results. Defaults to 2. Use “none” to return raw numbers.

circ_var_names: list

A list of circular variables to compute circular stats for.

stat_funcs: dict

A list of functions or a dict of functions with function names as keys used to calculate statistics. By default, the mean, standard deviation, simulation standard error, and highest posterior density intervals are included.

Each function is given one argument, the samples for a variable as an nD array. Functions should follow the style of a ufunc and return a single number. For example, numpy.mean or numpy.var would both work.

extend: boolean

If True, use the statistics returned by stat_funcs in addition to, rather than in place of, the default statistics. This is only meaningful when stat_funcs is not None.

hdi_prob: float, optional

Highest density interval to compute. Defaults to 0.94. This is only meaningful when stat_funcs is None.

skipna: bool

If True, ignore NaN values when computing the summary statistics. This does not affect the behaviour of the functions passed to stat_funcs. Defaults to False.

labeller: labeller instance, optional

Class providing the method make_label_flat to generate the labels in the plot titles. For more details on labeller usage, see the Label guide.

credible_interval: float, optional

deprecated: Please see hdi_prob

order

deprecated: order is now ignored.

index_origin

deprecated: index_origin is now ignored, modify the coordinate values to change the value used in summary.

Returns
pandas.DataFrame or xarray.Dataset

Return type dictated by the fmt argument. The return value will contain summary statistics for each variable. Default statistics are: mean, sd, hdi_3%, hdi_97%, mcse_mean, mcse_sd, ess_bulk, ess_tail, and r_hat. r_hat is only computed for traces with 2 or more chains.
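The hdi_3% and hdi_97% column names correspond to the default hdi_prob of 0.94, so a different hdi_prob changes the column labels as well as the values. The sketch below illustrates this on synthetic draws; the variable name mu and the random data are made up for illustration, and the behaviour is assumed to match the ArviZ version documented here.

```python
import numpy as np
import arviz as az

# Synthetic posterior: 4 chains x 500 draws of a hypothetical variable "mu".
rng = np.random.default_rng(0)
idata = az.from_dict(posterior={"mu": rng.normal(size=(4, 500))})

# Default interval (hdi_prob=0.94) versus a 90% interval.
default = az.summary(idata, kind="stats")
narrower = az.summary(idata, kind="stats", hdi_prob=0.9)

print(list(default.columns))
print(list(narrower.columns))
```

Code that indexes the result by column name should therefore build the HDI labels from the hdi_prob it passes rather than hard-coding hdi_3% and hdi_97%.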

See also

waic

Compute the widely applicable information criterion.

loo

Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

ess

Calculate estimate of the effective sample size (ess).

rhat

Compute estimate of rank normalized split R-hat for a set of traces.

mcse

Calculate Markov Chain Standard Error statistic.

Examples

In [1]: import arviz as az
   ...: data = az.load_arviz_data("centered_eight")
   ...: az.summary(data, var_names=["mu", "tau"])
   ...: 
Out[1]: 
      mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
mu   4.093  3.372  -2.118   10.403  ...    0.152     250.0     643.0   1.03
tau  4.089  3.001   0.569    9.386  ...    0.178      79.0      54.0   1.07

[2 rows x 9 columns]
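The kind and fmt arguments control which columns are computed and which container is returned. The following is a minimal sketch on synthetic data rather than the centered_eight dataset; the variable names mu and theta, the school coordinate, and the random draws are assumptions made only for this illustration.

```python
import numpy as np
import arviz as az

# Synthetic posterior with a scalar variable and a 3-element vector variable.
rng = np.random.default_rng(1)
idata = az.from_dict(
    posterior={
        "mu": rng.normal(size=(4, 500)),
        "theta": rng.normal(size=(4, 500, 3)),
    },
    coords={"school": ["a", "b", "c"]},
    dims={"theta": ["school"]},
)

# kind="stats" keeps only mean, sd and the HDI bounds.
stats_only = az.summary(idata, kind="stats")
print(list(stats_only.columns))

# fmt="xarray" returns an xarray.Dataset instead of a pandas.DataFrame.
as_xarray = az.summary(idata, fmt="xarray")
print(type(as_xarray).__name__)
```

Restricting kind to "stats" or "diagnostics" also avoids the column truncation seen in the wide output above.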

You can use filter_vars to select variables without having to specify all the exact names. Use filter_vars="like" to select based on partial naming:

In [2]: az.summary(data, var_names=["the"], filter_vars="like")
Out[2]: 
                          mean     sd  hdi_3%  ...  ess_bulk  ess_tail  r_hat
theta[Choate]            6.026  5.782  -3.707  ...     348.0     743.0   1.02
theta[Deerfield]         4.724  4.736  -4.039  ...     471.0    1018.0   1.02
theta[Phillips Andover]  3.576  5.559  -6.779  ...     463.0     674.0   1.01
theta[Phillips Exeter]   4.478  4.939  -5.528  ...     503.0     666.0   1.01
theta[Hotchkiss]         3.064  4.642  -5.972  ...     380.0     833.0   1.02
theta[Lawrenceville]     3.821  4.979  -5.507  ...     516.0    1104.0   1.02
theta[St. Paul's]        6.250  5.436  -3.412  ...     402.0    1026.0   1.02
theta[Mt. Hermon]        4.544  5.521  -5.665  ...     449.0    1084.0   1.01

[8 rows x 9 columns]

Use filter_vars="regex" to select based on regular expressions, and prefix the variables you want to exclude by ~. Here, we exclude from the summary all the variables starting with the letter t:

In [3]: az.summary(data, var_names=["~^t"], filter_vars="regex")
Out[3]: 
     mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
mu  4.093  3.372  -2.118   10.403  ...    0.152     250.0     643.0   1.03

[1 rows x 9 columns]
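Selection can also be narrowed along a coordinate with coords, which is applied in addition to var_names. The sketch below uses synthetic data; the theta variable, the school dimension, and its labels a, b and c are hypothetical names chosen for the example.

```python
import numpy as np
import arviz as az

# Synthetic posterior: "theta" has one entry per (made-up) school label.
rng = np.random.default_rng(2)
idata = az.from_dict(
    posterior={"theta": rng.normal(size=(4, 500, 3))},
    coords={"school": ["a", "b", "c"]},
    dims={"theta": ["school"]},
)

# Summarize only the entries of theta for schools "a" and "c".
subset = az.summary(
    idata,
    var_names=["theta"],
    coords={"school": ["a", "c"]},
    kind="stats",
)
print(list(subset.index))
```

Passing the coordinate values as a list keeps the dimension, so each selected entry gets its own labelled row.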

Other statistics can be calculated by passing a list of functions or a dictionary with key, function pairs.

In [4]: import numpy as np
   ...: def median_sd(x):
   ...:     median = np.percentile(x, 50)
   ...:     sd = np.sqrt(np.mean((x-median)**2))
   ...:     return sd
   ...: 
   ...: func_dict = {
   ...:     "std": np.std,
   ...:     "median_std": median_sd,
   ...:     "5%": lambda x: np.percentile(x, 5),
   ...:     "median": lambda x: np.percentile(x, 50),
   ...:     "95%": lambda x: np.percentile(x, 95),
   ...: }
   ...: az.summary(
   ...:     data,
   ...:     var_names=["mu", "tau"],
   ...:     stat_funcs=func_dict,
   ...:     extend=False
   ...: )
   ...: 
Out[4]: 
       std  median_std     5%  median    95%
mu   3.371       3.374 -1.312   3.961  9.640
tau  3.000       3.113  0.785   3.258  9.659
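With extend=True (the default), the functions in stat_funcs are added to the default statistics instead of replacing them. A minimal sketch of the difference, assuming synthetic draws for a hypothetical variable mu and using numpy.median as the extra statistic:

```python
import numpy as np
import arviz as az

# Synthetic posterior: 4 chains x 500 draws of a hypothetical variable "mu".
rng = np.random.default_rng(3)
idata = az.from_dict(posterior={"mu": rng.normal(size=(4, 500))})

extra = {"median": np.median}

# extend=True: default stats plus the custom "median" column.
extended = az.summary(idata, stat_funcs=extra, extend=True, kind="stats")

# extend=False: only the columns defined in stat_funcs.
replaced = az.summary(idata, stat_funcs=extra, extend=False)

print(list(extended.columns))
print(list(replaced.columns))
```

The dictionary keys become the column names, so they should be chosen to avoid clashing with the default statistic names.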