ArviZ.jl Quickstart

Note

This tutorial is adapted from ArviZ's quickstart.

using ArviZ
using PyPlot

# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")

Get started with plotting

ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl but works fine with raw arrays.

using Random

rng = Random.MersenneTwister(37772)
plot_posterior(randn(rng, 100_000));
gcf()

When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from the variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.

using Distributions

s = (10, 50)
plot_forest(
    Dict(
        "normal" => randn(rng, s),
        "gumbel" => rand(rng, Gumbel(), s),
        "student t" => rand(rng, TDist(6), s),
        "exponential" => rand(rng, Exponential(), s),
    ),
);
gcf()

Plotting with MCMCChains.jl's Chains objects produced by Turing.jl

ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, we give the schools names borrowed from a different eight schools.

This model is small enough to write down, is hierarchical, and benefits from labelling. Additionally, the centered parameterization causes divergences, which are interesting for illustration.
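
Written out, the centered parameterization implemented below (in both Turing and Stan) is

\begin{aligned}
\mu &\sim \mathrm{Normal}(0, 5) \\
\tau &\sim \mathrm{HalfCauchy}(5) \\
\theta_j &\sim \mathrm{Normal}(\mu, \tau) \\
y_j &\sim \mathrm{Normal}(\theta_j, \sigma_j), \quad j = 1, \ldots, J.
\end{aligned}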

First we create our data and set some sampling parameters.

J = 8
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
schools = [
    "Choate",
    "Deerfield",
    "Phillips Andover",
    "Phillips Exeter",
    "Hotchkiss",
    "Lawrenceville",
    "St. Paul's",
    "Mt. Hermon",
];

nwarmup, nsamples, nchains = 1000, 1000, 4;

Now we write and run the model using Turing:

using Turing

Turing.@model function turing_model(y, σ, J=length(y))
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5), 0, Inf)
    θ ~ filldist(Normal(μ, τ), J)
    for i in 1:J
        y[i] ~ Normal(θ[i], σ[i])
    end
end

param_mod = turing_model(y, σ)
sampler = NUTS(nwarmup, 0.8)

rng = Random.MersenneTwister(16653)
turing_chns = sample(
    rng, param_mod, sampler, MCMCThreads(), nsamples, nchains; progress=false
);
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.8
┌ Info: Found initial step size
└   ϵ = 0.30000000000000004
┌ Info: Found initial step size
└   ϵ = 0.4

Most ArviZ functions work fine with Chains objects from Turing:

plot_autocorr(turing_chns; var_names=["μ", "τ"]);
gcf()

Convert to InferenceData

For much more powerful querying, analysis, and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note that we also provide some labelling information via coords and dims.

ArviZ is built to work with InferenceData (a netcdf datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful analyses it can perform.

idata = from_mcmcchains(
    turing_chns;
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
InferenceData
    • posterior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
          τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
          θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
      Attributes:
          created_at:         2021-05-11T17:20:41.996609
          arviz_version:      0.11.2
          start_time:         1620753599.753838
          stop_time:          1620753617.297758
          inference_library:  Turing

    • sample_stats: Dataset (xarray.Dataset)
      Dimensions:           (chain: 4, draw: 1000)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          energy            (chain, draw) float64 49.74 48.67 52.26 ... 61.18 57.5
          energy_error      (chain, draw) float64 -0.1874 0.05797 ... -0.02117 0.02483
          tree_depth        (chain, draw) int64 3 3 3 4 5 2 3 3 2 ... 4 4 4 4 3 4 4 3
          diverging         (chain, draw) bool False False False ... False False False
          step_size_nom     (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
          acceptance_rate   (chain, draw) float64 0.9568 0.9265 ... 0.9876 0.991
          log_density       (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
          max_energy_error  (chain, draw) float64 -0.8297 0.1387 ... -0.2598 -0.5436
          is_accept         (chain, draw) bool True True True True ... True True True
          lp                (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
          step_size         (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
          n_steps           (chain, draw) int64 15 7 7 31 47 7 ... 15 31 15 15 15 15
      Attributes:
          created_at:         2021-05-11T17:20:42.032613
          arviz_version:      0.11.2
          start_time:         1620753599.753838
          stop_time:          1620753617.297758
          inference_library:  Turing
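
Because InferenceData is backed by netCDF (as noted above), it can also be written to disk and read back. A minimal sketch, assuming the to_netcdf/from_netcdf wrappers ArviZ.jl re-exports (the filename is hypothetical):

# Round-trip the InferenceData through a netCDF file
to_netcdf(idata, "schools.nc")
idata_loaded = from_netcdf("schools.nc")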

Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.

idata.posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
    τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
    θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
Attributes:
    created_at:         2021-05-11T17:20:41.996609
    arviz_version:      0.11.2
    start_time:         1620753599.753838
    stop_time:          1620753617.297758
    inference_library:  Turing

Here is a plot of the trace. Note the intelligent labels.

plot_trace(idata);
gcf()

We can also generate summary stats:

summarystats(idata)

10 rows × 10 columns

  row  variable     mean       sd   hdi_3%   hdi_97%   mcse_mean   mcse_sd   ess_bulk   ess_tail   r_hat
       String     Float64  Float64  Float64   Float64     Float64   Float64    Float64    Float64  Float64

    1  μ            4.377    3.433   -2.29     10.685       0.126     0.092      743.0      881.0    1.01
    2  τ            3.649    3.0      0.407     8.922       0.184     0.13       161.0      106.0    1.01
    3  θ[1]         6.233    5.668   -3.119    17.324       0.206     0.145      787.0     1324.0    1.01
    4  θ[2]         4.968    4.668   -3.795    13.939       0.138     0.105     1087.0     1834.0    1.01
    5  θ[3]         3.945    5.228   -5.885    13.905       0.15      0.113     1126.0     1586.0    1.01
    6  θ[4]         4.667    4.889   -4.779    13.732       0.147     0.115     1042.0     1677.0    1.0
    7  θ[5]         3.542    4.797   -5.586    12.766       0.135     0.117     1192.0     1511.0    1.01
    8  θ[6]         4.044    4.768   -4.807    13.581       0.145     0.108     1024.0     1348.0    1.01
    9  θ[7]         6.296    5.07    -2.371    16.512       0.189     0.133      727.0     1258.0    1.01
   10  θ[8]         4.766    5.289   -6.104    14.186       0.154     0.127     1047.0     1611.0    1.01
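
Keyword arguments to summarystats are forwarded to ArviZ's underlying summary function, so we could, for example, restrict the table to the hyperparameters. A minimal sketch, assuming the usual var_names keyword is supported here:

# Summarize only μ and τ (var_names is forwarded to arviz.summary)
summarystats(idata; var_names=["μ", "τ"])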

We can likewise examine the energy distribution of the Hamiltonian sampler:

plot_energy(idata);
gcf()

Additional information in Turing.jl

With a few more steps, we can use Turing to compute additional useful groups to add to the InferenceData.

To sample from the prior, one simply calls sample but with the Prior sampler:

prior = sample(param_mod, Prior(), nsamples; progress=false)
Chains MCMC chain (1000×11×1 Array{Float64, 3}):

Start time        = 2021-05-11T17:20:52.494
Stop time         = 2021-05-11T17:20:52.760
Wall duration     = 0.27 seconds
Iterations        = 1:1000
Thinning interval = 1
Chains            = 1
Samples per chain = 1000
parameters        = θ[1], θ[2], θ[3], θ[4], θ[5], θ[6], θ[7], θ[8], μ, τ
internals         = lp

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat    ⋯
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64    ⋯

        θ[1]    0.9495   67.9183     2.1478    1.8926    556.5266    1.0001    ⋯
        θ[2]   -0.6053   48.9281     1.5472    1.0120   1024.0944    1.0000    ⋯
        θ[3]   -1.1761   60.8375     1.9238    1.5571    999.5217    0.9991    ⋯
        θ[4]    0.1684   52.6024     1.6634    1.8087    811.6064    1.0025    ⋯
        θ[5]    1.7505   71.2369     2.2527    2.2212    918.8990    0.9993    ⋯
        θ[6]    2.0359   74.2804     2.3490    2.5844   1004.3982    1.0023    ⋯
        θ[7]    1.1503   53.3496     1.6871    1.6193   1044.4259    0.9991    ⋯
        θ[8]    0.7439   58.6209     1.8538    1.9887    965.1255    1.0011    ⋯
           μ    0.0522    4.8552     0.1535    0.1401   1030.5527    1.0003    ⋯
           τ   18.2545   68.3487     2.1614    2.4279    638.0227    0.9996    ⋯
                                                                1 column omitted

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%      97.5%
      Symbol    Float64   Float64   Float64   Float64    Float64

        θ[1]   -53.1021   -5.5494    0.0282    5.4924    42.0920
        θ[2]   -47.9095   -4.9937    0.2079    6.0771    38.2048
        θ[3]   -48.4412   -5.0144    0.3014    5.8416    39.0567
        θ[4]   -36.2510   -4.9584    0.4786    6.0296    46.7597
        θ[5]   -41.1661   -5.4494   -0.3680    5.0671    53.3500
        θ[6]   -40.9711   -5.4135    0.2481    5.6112    47.2051
        θ[7]   -39.5905   -5.2035    0.4159    6.4169    50.1237
        θ[8]   -49.2875   -4.9863    0.2812    5.8763    55.9611
           μ    -9.7057   -3.3349   -0.0951    3.3130     9.5634
           τ     0.1994    2.0832    4.9683   11.2756   112.9419

To draw from the prior and posterior predictive distributions, we can instantiate a "predictive model", i.e. a Turing model with the observations set to missing, and then call predict on the predictive model and the previously drawn samples:

# Instantiate the predictive model
param_mod_predict = turing_model(similar(y, Missing), σ)
# and then sample!
prior_predictive = predict(param_mod_predict, prior)
posterior_predictive = predict(param_mod_predict, turing_chns)
Chains MCMC chain (1000×8×4 Array{Float64, 3}):

Iterations        = 1:1000
Thinning interval = 1
Chains            = 1, 2, 3, 4
Samples per chain = 1000
parameters        = y[1], y[2], y[3], y[4], y[5], y[6], y[7], y[8]
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

        y[1]    6.1940   15.8365     0.2504    0.3187   2598.9319    0.9996
        y[2]    5.1315   11.0163     0.1742    0.2020   2972.2101    1.0004
        y[3]    4.1674   16.8618     0.2666    0.2583   3529.7049    1.0003
        y[4]    4.7150   12.1285     0.1918    0.2096   2859.1765    0.9996
        y[5]    3.5240   10.2518     0.1621    0.2029   2745.2054    1.0010
        y[6]    4.1669   11.9747     0.1893    0.2232   3018.5444    0.9996
        y[7]    6.5947   11.0310     0.1744    0.2484   2290.3377    1.0000
        y[8]    5.0596   18.4197     0.2912    0.2796   3201.2609    0.9999

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol    Float64   Float64   Float64   Float64   Float64

        y[1]   -25.4038   -4.0195    6.3500   16.6845   37.8474
        y[2]   -16.5638   -2.5435    5.3023   12.6178   26.2114
        y[3]   -29.6595   -7.0289    4.1721   15.3849   37.5304
        y[4]   -19.2035   -3.6940    4.5058   12.9910   28.7925
        y[5]   -16.5995   -3.2405    3.5815   10.3291   23.3610
        y[6]   -18.4540   -3.8768    4.2179   12.0992   27.4589
        y[7]   -14.3394   -0.7633    6.5626   13.9443   28.7613
        y[8]   -30.7903   -7.2315    5.0240   17.7800   41.0833

And we can extract the pointwise log-likelihoods, which are useful for computing metrics such as loo:

loglikelihoods = Turing.pointwise_loglikelihoods(
    param_mod, MCMCChains.get_sections(turing_chns, :parameters)
)
Dict{String, Matrix{Float64}} with 8 entries:
  "y[6]" => [-3.38973 -3.51189 -3.43534 -3.73188; -3.32076 -3.32165 -3.38096 -3…
  "y[2]" => [-3.27342 -3.22381 -3.30609 -3.35272; -3.36404 -3.42648 -3.23499 -3…
  "y[1]" => [-4.92564 -4.6233 -4.15682 -4.81507; -4.94726 -5.12105 -5.63477 -3.…
  "y[5]" => [-3.17195 -3.55913 -3.55442 -3.64785; -3.34779 -3.17499 -3.21129 -3…
  "y[8]" => [-3.92493 -3.82086 -4.23278 -3.88106; -3.9033 -3.92951 -4.07239 -3.…
  "y[7]" => [-4.20172 -3.96703 -3.49247 -4.77806; -4.1926 -4.5937 -4.66271 -3.4…
  "y[3]" => [-3.75618 -3.87783 -3.69645 -4.18097; -3.80406 -3.74826 -3.70795 -3…
  "y[4]" => [-3.38347 -3.32362 -3.32345 -3.31684; -3.33491 -3.50791 -3.35802 -3…

This can then be included in the from_mcmcchains call from above:

using LinearAlgebra
# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
ynames = string.(keys(posterior_predictive))
loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames)
# Each entry is an `(ndraws, nchains)` matrix; stack along a third dimension and
# permute to reshape into `(nchains, nsamples, size(y)...)`
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3))

idata = from_mcmcchains(
    turing_chns;
    posterior_predictive=posterior_predictive,
    log_likelihood=Dict("y" => loglikelihoods_arr),
    prior=prior,
    prior_predictive=prior_predictive,
    observed_data=Dict("y" => y),
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library="Turing",
)
InferenceData
    • posterior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 3.598 3.844 3.727 1.557 ... 7.414 3.701 3.061
          τ        (chain, draw) float64 1.341 1.253 0.9012 ... 2.304 4.701 3.185
          θ        (chain, draw, school) float64 3.826 4.778 2.754 ... 5.444 7.008
      Attributes:
          created_at:         2021-05-11T17:21:19.083858
          arviz_version:      0.11.2
          start_time:         1620753599.753838
          stop_time:          1620753617.297758
          inference_library:  Turing

    • posterior_predictive: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 37.05 16.89 6.065 ... 14.09 6.818
      Attributes:
          created_at:         2021-05-11T17:21:18.582580
          arviz_version:      0.11.2
          inference_library:  Turing

    • log_likelihood: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -4.926 -3.273 -3.756 ... -4.01 -3.848
      Attributes:
          created_at:         2021-05-11T17:21:19.008683
          arviz_version:      0.11.2
          inference_library:  Turing

    • sample_stats: Dataset (xarray.Dataset)
      Dimensions:           (chain: 4, draw: 1000)
      Coordinates:
        * chain             (chain) int64 0 1 2 3
        * draw              (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          energy            (chain, draw) float64 49.74 48.67 52.26 ... 61.18 57.5
          energy_error      (chain, draw) float64 -0.1874 0.05797 ... -0.02117 0.02483
          tree_depth        (chain, draw) int64 3 3 3 4 5 2 3 3 2 ... 4 4 4 4 3 4 4 3
          diverging         (chain, draw) bool False False False ... False False False
          step_size_nom     (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
          acceptance_rate   (chain, draw) float64 0.9568 0.9265 ... 0.9876 0.991
          log_density       (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
          max_energy_error  (chain, draw) float64 -0.8297 0.1387 ... -0.2598 -0.5436
          is_accept         (chain, draw) bool True True True True ... True True True
          lp                (chain, draw) float64 -46.53 -46.78 ... -55.97 -54.9
          step_size         (chain, draw) float64 0.1222 0.1288 ... 0.1542 0.1542
          n_steps           (chain, draw) int64 15 7 7 31 47 7 ... 15 31 15 15 15 15
      Attributes:
          created_at:         2021-05-11T17:21:19.089630
          arviz_version:      0.11.2
          start_time:         1620753599.753838
          stop_time:          1620753617.297758
          inference_library:  Turing

    • prior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 4.591 -3.59 7.95 ... -7.285 -2.541 -5.66
          τ        (chain, draw) float64 3.726 4.819 0.8977 ... 1.353 9.565 7.004
          θ        (chain, draw, school) float64 2.823 1.1 7.183 ... -2.011 -7.082
      Attributes:
          created_at:         2021-05-11T17:21:19.854776
          arviz_version:      0.11.2
          start_time:         1620753652.493586
          stop_time:          1620753652.760436
          inference_library:  Turing

    • prior_predictive: Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -19.08 -8.769 22.26 ... 15.77 -8.977
      Attributes:
          created_at:         2021-05-11T17:21:19.709455
          arviz_version:      0.11.2
          inference_library:  Turing

    • sample_stats_prior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 1000)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
      Data variables:
          lp       (chain, draw) float64 -55.65 -65.73 -43.9 ... -55.26 -71.98 -66.78
      Attributes:
          created_at:         2021-05-11T17:21:19.888505
          arviz_version:      0.11.2
          start_time:         1620753652.493586
          stop_time:          1620753652.760436
          inference_library:  Turing

    • observed_data: Dataset (xarray.Dataset)
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:         2021-05-11T17:21:21.328132
          arviz_version:      0.11.2
          inference_library:  Turing

Then we can, for example, compute the expected leave-one-out (LOO) predictive density, which is an estimate of the out-of-sample predictive fit of the model:

loo(idata) # higher is better

1 rows × 7 columns

  row      loo   loo_se     p_loo  n_samples  n_data_points  warning  loo_scale
       Float64  Float64   Float64      Int64          Int64     Bool     String

    1  -30.707   1.3569  0.861043       4000              8        0        log
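
For reference, the quantity estimated here is the expected log pointwise predictive density,

\mathrm{elpd}_{\mathrm{loo}} = \sum_{i=1}^{n} \log p(y_i \mid y_{-i}),

where y_{-i} is the data with the i-th observation held out. On the reported log scale, higher is better; ArviZ approximates the sum using Pareto-smoothed importance sampling (PSIS).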

If the model is well-calibrated, i.e. it replicates the true generative process well, the pointwise LOO probability integral transform (PIT) values should be approximately uniformly distributed. This can be inspected visually:

plot_loo_pit(idata; y="y", ecdf=true);
gcf()

Plotting with CmdStan.jl outputs

CmdStan.jl and StanSample.jl can also produce Chains outputs, and we can easily plot these chains.

Here is the same centered eight schools model:

using CmdStan, MCMCChains

schools_code = """
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}

generated quantities {
    vector[J] log_lik;
    vector[J] y_hat;
    for (j in 1:J) {
        log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
        y_hat[j] = normal_rng(theta[j], sigma[j]);
    }
}
"""

schools_dat = Dict("J" => J, "y" => y, "sigma" => σ)
stan_model = Stanmodel(;
    model=schools_code,
    name="schools",
    nchains=nchains,
    num_warmup=nwarmup,
    num_samples=nsamples,
    output_format=:mcmcchains,
    random=CmdStan.Random(28983),
)
_, stan_chns, _ = stan(stan_model, schools_dat; summary=false);
File /home/runner/work/ArviZ.jl/ArviZ.jl/docs/build/tmp/schools.stan will be updated.
plot_density(stan_chns; var_names=["mu", "tau"]);
gcf()

Again, converting to InferenceData gives us much richer labelling and mixing of data. Note that we use the same from_cmdstan function ArviZ uses to process CmdStan output files; through the power of dispatch in Julia, passing a Chains object instead invokes ArviZ.jl's overloads, which forward to from_mcmcchains.

idata = from_cmdstan(
    stan_chns;
    posterior_predictive="y_hat",
    observed_data=Dict("y" => schools_dat["y"]),
    log_likelihood="log_lik",
    coords=Dict("school" => schools),
    dims=Dict(
        "y" => ["school"],
        "sigma" => ["school"],
        "theta" => ["school"],
        "log_lik" => ["school"],
        "y_hat" => ["school"],
    ),
)
InferenceData
    • posterior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          theta    (chain, draw, school) float64 6.541 4.165 1.673 ... 7.588 7.146
          tau      (chain, draw) float64 2.026 3.44 2.047 ... 0.5002 0.5771 0.5771
          mu       (chain, draw) float64 5.864 4.849 3.906 5.24 ... 7.369 7.617 7.617
      Attributes:
          created_at:         2021-05-11T17:22:01.518920
          arviz_version:      0.11.2
          inference_library:  CmdStan

    • posterior_predictive: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y_hat    (chain, draw, school) float64 -10.71 23.66 -9.425 ... 9.078 20.88
      Attributes:
          created_at:         2021-05-11T17:22:01.487547
          arviz_version:      0.11.2
          inference_library:  CmdStan

    • log_likelihood: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          log_lik  (chain, draw, school) float64 -4.65 -3.295 -3.734 ... -3.764 -3.846
      Attributes:
          created_at:         2021-05-11T17:22:01.497752
          arviz_version:      0.11.2
          inference_library:  CmdStan

    • sample_stats: Dataset (xarray.Dataset)
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 0 1 2 3
        * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          tree_depth       (chain, draw) int64 4 2 4 3 3 3 4 4 4 ... 1 1 2 1 2 1 5 3 2
          diverging        (chain, draw) bool False False False ... False True False
          energy           (chain, draw) float64 21.85 17.36 18.12 ... 11.01 16.54
          lp               (chain, draw) float64 -11.6 -14.53 -12.58 ... -6.09 -6.09
          step_size        (chain, draw) float64 0.2012 0.2012 ... 0.1493 0.1493
          acceptance_rate  (chain, draw) float64 0.9582 0.8976 ... 0.04874 1.587e-05
          n_steps          (chain, draw) int64 15 7 15 15 7 15 15 ... 5 1 5 3 31 9 3
      Attributes:
          created_at:         2021-05-11T17:22:01.533252
          arviz_version:      0.11.2
          inference_library:  CmdStan

    • observed_data: Dataset (xarray.Dataset)
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:         2021-05-11T17:22:01.558449
          arviz_version:      0.11.2
          inference_library:  CmdStan

Here is a plot showing where the Hamiltonian sampler had divergences:

plot_pair(
    idata;
    coords=Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
    divergences=true,
);
gcf()

Plotting with Soss.jl outputs

With Soss, we can define our model for the posterior and easily use it to draw samples from the prior, prior predictive, posterior, and posterior predictive distributions.

First we define our model:

using Soss, NamedTupleTools

mod = Soss.@model (J, σ) begin
    μ ~ Normal(0, 5)
    τ ~ HalfCauchy(5)
    θ ~ iid(J)(Normal(μ, τ))
    y ~ For(1:J) do j
        Normal(θ[j], σ[j])
    end
end

constant_data = (J=J, σ=σ)
param_mod = mod(; constant_data...)
Joint Distribution
    Bound arguments: [J, σ]
    Variables: [τ, μ, θ, y]

@model (J, σ) begin
        τ ~ HalfCauchy(5)
        μ ~ Normal(0, 5)
        θ ~ (iid(J))(Normal(μ, τ))
        y ~ For(1:J) do j
                Normal(θ[j], σ[j])
            end
    end

Then we draw from the prior and prior predictive distributions.

Random.seed!(5298)
# Wrap the draws in a vector so they are treated as a single chain below
prior_priorpred = [
    map(1:(nchains * nsamples)) do _
        draw = rand(param_mod)
        return delete(draw, keys(constant_data))
    end,
];

Next, we draw from the posterior using DynamicHMC.jl.

post = map(1:nchains) do _
    dynamicHMC(param_mod, (y=y,), nsamples)
end;

Finally, we update the posterior samples with draws from the posterior predictive distribution.

pred = predictive(mod, :μ, :τ, :θ)
post_postpred = map(post) do post_draws
    map(post_draws) do post_draw
        pred_draw = rand(pred(post_draw))
        pred_draw = delete(pred_draw, keys(constant_data))
        return merge(pred_draw, post_draw)
    end
end;

Each Soss draw is a NamedTuple. We can plot the rank order statistics of the posterior to identify poor convergence:

plot_rank(post; var_names=["μ", "τ"]);
gcf()

Now we combine all of the samples into an InferenceData:

idata = from_namedtuple(
    post_postpred;
    posterior_predictive=[:y],
    prior=prior_priorpred,
    prior_predictive=[:y],
    observed_data=(y=y,),
    constant_data=constant_data,
    coords=Dict("school" => schools),
    dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
    library=Soss,
)
InferenceData
    • posterior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 4.192 9.221 0.8525 ... 10.62 10.42 8.331
          τ        (chain, draw) float64 11.95 12.96 15.19 7.62 ... 1.135 1.388 1.072
          θ        (chain, draw, school) float64 32.52 -0.5799 8.606 ... 7.977 6.837
      Attributes:
          created_at:         2021-05-11T17:23:14.210574
          arviz_version:      0.11.2
          inference_library:  Soss

    • posterior_predictive: Dataset (xarray.Dataset)
      Dimensions:  (chain: 4, draw: 1000, school: 8)
      Coordinates:
        * chain    (chain) int64 0 1 2 3
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 17.84 1.985 15.04 ... 2.03 14.85
      Attributes:
          created_at:         2021-05-11T17:23:14.172903
          arviz_version:      0.11.2
          inference_library:  Soss

    • prior: Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 4000, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          μ        (chain, draw) float64 -1.058 6.136 -0.3852 ... 1.484 -10.82 -0.7505
          τ        (chain, draw) float64 149.7 0.02342 3.398 ... 4.347 84.28 4.543
          θ        (chain, draw, school) float64 -8.511 154.9 44.86 ... -2.911 -6.732
      Attributes:
          created_at:         2021-05-11T17:23:15.317494
          arviz_version:      0.11.2
          inference_library:  Soss

    • prior_predictive: Dataset (xarray.Dataset)
      Dimensions:  (chain: 1, draw: 4000, school: 8)
      Coordinates:
        * chain    (chain) int64 0
        * draw     (draw) int64 0 1 2 3 4 5 6 7 ... 3993 3994 3995 3996 3997 3998 3999
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (chain, draw, school) float64 -19.44 137.9 18.57 ... 5.947 13.3
      Attributes:
          created_at:         2021-05-11T17:23:15.314620
          arviz_version:      0.11.2
          inference_library:  Soss

    • observed_data: Dataset (xarray.Dataset)
      Dimensions:  (school: 8)
      Coordinates:
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          y        (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
      Attributes:
          created_at:         2021-05-11T17:23:16.068613
          arviz_version:      0.11.2
          inference_library:  Soss

    • constant_data: Dataset (xarray.Dataset)
      Dimensions:  (J_dim_0: 1, school: 8)
      Coordinates:
        * J_dim_0  (J_dim_0) int64 0
        * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
      Data variables:
          J        (J_dim_0) int64 8
          σ        (school) float64 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
      Attributes:
          created_at:         2021-05-11T17:23:17.894044
          arviz_version:      0.11.2
          inference_library:  Soss

We can compare the prior and posterior predictive distributions:

plot_density(
    [idata.posterior_predictive, idata.prior_predictive];
    data_labels=["Post-pred", "Prior-pred"],
    var_names=["y"],
);
gcf()

Environment

using Pkg
Pkg.status()
using InteractiveUtils
versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
Environment:
  JULIA_CMDSTAN_HOME = /home/runner/work/ArviZ.jl/ArviZ.jl/.cmdstan//cmdstan-2.25.0/
  JULIA_NUM_THREADS = 2