arviz.extract_dataset

arviz.extract_dataset(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, rng=None)

Extract an InferenceData group or a subset of it as an xarray.Dataset.

Parameters
data : InferenceData or InferenceData_like

InferenceData from which to extract the data.

group : str, optional

Name of the InferenceData group to extract data from. Defaults to "posterior".

combined : bool, optional

Combine chain and draw dimensions into sample. Won’t work if a dimension named sample already exists.

var_names : str or list of str, optional

Variables to be extracted. Prefix a variable with ~ to exclude it from the result.

filter_vars : {None, “like”, “regex”}, optional

If None (default), interpret var_names as the real variable names. If “like”, interpret var_names as substrings of the real variable names. If “regex”, interpret var_names as regular expressions on the real variable names. This mirrors pandas.DataFrame.filter. As with plotting, it is sometimes easier to specify what to exclude than what to include.

num_samples : int, optional

Extract only a subset of the samples. Only valid if combined=True.

rng : bool, int, numpy.random.Generator, optional

Whether and how to shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None, and are left in the same order otherwise. Shuffling ensures that a subset of the samples does not consist only of consecutive draws from a single chain.
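
The reason for shuffling before subsetting can be illustrated without ArviZ. The sketch below uses only NumPy and assumes that stacked samples are ordered chain by chain (4 chains of 500 draws, as in the centered_eight example); the variable names and the seed are illustrative, and the use of a permutation is an assumption rather than a description of ArviZ internals.

```python
import numpy as np

n_chains, n_draws, num_samples = 4, 500, 100

# Chain label of every stacked sample, in chain-major order:
# chain 0 draws 0..499, then chain 1 draws 0..499, and so on.
chain_of_sample = np.repeat(np.arange(n_chains), n_draws)

# Without shuffling, the first `num_samples` entries all come from chain 0.
unshuffled = chain_of_sample[:num_samples]

# Shuffling first (seeded here for reproducibility) mixes the chains.
rng = np.random.default_rng(0)
order = rng.permutation(n_chains * n_draws)
shuffled = chain_of_sample[order[:num_samples]]

print(np.unique(unshuffled))  # only chain 0 is present
print(np.unique(shuffled))    # more than one chain is present
```

Taking the first 100 samples in their original order would only ever see the start of chain 0, which is why subsetting and shuffling are tied together by default.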

Returns
xarray.Dataset

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract_dataset(idata)
<xarray.Dataset>
Dimensions:  (school: 8, sample: 2000)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) MultiIndex
  - chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 0 3 3 3 3 3 3 3 3 3 3 3 3 3
  - draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614
    theta    (school, sample) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523
    tau      (sample) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
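
What stacking chain and draw into sample means for the array layout can be sketched with plain NumPy. This is not how ArviZ or xarray perform the operation; it is a stand-in that assumes the chain-major sample order visible in the output above (chain 0 first, with draw running from 0 to 499). The array theta here holds placeholder values, not posterior draws.

```python
import numpy as np

n_chains, n_draws, n_schools = 4, 500, 8

# Stand-in for a posterior variable with dimensions (chain, draw, school).
theta = np.arange(n_chains * n_draws * n_schools, dtype=float).reshape(
    n_chains, n_draws, n_schools
)

# Move school to the front, then flatten (chain, draw) into one sample axis.
stacked = np.moveaxis(theta, -1, 0).reshape(n_schools, n_chains * n_draws)

# Sample i corresponds to chain i // n_draws and draw i % n_draws.
i = 1234
chain, draw = divmod(i, n_draws)
print(stacked.shape)  # (8, 2000)
print(np.array_equal(stacked[:, i], theta[chain, draw]))  # True
```

The resulting (school, sample) shape matches the dimensions reported for theta in the output above.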

You can also request a subset to be returned, both in variables and in samples:

az.extract_dataset(idata, var_names="theta", num_samples=100)
<xarray.Dataset>
Dimensions:  (school: 8, sample: 100)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) MultiIndex
  - chain    (sample) int64 3 1 2 1 3 2 0 0 0 1 1 1 ... 3 0 3 1 1 1 3 2 2 3 2 2
  - draw     (sample) int64 448 343 477 347 438 153 ... 329 241 298 436 156 48
Data variables:
    theta    (school, sample) float64 7.55 3.614 3.295 ... 0.4809 1.109 7.632
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
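
The matching rules described for var_names and filter_vars can be sketched in pure Python. The helper select_names below is hypothetical and is not part of ArviZ; it only mimics the documented behaviour (exact names, substrings with “like”, regular expressions with “regex”, and a leading ~ for exclusion). How it treats a mix of included and excluded names is one plausible choice, and ArviZ may behave differently in that case.

```python
import re

def select_names(all_names, var_names, filter_vars=None):
    """Hypothetical helper mimicking the var_names/filter_vars rules."""
    if isinstance(var_names, str):
        var_names = [var_names]
    # A leading "~" marks a name or pattern to exclude.
    exclude = [name[1:] for name in var_names if name.startswith("~")]
    include = [name for name in var_names if not name.startswith("~")]

    def matches(pattern, name):
        if filter_vars is None:
            return pattern == name  # exact variable name
        if filter_vars == "like":
            return pattern in name  # substring match
        if filter_vars == "regex":
            return re.search(pattern, name) is not None
        raise ValueError(f"unknown filter_vars: {filter_vars!r}")

    # With no inclusion patterns, start from every variable.
    kept = [n for n in all_names
            if not include or any(matches(p, n) for p in include)]
    return [n for n in kept if not any(matches(p, n) for p in exclude)]

# Variable names taken from the prior group of the example dataset.
names = ["tau", "tau_log__", "mu", "theta", "obs"]
print(select_names(names, "theta"))                    # ['theta']
print(select_names(names, "tau", filter_vars="like"))  # ['tau', 'tau_log__']
print(select_names(names, "~obs"))                     # ['tau', 'tau_log__', 'mu', 'theta']
print(select_names(names, "^t", filter_vars="regex"))  # ['tau', 'tau_log__', 'theta']
```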

To keep the chain and draw dimensions, use combined=False.

az.extract_dataset(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:    (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631
    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704
    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273
    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555
    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071
Attributes:
    created_at:                 2019-06-21T17:36:34.490387
    inference_library:          pymc3
    inference_library_version:  3.7