arviz.InferenceData.add_groups

InferenceData.add_groups(group_dict=None, coords=None, dims=None, **kwargs)

Add new groups to the InferenceData object in place.

Parameters
group_dict : dict of {str : dict or xarray.Dataset}, optional

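Groups to be added, keyed by group name. Dict values are converted to xarray.Dataset objects using coords and dims; Dataset values are added as they are.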

coords : dict of {str : array_like}, optional

Coordinates for the datasets created from dict values.

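dims : dict of {str : list of str}, optional

Dimensions of each variable. The keys are variable names and the values are lists of dimension names. The chain and draw dimensions are prepended automatically and should not be listed.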

kwargs : dict, optional

The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
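A minimal sketch of both calling conventions on a small synthetic InferenceData built with az.from_dict. The variable names ("mu", "y"), the "obs" dimension and the array sizes are hypothetical; the error at the end reflects the "One of group_dict or kwargs must be provided" rule above.

```python
import arviz as az
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_chain, n_draw = 2, 50

# Hypothetical InferenceData holding only a posterior for a scalar "mu"
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chain, n_draw))})

# kwargs form: the keyword is the group name, the value a dict of raw arrays.
# "chain" and "draw" are prepended automatically, so dims only names "obs".
idata.add_groups(
    log_likelihood={"y": rng.normal(size=(n_chain, n_draw, 3))},
    dims={"y": ["obs"]},
    coords={"obs": ["a", "b", "c"]},
)

# group_dict form: the dict key is the group name; an already labeled
# xarray.Dataset needs no dims or coords.
prior = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(1, n_draw)))},
    coords={"chain": [0], "draw": np.arange(n_draw)},
)
idata.add_groups({"prior": prior})

print(idata.groups())

# Providing neither group_dict nor kwargs raises a ValueError
try:
    idata.add_groups()
except ValueError as err:
    print("ValueError:", err)
```

Because add_groups modifies the object in place, there is no return value to assign.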

See also

extend

Extend InferenceData with groups from another InferenceData.

concat

Concatenate InferenceData objects.

Examples

Add a log_likelihood group to the “rugby” example InferenceData, which does not include one when loaded:

import arviz as az
idata = az.load_arviz_data("rugby")
idata2 = idata.copy()
post = idata.posterior
obs = idata.observed_data
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
          intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
          atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
          defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
          sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
          sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
          atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
          defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6
          tree_size         (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0
          tune              (chain, draw) bool True False False ... False False False
          mean_tree_accept  (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4
          depth             (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5
          max_energy_error  (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014
          step_size         (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (chain, draw) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922
          intercept     (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049
          atts_star     (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538
          defs_star     (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067
          away_points   (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0
          sd_att        (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251
          sd_def_log__  (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981
          home          (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651
          atts          (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365
          sd_def        (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13
          defs          (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7

Knowing the model, we could compute the log likelihood manually. Here, however, we generate random samples with the right shape.

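import numpy as np
rng = np.random.default_rng(73)
# Random stand-in for the log likelihood, shaped (chain, draw, match)
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
    log_likelihood={"home_points": ary},
    # Only "match" is listed; chain and draw are prepended automatically
    dims={"home_points": ["match"]},
)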
idata
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
          intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
          atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
          defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
          sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
          sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
          atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
          defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • log_likelihood
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) int64 0 1 2 3 4 5 6 7 8 ... 51 52 53 54 55 56 57 58 59
      Data variables:
          home_points  (chain, draw, match) float64 -1.093 0.7781 ... 0.2405 1.643
      Attributes:
          created_at:     2022-10-06T12:14:39.159857
          arviz_version:  0.13.0.dev0

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6
          tree_size         (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0
          tune              (chain, draw) bool True False False ... False False False
          mean_tree_accept  (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4
          depth             (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5
          max_energy_error  (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014
          step_size         (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (chain, draw) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922
          intercept     (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049
          atts_star     (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538
          defs_star     (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067
          away_points   (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0
          sd_att        (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251
          sd_def_log__  (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981
          home          (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651
          atts          (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365
          sd_def        (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13
          defs          (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7
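In the output above, the match coordinate of the new log_likelihood group holds the integers 0 to 59 because no coords were passed. A minimal sketch of reusing the labels already stored in observed_data, on a small hypothetical stand-in for the rugby data (the match names and sizes are made up, so the example runs without downloading anything):

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(73)
matches = ["Wales Italy", "France England", "Ireland Scotland"]  # hypothetical labels

# Hypothetical stand-in for the rugby data: a posterior plus labeled observations
idata = az.from_dict(
    posterior={"home": rng.normal(size=(4, 100))},
    observed_data={"home_points": rng.poisson(20, size=len(matches))},
    coords={"match": matches},
    dims={"home_points": ["match"]},
)
post = idata.posterior
obs = idata.observed_data

# Same kind of random stand-in for a log likelihood as in the example above
ary = rng.normal(size=(post.sizes["chain"], post.sizes["draw"], obs.sizes["match"]))
idata.add_groups(
    log_likelihood={"home_points": ary},
    dims={"home_points": ["match"]},
    # Reuse the labels from observed_data instead of the default integers
    coords={"match": obs["match"].values},
)
print(idata.log_likelihood["match"].values)
```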

This works when starting from raw arrays, but it is inconvenient when the data is already labeled: why pass dims and coords by hand again? The next example builds a fake log likelihood (it does not match the model, but it serves for illustration) directly from the labeled posterior and observed_data groups:

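import xarray as xr
from xarray_einstats.stats import XrDiscreteRV  # from the xarray-einstats package
from scipy.stats import poisson
# Wrap the scipy distribution so it accepts and returns labeled DataArrays
dist = XrDiscreteRV(poisson)
log_lik = xr.Dataset()
# Broadcasting by dimension name keeps the labels from obs and post
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
idata2.add_groups({"log_likelihood": log_lik})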
idata2
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:    (chain: 4, draw: 500, team: 6)
      Coordinates:
        * chain      (chain) int64 0 1 2 3
        * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
          intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
          atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
          defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
          sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
          sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
          atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
          defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
      Attributes:
          created_at:                 2019-07-12T20:31:53.545143
          inference_library:          pymc3
          inference_library_version:  3.7

    • posterior_predictive
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 500, match: 60)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (chain, draw, match) int64 ...
          away_points  (chain, draw, match) int64 ...
      Attributes:
          created_at:                 2019-07-12T20:31:53.563854
          inference_library:          pymc3
          inference_library_version:  3.7

    • log_likelihood
      <xarray.Dataset>
      Dimensions:      (match: 60, chain: 4, draw: 500, team: 6)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team         (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
      Data variables:
          home_points  (match, chain, draw, team) float64 -50.27 -53.03 ... -21.14

    • sample_stats
      <xarray.Dataset>
      Dimensions:           (chain: 4, draw: 500)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499
      Data variables:
          energy_error      (chain, draw) float64 -0.07666 -0.4523 ... 0.115 -0.07691
          energy            (chain, draw) float64 540.2 545.3 542.3 ... 544.0 545.6
          tree_size         (chain, draw) float64 15.0 63.0 31.0 ... 63.0 31.0 31.0
          tune              (chain, draw) bool True False False ... False False False
          mean_tree_accept  (chain, draw) float64 1.0 0.8851 0.8875 ... 0.7791 0.7539
          lp                (chain, draw) float64 -536.4 -536.0 ... -536.1 -536.4
          depth             (chain, draw) int64 4 6 5 4 4 4 5 5 5 ... 6 4 6 5 3 6 5 5
          max_energy_error  (chain, draw) float64 -0.5361 -0.5871 ... 0.7109 1.014
          step_size         (chain, draw) float64 0.2469 0.2469 ... 0.2459 0.2459
          step_size_bar     (chain, draw) float64 0.2313 0.2313 ... 0.2488 0.2488
          diverging         (chain, draw) bool False False False ... False False False
      Attributes:
          created_at:                 2019-07-12T20:31:53.555203
          inference_library:          pymc3
          inference_library_version:  3.7

    • prior
      <xarray.Dataset>
      Dimensions:       (chain: 1, draw: 500, team: 6, match: 60)
      Coordinates:
        * chain         (chain) int64 0
        * draw          (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
        * team          (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
        * match         (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          sd_att_log__  (chain, draw) float64 1.322 -2.014 1.588 ... -0.8585 -0.1922
          intercept     (chain, draw) float64 4.464 3.352 1.567 ... 4.363 4.128 1.049
          atts_star     (chain, draw, team) float64 -2.64 4.172 ... -0.2874 -0.8538
          defs_star     (chain, draw, team) float64 -0.7817 -0.1478 ... 0.1655 0.01067
          away_points   (chain, draw, match) int64 11308 0 11 1 0 21442 ... 11 1 2 2 0
          sd_att        (chain, draw) float64 3.752 0.1334 4.896 ... 0.4238 0.8251
          sd_def_log__  (chain, draw) float64 -0.2662 0.2411 0.6071 ... 1.402 -1.981
          home          (chain, draw) float64 -1.511 -0.001582 ... -0.02416 0.2651
          atts          (chain, draw, team) float64 -4.667 2.145 ... -0.2702 -0.8365
          sd_def        (chain, draw) float64 0.7663 1.273 1.835 ... 3.922 4.063 0.138
          home_points   (chain, draw, match) int64 0 47 11899 3262 1 ... 3 2 1 12 13
          defs          (chain, draw, team) float64 -0.2517 0.3823 ... 0.089 -0.06586
      Attributes:
          created_at:                 2019-07-12T20:31:53.573731
          inference_library:          pymc3
          inference_library_version:  3.7

    • observed_data
      <xarray.Dataset>
      Dimensions:      (match: 60)
      Coordinates:
        * match        (match) object 'Wales Italy' ... 'Ireland England'
      Data variables:
          home_points  (match) float64 23.0 26.0 28.0 26.0 0.0 ... 61.0 29.0 20.0 13.0
          away_points  (match) float64 15.0 24.0 6.0 3.0 20.0 ... 21.0 0.0 18.0 9.0
      Attributes:
          created_at:                 2019-07-12T20:31:53.581293
          inference_library:          pymc3
          inference_library_version:  3.7

Note that the first example passes the new group through kwargs, while the second passes it through group_dict.
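In the second example, home_points ends up with dimensions (match, chain, draw, team): broadcasting orders dimensions by first appearance across the arguments, and the fake likelihood picks up team from atts. A minimal sketch of the same labeled computation using xarray.apply_ufunc with scipy.stats.poisson directly, as an alternative to xarray-einstats (a swapped-in technique, not the one used above). The data, labels and sizes are hypothetical, and the final transpose that puts chain and draw first is an optional convention; the output above shows add_groups accepts the original order.

```python
import arviz as az
import numpy as np
import xarray as xr
from scipy.stats import poisson

rng = np.random.default_rng(0)
teams = ["Wales", "France", "Ireland"]  # hypothetical labels
matches = ["Wales France", "Ireland Wales"]  # hypothetical labels

# Hypothetical stand-in for the rugby posterior and observed_data groups
idata = az.from_dict(
    posterior={"atts": rng.normal(size=(2, 50, len(teams)))},
    observed_data={"home_points": rng.poisson(20, size=len(matches))},
    coords={"team": teams, "match": matches},
    dims={"atts": ["team"], "home_points": ["match"]},
)
post = idata.posterior
obs = idata.observed_data

# apply_ufunc broadcasts by dimension name and keeps the labels. Like the
# example above, this is a fake log likelihood that also carries "team".
log_lik = xr.Dataset()
log_lik["home_points"] = xr.apply_ufunc(
    poisson.logpmf, obs["home_points"], np.exp(post["atts"])
)
print(log_lik["home_points"].dims)

# Optional: move the sample dimensions to the front before adding the group
log_lik = log_lik.transpose("chain", "draw", ...)
idata.add_groups({"log_likelihood": log_lik})
print(idata.log_likelihood["home_points"].dims)
```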