Getting started

First we define a model with Turing.jl.

using Turing

# Example model.
# `x` has a standard-normal prior and `y` is normally distributed around `x`
# with unit standard deviation.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# Condition on the observation `y = 2.0`, leaving `x` as the only latent variable.
model = demo() | (y = 2.0, )
DynamicPPL.Model{typeof(Main.demo), (), (), (), Tuple{}, Tuple{}, DynamicPPL.ConditionContext{NamedTuple{(:y,), Tuple{Float64}}, DynamicPPL.DefaultContext}}(Main.demo, NamedTuple(), NamedTuple(), ConditionContext((y = 2.0,), DynamicPPL.DefaultContext()))
Warning

TuringABC currently only supports conditioning of the form `model | (...)` or `condition(model, ...)`. That is, passing the observed variables as arguments to the model function is NOT supported (yet).

Let's first sample with NUTS so that we have something to compare against.

# Reference run with NUTS (1_000 draws). This model is conjugate, so the exact
# posterior is x | y = 2 ~ Normal(1, √(1/2)) ≈ Normal(1, 0.71), which is a handy
# check for both the NUTS and the ABC results.
samples_nuts = sample(model, NUTS(), 1_000)
Chains MCMC chain (1000×13×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 10.19 seconds
Compute duration  = 10.19 seconds
parameters        = x
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           x    0.9683    0.7136    0.0306   546.3916   695.1337    1.0063     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -0.4249    0.4885    0.9908    1.4668    2.3481

Now we sample with ABC (approximate Bayesian computation):

using TuringABC

# ABC sampler. The argument 0.1 is reported as the `threshold` internal of the
# resulting chain; presumably it is the acceptance tolerance (confirm in TuringABC docs).
spl = ABC(0.1)
# `chain_type=MCMCChains.Chains` returns the draws as an MCMCChains chain.
samples = sample(model, spl, 10_000; chain_type=MCMCChains.Chains)
Chains MCMC chain (10000×2×1 Array{Float64, 3}):

Iterations        = 1:1:10000
Number of chains  = 1
Samples per chain = 10000
parameters        = x
internals         = threshold

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

           x    1.0230    0.7146    0.0370   379.8197   289.8048    1.0038     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x   -0.2883    0.5219    1.0434    1.4851    2.4741

More complex example

Now we're going to try something a bit crazier: we'll run inference within a model, `outer_model`, and then run inference over that model itself! Yes, you read that right: inference within inference.

Warning

This is not something we recommend doing; this is just a demo of what one could do!

using Turing, TuringABC, LinearAlgebra, Logging
using Turing.DynamicPPL

First we define `inner_model`, i.e. the model we run inference on inside `outer_model`.

# Inner model: `x` has an `N`-dimensional standard-normal prior and the
# observation `y_inner` is Gaussian around `x` with isotropic covariance `σ² * I`.
@model function inner_model(σ², N)
    x ~ MvNormal(zeros(N), I)
    y_inner ~ MvNormal(x, σ² * I)
end
inner_model (generic function with 2 methods)

Then we need a function that converts the resulting approximation of the posterior of `inner_model` into summary statistics, which we can then use as the "observation" for the approximate posterior of `outer_model`:

"""
    f_default(samples::MCMCChains.Chains; probs=[0.25, 0.5, 0.75])

Summarise the (approximate) posterior represented by `samples` as a flat vector
of per-column quantiles.

The quantiles at levels `probs` are computed for every column of
`Array(samples)` (presumably one column per parameter) and concatenated column
by column. The result is `[q₁(θ₁), q₂(θ₁), …, q₁(θ₂), q₂(θ₂), …]`, with
`length(probs)` entries per column.
"""
function f_default(samples::MCMCChains.Chains; probs=[0.25, 0.5, 0.75])
    draws = Array(samples)
    # One vector of quantiles per column. `reduce(hcat, ::Vector{<:Vector})` has a
    # specialised method that allocates the result once, instead of the repeated
    # intermediate `hcat`s done by `mapreduce(..., hcat, ...)`.
    per_column = map(Base.Fix2(quantile, probs), eachcol(draws))
    # `vec` is column-major, so all quantiles of one column stay adjacent.
    return vec(reduce(hcat, per_column))
end
f_default (generic function with 1 method)

Now we can finally define the outer_model!

# Outer model. It draws a variance `σ²` and a data vector `y`, then runs
# `inner_sampler` on `inner_model(σ², length(μ))` conditioned on `y_inner = y`.
# Finally it exposes the summary `f(posterior)` of that inner posterior as the
# random variable `stat`. `stat` is meant to be conditioned on and the result
# sampled with ABC, since `f(posterior)` has no tractable likelihood.
#
# Arguments
# - `μ`: mean of the distribution of `y`; `length(μ)` sets the dimension.
# Keywords
# - `f`: maps the inner posterior (a chain) to a vector of summary statistics.
# - `inner_sampler`, `num_inner_samples`: sampler and number of draws for the inner run.
# Returns the named tuple `(; posterior, stat)`.
@model function outer_model(
    μ;
    f=f_default,
    # Sampler and number of samples for the inner model.
    # We'll use NUTS by default, but this will be expensive!
    inner_sampler=NUTS(),
    num_inner_samples=1000,
)
    N = length(μ)
    # Prior on the variance, shared by `y` below and by the inner model's likelihood.
    σ² ~ InverseGamma(2, 1)
    # Data vector that the inner model is conditioned on (centred at `μ`, variance σ²).
    y ~ MvNormal(μ, σ² * I)
    # Obtain (approximate) posterior of the inner model conditioned
    # on the sampled `y` from above.
    inner_mdl = inner_model(σ², N) | (y_inner = y,)
    # Turn off logging for this inner sample since it will be called many times.
    posterior = with_logger(NullLogger()) do
        sample(inner_mdl, inner_sampler, num_inner_samples; chain_type=MCMCChains.Chains, progress=false)
    end
    # `posterior` is only an empirical (sample-based) approximation, so we project
    # it onto summary statistics with `f`. `stat` is the variable we later
    # condition on; presumably `DiracDelta` is a point mass at `f(posterior)`.
    stat ~ DiracDelta(f(posterior))

    return (; posterior, stat)
end
outer_model (generic function with 2 methods)
# Let's generate some data.
μ = zeros(2)
model = outer_model(μ)

# "True" values of the latent variables used to simulate the observed statistic.
vars_true = (σ² = 1.0, y = fill(0.5, length(μ)))
# Simulate once with σ² and y fixed, so that `stat` is the only variable drawn.
stat_true = rand(model | vars_true).stat
6-element Vector{Float64}:
 -0.17079086740358163
  0.2291551754806787
  0.6912168383651762
 -0.2207447523678946
  0.2347794183863109
  0.685222740572736
# Now condition the model on the true statistic.
conditioned_model = condition(model, (stat = stat_true,))
# `stat` is now observed, so a draw from the conditioned model only contains σ² and y.
rand(conditioned_model)
(σ² = 0.25736572671285435, y = [-0.27275228378392236, -0.6898432769657111])
# We can now use ABC to sample.
# NOTE: This will take a few minutes to run since we're running NUTS in every ABC iteration.
chain = sample(
    conditioned_model,
    ABC(0.1),
    1000;
    # Drop the first 1000 iterations; the kept draws are iterations 1001-2000.
    discard_initial=1000,
    chain_type=MCMCChains.Chains,
    progress=true
)
Chains MCMC chain (1000×4×1 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 1
Samples per chain = 1000
parameters        = σ², y[1], y[2]
internals         = threshold

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

          σ²    1.0671    1.6403    0.1411   122.9034   135.7674    1.0289     ⋯
        y[1]    0.4492    0.6101    0.0469   188.8879   175.5196    1.0068     ⋯
        y[2]    0.4578    0.8644    0.0780   225.4631   114.8957    1.0014     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.1659    0.3321    0.5991    0.9439    6.0716
        y[1]   -0.2641    0.0527    0.3367    0.6024    2.3366
        y[2]   -0.2520    0.0465    0.2925    0.6129    2.8378
# Quantiles of every parameter in the ABC chain.
quantile(chain)
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

          σ²    0.1659    0.3321    0.5991    0.9439    6.0716
        y[1]   -0.2641    0.0527    0.3367    0.6024    2.3366
        y[2]   -0.2520    0.0465    0.2925    0.6129    2.8378
using StatsPlots
# Visualise the ABC chain (StatsPlots provides the plot recipe for MCMCChains).
plot(chain)

This is clearly not working very well :) But hey, at least it's possible!