ArviZ.jl Quickstart
This tutorial is adapted from ArviZ's quickstart.
using ArviZ
using PyPlot
# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")
Get started with plotting
ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl but works fine with raw arrays.
using Random
rng = Random.MersenneTwister(37772)
plot_posterior(randn(rng, 100_000));
gcf()
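Posterior plots of this kind are typically annotated with a highest density interval (HDI), the narrowest interval containing a given share of the draws. The following is a minimal sketch of one simple way to estimate it from sorted draws; `hdi_sketch` is a hypothetical helper written for illustration, not ArviZ's implementation.

```julia
# Narrowest interval that spans roughly `prob` of the sorted draws.
# Hypothetical helper for illustration only.
function hdi_sketch(draws::AbstractVector{<:Real}, prob::Real=0.94)
    x = sort(draws)
    n = length(x)
    k = min(floor(Int, prob * n), n - 1)    # index offset spanned by each candidate interval
    widths = x[(k + 1):n] .- x[1:(n - k)]   # width of every candidate interval
    i = argmin(widths)                      # left end of the narrowest one
    return (x[i], x[i + k])
end

# The outlier at 100.0 is excluded because including it would widen the interval.
lower, upper = hdi_sketch([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0], 0.8)
```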
When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from the variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.
using Distributions
s = (10, 50)
plot_forest(
    Dict(
        "normal" => randn(rng, s),
        "gumbel" => rand(rng, Gumbel(), s),
        "student t" => rand(rng, TDist(6), s),
        "exponential" => rand(rng, Exponential(), s),
    ),
);
gcf()
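To make the row-per-chain convention concrete, here is a small sketch using only the standard library. The seed and variable names are arbitrary choices for illustration.

```julia
using Random, Statistics

rng = MersenneTwister(1)               # arbitrary seed for this sketch
nchains, ndraws = 10, 50
draws = randn(rng, nchains, ndraws)    # row i holds the draws of chain i

chain_means = vec(mean(draws; dims=2)) # one mean per chain (reduce over the draw axis)
pooled_mean = mean(draws)              # mean over all chains and draws together
```

Because every chain has the same length, the average of the per-chain means equals the pooled mean.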
Plotting with MCMCChains.jl's Chains objects produced by Turing.jl
ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, I name the schools after a different set of eight schools.
This model is small enough to write down, is hierarchical, and uses labelling. Additionally, a centered parameterization causes divergences (which are interesting for illustration).
First we create our data and set some sampling parameters.
J = 8
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
schools = [
    "Choate",
    "Deerfield",
    "Phillips Andover",
    "Phillips Exeter",
    "Hotchkiss",
    "Lawrenceville",
    "St. Paul's",
    "Mt. Hermon",
];
nwarmup, nsamples, nchains = 1000, 1000, 4;
Now we write and run the model using Turing:
using Turing
Turing.@model function turing_model(J, y, σ, ::Type{TV}=Vector{Float64}) where {TV}
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5), 0, Inf)
    θ = TV(undef, J)
    θ .~ Normal(μ, τ)
    for i in eachindex(y)
        y[i] ~ Normal(θ[i], σ[i])
    end
    return y
end
param_mod = turing_model(J, y, σ)
sampler = NUTS(nwarmup, 0.8)
rng = Random.MersenneTwister(16653)
turing_chns = sample(
    rng, param_mod, sampler, MCMCThreads(), nwarmup + nsamples, nchains; progress=false
);
┌ Info: Found initial step size
└   ϵ = 1.6
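The model above is centered: θ is drawn directly from Normal(μ, τ). A common reparameterization that often reduces divergences is the non-centered form, which draws standard-normal "raw" effects and then shifts and scales them. The sketch below only illustrates the algebraic identity with the standard library; the values of μ and τ are hypothetical, and this is not a Turing model.

```julia
using Random, Statistics

rng = MersenneTwister(2)   # arbitrary seed for this sketch
μ, τ, J = 4.0, 3.0, 8      # hypothetical values chosen for illustration

z = randn(rng, J)          # non-centered: standard-normal raw effects
θ = μ .+ τ .* z            # shifting and scaling gives θ[i] ~ Normal(μ, τ)
```

The sampler would then explore `z`, whose scale does not depend on τ, instead of θ.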
Most ArviZ functions work fine with Chains objects from Turing:
plot_autocorr(turing_chns; var_names=["μ", "τ"]);
gcf()
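An autocorrelation plot shows, for each lag, how strongly draws separated by that many iterations are correlated. As a rough sketch of the quantity being plotted, the following computes the sample autocorrelation of a single chain; `autocorr_sketch` is a hypothetical helper, and the plotting function may use a different (for example FFT-based) estimator.

```julia
using Statistics

# Sample autocorrelation of `x` at a given lag, normalized so that lag 0 gives 1.
# Hypothetical helper for illustration only.
function autocorr_sketch(x::AbstractVector{<:Real}, lag::Integer)
    n = length(x)
    d = x .- mean(x)                                  # deviations from the mean
    return sum(d[1:(n - lag)] .* d[(lag + 1):n]) / sum(abs2, d)
end

acf = [autocorr_sketch([1.0, 2.0, 3.0, 4.0, 5.0], k) for k in 0:2]
```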
Convert to InferenceData
For much more powerful querying, analysis and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note we are also giving some information about labelling.
ArviZ is built to work with InferenceData (a NetCDF datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful analyses it can perform.
idata = from_mcmcchains(
    turing_chns;
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 2.242 7.977 5.731 5.574 ... 4.085 2.973 7.224
    τ        (chain, draw) float64 5.526 6.236 4.726 6.596 ... 2.249 1.253 2.106
    θ        (chain, draw, school) float64 4.309 9.479 10.76 ... 8.703 8.909
Attributes:
    created_at:         2021-01-20T07:27:15.683692
    arviz_version:      0.11.0
    inference_library:  Turing

sample_stats
Dataset (xarray.Dataset)
Dimensions:           (chain: 4, draw: 1000)
Coordinates:
  * chain             (chain) int64 0 1 2 3
  * draw              (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
Data variables:
    energy            (chain, draw) float64 64.14 65.4 63.17 ... 54.05 54.95
    energy_error      (chain, draw) float64 -0.02879 0.04508 ... -0.05148
    depth             (chain, draw) int64 4 4 4 4 4 4 4 4 4 ... 3 4 4 3 4 4 3 5
    diverging         (chain, draw) bool False False False ... False False False
    log_density       (chain, draw) float64 -59.74 -59.23 ... -48.34 -49.35
    max_energy_error  (chain, draw) float64 -0.03812 0.08593 ... -0.7244 -0.6883
    nom_step_size     (chain, draw) float64 0.2857 0.2857 ... 0.09645 0.09645
    is_accept         (chain, draw) float64 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
    tree_size         (chain, draw) int64 15 15 15 15 15 15 ... 15 15 15 15 47
    lp                (chain, draw) float64 -59.74 -59.23 ... -48.34 -49.35
    step_size         (chain, draw) float64 0.2857 0.2857 ... 0.09645 0.09645
    mean_tree_accept  (chain, draw) float64 0.9954 0.9501 ... 0.959 0.9753
Attributes:
    created_at:         2021-01-20T07:27:15.717875
    arviz_version:      0.11.0
    inference_library:  Turing
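The coords and dims arguments are what attach the school names to the last axis of θ. As a rough sketch of the idea, assuming the (chain, draw, school) layout shown in the output, the following looks up a slice by label with plain Julia containers; `select_by_label` and the stand-in array are hypothetical and are not part of ArviZ.

```julia
schools = ["Choate", "Deerfield", "Phillips Andover", "Phillips Exeter",
           "Hotchkiss", "Lawrenceville", "St. Paul's", "Mt. Hermon"]
coords = Dict("school" => schools)   # dimension name => labels along that dimension
dims = Dict("θ" => ["school"])       # variable name => names of its extra dimensions

# Stand-in for posterior draws of θ with shape (chain, draw, school).
θ = reshape(collect(1.0:(2 * 3 * 8)), 2, 3, 8)

# Hypothetical helper: slice a variable along a named dimension by its label.
function select_by_label(arr, vardims, coords, dimname, label)
    axis = 2 + findfirst(==(dimname), vardims)    # chain and draw axes come first
    idx = findfirst(==(label), coords[dimname])   # position of the label on that axis
    return selectdim(arr, axis, idx)
end

hotchkiss = select_by_label(θ, dims["θ"], coords, "school", "Hotchkiss")
```

The result keeps the chain and draw axes and drops the school axis, which is the kind of label-based selection that named dimensions make possible.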
Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.
idata.posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 2.242 7.977 5.731 5.574 ... 4.085 2.973 7.224
    τ        (chain, draw) float64 5.526 6.236 4.726 6.596 ... 2.249 1.253 2.106
    θ        (chain, draw, school) float64 4.309 9.479 10.76 ... 8.703 8.909
Attributes:
    created_at:         2021-01-20T07:27:15.683692
    arviz_version:      0.11.0
    inference_library:  Turing
Here is a plot of the trace. Note the intelligent labels.
plot_trace(idata);
gcf()
We can also generate summary statistics:
summarystats(idata)
| | variable | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_mean | ess_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | String | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 |
| 1 | μ | 4.516 | 3.432 | -1.244 | 11.264 | 0.191 | 0.154 | 323.0 | 248.0 | 345.0 | 174.0 | 1.01 |
| 2 | τ | 3.761 | 3.245 | 0.315 | 9.546 | 0.266 | 0.189 | 148.0 | 148.0 | 66.0 | 27.0 | 1.05 |
| 3 | θ[1] | 6.354 | 5.878 | -4.06 | 16.889 | 0.261 | 0.185 | 507.0 | 507.0 | 449.0 | 1436.0 | 1.01 |
| 4 | θ[2] | 5.227 | 4.875 | -3.001 | 14.861 | 0.216 | 0.17 | 511.0 | 410.0 | 492.0 | 1054.0 | 1.01 |
| 5 | θ[3] | 4.03 | 5.326 | -5.698 | 13.892 | 0.196 | 0.156 | 741.0 | 584.0 | 643.0 | 1421.0 | 1.0 |
| 6 | θ[4] | 4.846 | 4.855 | -4.4 | 13.445 | 0.21 | 0.149 | 534.0 | 534.0 | 495.0 | 1570.0 | 1.01 |
| 7 | θ[5] | 3.615 | 4.831 | -5.744 | 11.948 | 0.201 | 0.16 | 576.0 | 455.0 | 514.0 | 1079.0 | 1.0 |
| 8 | θ[6] | 4.157 | 5.059 | -5.959 | 13.44 | 0.191 | 0.157 | 698.0 | 521.0 | 596.0 | 1135.0 | 1.0 |
| 9 | θ[7] | 6.559 | 5.07 | -2.274 | 15.76 | 0.271 | 0.192 | 350.0 | 350.0 | 344.0 | 1419.0 | 1.01 |
| 10 | θ[8] | 4.913 | 5.657 | -4.391 | 16.094 | 0.21 | 0.149 | 726.0 | 726.0 | 590.0 | 1497.0 | 1.01 |
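The r_hat column compares variation between chains with variation within chains; values well above 1 suggest the chains disagree. The sketch below implements the classic Gelman-Rubin form with the standard library to show the idea. It is an assumption-laden illustration: `rhat_sketch` is a hypothetical helper, and the estimator ArviZ reports is more refined, so its numbers will generally differ.

```julia
using Statistics

# Classic (non-split, non-rank-normalized) potential scale reduction factor.
# `draws` has shape (nchains, ndraws). Hypothetical helper for illustration only.
function rhat_sketch(draws::AbstractMatrix{<:Real})
    n = size(draws, 2)
    chain_means = vec(mean(draws; dims=2))
    W = mean(vec(var(draws; dims=2)))    # mean within-chain variance
    B = n * var(chain_means)             # between-chain variance
    var_hat = (n - 1) / n * W + B / n    # pooled estimate of the posterior variance
    return sqrt(var_hat / W)
end

agree = rhat_sketch([1.0 2.0 3.0 4.0; 1.0 2.0 3.0 4.0])        # identical chains
disagree = rhat_sketch([1.0 2.0 3.0 4.0; 11.0 12.0 13.0 14.0]) # shifted chains
```

Identical chains give a value slightly below 1 here because of the finite-sample factor (n - 1) / n, while the shifted chains give a value far above 1.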
We can likewise examine the energy distribution of the Hamiltonian sampler:
plot_energy(idata);
gcf()
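An energy plot compares the spread of the energy itself with the spread of its draw-to-draw changes. One common numerical summary of that comparison is the energy-based fraction of missing information (BFMI). The sketch below shows one widely used estimator for a single chain, assuming a population variance in the denominator; `bfmi_sketch` is a hypothetical helper, not the function ArviZ exposes.

```julia
using Statistics

# Mean squared energy transition divided by the variance of the energy.
# Hypothetical helper for illustration only.
function bfmi_sketch(energy::AbstractVector{<:Real})
    return mean(abs2, diff(energy)) / var(energy; corrected=false)
end

b = bfmi_sketch([1.0, 2.0, 3.0, 4.0])
```

Low values indicate that the sampler moves through energy levels slowly relative to their overall spread.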
Additional information in Turing.jl
With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.
To sample from the prior, one simply calls sample but with the Prior sampler:
prior = sample(param_mod, Prior(), nsamples; progress=false)
Chains MCMC chain (1000×11×1 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = θ[1], θ[2], θ[3], θ[4], θ[5], θ[6], θ[7], θ[8], μ, τ
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

        θ[1]    0.9437   67.9184     2.1478    3.0807    556.5721    1.0001
        θ[2]   -0.6113   48.9282     1.5472    1.1096   1023.9913    1.0000
        θ[3]   -1.1874   60.8370     1.9238    1.4449    999.5571    0.9991
        θ[4]    0.1588   52.6023     1.6634    2.0985    812.8250    1.0025
        θ[5]    1.7428   71.2371     2.2527    2.2148    918.9025    0.9993
        θ[6]    2.0271   74.2805     2.3490    2.8746   1004.2599    1.0024
        θ[7]    1.1439   53.3499     1.6871    0.9483   1044.3923    0.9991
        θ[8]    0.7388   58.6210     1.8538    1.8743    964.9287    1.0012
           μ    0.0436    4.8547     0.1535    0.1428   1027.7394    1.0001
           τ   18.2513   68.3494     2.1614    3.0856    638.1239    0.9996

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%      97.5%
      Symbol    Float64   Float64   Float64   Float64    Float64

        θ[1]   -53.1021   -5.5494    0.0134    5.4924    42.0920
        θ[2]   -47.9095   -4.9937    0.2045    6.0771    38.2048
        θ[3]   -48.4412   -5.0144    0.2915    5.8051    39.0567
        θ[4]   -36.2510   -4.9584    0.4698    6.0296    46.7597
        θ[5]   -41.1661   -5.4494   -0.3852    5.0671    53.3500
        θ[6]   -40.9711   -5.4135    0.2146    5.5908    47.2051
        θ[7]   -39.5905   -5.2035    0.3970    6.4169    50.1237
        θ[8]   -49.2875   -4.9863    0.2679    5.8763    55.9611
           μ    -9.7057   -3.3430   -0.0990    3.3103     9.5634
           τ     0.1994    2.0780    4.9683   11.2756   112.9419
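Sampling from the prior amounts to simulating the model forward without conditioning on the data. As a minimal sketch of what one such joint draw looks like for the priors written in the model, the following uses only the standard library; the seed is arbitrary, and the half-Cauchy draw relies on the fact that a zero-centered Cauchy truncated to the non-negative half-line has the same distribution as its absolute value.

```julia
using Random

rng = MersenneTwister(3)   # arbitrary seed for this sketch
J = 8

# One joint draw from the prior, simulated top-down.
μ = 5 * randn(rng)                         # μ ~ Normal(0, 5)
τ = abs(5 * tan(π * (rand(rng) - 0.5)))    # τ ~ Cauchy(0, 5) truncated to [0, Inf), via inverse CDF
θ = μ .+ τ .* randn(rng, J)                # θ[i] ~ Normal(μ, τ)
```

The heavy tail of the Cauchy prior on τ is what produces the very large standard deviations in the prior summary above.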
To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model and the previously drawn samples:
# Instantiate the predictive model
param_mod_predict = turing_model(J, similar(y, Missing), σ)
# and then sample!
prior_predictive = predict(param_mod_predict, prior)
posterior_predictive = predict(param_mod_predict, turing_chns)
Chains MCMC chain (1000×8×4 Array{Float64,3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1, 2, 3, 4
Samples per chain = 1000
parameters        = y[1], y[2], y[3], y[4], y[5], y[6], y[7], y[8]
internals         =

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

        y[1]    4.5082   16.0150     0.2532    0.2930   3268.5480    1.0007
        y[2]    4.6493   11.9315     0.1887    0.3013   2426.6130    1.0007
        y[3]    4.5704   17.2858     0.2733    0.3852   3246.5857    1.0000
        y[4]    4.6143   12.6587     0.2002    0.2960   3443.9512    1.0005
        y[5]    4.4985   10.6522     0.1684    0.2672   1895.2847    1.0018
        y[6]    4.7272   12.4434     0.1967    0.2517   3152.6502    1.0012
        y[7]    4.3443   11.7323     0.1855    0.2427   3064.7075    1.0001
        y[8]    4.9657   18.9625     0.2998    0.3368   3958.3019    1.0001

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol    Float64   Float64   Float64   Float64   Float64

        y[1]   -26.7190   -6.0134    4.6518   15.1441   35.9749
        y[2]   -18.8314   -3.2552    4.5512   12.5614   27.9407
        y[3]   -29.6524   -6.9376    4.6521   16.0361   38.4275
        y[4]   -20.5275   -3.5383    4.7111   12.8250   29.4160
        y[5]   -15.8879   -2.7019    4.7310   11.4978   25.5380
        y[6]   -19.2152   -3.5181    4.4492   13.2566   28.6721
        y[7]   -19.1155   -3.2005    4.6943   12.0731   26.8739
        y[8]   -31.5715   -8.0353    4.6702   17.8301   41.3367
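The call similar(y, Missing) in the predictive model is what marks every observation as unobserved. A short check with plain Julia shows what it produces: an array with the same shape as y whose elements are all missing.

```julia
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]

# Same shape as `y`, but with element type Missing; since Missing has a single
# value, every entry is `missing`.
y_missing = similar(y, Missing)
```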
To extract the pointwise log-likelihoods, which are useful for computing metrics such as loo:
loglikelihoods = Turing.pointwise_loglikelihoods(param_mod, turing_chns)
Dict{String,Array{Float64,2}} with 8 entries:
  "y[6]" => [-3.36285 -4.08918 -3.39046 -3.32844; -3.81335 -3.62125 -3.38203 -3…
  "y[2]" => [-3.23247 -3.22337 -3.23758 -3.65949; -3.28484 -3.26335 -3.23609 -3…
  "y[1]" => [-4.87426 -4.878 -4.71977 -5.22698; -4.08751 -4.70315 -4.60937 -4.7…
  "y[5]" => [-3.2681 -4.81419 -3.37238 -3.16239; -3.43789 -3.95455 -3.39972 -3.…
  "y[8]" => [-4.03523 -3.81134 -3.86919 -3.86369; -3.81057 -3.82789 -3.86432 -3…
  "y[7]" => [-3.72935 -4.31431 -3.92227 -4.56106; -3.4916 -3.48724 -3.94398 -4.…
  "y[3]" => [-4.06155 -3.84478 -3.85029 -3.69183; -3.73377 -3.69364 -3.85202 -3…
  "y[4]" => [-4.04857 -3.33035 -3.31907 -3.32388; -3.75523 -3.33412 -3.31756 -3…
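Each value in this dictionary is a matrix with one row per draw and one column per chain, whereas ArviZ expects arrays laid out as (chain, draw, observation). The following sketch shows the stacking and axis swap on small stand-in matrices using only Base functions; the sizes and fill values are hypothetical.

```julia
nsamples, nchains, nobs = 3, 2, 4

# Stand-in values: one (nsamples × nchains) matrix per observation.
vals = [fill(Float64(k), nsamples, nchains) for k in 1:nobs]

stacked = cat(vals...; dims=3)           # (nsamples, nchains, nobs)
arr = permutedims(stacked, (2, 1, 3))    # (nchains, nsamples, nobs)
```

After the permutation, arr[c, d, k] holds the value for chain c, draw d, and observation k.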
The log-likelihoods can then be included in the from_mcmcchains call from above:
using LinearAlgebra
# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
ynames = string.(keys(posterior_predictive))
loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames)
# Reshape into `(nchains, nsamples, size(y)...)`
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3))
idata = from_mcmcchains(
    turing_chns;
    posterior_predictive=posterior_predictive,
    log_likelihood=Dict("y" => loglikelihoods_arr),
    prior=prior,
    prior_predictive=prior_predictive,
    observed_data=Dict("y" => y),
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 2.242 7.977 5.731 5.574 ... 4.085 2.973 7.224
    τ        (chain, draw) float64 5.526 6.236 4.726 6.596 ... 2.249 1.253 2.106
    θ        (chain, draw, school) float64 4.309 9.479 10.76 ... 8.703 8.909
Attributes:
    created_at:         2021-01-20T07:27:45.372934
    arviz_version:      0.11.0
    inference_library:  Turing

posterior_predictive
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 25.57 11.68 13.62 ... -5.516 40.3
Attributes:
    created_at:         2021-01-20T07:27:44.867349
    arviz_version:      0.11.0
    inference_library:  Turing

log_likelihood
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 -4.874 -3.232 ... -3.654 -3.824
Attributes:
    created_at:         2021-01-20T07:27:45.314610
    arviz_version:      0.11.0
    inference_library:  Turing
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
Data variables:
    energy (chain, draw) float64 64.14 65.4 63.17 ... 54.05 54.95
    energy_error (chain, draw) float64 -0.02879 0.04508 ... -0.05148
    depth (chain, draw) int64 4 4 4 4 4 4 4 4 4 ... 3 4 4 3 4 4 3 5
    diverging (chain, draw) bool False False False ... False False False
    log_density (chain, draw) float64 -59.74 -59.23 ... -48.34 -49.35
    max_energy_error (chain, draw) float64 -0.03812 0.08593 ... -0.7244 -0.6883
    nom_step_size (chain, draw) float64 0.2857 0.2857 ... 0.09645 0.09645
    is_accept (chain, draw) float64 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
    tree_size (chain, draw) int64 15 15 15 15 15 15 ... 15 15 15 15 47
    lp (chain, draw) float64 -59.74 -59.23 ... -48.34 -49.35
    step_size (chain, draw) float64 0.2857 0.2857 ... 0.09645 0.09645
    mean_tree_accept (chain, draw) float64 0.9954 0.9501 ... 0.959 0.9753
Attributes:
    created_at: 2021-01-20T07:27:45.376715
    arviz_version: 0.11.0
    inference_library: Turing
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ (chain, draw) float64 -3.59 7.95 2.976 ... -2.541 -5.66 -3.953
    τ (chain, draw) float64 4.819 0.8977 2.399 ... 9.565 7.004 0.5758
    θ (chain, draw, school) float64 1.91 -4.607 6.231 ... -4.287 -3.628
Attributes:
    created_at: 2021-01-20T07:27:45.785334
    arviz_version: 0.11.0
    inference_library: Turing
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y (chain, draw, school) float64 12.63 -22.35 29.94 ... 1.514 -8.327
Attributes:
    created_at: 2021-01-20T07:27:45.643680
    arviz_version: 0.11.0
    inference_library: Turing
Dataset (xarray.Dataset)
Dimensions: (chain: 1, draw: 1000)
Coordinates:
  * chain (chain) int64 0
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
Data variables:
    lp (chain, draw) float64 -65.73 -43.9 -53.15 ... -71.98 -66.78 -45.63
Attributes:
    created_at: 2021-01-20T07:27:45.836302
    arviz_version: 0.11.0
    inference_library: Turing
Dataset (xarray.Dataset)
Dimensions: (school: 8)
Coordinates:
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at: 2021-01-20T07:27:46.612903
    arviz_version: 0.11.0
    inference_library: Turing
Then we can, for example, compute the expected log pointwise predictive density using leave-one-out (LOO) cross-validation, which is an estimate of the out-of-sample predictive fit of the model:
loo(idata) # higher is better
| | loo | loo_se | p_loo | n_samples | n_data_points | warning | loo_scale |
|---|---|---|---|---|---|---|---|
| | Float64 | Float64 | Float64 | Int64 | Int64 | Bool | String |
| 1 | -30.7109 | 1.34611 | 0.881344 | 4000 | 8 | 0 | log |
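To make the quantity in the loo column more concrete, here is a minimal sketch of a leave-one-out estimate computed from a matrix of pointwise log-likelihood values. It uses plain (unsmoothed) importance sampling, which is an assumption made for brevity and is not how ArviZ computes it; ArviZ additionally applies Pareto smoothing to the importance weights. The names logmeanexp, naive_loo and toy_log_lik are hypothetical.

```julia
using Statistics

# Numerically stable log(mean(exp.(x))).
function logmeanexp(x)
    m = maximum(x)
    return m + log(mean(exp.(x .- m)))
end

# `log_lik` has one row per posterior draw and one column per data point.
# For each point, the leave-one-out predictive density is approximated by the
# harmonic mean of the likelihood over draws (unsmoothed importance sampling).
function naive_loo(log_lik::AbstractMatrix)
    return sum(-logmeanexp(-log_lik[:, i]) for i in axes(log_lik, 2))
end

# Toy input (made-up numbers): 2 draws, 2 data points.
toy_log_lik = [-1.0 -2.0; -3.0 -2.0]
elpd_toy = naive_loo(toy_log_lik)
```

Because a harmonic mean never exceeds the arithmetic mean, this estimate is never larger than the in-sample log predictive density, which matches the intuition that held-out points are harder to predict.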
If the model is well-calibrated, i.e. it replicates the true generative process well, the leave-one-out probability integral transform (LOO-PIT) values should be approximately uniformly distributed. This can be inspected visually:
plot_loo_pit(idata; y="y", ecdf=true);
gcf()
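The idea behind this check can be sketched without ArviZ. The toy example below uses an ordinary probability integral transform (not the leave-one-out version) on simulated data: when the CDF of the correct model is applied to the observations, the results look uniform, while a misspecified model gives a clearly non-uniform result. All names (y_toy, pit_good, pit_bad, max_ecdf_gap) are hypothetical.

```julia
using Random

rng = MersenneTwister(1)

# Toy "observations" from an Exponential distribution with mean 1,
# generated by inverse-CDF sampling.
y_toy = -log.(1 .- rand(rng, 5_000))

# PIT under the correct model: u = F(y) with F(y) = 1 - exp(-y).
pit_good = 1 .- exp.(-y_toy)

# PIT under a misspecified model (Exponential with mean 3).
pit_bad = 1 .- exp.(-y_toy ./ 3)

# Largest gap between the empirical CDF of u and the Uniform(0, 1) CDF.
function max_ecdf_gap(u)
    s = sort(u)
    n = length(s)
    return maximum(abs.((1:n) ./ n .- s))
end

gap_good = max_ecdf_gap(pit_good)
gap_bad = max_ecdf_gap(pit_bad)
```

Here gap_good should be small and gap_bad large, which is the same kind of deviation from uniformity that the ecdf=true plot visualizes.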
Plotting with CmdStan.jl outputs
CmdStan.jl and StanSample.jl also default to producing Chains outputs, and we can easily plot these chains.
Here is the same centered eight schools model:
using CmdStan, MCMCChains
schools_code = """
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}
parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}
model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}
generated quantities {
  vector[J] log_lik;
  vector[J] y_hat;
  for (j in 1:J) {
    log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
    y_hat[j] = normal_rng(theta[j], sigma[j]);
  }
}
"""
schools_dat = Dict("J" => J, "y" => y, "sigma" => σ)
stan_model = Stanmodel(;
    model=schools_code,
    name="schools",
    nchains=nchains,
    num_warmup=nwarmup,
    num_samples=nsamples,
    output_format=:mcmcchains,
    random=CmdStan.Random(28983),
)
_, stan_chns, _ = stan(stan_model, schools_dat; summary=false);
plot_density(stan_chns; var_names=["mu", "tau"]);
gcf()
Again, converting to InferenceData, we can get much richer labelling and mixing of data. Note that we're using the same from_cmdstan function used by ArviZ to process cmdstan output files, but through the power of dispatch in Julia, if we pass a Chains object, it instead uses ArviZ.jl's overloads, which forward to from_mcmcchains.
idata = from_cmdstan(
    stan_chns;
    posterior_predictive="y_hat",
    observed_data=Dict("y" => schools_dat["y"]),
    log_likelihood="log_lik",
    coords=Dict("school" => schools),
    dims=Dict(
        "y" => ["school"],
        "sigma" => ["school"],
        "theta" => ["school"],
        "log_lik" => ["school"],
        "y_hat" => ["school"],
    ),
)
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    theta (chain, draw, school) float64 13.28 6.15 -3.49 ... 3.837 3.481
    tau (chain, draw) float64 6.678 4.777 2.97 4.605 ... 3.453 1.678 4.029
    mu (chain, draw) float64 9.217 9.088 1.585 2.41 ... 6.843 6.242 0.2594
Attributes:
    created_at: 2021-01-20T07:28:30.029410
    arviz_version: 0.11.0
    inference_library: CmdStan
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y_hat (chain, draw, school) float64 -3.194 -2.819 3.587 ... -4.648 -16.89
Attributes:
    created_at: 2021-01-20T07:28:29.994263
    arviz_version: 0.11.0
    inference_library: CmdStan
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    log_lik (chain, draw, school) float64 -4.108 -3.239 ... -4.224 -3.921
Attributes:
    created_at: 2021-01-20T07:28:30.006758
    arviz_version: 0.11.0
    inference_library: CmdStan
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    stepsize (chain, draw) float64 0.1873 0.1873 0.1873 ... 0.1227 0.1227
    treedepth (chain, draw) float64 4.0 4.0 4.0 4.0 4.0 ... 4.0 4.0 5.0 5.0
    n_leapfrog (chain, draw) float64 15.0 15.0 15.0 15.0 ... 15.0 31.0 31.0
    accept_stat (chain, draw) float64 0.9738 0.9913 0.9979 ... 0.9588 0.8692
    diverging (chain, draw) bool False False False ... False False False
    energy (chain, draw) float64 25.34 27.53 22.46 ... 22.79 22.62 18.96
    lp (chain, draw) float64 -22.39 -18.73 -17.07 ... -12.86 -15.88
Attributes:
    created_at: 2021-01-20T07:28:30.041882
    arviz_version: 0.11.0
    inference_library: CmdStan
Dataset (xarray.Dataset)
Dimensions: (school: 8)
Coordinates:
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at: 2021-01-20T07:28:30.079842
    arviz_version: 0.11.0
    inference_library: CmdStan
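The dispatch mechanism mentioned above, where one function name handles both file paths and in-memory chains, can be illustrated with a small self-contained sketch. Everything here is hypothetical (ToyChains, load_samples, from_toychains); it only mimics the pattern and is not ArviZ.jl's actual code.

```julia
# Hypothetical stand-in for an in-memory chains object.
struct ToyChains
    draws::Matrix{Float64}
end

# Method selected when the argument is a file path.
load_samples(path::AbstractString) = "parse CSV file at $path"

# Dedicated converter for the in-memory type.
from_toychains(chains::ToyChains) = "convert $(size(chains.draws, 1)) draws in memory"

# Method selected when the argument is a ToyChains; it simply forwards.
load_samples(chains::ToyChains) = from_toychains(chains)

from_file = load_samples("output.csv")
from_memory = load_samples(ToyChains(zeros(1000, 3)))
```

Julia picks the method from the runtime type of the argument, so callers can keep using a single entry point regardless of where the samples come from.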
Here is a plot showing where the Hamiltonian sampler had divergences:
plot_pair(
    idata;
    coords=Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
    divergences=true,
);
gcf()
Plotting with Soss.jl outputs
With Soss, we can define our model for the posterior and easily use it to draw samples from the prior, prior predictive, posterior, and posterior predictive distributions.
First we define our model:
using Soss, NamedTupleTools
mod = Soss.@model (J, σ) begin
    μ ~ Normal(0, 5)
    τ ~ HalfCauchy(5)
    θ ~ iid(J)(Normal(μ, τ))
    y ~ For(1:J) do j
        Normal(θ[j], σ[j])
    end
end
constant_data = (J=J, σ=σ)
param_mod = mod(; constant_data...)
Joint Distribution
    Bound arguments: [J, σ]
    Variables: [τ, μ, θ, y]
@model (J, σ) begin
        τ ~ HalfCauchy(5)
        μ ~ Normal(0, 5)
        θ ~ (iid(J))(Normal(μ, τ))
        y ~ For(1:J) do j
                Normal(θ[j], σ[j])
            end
    end
Then we draw from the prior and prior predictive distributions.
Random.seed!(5298)
prior_priorpred = [
    map(1:(nchains * nsamples)) do _
        draw = rand(param_mod)
        return delete(draw, keys(constant_data))
    end,
];
Next, we draw from the posterior using DynamicHMC.jl.
post = map(1:nchains) do _
    dynamicHMC(param_mod, (y=y,), nsamples)
end;
Finally, we update the posterior samples with draws from the posterior predictive distribution.
pred = predictive(mod, :μ, :τ, :θ)
post_postpred = map(post) do post_draws
    map(post_draws) do post_draw
        pred_draw = rand(pred(post_draw))
        pred_draw = delete(pred_draw, keys(constant_data))
        return merge(pred_draw, post_draw)
    end
end;
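The nested layout built above (a vector of chains, each a vector of draws) and the merge step can be sketched with toy data in plain Julia. The values and names below (post_toy, post_pred_toy, μ_matrix) are made up for illustration, and the "predictive draw" is a fixed function of μ rather than a random sample.

```julia
# Toy data: 2 chains, each with 3 draws; every draw is a NamedTuple.
post_toy = [[(μ = Float64(10 * c + d), τ = 1.0) for d in 1:3] for c in 1:2]

# Attach a stand-in predictive value to each posterior draw with `merge`.
# Keys of the first argument come first; keys of the second are appended.
post_pred_toy = map(post_toy) do chain
    map(chain) do draw
        merge((y = [draw.μ, draw.μ + 1],), draw)
    end
end

# Stack one variable into a (chain, draw) matrix.
μ_matrix = [post_pred_toy[c][d].μ for c in 1:2, d in 1:3]
```

Each merged draw then carries both the parameters and the predictive values, which is the shape that a converter working on NamedTuple draws can split into separate groups.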
Each Soss draw is a NamedTuple. We can plot the rank order statistics of the posterior to identify poor convergence:
plot_rank(post; var_names=["μ", "τ"]);
gcf()
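As a rough illustration of what rank statistics reveal, the sketch below ranks every draw within the pooled sample and then compares the ranks chain by chain. The data are simulated, with one chain deliberately shifted to mimic poor mixing; all names are hypothetical, and this is not the exact computation plot_rank performs.

```julia
using Random, Statistics

rng = MersenneTwister(2)
ndraws = 500

# Three well-behaved chains plus one chain stuck in a different region.
chains = [randn(rng, ndraws) for _ in 1:3]
push!(chains, randn(rng, ndraws) .+ 2)

# Rank of every draw within the pooled sample (1 = smallest).
pooled = reduce(vcat, chains)
ranks = invperm(sortperm(pooled))

# Split the ranks back into chains and summarize each one.
chain_ranks = [ranks[(ndraws * (c - 1) + 1):(ndraws * c)] for c in 1:4]
mean_rank = [mean(r) for r in chain_ranks]
```

If all chains sampled the same distribution, each chain's ranks would be spread roughly uniformly over 1 to 2000; here the shifted chain has a much higher mean rank than the others, which is the kind of imbalance a rank plot makes visible.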
Now we combine all of the samples into an InferenceData:
idata = from_namedtuple(
    post_postpred;
    posterior_predictive=[:y],
    prior=prior_priorpred,
    prior_predictive=[:y],
    observed_data=(y=y,),
    constant_data=constant_data,
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library=Soss,
)
Dataset (xarray.Dataset)
Dimensions: (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain (chain) int64 0 1 2 3
  * draw (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ (chain, draw) float64 -0.2821 -2.515 3.409 ... 11.26 8.506 8.948
    τ (chain, draw) float64 6.09 10.71 3.243 3.181 ... 2.033 2.397 2.114
    θ (chain, draw, school) float64 6.039 -2.826 0.1498 ... 12.7 6.617
Attributes:
    created_at: 2021-01-20T07:32:43.917335
    arviz_version: 0.11.0
    inference_library: Soss
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 0.9461 7.547 ... -1.261 -0.9876
Attributes:
    created_at:         2021-01-20T07:32:43.884761
    arviz_version:      0.11.0
    inference_library:  Soss
Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 4000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 -1.058 6.136 -0.3852 ... 1.484 -10.82 -0.7505
    τ        (chain, draw) float64 149.7 0.02342 3.398 ... 4.347 84.28 4.543
    θ        (chain, draw, school) float64 -8.511 154.9 44.86 ... -2.911 -6.732
Attributes:
    created_at:         2021-01-20T07:32:44.904443
    arviz_version:      0.11.0
    inference_library:  Soss
Dataset (xarray.Dataset)
Dimensions:  (chain: 1, draw: 4000, school: 8)
Coordinates:
  * chain    (chain) int64 0
  * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (chain, draw, school) float64 -19.44 137.9 18.57 ... 5.947 13.3
Attributes:
    created_at:         2021-01-20T07:32:44.902611
    arviz_version:      0.11.0
    inference_library:  Soss
Dataset (xarray.Dataset)
Dimensions:  (school: 8)
Coordinates:
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes:
    created_at:         2021-01-20T07:32:45.318564
    arviz_version:      0.11.0
    inference_library:  Soss
Dataset (xarray.Dataset)
Dimensions:  (J_dim_0: 1, school: 8)
Coordinates:
  * J_dim_0  (J_dim_0) int64 0
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    J        (J_dim_0) int64 8
    σ        (school) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
Attributes:
    created_at:         2021-01-20T07:32:46.482295
    arviz_version:      0.11.0
    inference_library:  Soss
We can compare the prior and posterior predictive distributions:
plot_density(
[idata.posterior_predictive, idata.prior_predictive];
data_labels=["Post-pred", "Prior-pred"],
var_names=["y"],
)
gcf()
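A density plot shows the comparison visually; the same comparison can also be made numerically. The sketch below is only an illustration: `summarize_draws` is a hypothetical helper (not part of ArviZ.jl), and the two arrays are synthetic stand-ins shaped `(chain, draw, school)`, not the draws produced in this tutorial. With real results, the corresponding arrays would first have to be extracted from the predictive groups.

```julia
using Random
using Statistics

# Hypothetical helper (not part of ArviZ.jl): pool all draws and report
# the mean, standard deviation, and a central 90% interval.
function summarize_draws(draws)
    v = vec(collect(draws))
    return (mean=mean(v), std=std(v), q05=quantile(v, 0.05), q95=quantile(v, 0.95))
end

rng = Random.MersenneTwister(1)

# Synthetic stand-ins with the same shapes as the groups above:
# 1 chain x 4000 draws x 8 schools and 4 chains x 1000 draws x 8 schools.
prior_pred = 20 .* randn(rng, 1, 4000, 8)
post_pred = 5 .+ 12 .* randn(rng, 4, 1000, 8)

prior_summary = summarize_draws(prior_pred)
post_summary = summarize_draws(post_pred)

println("prior predictive:     ", prior_summary)
println("posterior predictive: ", post_summary)
```

Pooling over chains and schools discards per-school structure, so a summary like this complements the density plot rather than replacing it.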
Environment
using Pkg
Pkg.status()
Status `~/work/ArviZ.jl/ArviZ.jl/docs/Project.toml`
  [131c737c] ArviZ v0.4.11 `~/work/ArviZ.jl/ArviZ.jl`
  [593b3428] CmdStan v6.1.6
  [31c24e10] Distributions v0.23.8
  [e30172f5] Documenter v0.26.1
  [c7f686f2] MCMCChains v4.4.0
  [d9ec5142] NamedTupleTools v0.13.7
  [438e738f] PyCall v1.92.2
  [d330b81b] PyPlot v2.9.0
  [8ce77f84] Soss v0.14.4
  [fce5fe82] Turing v0.14.12
  [37e2e46d] LinearAlgebra
using InteractiveUtils
versioninfo()
Julia Version 1.5.3
Commit 788b2c77c1 (2020-11-09 13:37 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake-avx512)
Environment:
  JULIA_CMDSTAN_HOME = /home/runner/work/ArviZ.jl/ArviZ.jl/.cmdstan//cmdstan-2.25.0/
  JULIA_NUM_THREADS = 2
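To approximate this environment elsewhere, the package versions reported by `Pkg.status()` could be pinned when installing. The following is a minimal sketch under some assumptions: `install_pinned` and the project directory name `arviz-quickstart-env` are hypothetical, ArviZ was built from a local checkout here (so installing v0.4.11 from the registry is an assumption), and these older versions may not resolve on a newer Julia release.

```julia
using Pkg

# Versions copied from the `Pkg.status()` output above. Documenter is left out
# because it is only needed to build the documentation.
pinned_versions = [
    "ArviZ" => "0.4.11",
    "CmdStan" => "6.1.6",
    "Distributions" => "0.23.8",
    "MCMCChains" => "4.4.0",
    "NamedTupleTools" => "0.13.7",
    "PyCall" => "1.92.2",
    "PyPlot" => "2.9.0",
    "Soss" => "0.14.4",
    "Turing" => "0.14.12",
]

# Build one PackageSpec per package; nothing is installed at this point.
specs = [Pkg.PackageSpec(name=name, version=version) for (name, version) in pinned_versions]

# Hypothetical helper: activate a project directory and install the pinned versions.
function install_pinned(specs; project_dir="arviz-quickstart-env")
    Pkg.activate(project_dir)
    Pkg.add(specs)
end

println("Prepared ", length(specs), " pinned package specs")
```

Calling `install_pinned(specs)` performs the actual installation. Keeping it in a separate function means the list of versions can be inspected without changing the active environment.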