2 Working with DataTree
A modern Bayesian analysis usually generates many sets of data, including posterior samples and posterior predictive samples. We also have observed data, statistics generated by the sampling method, samples from the prior and/or prior predictive distribution, etc. To keep all this data tidy and avoid confusion, ArviZ relies on the data structures provided by xarray (Hoyer and Hamman 2017). If you are not familiar with xarray, this chapter introduces some basic elements in the context of Bayesian statistics. For a deeper understanding of xarray data structures and functionality, we recommend that you check its documentation. You may also find xarray useful for problems outside Bayesian analysis.
We need to become familiar with three data structures:
DataArray
: A labeled, N-dimensional array. In other words, this is like a NumPy array, but you can access the data using meaningful labels instead of numerical indices. You may also think of it as the N-D generalization of a pandas or polars Series.
Dataset
: A dict-like container of DataArray objects aligned along any number of shared dimensions. You may also think of it as the N-D generalization of a pandas or polars DataFrame.
DataTree
: A tree-like container of Datasets, where each Dataset is associated with a group.
The best way to understand these data structures is to explore them. ArviZ comes equipped with a few example DataTree objects, so we can start playing with them without having to fit a model. Let’s start by loading the centered_eight DataTree.
In the context of Bayesian statistics, a DataTree has groups like posterior, observed_data, posterior_predictive, log_likelihood, etc.
When displayed in a notebook or a browser, a DataTree has an interactive HTML representation. If you click on the posterior group you will see that we have three dimensions, named chain, draw, and school. You can think of dimensions as the named axes of an array. Here they tell us that the posterior samples were generated by running an MCMC sampler with 4 chains of 500 draws each, and that at least one of the parameters in the posterior has an additional dimension, school. If you click on coordinates you will be able to see the actual values that each dimension can take, like the integers [0, 1, 2, 3] for chain and the strings ['Choate', 'Deerfield', 'Phillips Andover', 'Phillips Exeter', 'Hotchkiss', 'Lawrenceville', "St. Paul's", 'Mt. Hermon'] for school. Notice that the school coordinate is an array of strings (dtype <U16). Furthermore, if you click on the symbol next to the school coordinate, you will be able to see the names of all the schools.
2.0.1 Get the dataset corresponding to a single group
We can access each group using a dictionary-like notation:
```
<xarray.DatasetView> Size: 165kB
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 16kB ...
    theta    (chain, draw, school) float64 128kB ...
    tau      (chain, draw) float64 16kB ...
Attributes: (6)
```
Alternatively, we can use the dot notation, as groups are attributes of the DataTree. For instance, to access the posterior group we can write:
The dot notation works at the group level, and also within Datasets and DataArrays, as long as the name does not clash with a method or attribute of these objects. If there is a clash, you can always use the dictionary-like notation.
Notice that we still get a DataTree, in this case a single node with no child groups. If you want a Dataset instead, you can do:
```
<xarray.Dataset> Size: 165kB
Dimensions:  (chain: 4, draw: 500, school: 8)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 16kB ...
    theta    (chain, draw, school) float64 128kB ...
    tau      (chain, draw) float64 16kB ...
Attributes: (6)
```
2.0.2 Get coordinate values
As we have seen, there are 8 schools, each with a name. If we want to access the names programmatically, we can do:
```
<xarray.DataArray 'school' (school: 8)> Size: 512B
'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
```
This returns a DataArray with the names of the schools. To obtain a NumPy array, we can do:
```
array(['Choate', 'Deerfield', 'Phillips Andover', 'Phillips Exeter',
       'Hotchkiss', 'Lawrenceville', "St. Paul's", 'Mt. Hermon'],
      dtype='<U16')
```
If we want to get the number of schools we can write:
Notice that we do not need to first obtain the NumPy array and then compute its length. When working with DataTrees, Datasets, and DataArrays, you may be tempted to reduce them to NumPy arrays, because you are more familiar with those. But for many problems that is not needed, and for many others it is not even recommended, as you may lose the benefits of working with labeled array-like structures.
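A minimal sketch of the coordinate operations in this section, on a hypothetical stand-in Dataset that only carries the school coordinate (the theta values are zeros, not real data):

```python
import numpy as np
import xarray as xr

schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
# Hypothetical stand-in: only the school coordinate matters for this sketch
posterior = xr.Dataset({"theta": ("school", np.zeros(8))}, coords={"school": schools})

names = posterior["school"]            # DataArray with the coordinate values
names_np = posterior["school"].values  # plain NumPy array, if you really need one
n_schools = len(posterior["school"])   # no NumPy conversion needed
print(n_schools, posterior.sizes["school"])  # 8 8
```

Both len on the DataArray and the sizes mapping give the length directly, so the detour through NumPy is unnecessary.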
2.0.3 Get a subset of chains
Because we have labels for the names of the schools, we can use them to access their associated information. Labels are usually much easier to remember than numerical indices. For instance, to access the posterior samples of the school Choate we can write:
```
<xarray.DatasetView> Size: 52kB
Dimensions:  (chain: 4, draw: 500)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
    school   <U16 64B 'Choate'
Data variables:
    mu       (chain, draw) float64 16kB ...
    theta    (chain, draw) float64 16kB ...
    tau      (chain, draw) float64 16kB ...
Attributes: (6)
```
The draw and chain coordinates are indexed using numbers; the following code returns the last draw from chains 1 and 2:
```
<xarray.DatasetView> Size: 696B
Dimensions:  (chain: 2, school: 8)
Coordinates:
  * chain    (chain) int64 16B 1 2
    draw     int64 8B 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain) float64 16B ...
    theta    (chain, school) float64 128B ...
    tau      (chain) float64 16B ...
Attributes: (6)
```
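Because chain and draw are numeric, it is easy to mix up labels and positions. A minimal sketch on synthetic data (random numbers) showing that selecting the draw labeled 499 with sel is the same as selecting the last position with isel, as long as draws are labeled 0 to 499:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the centered_eight posterior (random numbers, not real samples)
rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
        "tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

by_label = posterior.sel(chain=[1, 2], draw=499)      # labels
by_position = posterior.isel(chain=[1, 2], draw=-1)   # positions, -1 is the last draw
print(by_label["draw"].item())                        # 499
```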
Usually, in Bayesian statistics, we don’t need to access individual draws or chains; a more common operation is to select a range. For that purpose, we can use Python’s built-in slice. For example, the following line of code selects draws 0 to 200 from all chains. Because the selection is made by label, both endpoints are included, so we get 201 draws rather than 200:
```
<xarray.DatasetView> Size: 66kB
Dimensions:  (chain: 4, draw: 201, school: 8)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 194 195 196 197 198 199 200
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 6kB ...
    theta    (chain, draw, school) float64 51kB ...
    tau      (chain, draw) float64 6kB ...
Attributes: (6)
```
Using slice, we can also remove the first 100 draws of each chain:
```
<xarray.DatasetView> Size: 132kB
Dimensions:  (chain: 4, draw: 400, school: 8)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * draw     (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain, draw) float64 13kB ...
    theta    (chain, draw, school) float64 102kB ...
    tau      (chain, draw) float64 13kB ...
Attributes: (6)
```
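The inclusive endpoint is easy to miss. This sketch on synthetic data (random numbers) contrasts label-based slicing with sel, which includes both ends, and positional slicing with isel, which follows the usual NumPy convention of excluding the stop value:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the centered_eight posterior (random numbers, not real samples)
rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
        "tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

by_label = posterior.sel(draw=slice(0, 200))      # labels 0..200, both included
by_position = posterior.isel(draw=slice(0, 200))  # positions 0..199, stop excluded
burned = posterior.sel(draw=slice(100, None))     # drop the first 100 draws
print(by_label.sizes["draw"], by_position.sizes["draw"], burned.sizes["draw"])  # 201 200 400
```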
We cannot apply the same operation to the entire DataTree object directly, because "draw" is not a dimension of every group. But we can restrict the operation to the groups that have a "draw" dimension.
If you check the object, you will see that the groups posterior, posterior_predictive, log_likelihood, sample_stats, prior, and prior_predictive have 400 draws, compared to the original 500. The group observed_data has not been affected because it does not have the draw dimension.
2.0.4 Compute posterior mean
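We can perform operations on the DataTree object. For instance, we can compute the mean of the first 200 draws. Since we do not specify any dimension, the mean is computed over all of them, so we get a single value per variable: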
```
<xarray.DatasetView> Size: 24B
Dimensions:  ()
Data variables:
    mu       float64 8B 4.542
    theta    float64 8B 5.005
    tau      float64 8B 4.196
```
In NumPy, it is common to perform operations like this along a given axis. We can do the same by specifying the dimension along which we want to operate. For instance, to compute the mean along the draw dimension we can write:
```
<xarray.DatasetView> Size: 864B
Dimensions:  (chain: 4, school: 8)
Coordinates:
  * chain    (chain) int64 32B 0 1 2 3
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       (chain) float64 32B 4.246 4.184 4.659 4.855
    theta    (chain, school) float64 256B 5.793 4.648 3.742 ... 4.14 6.74 5.275
    tau      (chain) float64 32B 3.682 4.247 4.656 3.912
```
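A sketch on synthetic data (random numbers) of the two reductions above: mean() without arguments collapses every dimension, while mean("draw") plays the role of NumPy's axis argument but refers to the dimension by name.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the centered_eight posterior (random numbers, not real samples)
rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
        "tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

overall = posterior.mean()           # no dim: reduce over every dimension
by_chain = posterior.mean("draw")    # reduce over draw only
print(overall["mu"].dims, by_chain["theta"].dims)  # () ('chain', 'school')
```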
This returns the mean for each chain and school. Can you anticipate how different this would be if the dimension was chain instead of draw? And what if we use school?
We can also specify multiple dimensions. For instance, to compute the mean along the draw and chain dimensions we can write:
```
<xarray.DatasetView> Size: 592B
Dimensions:  (school: 8)
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
    mu       float64 8B 4.486
    theta    (school) float64 64B 6.46 5.028 3.938 4.872 3.667 3.975 6.581 4.772
    tau      float64 8B 4.124
```
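To connect this with NumPy, the following sketch (synthetic data, random numbers) checks that reducing over the chain and draw dimensions by name gives the same numbers as reducing over axes 0 and 1 of the raw array. The named version keeps working if the dimension order changes; the positional one does not.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the centered_eight posterior (random numbers, not real samples)
rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
        "tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

per_school = posterior.mean(("chain", "draw"))      # reduce over both, by name
theta_np = posterior["theta"].values.mean(axis=(0, 1))  # NumPy: by position
print(per_school["theta"].dims)                     # ('school',)
```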
2.0.5 Combine chains and draws
Often, what we care about are the posterior samples themselves, not which chain or draw they came from. In those cases, we can use the az.extract function. This combines the chain and draw dimensions into a single sample dimension, which can make further operations easier. By default, az.extract works on the posterior, but you can specify other groups using the group argument.
```
<xarray.Dataset> Size: 209kB
Dimensions:  (sample: 2000, school: 8)
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * sample   (sample) object 16kB MultiIndex
  * chain    (sample) int64 16kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw     (sample) int64 16kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    mu       (sample) float64 16kB 7.872 3.385 9.1 7.304 ... 1.767 3.486 3.404
    theta    (school, sample) float64 128kB 12.32 11.29 5.709 ... 8.452 1.295
    tau      (sample) float64 16kB 4.726 3.909 4.844 1.857 ... 2.741 2.932 4.461
Attributes: (6)
```
You can achieve the same result using dt.posterior.to_dataset().stack(sample=("chain", "draw")). But extract can be more flexible, because it takes care of the most common subsetting operations with MCMC samples. It can:
- Combine chains and draws
- Return a subset of variables (with optional filtering with regular expressions or string matching)
- Return a subset of samples. By default, this subset is chosen at random, to avoid getting non-representative samples when the chains mix poorly.
- Access any group
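The stack route mentioned above can be tried on synthetic data (random numbers, not real samples). This sketch shows that stacking chain and draw yields 2000 samples ordered chain by chain, and that the new sample dimension is placed last, which is why theta ends up with dimensions (school, sample) in the output above.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the centered_eight posterior (random numbers, not real samples)
rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
        "tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

stacked = posterior.stack(sample=("chain", "draw"))
print(stacked.sizes["sample"])   # 2000
print(stacked["theta"].dims)     # ('school', 'sample')
```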
To get a subsample, we can specify the number of samples we want with the num_samples argument. For instance, to get 100 samples we can write:
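The idea behind a random subset can be sketched in plain xarray: stack the samples, draw random positions without replacement, and index with isel. This illustrates the concept on synthetic data (random numbers) and is not az.extract's implementation; the seed 42 is arbitrary.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), rng.normal(size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)
stacked = posterior.stack(sample=("chain", "draw"))   # 2000 samples
# Draw 100 sample positions at random, without replacement
rng_sub = np.random.default_rng(42)
idx = rng_sub.choice(stacked.sizes["sample"], size=100, replace=False)
subset = stacked.isel(sample=idx)
```

Taking random positions, instead of the first 100 samples, spreads the subset over all chains and over the whole length of each chain.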
If you need to extract subsets from multiple groups, you should use a random seed. This will ensure that the subsamples match. For example, if you extract 100 samples from the posterior and 100 from the log_likelihood group using the same seed, and store them in the variables posterior and ll, you can inspect both and see that they match.
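Why does a shared seed make the subsets match? A minimal sketch with a hypothetical helper, subsample (not an ArviZ function), on two synthetic groups with the same chain and draw layout: with the same seed and the same number of samples, the random generator produces the same positions, so both groups keep the same (chain, draw) pairs.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
dims = ("chain", "draw", "school")
coords = {"chain": np.arange(4), "draw": np.arange(500), "school": np.arange(8)}
# Two synthetic groups with the same chain/draw layout (random numbers only)
posterior = xr.Dataset({"theta": (dims, rng.normal(size=(4, 500, 8)))}, coords=coords)
log_lik = xr.Dataset({"obs": (dims, rng.normal(size=(4, 500, 8)))}, coords=coords)


def subsample(ds, num_samples, seed):
    """Hypothetical helper: stack chain/draw, then pick a random subset of samples."""
    stacked = ds.stack(sample=("chain", "draw"))
    rng_sub = np.random.default_rng(seed)
    idx = rng_sub.choice(stacked.sizes["sample"], size=num_samples, replace=False)
    return stacked.isel(sample=idx)


# Same seed and same number of samples per group -> same (chain, draw) pairs
posterior_sub = subsample(posterior, 100, seed=124)
ll = subsample(log_lik, 100, seed=124)
```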
2.1 Plotting
Xarray has some plotting capabilities; for instance, we can do:
But in most scenarios, calling a plotting function from ArviZ and passing the DataTree as an argument will be a much better idea.
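As an illustration of xarray's built-in plotting, here is a sketch that draws a histogram of a synthetic mu variable (random numbers, not real samples). It assumes matplotlib is installed; DataArray.plot.hist flattens all dimensions and forwards keyword arguments such as bins to matplotlib's hist. Calling mu.plot() on this 2D array would instead produce a 2D pcolormesh, which is rarely what you want for MCMC samples.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs without a display
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Hypothetical stand-in for posterior["mu"] (random numbers, not real samples)
mu = xr.DataArray(
    rng.normal(size=(4, 500)),
    dims=("chain", "draw"),
    coords={"chain": np.arange(4), "draw": np.arange(500)},
    name="mu",
)
# Flattens chain and draw into one histogram; returns matplotlib's (counts, edges, patches)
counts, bin_edges, _ = mu.plot.hist(bins=20)
```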
2.2 Add a new variable
We can add variables to existing groups. For instance, we may want to transform a parameter from the posterior, like computing the \(\log\) of the parameter \(\tau\) and adding it as a new variable:
```
<xarray.Dataset> Size: 12kB
Dimensions:  (sample: 100, school: 8)
Coordinates:
  * school   (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * sample   (sample) object 800B MultiIndex
  * chain    (sample) int64 800B 3 1 0 2 1 0 2 1 3 1 3 ... 1 3 0 3 2 3 0 1 0 2 1
  * draw     (sample) int64 800B 80 443 428 346 215 341 ... 424 169 116 493 477
Data variables:
    mu       (sample) float64 800B 5.949 2.821 8.765 4.16 ... 1.501 1.963 2.821
    theta    (school, sample) float64 6kB 5.783 3.258 13.21 ... -5.54 3.351
    tau      (sample) float64 800B 1.637 0.8965 2.171 ... 4.327 5.192 0.8965
    log_tau  (sample) float64 800B 0.4928 -0.1093 0.7754 ... 1.465 1.647 -0.1093
Attributes: (6)
```
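A minimal sketch of two ways to add a variable, on a synthetic Dataset (random positive numbers standing in for tau): item assignment modifies the Dataset in place, while assign returns a new Dataset and leaves the original untouched. The variable name sqrt_tau is made up for the example. NumPy functions such as np.log work directly on DataArrays and keep their dimensions and coordinates.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Synthetic stand-in for the tau variable (positive random numbers, not real samples)
posterior = xr.Dataset(
    {"tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)

# In place: add a new data variable to the existing Dataset
posterior["log_tau"] = np.log(posterior["tau"])
# Non-mutating alternative: assign returns a new Dataset
extended = posterior.assign(sqrt_tau=np.sqrt(posterior["tau"]))
```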
2.3 Advanced operations with DataTrees
Now we delve into more advanced operations with DataTree. While these operations are not essential to use ArviZ, they can be useful in some cases. Exploring them will also help you become more familiar with DataTree and get more out of ArviZ.
2.3.1 Compute and store posterior pushforward quantities
We use “posterior push-forward quantities” to refer to quantities that are not variables in the posterior but deterministic computations using posterior variables.
You can use xarray for these push-forward operations and store them as a new variable in the posterior group. You’ll then be able to plot them with ArviZ functions, calculate stats and diagnostics on them (like mcse), or save and share the DataTree object with the push-forward quantities included.
The first thing we are going to do is store the posterior group in a variable called post, to make the code more readable, and compute the log of \(\tau\).
Then, we compute the rolling mean of \(\log(\tau)\) with xarray.DataArray.rolling, storing the result in the posterior:
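A sketch of these two push-forward steps on a synthetic stand-in for post (random positive numbers instead of real tau samples). The window of 50 draws is an assumption, since the window used for the output below is not shown. The rolling mean is computed along draw separately for each chain, and with the default min_periods (equal to the window size) the first 49 draws are NaN, which is why mlogtau starts with nan values.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Synthetic stand-in for the posterior group (random numbers only)
post = xr.Dataset(
    {"tau": (("chain", "draw"), rng.gamma(2.0, size=(4, 500)))},
    coords={"chain": np.arange(4), "draw": np.arange(500)},
)
post["log_tau"] = np.log(post["tau"])
# Trailing window of 50 draws (window size is an assumption), computed per chain
post["mlogtau"] = post["log_tau"].rolling(draw=50).mean()
```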
Using xarray for push-forward calculations has all the advantages of working with xarray. It also inherits the disadvantages of working with xarray, but we believe those to be outweighed by the advantages, and we have already shown how to extract the data as NumPy arrays.
Some examples of these advantages are specifying operations with named dimensions instead of positional ones (as seen in some previous sections), automatic alignment and broadcasting of arrays (as we’ll see now), or integration with Dask (as shown in the dask_for_arviz guide).
Next, we compute pairwise differences between schools on their mean effects (variable theta). To do so, we subtract from theta a copy of theta whose school dimension has been renamed to school_bis. xarray aligns the shared dimensions (chain and draw) and broadcasts over the different ones (school and school_bis), so the result is a 4D variable with all the pairwise differences.
Finally, we store the result in the theta_school_diff variable. Notice that the theta_school_diff variable in the posterior has kept the named dimensions and coordinates:
```
<xarray.DatasetView> Size: 1MB
Dimensions:            (chain: 4, draw: 500, school: 8, school_bis: 8)
Coordinates:
  * chain              (chain) int64 32B 0 1 2 3
  * draw               (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
  * school             (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
  * school_bis         (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
Data variables:
    mu                 (chain, draw) float64 16kB 7.872 3.385 ... 3.486 3.404
    theta              (chain, draw, school) float64 128kB 12.32 9.905 ... 1.295
    tau                (chain, draw) float64 16kB 4.726 3.909 ... 2.932 4.461
    log_tau            (chain, draw) float64 16kB 1.553 1.363 ... 1.076 1.495
    mlogtau            (chain, draw) float64 16kB nan nan nan ... 1.496 1.511
    theta_school_diff  (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0
Attributes: (6)
```
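The rename-and-broadcast trick can be reproduced on synthetic data (random numbers standing in for theta). This sketch shows that the result gains both school dimensions, and that the diagonal, where a school is compared with itself, is exactly zero, matching the 0.0 values at the edges of the output above.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
post = xr.Dataset(
    {"theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8)))},
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)
theta = post["theta"]
# Second operand: same values, but the school dimension is now called school_bis.
# chain and draw are aligned; school and school_bis are broadcast against each other.
post["theta_school_diff"] = theta - theta.rename({"school": "school_bis"})
print(post["theta_school_diff"].dims)  # ('chain', 'draw', 'school', 'school_bis')
```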
2.3.2 Advanced subsetting
To select the values corresponding to the difference between the Choate and Deerfield schools, we can do:
```
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)> Size: 16kB
2.415 2.156 -0.04943 1.228 3.384 9.662 ... -1.656 -0.4021 1.524 -3.372 -6.305
Coordinates:
  * chain       (chain) int64 32B 0 1 2 3
  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
    school      <U16 64B 'Choate'
    school_bis  <U16 64B 'Deerfield'
```
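It helps to confirm the sign convention of the broadcast difference. In this sketch (synthetic data, only three schools to keep it short), selecting one pair from the 4D array equals the value of the school coordinate minus the value of the school_bis coordinate:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
schools = ["Choate", "Deerfield", "Phillips Andover"]  # shortened list for the sketch
theta = xr.DataArray(
    rng.normal(size=(4, 500, 3)),
    dims=("chain", "draw", "school"),
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
    name="theta",
)
theta_school_diff = theta - theta.rename({"school": "school_bis"})

# One pair: theta at "school" minus theta at "school_bis"
pair = theta_school_diff.sel(school="Choate", school_bis="Deerfield")
direct = theta.sel(school="Choate") - theta.sel(school="Deerfield")
print(pair.dims)  # ('chain', 'draw')
```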
For more advanced subsetting (the equivalent of what is sometimes called “fancy indexing” in NumPy, and what xarray calls vectorized indexing), you need to provide the indices as DataArray objects:
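```python
import xarray as xr

# post is the posterior group used in the previous sections. Both indexers share the
# dimension "pairwise_school_diff", so xarray pairs them element by element
school_idx = xr.DataArray(["Choate", "Hotchkiss", "Mt. Hermon"], dims=["pairwise_school_diff"])
school_bis_idx = xr.DataArray(
    ["Deerfield", "Choate", "Lawrenceville"], dims=["pairwise_school_diff"]
)
post["theta_school_diff"].sel(school=school_idx, school_bis=school_bis_idx)
```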
```
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, pairwise_school_diff: 3)> Size: 48kB
2.415 -6.741 -1.84 2.156 -3.474 3.784 ... -2.619 6.923 -6.305 1.667 -6.641
Coordinates:
  * chain       (chain) int64 32B 0 1 2 3
  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
    school      (pairwise_school_diff) <U16 192B 'Choate' ... 'Mt. Hermon'
    school_bis  (pairwise_school_diff) <U16 192B 'Deerfield' ... 'Lawrenceville'
Dimensions without coordinates: pairwise_school_diff
```
Using lists or NumPy arrays instead of DataArrays performs orthogonal (outer) indexing: each dimension is indexed independently. As you can see, the result has 9 values of theta_school_diff per draw, one for each of the 3 × 3 combinations, instead of the 3 pairs of differences we selected in the previous cell:
```python
# Plain lists are applied to each dimension independently (orthogonal indexing)
post["theta_school_diff"].sel(
    school=["Choate", "Hotchkiss", "Mt. Hermon"],
    school_bis=["Deerfield", "Choate", "Lawrenceville"],
)
```
```
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3, school_bis: 3)> Size: 144kB
2.415 0.0 -4.581 -4.326 -6.741 -11.32 ... 1.667 -6.077 -5.203 1.102 -6.641
Coordinates:
  * chain       (chain) int64 32B 0 1 2 3
  * draw        (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * school      (school) <U16 192B 'Choate' 'Hotchkiss' 'Mt. Hermon'
  * school_bis  (school_bis) <U16 192B 'Deerfield' 'Choate' 'Lawrenceville'
```
2.3.3 Add new chains using concat
After checking the mcse and realizing you need more samples, you rerun the model with two chains and obtain a dt_rerun object.
You can combine the two into a single DataTree object using the concat function from ArviZ:
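At the xarray level, adding chains boils down to concatenating along the chain dimension. The sketch below does this for a single synthetic group (random numbers, hypothetical helper fake_posterior) with xr.concat; it is not ArviZ's concat. Relabeling the new chains is our own choice: without it the combined object would contain the chain labels 0 and 1 twice, which makes label-based selection ambiguous.

```python
import numpy as np
import xarray as xr


def fake_posterior(n_chains, seed):
    """Hypothetical stand-in for a posterior group: random numbers, not real MCMC output."""
    rng = np.random.default_rng(seed)
    return xr.Dataset(
        {"mu": (("chain", "draw"), rng.normal(size=(n_chains, 500)))},
        coords={"chain": np.arange(n_chains), "draw": np.arange(500)},
    )


post = fake_posterior(4, seed=0)
post_rerun = fake_posterior(2, seed=1)  # the rerun also labels its chains 0 and 1

# Relabel the new chains so they do not clash with the existing labels 0..3
post_rerun = post_rerun.assign_coords(chain=np.arange(4, 6))
combined = xr.concat([post, post_rerun], dim="chain")
print(combined["chain"].values)  # [0 1 2 3 4 5]
```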
2.3.4 Add groups to DataTrees
This will be simplified in the future, but for now, you can add groups to a DataTree by converting the DataTree to a dictionary, adding the new group, and then converting the dictionary back to a DataTree.
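```python
import numpy as np
import xarray as xr
import arviz_base as azb

dt = azb.load_arviz_data("centered_eight")  # the DataTree used throughout this chapter

rng = np.random.default_rng(3)
# With the default sample dims, the first two axes of "obs" are chain and draw
ds = azb.dict_to_dataset(
    {"obs": rng.normal(size=(4, 500, 2))},
    dims={"obs": ["new_school"]},
    coords={"new_school": ["Essex College", "Moordale"]},
)
# Collect the existing groups in a plain dict, add the new one, and rebuild the tree
dicto = {k: v for k, v in dt.children.items()}
dicto["predictions"] = ds
new_dt = xr.DataTree.from_dict(dicto)
new_dt
```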