ArviZ.jl Quickstart

This quickstart is adapted from ArviZ's Quickstart.

using ArviZ

# ArviZ ships with style sheets!
ArviZ.use_style("arviz-darkgrid")

Get started with plotting

ArviZ.jl is designed to be used with libraries like CmdStan, Turing.jl, and Soss.jl but works fine with raw arrays.

plot_posterior(randn(100_000));

When plotting a dictionary of arrays, ArviZ.jl interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from that variable, called a chain. Below, we have 10 chains of 50 draws each for four different distributions.

using Distributions

s = (10, 50)
plot_forest(Dict(
    "normal" => randn(s),
    "gumbel" => rand(Gumbel(), s),
    "student t" => rand(TDist(6), s),
    "exponential" => rand(Exponential(), s)
));
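
The same dictionary of arrays can also be converted to an InferenceData up front and then passed to any plotting function. This is a minimal sketch, assuming convert_to_inference_data accepts the dictionary form used above (each array laid out as (chain, draw)):

dict_idata = convert_to_inference_data(Dict(
    "normal" => randn(s),
    "exponential" => rand(Exponential(), s),
))
plot_posterior(dict_idata);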

Plotting with MCMCChains.jl's Chains objects produced by Turing.jl

ArviZ is designed to work well with high-dimensional, labelled data. Consider the eight schools model, which roughly tries to measure the effectiveness of SAT classes at eight different schools. To show off ArviZ's labelling, I give the schools the names of a different set of eight schools.

This model is small enough to write down, is hierarchical, and uses labelling. Additionally, the centered parameterization causes divergences, which are interesting for illustration.
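
Written generatively, the centered parameterization we will express in Turing, Stan, and Soss below is

    μ ~ Normal(0, 5)
    τ ~ HalfCauchy(5)
    θ[j] ~ Normal(μ, τ)        for j = 1, …, J
    y[j] ~ Normal(θ[j], σ[j])  for j = 1, …, J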

First we create our data and set some sampling parameters.

J = 8
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
schools = [
    "Choate",
    "Deerfield",
    "Phillips Andover",
    "Phillips Exeter",
    "Hotchkiss",
    "Lawrenceville",
    "St. Paul's",
    "Mt. Hermon"
];

nwarmup, nsamples, nchains = 1000, 1000, 4

Now we write and run the model using Turing:

using Turing

Turing.@model function turing_model(J, y, σ)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5), 0, Inf)  # half-Cauchy prior on the group scale
    θ ~ filldist(Normal(μ, τ), J)        # centered parameterization
    for j in 1:J
        y[j] ~ Normal(θ[j], σ[j])
    end
end

param_mod = turing_model(J, y, σ)
sampler = NUTS(nwarmup, 0.8)
turing_chns = mapreduce(chainscat, 1:nchains) do _
    return sample(param_mod, sampler, nwarmup + nsamples; progress = false)
end;

Most ArviZ functions work fine with Chains objects from Turing:

plot_autocorr(convert_to_inference_data(turing_chns); var_names = ["μ", "τ"]);

Convert to InferenceData

For much more powerful querying, analysis, and plotting, we can use built-in ArviZ utilities to convert Chains objects to xarray datasets. Note that we are also supplying some information about labelling.

ArviZ is built to work with InferenceData (a netcdf datastore that loads data into xarray datasets), and the more groups it has access to, the more powerful analyses it can perform.

idata = from_mcmcchains(
    turing_chns,
    coords = Dict("school" => schools),
    dims = Dict(
        "y" => ["school"],
        "σ" => ["school"],
        "θ" => ["school"],
    ),
    library = "Turing",
)
InferenceData with groups:
	> posterior
	> sample_stats

Each group is an ArviZ.Dataset (a thinly wrapped xarray.Dataset). We can view a summary of the dataset.

idata.posterior
Dataset (xarray.Dataset)
Dimensions:  (chain: 4, draw: 1000, school: 8)
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999
  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
Data variables:
    μ        (chain, draw) float64 3.305 2.345 3.058 ... 2.119 5.054
    τ        (chain, draw) float64 0.6879 0.6479 ... 3.398 7.237
    θ        (chain, draw, school) float64 4.075 2.782 4.191 ... 7.114 9.62
Attributes:
    created_at:         2019-12-05T03:47:33.485250
    inference_library:  Turing

Here is a plot of the trace. Note the intelligent labels.

plot_trace(idata);

We can also generate summary statistics:

summarystats(idata)
       mean     sd  hpd_3%  hpd_97%  ...  ess_sd  ess_bulk  ess_tail  r_hat
μ     4.377  3.083  -1.667   10.020  ...   538.0     538.0    1076.0   1.01
τ     3.955  3.338   0.415    9.964  ...   149.0      92.0      78.0   1.03
θ[1]  6.425  6.031  -3.742   18.191  ...   754.0     694.0    1390.0   1.01
θ[2]  4.970  4.820  -3.771   14.783  ...   748.0    1096.0    1657.0   1.00
θ[3]  3.871  5.398  -6.413   14.506  ...  1078.0    1129.0    1537.0   1.01
θ[4]  4.819  4.756  -3.743   14.583  ...  1042.0     982.0    2188.0   1.01
θ[5]  3.517  4.615  -4.736   12.860  ...   941.0     977.0    1473.0   1.01
θ[6]  3.949  4.653  -5.486   12.415  ...  1102.0    1153.0    1838.0   1.00
θ[7]  6.514  4.874  -1.881   16.258  ...   390.0     413.0    1074.0   1.01
θ[8]  4.954  5.314  -4.265   15.910  ...   806.0    1075.0    1557.0   1.01

[10 rows x 11 columns]
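
The summary can also be restricted to particular variables. A small sketch, assuming summarystats forwards keyword arguments such as var_names on to arviz.summary:

summarystats(idata; var_names = ["μ", "τ"])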

We can also examine the energy distribution of the Hamiltonian sampler:

plot_energy(idata);

Plotting with CmdStan.jl outputs

CmdStan.jl and StanSample.jl also default to producing Chains outputs, and we can easily plot these chains.

Here is the same centered eight schools model:

using CmdStan

schools_code = """
data {
  int<lower=0> J;
  real y[J];
  real<lower=0> sigma[J];
}

parameters {
  real mu;
  real<lower=0> tau;
  real theta[J];
}

model {
  mu ~ normal(0, 5);
  tau ~ cauchy(0, 5);
  theta ~ normal(mu, tau);
  y ~ normal(theta, sigma);
}

generated quantities {
    vector[J] log_lik;
    vector[J] y_hat;
    for (j in 1:J) {
        log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
        y_hat[j] = normal_rng(theta[j], sigma[j]);
    }
}
"""

schools_dat = Dict("J" => J, "y" => y, "sigma" => σ)
stan_model = Stanmodel(
    model = schools_code,
    nchains = nchains,
    num_warmup = nwarmup,
    num_samples = nsamples,
)
_, stan_chns, _ = stan(stan_model, schools_dat, summary = false);

plot_density(convert_to_inference_data(stan_chns); var_names=["mu", "tau"]);

Again, converting to InferenceData gives us much richer labelling and mixing of data. Note that we're using the same from_cmdstan function ArviZ uses to process CmdStan output files, but thanks to Julia's dispatch, passing a Chains object instead invokes ArviZ.jl's overloads, which forward to from_mcmcchains.

idata = from_cmdstan(
    stan_chns;
    posterior_predictive = "y_hat",
    observed_data = Dict("y" => schools_dat["y"]),
    log_likelihood = "log_lik",
    coords = Dict("school" => schools),
    dims = Dict(
        "y" => ["school"],
        "sigma" => ["school"],
        "theta" => ["school"],
        "log_lik" => ["school"],
        "y_hat" => ["school"],
    ),
)
InferenceData with groups:
	> posterior
	> posterior_predictive
	> sample_stats
	> observed_data
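
Because we passed log_lik as the log-likelihood above, the resulting InferenceData also carries pointwise log-likelihood values, so model-fit statistics are available too. A sketch, assuming loo finds the log-likelihood where from_cmdstan stored it:

loo(idata)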

Here is a plot showing where the Hamiltonian sampler had divergences:

plot_pair(
    idata;
    coords = Dict("school" => ["Choate", "Deerfield", "Phillips Andover"]),
    divergences = true,
);

Plotting with Soss.jl outputs

With Soss, we can define our model for the posterior and easily use it to draw samples from the prior, prior predictive, posterior, and posterior predictive distributions.

First we define our model:

using Soss, NamedTupleTools

mod = Soss.@model (J, σ) begin
    μ ~ Normal(0, 5)
    τ ~ HalfCauchy(5)
    θ ~ Normal(μ, τ) |> iid(J)
    y ~ For(J) do j
        Normal(θ[j], σ[j])
    end
end

constant_data = (J = J, σ = σ)
param_mod = mod(; constant_data...)
Joint Distribution
    Bound arguments: [J, σ]
    Variables: [τ, μ, θ, y]

@model (J, σ) begin
        τ ~ HalfCauchy(5)
        μ ~ Normal(0, 5)
        θ ~ Normal(μ, τ) |> iid(J)
        y ~ For(J) do j
                Normal(θ[j], σ[j])
            end
    end

Then we draw from the prior and prior predictive distributions.

prior_prior_pred = map(1:nchains*nsamples) do _
    draw = rand(param_mod)
    return delete(draw, keys(constant_data))
end

prior = map(draw -> delete(draw, :y), prior_prior_pred)
prior_pred = map(draw -> delete(draw, (:μ, :τ, :θ)), prior_prior_pred);

Next, we draw from the posterior using DynamicHMC.jl.

post = map(1:nchains) do _
    dynamicHMC(param_mod, (y = y,), logpdf, nsamples)
end;

Finally, we use the posterior samples to draw from the posterior predictive distribution.

pred = predictive(mod, :μ, :τ, :θ)
post_pred = map(post) do post_draws
    map(post_draws) do post_draw
        pred_draw = rand(pred(post_draw)(constant_data))
        return delete(pred_draw, keys(constant_data))
    end
end;

Each Soss draw is a NamedTuple. Now we combine all of the samples into an InferenceData:

idata = from_namedtuple(
    post;
    posterior_predictive = post_pred,
    prior = [prior],
    prior_predictive = [prior_pred],
    observed_data = Dict("y" => y),
    constant_data = constant_data,
    coords = Dict("school" => schools),
    dims = Dict(
        "y" => ["school"],
        "σ" => ["school"],
        "θ" => ["school"],
    ),
    library = Soss,
)
InferenceData with groups:
	> posterior
	> posterior_predictive
	> prior
	> prior_predictive
	> observed_data
	> constant_data

We can compare the prior and posterior predictive distributions:

plot_density(
    [idata.posterior_predictive, idata.prior_predictive];
    data_labels = ["Post-pred", "Prior-pred"],
    var_names = ["y"],
)
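
Finally, since this InferenceData contains both observed_data and posterior_predictive groups for y, a posterior predictive check is one more call. A minimal sketch, assuming plot_ppc matches the observed and predicted y as grouped above:

plot_ppc(idata);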