ArviZ.jl Quickstart
This tutorial is adapted from ArviZ's quickstart.
using ArviZ
using PyPlot
# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")
Get started with plotting
ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl, but it also works fine with raw arrays.
using Random
rng = Random.MersenneTwister(37772)
plot_posterior(randn(rng, 100_000));
gcf()
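plot_posterior summarises the draws with a highest-density interval (HDI). ArviZ's default HDI probability is 94%, which matches the hdi_3% and hdi_97% columns of the summary table in this tutorial. The sketch below computes an HDI from raw draws in plain Julia. It assumes only the Random standard library. hdi_sketch and demo_rng are hypothetical names chosen for illustration, and this is not ArviZ's own implementation.

```julia
using Random

# Hypothetical helper (not an ArviZ.jl function): the narrowest window over the
# sorted draws that contains roughly a fraction `prob` of them.
function hdi_sketch(x::AbstractVector{<:Real}, prob::Real=0.94)
    s = sort(x)
    n = length(s)
    k = floor(Int, prob * n)               # each candidate window spans s[i] .. s[i + k]
    widths = s[(k + 1):n] .- s[1:(n - k)]  # width of every candidate window
    i = argmin(widths)                     # start of the narrowest window
    return (s[i], s[i + k])
end

# A separate RNG so the tutorial's own `rng` stream is left untouched
demo_rng = Random.MersenneTwister(37772)
lower, upper = hdi_sketch(randn(demo_rng, 100_000))  # close to ±1.88 for a standard normal
```

For a symmetric, unimodal distribution the HDI coincides with the central interval. For skewed posteriors it is shifted towards the mode.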
When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from the variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.
using Distributions
s = (10, 50)
plot_forest(
    Dict(
        "normal" => randn(rng, s),
        "gumbel" => rand(rng, Gumbel(), s),
        "student t" => rand(rng, TDist(6), s),
        "exponential" => rand(rng, Exponential(), s),
    ),
);
gcf()
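Here is a plain-Julia sketch of the (chains, draws) layout described above. It makes no ArviZ calls, and the names draws, chain_means and pooled_mean are illustrative only. Reducing over dims=2 summarises each chain (row) separately, and reducing over the whole array pools every chain.

```julia
using Random, Statistics

demo_rng = Random.MersenneTwister(1)
draws = randn(demo_rng, 10, 50)          # 10 chains (rows) × 50 draws (columns)
nchains_demo, ndraws_demo = size(draws)

chain_means = vec(mean(draws; dims=2))   # one mean per chain, i.e. per row
pooled_mean = mean(draws)                # mean over every chain and draw
```

Comparing the per-chain summaries with the pooled one is the intuition behind convergence diagnostics such as R-hat.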
Plotting with MCMCChains.jl's Chains objects produced by Turing.jl
ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, I give the schools the names of a different set of eight schools.
This model is small enough to write down, is hierarchical, and uses labelling. Additionally, a centered parameterization causes divergences (which are interesting for illustration).
First we create our data and set some sampling parameters.
J = 8
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
schools = [
    "Choate",
    "Deerfield",
    "Phillips Andover",
    "Phillips Exeter",
    "Hotchkiss",
    "Lawrenceville",
    "St. Paul's",
    "Mt. Hermon",
];
nwarmup, nsamples, nchains = 1000, 1000, 4;
Now we write and run the model using Turing:
using Turing
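Turing.@model function turing_model(y, σ, J=length(y))
    μ ~ Normal(0, 5)                     # population mean effect
    τ ~ truncated(Cauchy(0, 5), 0, Inf)  # population scale (half-Cauchy prior)
    θ ~ filldist(Normal(μ, τ), J)        # per-school effects (centered parameterization)
    # Observation model with known standard errors σ
    for i in 1:J
        y[i] ~ Normal(θ[i], σ[i])
    end
end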
param_mod = turing_model(y, σ)
sampler = NUTS(nwarmup, 0.8)
rng = Random.MersenneTwister(16653)
turing_chns = sample(
    rng, param_mod, sampler, MCMCThreads(), nsamples, nchains; progress=false
);
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.30000000000000004
┌ Info: Found initial step size
└   ϵ = 0.4
Most ArviZ functions work fine with Chains objects from Turing:
plot_autocorr(turing_chns; var_names=["μ", "τ"]);
gcf()
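plot_autocorr shows, for each chain, how strongly a draw is correlated with the draw k steps later. High, slowly decaying autocorrelation means fewer effectively independent draws. The sketch below computes the same quantity directly in plain Julia. autocorr_sketch is a hypothetical name, and ArviZ's internal routine may be computed differently. The AR(1) series is a synthetic stand-in for a slowly mixing chain.

```julia
using Random, Statistics

# Sample autocorrelation of one chain at lags 0, 1, ..., maxlag, normalised by the
# lag-0 autocovariance so that the first entry is 1.
function autocorr_sketch(x::AbstractVector{<:Real}, maxlag::Integer)
    n = length(x)
    xc = x .- mean(x)
    denom = sum(abs2, xc)
    return [sum(xc[1:(n - k)] .* xc[(1 + k):n]) / denom for k in 0:maxlag]
end

# Synthetic AR(1) chain with coefficient 0.9, mimicking a slowly mixing sampler
demo_rng = Random.MersenneTwister(2)
ar_chain = zeros(10_000)
for t in 2:length(ar_chain)
    ar_chain[t] = 0.9 * ar_chain[t - 1] + randn(demo_rng)
end
acf = autocorr_sketch(ar_chain, 5)   # decays roughly like 0.9^k
```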
Convert to InferenceData
For much more powerful querying, analysis and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note that we also pass labelling information through the coords and dims keyword arguments.
ArviZ is built to work with InferenceData (a netCDF datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful the analyses it can perform.
idata = from_mcmcchains(
    turing_chns;
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
    τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
    θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
Attributes:
    created_at:         2021-05-11T17:20:41.996609
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing
sample_stats
Dataset (xarray.Dataset)
Dimensions:           (chain: 4, draw: 1000)
Coordinates:
  * chain             (chain) int64 0 1 2 3
  * draw              (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
Data variables:
    energy            (chain, draw) float64 49.74 48.67 52.26 ... 61.18 57.5
    energy_error      (chain, draw) float64 -0.1874 0.05797 ... -0.02117 0.02483
    tree_depth        (chain, draw) int64 3 3 3 4 5 2 3 3 2 ... 4 4 4 4 3 4 4 3
    diverging         (chain, draw) bool False False False ... False False False
    step_size_nom     (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
    acceptance_rate   (chain, draw) float64 0.9568 0.9265 ... 0.9876 0.991
    log_density       (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
    max_energy_error  (chain, draw) float64 -0.8297 0.1387 ... -0.2598 -0.5436
    is_accept         (chain, draw) bool True True True True ... True True True
    lp                (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
    step_size         (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
    n_steps           (chain, draw) int64 15 7 7 31 47 7 ... 15 31 15 15 15 15
Attributes:
    created_at:         2021-05-11T17:20:42.032613
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing
Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.
idata.posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
    τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
    θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
Attributes:
    created_at:         2021-05-11T17:20:41.996609
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing
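The coords and dims passed to from_mcmcchains attach names to the axes of each array. This lets θ be addressed by school name rather than by integer position. In ArviZ, xarray handles this bookkeeping. The plain-Julia sketch below only mimics the lookup by hand. select_school, school_names and θ_demo are hypothetical names, and the random array stands in for real posterior draws with dims (chain, draw, school).

```julia
using Random

# Hypothetical stand-in for the posterior θ, dims (chain, draw, school)
school_names = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
                "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
demo_rng = Random.MersenneTwister(3)
θ_demo = randn(demo_rng, 4, 1000, length(school_names))

# Find the position of a label along the "school" axis, then slice it out.
function select_school(θ::AbstractArray{<:Real,3}, names::AbstractVector{<:AbstractString},
                       school::AbstractString)
    idx = findfirst(==(school), names)
    idx === nothing && throw(ArgumentError("unknown school: $school"))
    return θ[:, :, idx]   # (chain, draw) matrix for that school
end

hotchkiss = select_school(θ_demo, school_names, "Hotchkiss")
```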
Here is a plot of the trace. Note that the plot is labelled using the variable names and coordinates stored in idata.
plot_trace(idata);
gcf()
We can also generate summary statistics:
summarystats(idata)
| | variable | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|
| | String | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | μ | 4.377 | 3.433 | -2.29 | 10.685 | 0.126 | 0.092 | 743.0 | 881.0 | 1.01 |
| 2 | τ | 3.649 | 3.0 | 0.407 | 8.922 | 0.184 | 0.13 | 161.0 | 106.0 | 1.01 |
| 3 | θ[1] | 6.233 | 5.668 | -3.119 | 17.324 | 0.206 | 0.145 | 787.0 | 1324.0 | 1.01 |
| 4 | θ[2] | 4.968 | 4.668 | -3.795 | 13.939 | 0.138 | 0.105 | 1087.0 | 1834.0 | 1.01 |
| 5 | θ[3] | 3.945 | 5.228 | -5.885 | 13.905 | 0.15 | 0.113 | 1126.0 | 1586.0 | 1.01 |
| 6 | θ[4] | 4.667 | 4.889 | -4.779 | 13.732 | 0.147 | 0.115 | 1042.0 | 1677.0 | 1.0 |
| 7 | θ[5] | 3.542 | 4.797 | -5.586 | 12.766 | 0.135 | 0.117 | 1192.0 | 1511.0 | 1.01 |
| 8 | θ[6] | 4.044 | 4.768 | -4.807 | 13.581 | 0.145 | 0.108 | 1024.0 | 1348.0 | 1.01 |
| 9 | θ[7] | 6.296 | 5.07 | -2.371 | 16.512 | 0.189 | 0.133 | 727.0 | 1258.0 | 1.01 |
| 10 | θ[8] | 4.766 | 5.289 | -6.104 | 14.186 | 0.154 | 0.127 | 1047.0 | 1611.0 | 1.01 |
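The r_hat column compares the variance between chains with the variance within chains. Values near 1 suggest that the chains agree. ArviZ's default R-hat is a rank-normalized split-R-hat. The sketch below implements only the older, simpler split-R-hat from Gelman et al.'s Bayesian Data Analysis, so it will not reproduce the table's numbers exactly. split_rhat is a hypothetical helper, and the two input matrices are synthetic.

```julia
using Random, Statistics

# Classic split-R-hat (not ArviZ's rank-normalized version); `draws` is (nchains, ndraws).
function split_rhat(draws::AbstractMatrix{<:Real})
    ndraws = size(draws, 2)
    n = ndraws ÷ 2
    # Split every chain into a first and second half (the middle draw is dropped
    # when ndraws is odd), giving 2 * nchains half-chains of length n.
    halves = vcat(draws[:, 1:n], draws[:, (ndraws - n + 1):ndraws])
    W = mean(var(halves; dims=2))        # mean within-chain variance
    B = n * var(mean(halves; dims=2))    # between-chain variance
    var_plus = (n - 1) / n * W + B / n   # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

demo_rng = Random.MersenneTwister(4)
mixed = randn(demo_rng, 4, 1000)         # four chains drawn from the same distribution
stuck = mixed .+ [0.0, 0.0, 0.0, 5.0]    # fourth chain shifted away from the others
split_rhat(mixed)   # close to 1.0
split_rhat(stuck)   # well above 1.0
```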
We can also examine the energy distribution of the Hamiltonian Monte Carlo sampler:
plot_energy(idata);
gcf()
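plot_energy overlays two distributions: the marginal energy distribution and the distribution of energy changes between successive draws. When the second is much narrower than the first, the sampler may struggle to explore the posterior. ArviZ can also report the estimated Bayesian fraction of missing information (E-BFMI) for each chain, and its documentation advises that values below 0.3 indicate poor sampling. Below is a minimal plain-Julia sketch of the per-chain estimator. bfmi_sketch is a hypothetical name, chosen so it does not clash with any bfmi exported by ArviZ.jl, and the energy traces are synthetic.

```julia
using Random, Statistics

# E-BFMI for one chain's energy trace: mean squared change in energy between
# consecutive draws divided by the (uncorrected) variance of the energy.
function bfmi_sketch(energy::AbstractVector{<:Real})
    return mean(abs2, diff(energy)) / var(energy; corrected=false)
end

bfmi_sketch([1.0, 2.0, 3.0, 4.0])   # 1.0 / 1.25 = 0.8

demo_rng = Random.MersenneTwister(5)
iid_energy = randn(demo_rng, 10_000)            # independent draws: E-BFMI near 2
rw_energy = cumsum(randn(demo_rng, 10_000))     # random-walk energy: E-BFMI near 0
```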
Additional information in Turing.jl
With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.
To sample from the prior, one simply calls sample but with the Prior sampler:
prior = sample(param_mod, Prior(), nsamples; progress=false)
Chains MCMC chain (1000×11×1 Array{Float64, 3}):

Start time        = 2021-05-11T17:20:52.494
Stop time         = 2021-05-11T17:20:52.760
Wall duration     = 0.27 seconds
Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = θ[1], θ[2], θ[3], θ[4], θ[5], θ[6], θ[7], θ[8], μ, τ
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat  ⋯
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64  ⋯

        θ[1]    0.9495   67.9183     2.1478    1.8926    556.5266    1.0001  ⋯
        θ[2]   -0.6053   48.9281     1.5472    1.0120   1024.0944    1.0000  ⋯
        θ[3]   -1.1761   60.8375     1.9238    1.5571    999.5217    0.9991  ⋯
        θ[4]    0.1684   52.6024     1.6634    1.8087    811.6064    1.0025  ⋯
        θ[5]    1.7505   71.2369     2.2527    2.2212    918.8990    0.9993  ⋯
        θ[6]    2.0359   74.2804     2.3490    2.5844   1004.3982    1.0023  ⋯
        θ[7]    1.1503   53.3496     1.6871    1.6193   1044.4259    0.9991  ⋯
        θ[8]    0.7439   58.6209     1.8538    1.9887    965.1255    1.0011  ⋯
           μ    0.0522    4.8552     0.1535    0.1401   1030.5527    1.0003  ⋯
           τ   18.2545   68.3487     2.1614    2.4279    638.0227    0.9996  ⋯
                                                              1 column omitted

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

        θ[1]   -53.1021    -5.5494     0.0282     5.4924    42.0920
        θ[2]   -47.9095    -4.9937     0.2079     6.0771    38.2048
        θ[3]   -48.4412    -5.0144     0.3014     5.8416    39.0567
        θ[4]   -36.2510    -4.9584     0.4786     6.0296    46.7597
        θ[5]   -41.1661    -5.4494    -0.3680     5.0671    53.3500
        θ[6]   -40.9711    -5.4135     0.2481     5.6112    47.2051
        θ[7]   -39.5905    -5.2035     0.4159     6.4169    50.1237
        θ[8]   -49.2875    -4.9863     0.2812     5.8763    55.9611
           μ    -9.7057    -3.3349    -0.0951     3.3130     9.5634
           τ     0.1994     2.0832     4.9683    11.2756   112.9419
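Conceptually, sampling with Prior() is ancestral sampling. Each parameter is drawn in the order the model declares it, conditioning on the values already drawn. The plain-Julia sketch below does this by hand for turing_model without Turing. draw_prior is a hypothetical helper. The half-Cauchy prior on τ (what truncated(Cauchy(0, 5), 0, Inf) amounts to) is drawn via its inverse CDF, 5·tan(πu/2), with u uniform on [0, 1).

```julia
using Random, Statistics

# One ancestral draw from the eight-schools prior of turing_model.
function draw_prior(rng::AbstractRNG, J::Integer)
    μ = 5 * randn(rng)               # μ ~ Normal(0, 5)
    τ = 5 * tan(π * rand(rng) / 2)   # τ ~ half-Cauchy with scale 5, via its quantile function
    θ = μ .+ τ .* randn(rng, J)      # θ[j] ~ Normal(μ, τ), independently for each school
    return (; μ, τ, θ)
end

demo_rng = Random.MersenneTwister(6)
prior_draws = [draw_prior(demo_rng, 8) for _ in 1:100_000]
τ_draws = [d.τ for d in prior_draws]
median(τ_draws)   # close to 5, the median of a half-Cauchy with scale 5
```

The heavy right tail of τ inflates its sample mean and standard deviation in the Prior summary above, while its median stays near 5.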
To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model and the previously drawn samples:
# Instantiate the predictive model
param_mod_predict = turing_model(similar(y, Missing), σ)
# and then sample!
prior_predictive = predict(param_mod_predict, prior)
posterior_predictive = predict(param_mod_predict, turing_chns)
Chains MCMC chain (1000×8×4 Array{Float64, 3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1, 2, 3, 4
Samples per chain = 1000
parameters        = y[1], y[2], y[3], y[4], y[5], y[6], y[7], y[8]
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

        y[1]    6.1940   15.8365     0.2504    0.3187   2598.9319    0.9996
        y[2]    5.1315   11.0163     0.1742    0.2020   2972.2101    1.0004
        y[3]    4.1674   16.8618     0.2666    0.2583   3529.7049    1.0003
        y[4]    4.7150   12.1285     0.1918    0.2096   2859.1765    0.9996
        y[5]    3.5240   10.2518     0.1621    0.2029   2745.2054    1.0010
        y[6]    4.1669   11.9747     0.1893    0.2232   3018.5444    0.9996
        y[7]    6.5947   11.0310     0.1744    0.2484   2290.3377    1.0000
        y[8]    5.0596   18.4197     0.2912    0.2796   3201.2609    0.9999

Quantiles
  parameters       2.5%      25.0%      50.0%      75.0%      97.5%
      Symbol    Float64    Float64    Float64    Float64    Float64

        y[1]   -25.4038    -4.0195     6.3500    16.6845    37.8474
        y[2]   -16.5638    -2.5435     5.3023    12.6178    26.2114
        y[3]   -29.6595    -7.0289     4.1721    15.3849    37.5304
        y[4]   -19.2035    -3.6940     4.5058    12.9910    28.7925
        y[5]   -16.5995    -3.2405     3.5815    10.3291    23.3610
        y[6]   -18.4540    -3.8768     4.2179    12.0992    27.4589
        y[7]   -14.3394    -0.7633     6.5626    13.9443    28.7613
        y[8]   -30.7903    -7.2315     5.0240    17.7800    41.0833
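Conceptually, for each posterior draw of θ, predict simulates a fresh dataset from the observation model y[j] ~ Normal(θ[j], σ[j]). The sketch below performs that step by hand for a matrix of θ draws, without Turing. predictive_draws, σ_obs and θ_fake are hypothetical names. The constant θ_fake only stands in for real posterior draws so that the example runs on its own.

```julia
using Random, Statistics

# Replicated data, one row per draw: y_rep[s, j] ~ Normal(θ_draws[s, j], σ_obs[j]).
# θ_draws is (ndraws, J); σ_obs holds the J known standard errors.
function predictive_draws(rng::AbstractRNG, θ_draws::AbstractMatrix{<:Real},
                          σ_obs::AbstractVector{<:Real})
    size(θ_draws, 2) == length(σ_obs) || throw(DimensionMismatch("one σ per school expected"))
    # permutedims(σ_obs) is a 1×J row, so each school's σ scales its own column
    return θ_draws .+ permutedims(σ_obs) .* randn(rng, size(θ_draws)...)
end

σ_obs = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]   # same values as σ above
demo_rng = Random.MersenneTwister(7)
θ_fake = fill(5.0, 4000, length(σ_obs))   # hypothetical stand-in for posterior θ draws
y_rep = predictive_draws(demo_rng, θ_fake, σ_obs)
```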
To extract the pointwise log-likelihoods, which are needed for metrics such as loo (leave-one-out cross-validation), we call Turing.pointwise_loglikelihoods on the posterior parameters:
loglikelihoods = Turing.pointwise_loglikelihoods(
    param_mod, MCMCChains.get_sections(turing_chns, :parameters)
)
Dict{String, Matrix{Float64}} with 8 entries:
  "y[6]" => [-3.38973 -3.51189 -3.43534 -3.73188; -3.32076 -3.32165 -3.38096 -3…
  "y[2]" => [-3.27342 -3.22381 -3.30609 -3.35272; -3.36404 -3.42648 -3.23499 -3…
  "y[1]" => [-4.92564 -4.6233 -4.15682 -4.81507; -4.94726 -5.12105 -5.63477 -3.…
  "y[5]" => [-3.17195 -3.55913 -3.55442 -3.64785; -3.34779 -3.17499 -3.21129 -3…
  "y[8]" => [-3.92493 -3.82086 -4.23278 -3.88106; -3.9033 -3.92951 -4.07239 -3.…
  "y[7]" => [-4.20172 -3.96703 -3.49247 -4.77806; -4.1926 -4.5937 -4.66271 -3.4…
  "y[3]" => [-3.75618 -3.87783 -3.69645 -4.18097; -3.80406 -3.74826 -3.70795 -3…
  "y[4]" => [-3.38347 -3.32362 -3.32345 -3.31684; -3.33491 -3.50791 -3.35802 -3…
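For this model, each entry above is the log density log p(y[j] | θ[j], σ[j]) of a Normal observation, evaluated at every posterior draw. Below is a plain-Julia sketch of that quantity with the Normal log density written out explicitly. normal_logpdf and pointwise_ll are hypothetical helpers, not Turing functions. The final line plugs in the first posterior draw of θ for Choate (about 3.826; y = 28, σ = 15) and gives about -4.92564, matching the first entry of loglikelihoods["y[1]"] above.

```julia
# Log density of Normal(μ, σ) at y, written out explicitly.
normal_logpdf(y, μ, σ) = -0.5 * log(2π * σ^2) - (y - μ)^2 / (2 * σ^2)

# Pointwise log-likelihood for y[j] ~ Normal(θ[j], σ[j]):
# one value per (draw, school) when θ_draws has size (ndraws, J).
pointwise_ll(θ_draws, y_obs, σ_obs) =
    normal_logpdf.(permutedims(y_obs), θ_draws, permutedims(σ_obs))

# First posterior draw of θ for Choate, with y = 28 and σ = 15
normal_logpdf(28.0, 3.82580815, 15.0)   # ≈ -4.92564
```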
These pointwise log-likelihoods can then be included, together with the other groups, in the from_mcmcchains call from above:
using LinearAlgebra
# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
ynames = string.(keys(posterior_predictive))
loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames)
# Reshape into `(nchains, nsamples, size(y)...)`
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3))
idata = from_mcmcchains(
    turing_chns;
    posterior_predictive=posterior_predictive,
    log_likelihood=Dict("y" => loglikelihoods_arr),
    prior=prior,
    prior_predictive=prior_predictive,
    observed_data=Dict("y" => y),
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
    τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
    θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
Attributes:
    created_at:         2021-05-11T17:21:19.083858
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing
posterior_predictive
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 37.05 16.89 6.065 ... 14.09 6.818
Attributes:
    created_at:         2021-05-11T17:21:18.582580
    arviz_version:      0.11.2
    inference_library:  Turing
log_likelihood
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 -4.926 -3.273 -3.756 ... -4.01 -3.848
Attributes:
    created_at:         2021-05-11T17:21:19.008683
    arviz_version:      0.11.2
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions:           (chain: 4, draw: 1000)
Coordinates:
  * chain             (chain) int64 0 1 2 3
  * draw              (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
Data variables:
    energy            (chain, draw) float64 49.74 48.67 52.26 ... 61.18 57.5
    energy_error      (chain, draw) float64 -0.1874 0.05797 ... -0.02117 0.02483
    tree_depth        (chain, draw) int64 3 3 3 4 5 2 3 3 2 ... 4 4 4 4 3 4 4 3
    diverging         (chain, draw) bool False False False ... False False False
    step_size_nom     (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
    acceptance_rate   (chain, draw) float64 0.9568 0.9265 ... 0.9876 0.991
    log_density       (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
    max_energy_error  (chain, draw) float64 -0.8297 0.1387 ... -0.2598 -0.5436
    is_accept         (chain, draw) bool True True True True ... True True True
    lp                (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
    step_size         (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
    n_steps           (chain, draw) int64 15 7 7 31 47 7 ... 15 31 15 15 15 15
Attributes:
    created_at:         2021-05-11T17:21:19.089630
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 4.591 -3.59 7.95 ... -7.285 -2.541 -5.66
    τ        (chain, draw) float64 3.726 4.819 0.8977 ... 1.353 9.565 7.004
    θ        (chain, draw, school) float64 2.823 1.1 7.183 ... -2.011 -7.082
Attributes:
    created_at:         2021-05-11T17:21:19.854776
    arviz_version:      0.11.2
    start_time:         1620753652.493586
    stop_time:          1620753652.760436
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 -19.08 -8.769 22.26 ... 15.77 -8.977
Attributes:
    created_at:         2021-05-11T17:21:19.709455
    arviz_version:      0.11.2
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 1000)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
Data variables:
    lp       (chain, draw) float64 -55.65 -65.73 -43.9 ... -55.26 -71.98 -66.78
Attributes:
    created_at:         2021-05-11T17:21:19.888505
    arviz_version:      0.11.2
    start_time:         1620753652.493586
    stop_time:          1620753652.760436
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions:  (school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at:         2021-05-11T17:21:21.328132
    arviz_version:      0.11.2
    inference_library:  Turing
Then we can, for example, compute the expected log pointwise predictive density estimated with leave-one-out cross-validation (LOO), which estimates the model's out-of-sample predictive fit:
loo(idata) # higher is better
| | loo | loo_se | p_loo | n_samples | n_data_points | warning | loo_scale |
|---|---|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Int64 | Int64 | Bool | String |
| 1 | -30.707 | 1.3569 | 0.861043 | 4000 | 8 | 0 | log |
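To build intuition for the loo and p_loo columns, here is a minimal sketch of LOO computed with plain importance sampling from an S×N matrix of pointwise log-likelihoods (S draws, N observations). It is not ArviZ's implementation. ArviZ stabilizes the importance weights with Pareto-smoothed importance sampling (PSIS), which this sketch omits. The helper names is_loo and logmeanexp are hypothetical.

```julia
using Statistics

# Numerically stable log(mean(exp.(x))) for a vector x.
function logmeanexp(x::AbstractVector)
    m = maximum(x)
    return m + log(mean(exp.(x .- m)))
end

# Plain importance-sampling LOO (no Pareto smoothing) from an S×N matrix of
# pointwise log-likelihoods. Returns the LOO estimate on the log scale and p_loo.
function is_loo(log_lik::AbstractMatrix{<:Real})
    N = size(log_lik, 2)
    # lppd_i: log of the posterior-mean likelihood of observation i
    lppd = [logmeanexp(log_lik[:, i]) for i in 1:N]
    # IS-LOO: p(y_i | y_-i) ≈ 1 / mean_s(1 / p(y_i | θ_s)), a harmonic mean
    elpd_i = [-logmeanexp(-log_lik[:, i]) for i in 1:N]
    elpd_loo = sum(elpd_i)
    p_loo = sum(lppd) - elpd_loo  # effective number of parameters
    return (elpd_loo=elpd_loo, p_loo=p_loo)
end

# Synthetic 4000×8 log-likelihood matrix, a stand-in for the tutorial's data
ll = -3.5 .+ 0.3 .* randn(4000, 8)
res = is_loo(ll)
```

The LOO term is a harmonic mean of the likelihood and lppd is an arithmetic mean, so p_loo is never negative here. Higher values of elpd_loo mean better expected predictive fit, which matches the "higher is better" comment on loo(idata).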
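If the model is well-calibrated, i.e. it replicates the true generative process well, then the LOO probability integral transform (LOO-PIT) values, obtained by evaluating each observation's leave-one-out predictive CDF at the observed value, should be approximately uniformly distributed. This can be inspected visually; with ecdf=true, the plot shows the difference between the empirical CDF of these values and the uniform CDF: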
plot_loo_pit(idata; y="y", ecdf=true);
gcf()
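The LOO-PIT values behind this plot can be sketched as follows, assuming you have the observations, an S×N matrix of posterior predictive draws, and the matching S×N matrix of pointwise log-likelihoods. This sketch uses plain, unsmoothed importance weights, whereas ArviZ uses PSIS weights. The function name loo_pit is hypothetical.

```julia
# LOO-PIT with plain (unsmoothed) importance weights.
# y has length N; y_rep (posterior predictive draws) and log_lik are S×N.
function loo_pit(y::AbstractVector, y_rep::AbstractMatrix, log_lik::AbstractMatrix)
    N = size(y_rep, 2)
    pit = zeros(N)
    for i in 1:N
        # Leave-one-out importance weights are proportional to 1 / p(y_i | θ_s).
        lw = -log_lik[:, i]
        w = exp.(lw .- maximum(lw))
        w ./= sum(w)
        # Weighted fraction of predictive draws at or below the observed value.
        pit[i] = sum(w .* (y_rep[:, i] .<= y[i]))
    end
    return pit
end

# With equal log-likelihoods the weights are uniform, so each PIT value is
# simply the fraction of predictive draws at or below y_i.
y = randn(8)
y_rep = randn(1000, 8)
pit = loo_pit(y, y_rep, zeros(1000, 8))
```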
Plotting with CmdStan.jl outputs
CmdStan.jl and StanSample.jl also default to producing Chains outputs, and we can easily plot these chains.
Here is the same centered eight schools model:
using CmdStan, MCMCChains
schools_code = """
data {
int<lower=0> J;
real y[J];
real<lower=0> sigma[J];
}
parameters {
real mu;
real<lower=0> tau;
real theta[J];
}
model {
mu ~ normal(0, 5);
tau ~ cauchy(0, 5);
theta ~ normal(mu, tau);
y ~ normal(theta, sigma);
}
generated quantities {
vector[J] log_lik;
vector[J] y_hat;
for (j in 1:J) {
log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
y_hat[j] = normal_rng(theta[j], sigma[j]);
}
}
"""
schools_dat = Dict("J" => J, "y" => y, "sigma" => σ)
stan_model = Stanmodel(;
model=schools_code,
name="schools",
nchains=nchains,
num_warmup=nwarmup,
num_samples=nsamples,
output_format=:mcmcchains,
random=CmdStan.Random(28983),
)
_, stan_chns, _ = stan(stan_model, schools_dat; summary=false);
File /home/runner/work/ArviZ.jl/ArviZ.jl/docs/build/tmp/schools.stan will be updated.
plot_density(stan_chns; var_names=["mu", "tau"]);
gcf()
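The generated quantities block of the Stan program computes each log_lik[j] with normal_lpdf(y[j] | theta[j], sigma[j]). Here is a plain-Julia sketch of the same pointwise log-likelihood, computed from a matrix of θ draws. It uses a hand-written normal log density and hypothetical helper names (normal_logpdf, pointwise_loglik), not any Stan or ArviZ function.

```julia
# Normal log density written out by hand; same quantity as Stan's normal_lpdf.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Given an S×J matrix of θ draws, return the S×J matrix whose entry (s, j)
# is log N(y[j] | θ[s, j], σ[j]), mirroring log_lik in the Stan model.
pointwise_loglik(θ_draws::AbstractMatrix, y::AbstractVector, σ::AbstractVector) =
    normal_logpdf.(y', θ_draws, σ')

y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
θ_draws = 5 .+ 3 .* randn(100, 8)   # stand-in draws, not real posterior samples
ll = pointwise_loglik(θ_draws, y, σ)
```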
Converting to InferenceData again gives us much richer labelling and lets us combine the different groups of data. Note that we're using the same from_cmdstan function that ArviZ uses to process CmdStan output files. Because Julia uses multiple dispatch, passing a Chains object instead invokes ArviZ.jl's overloads, which forward to from_mcmcchains.
idata = from_cmdstan(
stan_chns;
posterior_predictive="y_hat",
observed_data=Dict("y" => schools_dat["y"]),
log_likelihood="log_lik",
coords=Dict("school" => schools),
dims=Dict(
"y" => ["school"],
"sigma" => ["school"],
"theta" => ["school"],
"log_lik" => ["school"],
"y_hat" => ["school"],
),
)
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    theta    (chain, draw, school) float64 6.541 4.165 1.673 ... 7.588 7.146
    tau      (chain, draw) float64 2.026 3.44 2.047 ... 0.5002 0.5771 0.5771
    mu       (chain, draw) float64 5.864 4.849 3.906 5.24 ... 7.369 7.617 7.617
Attributes:
    created_at:         2021-05-11T17:22:01.518920
    arviz_version:      0.11.2
    inference_library:  CmdStan
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y_hat    (chain, draw, school) float64 -10.71 23.66 -9.425 ... 9.078 20.88
Attributes:
    created_at:         2021-05-11T17:22:01.487547
    arviz_version:      0.11.2
    inference_library:  CmdStan
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    log_lik  (chain, draw, school) float64 -4.65 -3.295 -3.734 ... -3.764 -3.846
Attributes:
    created_at:         2021-05-11T17:22:01.497752
    arviz_version:      0.11.2
    inference_library:  CmdStan
Dataset (xarray.Dataset)
Dimensions:          (chain: 4, draw: 1000)
Coordinates:
  * chain            (chain) int64 0 1 2 3
  * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
Data variables:
    tree_depth       (chain, draw) int64 4 2 4 3 3 3 4 4 4 ... 1 1 2 1 2 1 5 3 2
    diverging        (chain, draw) bool False False False ... False True False
    energy           (chain, draw) float64 21.85 17.36 18.12 ... 11.01 16.54
    lp               (chain, draw) float64 -11.6 -14.53 -12.58 ... -6.09 -6.09
    step_size        (chain, draw) float64 0.2012 0.2012 ... 0.1493 0.1493
    acceptance_rate  (chain, draw) float64 0.9582 0.8976 ... 0.04874 1.587e-05
    n_steps          (chain, draw) int64 15 7 15 15 7 15 15 ... 5 1 5 3 31 9 3
Attributes:
    created_at:         2021-05-11T17:22:01.533252
    arviz_version:      0.11.2
    inference_library:  CmdStan
Dataset (xarray.Dataset)
Dimensions:  (school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at:         2021-05-11T17:22:01.558449
    arviz_version:      0.11.2
    inference_library:  CmdStan
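The from_cmdstan behaviour described earlier relies on Julia's multiple dispatch: there is one generic function name, and the method that runs is chosen from the argument's type. The toy sketch below shows that pattern with hypothetical names (FakeChains, convert_to_idata, convert_from_chains). ArviZ.jl's actual types and methods differ.

```julia
# Hypothetical stand-in for a chains type; not MCMCChains.Chains.
struct FakeChains
    draws::Dict{String,Matrix{Float64}}
end

# Method for strings: treat the argument as a path to sampler output files.
convert_to_idata(path::AbstractString) = "would parse sampler output files at $path"

# Method for chains objects: forward to a chains-specific converter.
convert_to_idata(chns::FakeChains) = convert_from_chains(chns)

convert_from_chains(chns::FakeChains) =
    "converted chains with variables: " * join(sort(collect(keys(chns.draws))), ", ")

from_path = convert_to_idata("output.csv")
from_chains = convert_to_idata(FakeChains(Dict("mu" => zeros(10, 4), "tau" => zeros(10, 4))))
```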
Here is a plot showing where the Hamiltonian sampler had divergences:
plot_pair(
idata;
coords=Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
divergences=true,
);
gcf()
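To quantify the divergences shown in the pair plot, you can count the diverging flags per chain. The sketch below assumes you already have those flags as a chains × draws Bool matrix, which is the layout of the diverging variable in sample_stats. Extracting that matrix from idata is not shown, and divergence_summary is a hypothetical helper.

```julia
# diverging: chains × draws Bool matrix (one flag per transition).
# Returns the number of divergent transitions in each chain and the overall fraction.
function divergence_summary(diverging::AbstractMatrix{Bool})
    per_chain = vec(sum(diverging; dims=2))
    fraction = count(diverging) / length(diverging)
    return (per_chain=per_chain, fraction=fraction)
end

# Toy example: 4 chains × 10 draws with two divergent transitions.
d = falses(4, 10)
d[2, 3] = true
d[4, 1] = true
summary = divergence_summary(d)
```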
Plotting with Soss.jl outputs
With Soss, we can define our model for the posterior and easily use it to draw samples from the prior, prior predictive, posterior, and posterior predictive distributions.
First we define our model:
using Soss, NamedTupleTools
mod = Soss.@model (J, σ) begin
μ ~ Normal(0, 5)
τ ~ HalfCauchy(5)
θ ~ iid(J)(Normal(μ, τ))
y ~ For(1:J) do j
Normal(θ[j], σ[j])
end
end
constant_data = (J=J, σ=σ)
param_mod = mod(; constant_data...)
Joint Distribution
    Bound arguments: [J, σ]
    Variables: [τ, μ, θ, y]

@model (J, σ) begin
    τ ~ HalfCauchy(5)
    μ ~ Normal(0, 5)
    θ ~ (iid(J))(Normal(μ, τ))
    y ~ For(1:J) do j
        Normal(θ[j], σ[j])
    end
end
Then we draw from the prior and prior predictive distributions.
Random.seed!(5298)
prior_priorpred = [
map(1:(nchains * nsamples)) do _
draw = rand(param_mod)
return delete(draw, keys(constant_data))
end,
];
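Here rand(param_mod) performs ancestral sampling: each variable is drawn given the variables it depends on. Below is a package-free sketch of one such draw from the same centered eight-schools model. The helper draw_prior is hypothetical. It obtains the Half-Cauchy by folding a Cauchy draw generated from the Cauchy inverse CDF, which is not necessarily how Soss samples it.

```julia
using Random

# One ancestral draw from the centered eight-schools model:
# μ ~ Normal(0, 5), τ ~ HalfCauchy(5), θ_j ~ Normal(μ, τ), y_j ~ Normal(θ_j, σ_j).
function draw_prior(rng::AbstractRNG, J::Integer, σ::AbstractVector)
    μ = 5 * randn(rng)
    # Cauchy(0, 5) via its inverse CDF, folded at zero to give HalfCauchy(5).
    τ = abs(5 * tan(π * (rand(rng) - 0.5)))
    θ = μ .+ τ .* randn(rng, J)
    y = θ .+ σ .* randn(rng, J)
    return (μ=μ, τ=τ, θ=θ, y=y)
end

σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
rng = MersenneTwister(5298)
draws = [draw_prior(rng, 8, σ) for _ in 1:1000]
```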
Next, we draw from the posterior using DynamicHMC.jl.
post = map(1:nchains) do _
dynamicHMC(param_mod, (y=y,), nsamples)
end;
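DynamicHMC samples from the unnormalized posterior density implied by the model and the observed y. For the centered parameterization, that log density can be written out by hand as in the sketch below, up to an additive constant. logdensity is a hypothetical helper. Samplers such as DynamicHMC typically work on an unconstrained space (e.g. log τ plus a Jacobian term), and this sketch leaves that step out.

```julia
# Log posterior density of the centered eight-schools model, up to a constant.
function logdensity(μ::Real, τ::Real, θ::AbstractVector, y::AbstractVector, σ::AbstractVector)
    τ <= 0 && return -Inf                              # τ must be positive
    lp = -0.5 * (μ / 5)^2                              # μ ~ Normal(0, 5)
    lp -= log1p((τ / 5)^2)                             # τ ~ HalfCauchy(5)
    lp += sum(-log(τ) .- 0.5 .* ((θ .- μ) ./ τ) .^ 2)  # θ_j ~ Normal(μ, τ)
    lp += sum(-0.5 .* ((y .- θ) ./ σ) .^ 2)            # y_j ~ Normal(θ_j, σ_j), σ fixed
    return lp
end

y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
lp_example = logdensity(4.0, 3.0, fill(4.0, 8), y, σ)
```

Summed over the J schools, the -log τ term lets the density concentrate sharply when τ is small. This is the funnel geometry that makes the centered parameterization prone to divergences.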
Finally, we update the posterior samples with draws from the posterior predictive distribution.
pred = predictive(mod, :μ, :τ, :θ)
post_postpred = map(post) do post_draws
map(post_draws) do post_draw
pred_draw = rand(pred(post_draw))
pred_draw = delete(pred_draw, keys(constant_data))
return merge(pred_draw, post_draw)
end
end;
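The predictive step above draws y for each posterior draw and merges the result with that draw. Without Soss, the operation for a single draw looks roughly like the sketch below, where add_postpred is a hypothetical helper. As with merge(pred_draw, post_draw) in the code above, fields of the posterior draw take precedence if names clash.

```julia
using Random

# Given one posterior draw (a NamedTuple with at least a field θ), draw
# y_rep_j ~ Normal(θ_j, σ_j) and merge it in as field y.
function add_postpred(rng::AbstractRNG, post_draw::NamedTuple, σ::AbstractVector)
    y_rep = post_draw.θ .+ σ .* randn(rng, length(σ))
    return merge((y=y_rep,), post_draw)
end

rng = MersenneTwister(3)
draw = add_postpred(rng, (μ=1.0, τ=2.0, θ=zeros(3)), ones(3))
```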
Each Soss draw is a NamedTuple. We can plot the rank order statistics of the posterior to identify poor convergence:
plot_rank(post; var_names=["μ", "τ"]);
gcf()
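A rank plot pools all draws of a parameter across chains and ranks the pooled draws. It then shows one histogram of those ranks per chain, and well-mixed chains give roughly uniform histograms. The sketch below reproduces the counting behind those histograms. It uses two hypothetical helpers: stack_field turns Soss-style vectors of NamedTuple draws into a chains × draws matrix, and rank_counts bins the ranks. It is not plot_rank's implementation.

```julia
# Turn a vector of chains, each a vector of NamedTuple draws, into a
# chains × draws matrix holding the scalar field `name`.
stack_field(chains::AbstractVector, name::Symbol) =
    permutedims(reduce(hcat, [[getfield(d, name) for d in ch] for ch in chains]))

# Rank all draws jointly, then count, per chain, how many ranks fall into
# each of `nbins` equal-width rank bins.
function rank_counts(x::AbstractMatrix; nbins::Integer=20)
    nchains, ndraws = size(x)
    ranks = reshape(invperm(sortperm(vec(x))), nchains, ndraws)
    N = length(x)
    counts = zeros(Int, nchains, nbins)
    for c in 1:nchains, r in ranks[c, :]
        b = min(nbins, 1 + (r - 1) * nbins ÷ N)
        counts[c, b] += 1
    end
    return counts
end

# Two deliberately non-mixing toy chains: all of chain 2 lies above chain 1.
chains = [[(μ = Float64(i + 10k),) for i in 1:5] for k in 0:1]
x = stack_field(chains, :μ)       # 2 × 5 matrix
counts = rank_counts(x; nbins=5)
```

Because the two toy chains do not overlap, their rank counts pile up at opposite ends instead of being flat. A rank plot displays the same symptom for poorly mixing chains.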
Now we combine all of the samples into an InferenceData:
idata = from_namedtuple(
    post_postpred;
    posterior_predictive=[:y],  # move y from the posterior draws into its own group
    prior=prior_priorpred,
    prior_predictive=[:y],      # likewise for the prior draws
    observed_data=(y=y,),
    constant_data=constant_data,
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library=Soss,
)
posterior: Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 4.192 9.221 0.8525 ... 10.62 10.42 8.331
    τ        (chain, draw) float64 11.95 12.96 15.19 7.62 ... 1.135 1.388 1.072
    θ        (chain, draw, school) float64 32.52 -0.5799 8.606 ... 7.977 6.837
Attributes:
    created_at:         2021-05-11T17:23:14.210574
    arviz_version:      0.11.2
    inference_library:  Soss
posterior_predictive: Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 17.84 1.985 15.04 ... 2.03 14.85
Attributes:
    created_at:         2021-05-11T17:23:14.172903
    arviz_version:      0.11.2
    inference_library:  Soss
prior: Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 4000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 -1.058 6.136 -0.3852 ... 1.484 -10.82 -0.7505
    τ        (chain, draw) float64 149.7 0.02342 3.398 ... 4.347 84.28 4.543
    θ        (chain, draw, school) float64 -8.511 154.9 44.86 ... -2.911 -6.732
Attributes:
    created_at:         2021-05-11T17:23:15.317494
    arviz_version:      0.11.2
    inference_library:  Soss
prior_predictive: Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 4000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 -19.44 137.9 18.57 ... 5.947 13.3
Attributes:
    created_at:         2021-05-11T17:23:15.314620
    arviz_version:      0.11.2
    inference_library:  Soss
observed_data: Dataset (xarray.Dataset)
Dimensions:  (school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at:         2021-05-11T17:23:16.068613
    arviz_version:      0.11.2
    inference_library:  Soss
constant_data: Dataset (xarray.Dataset)
Dimensions:  (J_dim_0: 1, school: 8)
Coordinates:
  * J_dim_0  (J_dim_0) int64 0
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    J        (J_dim_0) int64 8
    σ        (school) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
Attributes:
    created_at:         2021-05-11T17:23:17.894044
    arviz_version:      0.11.2
    inference_library:  Soss
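The datasets above store each variable as an array indexed by chain and draw. θ and y also get a trailing school dimension. Since `post_postpred` holds 4 chains of 1000 NamedTuple draws, the posterior datasets have dimensions (chain: 4, draw: 1000, school: 8). The sketch below shows one way to stack such nested draws. It is an illustration, not ArviZ's `from_namedtuple` implementation, and it only handles scalar- and vector-valued variables. `stack_draws` and the demo values are hypothetical.

```julia
# Hypothetical helper. `chains` is a vector of chains, each a vector of
# NamedTuple draws. Returns an array of size (nchains, ndraws) for a scalar
# variable, or (nchains, ndraws, length) for a vector-valued one.
function stack_draws(chains::AbstractVector, name::Symbol)
    nchains = length(chains)
    ndraws = length(first(chains))
    value = getproperty(first(first(chains)), name)
    out = Array{eltype(value)}(undef, nchains, ndraws, length(value))
    for c in 1:nchains, d in 1:ndraws
        out[c, d, :] .= getproperty(chains[c][d], name)
    end
    return value isa Number ? dropdims(out; dims = 3) : out
end

# Two chains of three placeholder draws each.
demo_chains = [[(μ = Float64(c + d), θ = fill(Float64(c * d), 8)) for d in 1:3] for c in 1:2]
μ_stacked = stack_draws(demo_chains, :μ)  # size (2, 3)
θ_stacked = stack_draws(demo_chains, :θ)  # size (2, 3, 8)
```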
We can compare the prior and posterior predictive distributions:
plot_density(
    [idata.posterior_predictive, idata.prior_predictive];
    data_labels=["Post-pred", "Prior-pred"],
    var_names=["y"],
);
gcf()
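A numeric counterpart to this plot is the width of a central interval of y under each predictive distribution. Because the posterior has been conditioned on the observed data, one would usually expect its predictive intervals to be narrower than the prior's. The standard-library sketch below computes such a width for a vector of draws. `central_width` is a hypothetical helper, and the demo inputs are synthetic stand-ins rather than draws taken from `idata`.

```julia
using Random, Statistics

# Hypothetical helper: width of the central `p` interval of `draws`,
# e.g. p = 0.9 measures the distance between the 5% and 95% quantiles.
function central_width(draws::AbstractVector{<:Real}, p::Real = 0.9)
    lo, hi = quantile(draws, [(1 - p) / 2, (1 + p) / 2])
    return hi - lo
end

demo_rng = Random.MersenneTwister(4)
wide = 10 .* randn(demo_rng, 10_000)   # stand-in for a diffuse predictive sample
narrow = randn(demo_rng, 10_000)       # stand-in for a concentrated one
central_width(wide) > central_width(narrow)
```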
Environment
using Pkg
Pkg.status()
using InteractiveUtils
versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
  JULIA_CMDSTAN_HOME = /home/runner/work/ArviZ.jl/ArviZ.jl/.cmdstan//cmdstan-2.25.0/
  JULIA_NUM_THREADS = 2