arviz.InferenceData.sel

InferenceData.sel(groups: Optional[Union[str, List[str]]] = None, filter_groups: Optional[Literal['like', 'regex']] = None, inplace: bool = False, chain_prior: Optional[bool] = None, **kwargs: Any) -> Optional[arviz.data.inference_data.InferenceDataT]

Perform an xarray selection on all groups.

Loops over the groups and performs Dataset.sel(key=item) for every kwarg whose key is a dimension of that group's dataset. Typical uses are applying a burn-in cut to the InferenceData object or discarding a chain. The selection is applied to all relevant groups (such as posterior, prior and sample_stats), while groups that lack the selected dimension, such as observed_data, are left unchanged. See xarray.Dataset.sel.
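As a minimal sketch of the burn-in cut mentioned above, the following builds a small synthetic InferenceData with az.from_dict (the variable names, shapes and the cut at draw 100 are assumptions chosen for illustration) and drops the first 100 draws. Groups without a draw dimension, here observed_data, should come through unchanged.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): 4 chains, 500 draws.
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 500))},
    sample_stats={"diverging": np.zeros((4, 500), dtype=bool)},
    observed_data={"y": rng.normal(size=8)},
)

# Label-based slice on the "draw" coordinate: keep draws labelled 100 and up.
idata_cut = idata.sel(draw=slice(100, None))

# Groups that have a "draw" dimension are shortened ...
print(idata_cut.posterior.sizes["draw"])
print(idata_cut.sample_stats.sizes["draw"])
# ... while observed_data has no "draw" dimension and keeps its shape.
print(idata_cut.observed_data["y"].shape)
```

Because the selection is label based, the slice refers to the values of the draw coordinate rather than to positions; the two coincide here only because from_dict labels draws 0, 1, 2, and so on.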

Parameters
groups : str or list of str, optional

Groups where the selection is to be applied. Can either be group names or metagroup names.

filter_groups : {None, “like”, “regex”}, optional, default=None

If None (default), interpret groups as the real group or metagroup names. If “like”, interpret groups as substrings of the real group or metagroup names. If “regex”, interpret groups as regular expressions on the real group or metagroup names. This follows the like and regex options of pandas.DataFrame.filter.

inplace : bool, optional

If True, modify the InferenceData object in place. Otherwise, return a modified copy.

chain_prior : bool, optional, deprecated

If False, do not apply a selection on the chain dimension to prior-related groups. Otherwise, apply the selection on chain if that dimension is present. Default=False

kwargs : dict, optional

Keyword arguments defining the selection. They must be accepted by Dataset.sel().
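The interaction between groups and filter_groups can be sketched as follows. The data are synthetic and the variable names are assumptions made for illustration; the point is only that an exact name targets one group, whereas a "like" filter also matches any group whose name contains the given substring.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): 4 chains, 100 draws.
idata = az.from_dict(
    posterior={"mu": rng.normal(size=(4, 100))},
    posterior_predictive={"y": rng.normal(size=(4, 100, 8))},
    sample_stats={"lp": rng.normal(size=(4, 100))},
)

# Exact name: only the "posterior" group is subset.
exact = idata.sel(groups="posterior", chain=[0, 1])

# Substring match: "posterior" also matches "posterior_predictive".
like = idata.sel(groups="posterior", filter_groups="like", chain=[0, 1])

# Report how many chains each group keeps under each call.
for name, result in [("exact", exact), ("like", like)]:
    sizes = {g: getattr(result, g).sizes["chain"] for g in result.groups()}
    print(name, sizes)
```

Groups that are not matched are carried over to the returned object without any selection applied.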

Returns
InferenceData

A new InferenceData object by default. When inplace==True, the selection is performed in place and None is returned.
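The two return behaviours can be contrasted with a short sketch on synthetic data (the variable name mu and the shapes are assumptions made for illustration).

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): 4 chains, 100 draws.
idata = az.from_dict(posterior={"mu": rng.normal(size=(4, 100))})

# Default: a new object is returned and the original keeps all 4 chains.
subset = idata.sel(chain=[0, 1])
print(subset.posterior.sizes["chain"], idata.posterior.sizes["chain"])

# In place: the original object is modified and the call returns None.
result = idata.sel(chain=[0], inplace=True)
print(result, idata.posterior.sizes["chain"])
```

Passing a list such as chain=[0] keeps chain as a dimension of length one; passing the scalar chain=0 would follow xarray semantics and drop the dimension, which downstream code that expects a chain dimension may not handle.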

See also

xarray.Dataset.sel

Returns a new dataset with each array indexed by tick labels along the specified dimension(s).

isel

Returns a new dataset with each array indexed along the specified dimension(s).
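To illustrate the difference between label-based sel and position-based isel, the sketch below first discards one chain and then retrieves the same chain in both ways. The data are synthetic and the names are assumptions made for illustration.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Synthetic data (an assumption for illustration): 4 chains, 100 draws.
idata = az.from_dict(posterior={"mu": rng.normal(size=(4, 100))})

# After dropping the chain labelled 2, the remaining labels are 0, 1 and 3.
subset = idata.sel(chain=[0, 1, 3])

# sel uses coordinate labels: the chain labelled 3.
by_label = subset.sel(chain=[3])

# isel uses integer positions: the chain at position 2, which is the same chain.
by_position = subset.isel(chain=[2])

same = np.array_equal(
    by_label.posterior["mu"].values, by_position.posterior["mu"].values
)
print(same)
```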

Examples

Use sel to discard one chain of the InferenceData object. We first check the dimensions of the original object:

import arviz as az
idata = az.load_arviz_data("centered_eight")
del idata.prior  # prior group only has 1 chain currently
idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614
          theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523
          tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407
      Attributes:
          created_at:                 2019-06-21T17:36:34.398087
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 4.698 -15.07
      Attributes:
          created_at:                 2019-06-21T17:36:34.489022
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500, school: 8)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          tune              (chain, draw) bool True False False ... False False False
          depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 5 5 5 5
          tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 31.0 31.0 31.0
          lp                (chain, draw) float64 -59.05 -56.19 ... -63.62 -58.35
          energy_error      (chain, draw) float64 0.07387 -0.1841 ... -0.087 -0.003652
          step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502
          max_energy_error  (chain, draw) float64 0.131 -0.2067 ... -0.101 -0.1757
          energy            (chain, draw) float64 60.76 62.76 64.4 ... 67.77 67.21
          mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.9875 0.9967
          step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064
          diverging         (chain, draw) bool False False False ... False False False
          log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.896
      Attributes:
          created_at:                 2019-06-21T17:36:34.485802
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:                 2019-06-21T17:36:34.491909
          inference_library:          pymc3
          inference_library_version:  3.7

To remove the third chain (chain label 2), select the labels of the chains to keep:

idata_subset = idata.sel(chain=[0, 1, 3])
idata_subset
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 3, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          mu       (chain, draw) float64 -3.477 -2.456 -2.826 ... 4.597 5.899 0.1614
          theta    (chain, draw, school) float64 1.669 -8.537 -2.623 ... 10.59 4.523
          tau      (chain, draw) float64 3.73 2.075 3.703 4.146 ... 8.346 7.711 5.407
      Attributes:
          created_at:                 2019-06-21T17:36:34.398087
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (chain: 3, draw: 500, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (chain, draw, school) float64 7.85 -19.03 -22.5 ... 4.698 -15.07
      Attributes:
          created_at:                 2019-06-21T17:36:34.489022
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:           (chain: 3, draw: 500, school: 8)
      Coordinates:
        * chain             (chain) int64 0 1 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
        * school            (school) object 'Choate' 'Deerfield' ... 'Mt. Hermon'
      Data variables:
          tune              (chain, draw) bool True False False ... False False False
          depth             (chain, draw) int64 5 3 3 4 5 5 4 4 5 ... 4 4 4 5 5 5 5 5
          tree_size         (chain, draw) float64 31.0 7.0 7.0 15.0 ... 31.0 31.0 31.0
          lp                (chain, draw) float64 -59.05 -56.19 ... -63.62 -58.35
          energy_error      (chain, draw) float64 0.07387 -0.1841 ... -0.087 -0.003652
          step_size_bar     (chain, draw) float64 0.2417 0.2417 ... 0.1502 0.1502
          max_energy_error  (chain, draw) float64 0.131 -0.2067 ... -0.101 -0.1757
          energy            (chain, draw) float64 60.76 62.76 64.4 ... 67.77 67.21
          mean_tree_accept  (chain, draw) float64 0.9506 0.9906 ... 0.9875 0.9967
          step_size         (chain, draw) float64 0.1275 0.1275 ... 0.1064 0.1064
          diverging         (chain, draw) bool False False False ... False False False
          log_likelihood    (chain, draw, school) float64 -5.168 -4.589 ... -3.896
      Attributes:
          created_at:                 2019-06-21T17:36:34.485802
          inference_library:          pymc3
          inference_library_version:  3.7

    • <xarray.Dataset>
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) object 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:                 2019-06-21T17:36:34.491909
          inference_library:          pymc3
          inference_library_version:  3.7