arviz.extract

arviz.extract(data, group='posterior', combined=True, var_names=None, filter_vars=None, num_samples=None, keep_dataset=False, rng=None)

Extract an InferenceData group or subset of it.

Parameters
data : InferenceData or InferenceData_like

InferenceData from which to extract the data.

group : str, optional

Which InferenceData group to extract data from. Defaults to "posterior".

combined : bool, optional

Combine the chain and draw dimensions into a single sample dimension. This does not work if a dimension named sample already exists.

var_names : str or list of str, optional

Variables to be extracted. Prefix a variable name with ~ to exclude it.

filter_vars : {None, "like", "regex"}, optional

If None (default), interpret var_names as exact variable names. If "like", interpret var_names as substrings of the variable names. If "regex", interpret var_names as regular expressions matched against the variable names, in the style of pandas.DataFrame.filter. As with plotting, it is sometimes easier to subset by saying what to exclude rather than what to include.

num_samples : int, optional

Extract only a subset of the samples. Only valid if combined=True.

keep_dataset : bool, optional

If True, always return a Dataset. If False (default), return a DataArray when there is a single variable.

rng : bool, int, numpy.random.Generator, optional

Shuffle the samples. Only valid if combined=True. By default, samples are shuffled if num_samples is not None and are left in their original order otherwise. Shuffling ensures that a subset of the samples does not consist only of consecutive draws from a single chain.
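
As a rough illustration of how var_names and filter_vars interact, the sketch below selects variables from the centered_eight example dataset in three ways. It assumes the posterior group contains variables named mu, theta, and tau, as in the examples on this page; the substring "ta" and the pattern "^t" are chosen only for illustration.

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# Exclude a variable by prefixing its name with "~".
no_theta = az.extract(idata, var_names="~theta")
print(sorted(no_theta.data_vars))

# filter_vars="like": keep variables whose name contains the substring "ta".
like_ta = az.extract(idata, var_names="ta", filter_vars="like")
print(sorted(like_ta.data_vars))

# filter_vars="regex": keep variables whose name matches the pattern "^t".
regex_t = az.extract(idata, var_names="^t", filter_vars="regex")
print(sorted(regex_t.data_vars))
```

Each call here is expected to leave more than one variable, so each result is a Dataset and exposes data_vars.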

Returns
xarray.DataArray or xarray.Dataset
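
The return type depends on how many variables are selected and on keep_dataset. A minimal sketch, assuming the centered_eight posterior has a variable named mu:

```python
import arviz as az
import xarray as xr

idata = az.load_arviz_data("centered_eight")

# A single variable is returned as a DataArray by default.
single = az.extract(idata, var_names="mu")
print(type(single).__name__)

# keep_dataset=True wraps the same variable in a Dataset instead.
kept = az.extract(idata, var_names="mu", keep_dataset=True)
print(type(kept).__name__)
```

Requesting a Dataset consistently can be convenient in code that loops over data_vars regardless of how many variables were selected.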

Examples

The default behaviour is to return the posterior group after stacking the chain and draw dimensions.

import arviz as az
idata = az.load_arviz_data("centered_eight")
az.extract(idata)
<xarray.Dataset>
Dimensions:  (school: 8, sample: 2000)
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 -3.477 -2.456 -2.826 -1.996 ... 4.597 5.899 0.1614
    theta    (school, sample) float64 1.669 -6.239 2.195 ... -1.095 4.013 4.523
    tau      (sample) float64 3.73 2.075 3.703 4.146 ... 8.589 8.346 7.711 5.407
Attributes:
    created_at:                 2019-06-21T17:36:34.398087
    inference_library:          pymc3
    inference_library_version:  3.7
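
For intuition, the default call appears to be comparable to stacking the posterior group by hand with xarray. The sketch below assumes idata.posterior is an xarray.Dataset with chain and draw dimensions and that no shuffling happens when num_samples is not given; it is a comparison, not a description of the library's internals.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# Default extraction: posterior group with chain and draw combined into sample.
extracted = az.extract(idata)

# Manual equivalent using xarray's own stacking.
stacked = idata.posterior.stack(sample=("chain", "draw"))

print(extracted.sizes["sample"], stacked.sizes["sample"])
print(np.array_equal(extracted["mu"].values, stacked["mu"].values))
```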

You can also request a subset, both in variables and in samples:

az.extract(idata, var_names="theta", num_samples=100)
<xarray.DataArray 'theta' (school: 8, sample: 100)>
array([[ 9.47797273e+00,  8.07394570e+00,  6.57739555e+00,
         4.44629688e+00,  2.31012821e+00,  1.14628350e+01,
         1.21870955e+01,  1.21400044e+01,  5.51004269e+00,
         1.18884236e+01,  7.52389042e+00,  1.40280761e+01,
         1.22481857e+01,  8.47585134e+00,  2.82681087e+00,
         2.22251050e+01,  8.32077501e+00,  1.20843068e+01,
         1.06155160e+00,  1.45233093e+01, -1.27726030e+00,
         3.80132180e+00,  3.04669257e+00,  4.85693582e+00,
         1.72680493e-01,  2.68244068e+00, -3.22052364e+00,
         2.47423848e+00,  5.07675798e+00,  5.60073302e-01,
         5.11231058e+00,  1.70405474e+00,  6.93268859e+00,
         8.46832328e+00,  1.12449759e+01,  5.36165347e+00,
         6.43591849e-01, -5.87679594e-02,  3.28567671e+00,
         9.64280953e+00,  1.16342430e+01,  5.44104521e+00,
         1.10787348e+01,  6.99380298e+00,  4.36693773e+00,
         1.73021909e+01,  3.84964965e+00,  5.51051257e+00,
         9.14227949e+00,  4.64742585e+00,  4.94933880e+00,
         1.23046196e+01,  9.83187295e+00,  2.86944710e+01,
        -2.26498561e-01, -2.62736064e+00,  7.99486485e+00,
         9.76957118e+00, -2.65331395e-01,  6.95249013e+00,
...
        -9.25441916e-01,  7.99275901e+00, -3.04138574e+00,
         1.98826185e+01, -5.39859896e-01,  4.02215059e+00,
         1.09270506e+01,  1.03642396e+01,  1.53432200e+01,
         5.30379149e+00,  7.23149524e+00,  1.33883512e+01,
        -2.42070065e+00, -4.48112554e+00,  9.46543871e+00,
         1.24956179e+01,  3.30441490e+00,  2.17010592e+01,
         3.27090580e+00,  5.26479115e+00,  1.85766017e+00,
         3.18741926e+00,  1.68573362e+01,  5.47246773e+00,
         4.33474724e+00,  2.32796811e+00,  8.38390126e+00,
         1.15578350e+01,  9.29032030e+00,  8.52082240e+00,
         3.17016285e+00,  4.93396155e-01,  3.82081193e+00,
         6.12213726e+00, -1.58673174e-01,  8.54297503e-01,
         5.25322948e+00,  2.80196072e-01,  9.22787325e+00,
        -1.23851551e+00,  8.99318700e+00,  1.12728305e+01,
         3.64654434e+00,  2.84331668e+00,  1.00132946e+01,
         6.61202173e+00,  1.96462631e+00,  2.98455306e+00,
         7.53618424e+00,  2.84415240e+00, -2.50318575e+00,
         9.22106981e+00, -1.39248206e+00,  6.00918339e+00,
        -8.81937036e-01,  8.17238390e+00,  3.73069157e+00,
         1.29453791e+00]])
Coordinates:
  * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
  * sample   (sample) object MultiIndex
  * chain    (sample) int64 0 1 3 2 1 2 1 2 1 1 0 1 ... 0 0 0 1 1 0 2 3 0 0 0 0
  * draw     (sample) int64 188 35 32 237 156 214 135 ... 56 85 384 248 124 191
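
Because the subset is shuffled by default, repeated calls generally return different samples. Passing an integer as rng should make the subset reproducible; the seed value 0 below is arbitrary and used only for illustration.

```python
import arviz as az
import numpy as np

idata = az.load_arviz_data("centered_eight")

# Same seed, same shuffled subset of 100 samples.
first = az.extract(idata, var_names="theta", num_samples=100, rng=0)
second = az.extract(idata, var_names="theta", num_samples=100, rng=0)

print(first.sizes["sample"])
print(np.array_equal(first.values, second.values))

# The chain coordinate shows which chains the selected samples came from.
print(np.unique(first["chain"].values))
```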

To keep the chain and draw dimensions, use combined=False.

az.extract(idata, group="prior", combined=False)
<xarray.Dataset>
Dimensions:    (chain: 1, draw: 500, school: 8)
Coordinates:
  * chain      (chain) int64 0
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * school     (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    tau        (chain, draw) float64 6.561 1.016 68.91 ... 1.56 5.949 0.7631
    tau_log__  (chain, draw) float64 1.881 0.01593 4.233 ... 1.783 -0.2704
    mu         (chain, draw) float64 5.293 0.8137 0.7122 ... -1.658 -3.273
    theta      (chain, draw, school) float64 2.357 7.371 7.251 ... -3.775 -3.555
    obs        (chain, draw, school) float64 -3.54 6.769 19.68 ... -21.16 -6.071
Attributes:
    created_at:                 2019-06-21T17:36:34.490387
    inference_library:          pymc3
    inference_library_version:  3.7
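
Keeping chain and draw separate is useful when a summary is needed per chain. The following sketch, which assumes the posterior contains a variable named mu, compares per-chain means with the pooled mean over the stacked sample dimension.

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# combined=False keeps the (chain, draw) layout, so reductions can target one dimension.
mu_by_chain = az.extract(idata, var_names="mu", combined=False)
per_chain_mean = mu_by_chain.mean("draw")

# With the default combined=True, all draws are pooled along the sample dimension.
pooled_mean = az.extract(idata, var_names="mu").mean("sample")

print(per_chain_mean.dims)
print(float(pooled_mean))
```

When every chain has the same number of draws, the average of the per-chain means matches the pooled mean.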