Regression with Student-t noise
You are seeing the HTML output generated by Documenter.jl and Literate.jl from the Julia source file. The corresponding notebook can be viewed in nbviewer.
We load all the necessary packages
using AbstractGPs
using ApproximateGPs
using AugmentedGPLikelihoods
using Distributions
using LinearAlgebra
Plotting library
using Plots
We create some random data on a sorted grid of inputs (sorted so that the curves plot cleanly)
N = 100  # number of observations
x = range(-10, 10; length=N)  # sorted, evenly spaced inputs
ν = 3.5  # first Student-t parameter (presumably degrees of freedom — confirm in docs)
σ = 2.0  # second Student-t parameter (presumably the scale — confirm in docs)
lik = StudentTLikelihood(ν, σ)
kernel = with_lengthscale(SqExponentialKernel(), 2.0)
gp = GP(kernel)
lf = LatentGP(gp, lik, 1e-6)  # NOTE(review): 1e-6 looks like a jitter term — confirm
f, y = rand(lf(x));
We plot the sampled data
plt = scatter(x, y; label="Data")  # the sampled observations `y`
plot!(plt, x, f; label="Latent GP", color=:red)  # the latent sample `f` behind them
CAVI Updates
We write our CAVI (coordinate ascent variational inference) algorithm
"""
    u_posterior(fz, m, S)

Return the approximate GP posterior induced by a `Centered()` sparse variational
approximation with inducing-point prior `fz` and variational distribution
`q(u) = MvNormal(m, S)`.
"""
function u_posterior(fz, m, S)
    q_u = MvNormal(m, S)
    approx = SparseVariationalApproximation(Centered(), fz, q_u)
    return posterior(approx)
end
"""
    cavi!(lik, fz, x, y, m, S, qΩ; niter=10)
    cavi!(fz, x, y, m, S, qΩ; niter=10)

Run `niter` sweeps of coordinate ascent variational inference for the augmented
likelihood `lik`. Each sweep updates the auxiliary posterior `qΩ` and then the
variational mean `m` and covariance `S`, all in place.

In each sweep:
1. `qΩ` is refreshed from the marginals of the current `q(f)` evaluated at `x`.
2. `S` is set to `inv(K⁻¹ + Diagonal(expected precision))`, where `K` is the
   Cholesky-factorised prior covariance of `fz`.
3. `m` is set to `S * (expected potential + K \\ mean(fz))`.

The 6-argument form reads the global `lik`. Prefer passing `lik` explicitly, as
`aug_elbo(lik, ...)` does, to avoid depending on global state.

Returns `(m, S)`.
"""
function cavi!(lik, fz::AbstractGPs.FiniteGP, x, y, m, S, qΩ; niter=10)
    K = ApproximateGPs._chol_cov(fz)
    # The prior terms are the same in every sweep: compute them (O(N³)) only once.
    Kinv = inv(K)
    prior_potential = K \ mean(fz)
    for _ in 1:niter
        post_fs = marginals(u_posterior(fz, m, S)(x))
        aux_posterior!(qΩ, lik, y, post_fs)
        precision = only(expected_auglik_precision(lik, qΩ, y))
        S .= inv(Symmetric(Kinv + Diagonal(precision)))
        potential = only(expected_auglik_potential(lik, qΩ, y))
        m .= S * (potential + prior_potential)
    end
    return m, S
end
# Backward-compatible form: uses the global likelihood `lik`.
function cavi!(fz::AbstractGPs.FiniteGP, x, y, m, S, qΩ; niter=10)
    return cavi!(lik, fz, x, y, m, S, qΩ; niter=niter)
end;
Now we just initialize the variational parameters
m = zeros(Float64, N)  # variational mean, initialised at zero
S = Matrix{Float64}(I, N, N)  # variational covariance, initialised to the identity
qΩ = init_aux_posterior(lik, N)  # initial auxiliary-variable posterior
fz = gp(x, 1e-8);  # prior at the training inputs, passed to `u_posterior` as `fz`
And visualize the current posterior
x_te = -10:0.01:10  # dense grid for drawing the posterior
plot!(plt, x_te, u_posterior(fz, m, S);
      color=:blue, alpha=0.3, label="Initial VI Posterior")
We run CAVI for 4 iterations
cavi!(fz, x, y, m, S, qΩ, niter=4);  # updates m, S and qΩ in place
And visualize the obtained variational posterior
plot!(plt, x_te, u_posterior(fz, m, S);
      (color=:darkgreen, alpha=0.3, label="Final VI Posterior")...)
ELBO
How can one compute the augmented ELBO? Again, AugmentedGPLikelihoods provides helper functions so that you do not have to compute everything yourself.
"""
    aug_elbo(lik, u_post, x, y)

Evaluate the augmented ELBO of the sparse variational posterior `u_post` for the
augmented likelihood `lik` and observations `y` at inputs `x`.

The auxiliary posterior is recomputed from the marginals of `u_post` at `x`. The
result is the expected log-tilt minus the auxiliary KL divergence minus the KL
divergence between the variational distribution and the prior.
"""
function aug_elbo(lik, u_post, x, y)
    approx = u_post.approx
    qf_marginals = marginals(u_post(x))
    q_aux = aux_posterior(lik, y, qf_marginals)
    fit_term = expected_logtilt(lik, q_aux, y, qf_marginals)
    aux_kl = aux_kldivergence(lik, q_aux, y)
    # `approx.q` is the variational distribution and `approx.fz` the prior.
    prior_kl = kldivergence(approx.q, approx.fz)
    return fit_term - aux_kl - prior_kl
end
let q_post = u_posterior(fz, m, S)
    aug_elbo(lik, q_post, x, y)
end
-347.29660251565616
Gibbs Sampling
We create our Gibbs sampling algorithm (we could do something fancier with AbstractMCMC).
"""
    gibbs_sample(lik, fz, y, f, Ω; nsamples=200)
    gibbs_sample(fz, f, Ω; nsamples=200)

Draw `nsamples` Gibbs samples of the latent vector `f` under the augmented
likelihood `lik` with observations `y`.

Each step does two things:
1. Resamples the auxiliary variables `Ω` given `f`, via `aux_sample!`.
2. Draws `f` in place from `MvNormal(μ, Σ)`, where
   `Σ = inv(K⁻¹ + Diagonal(precision))` and `μ = Σ * (potential + K \\ mean(fz))`.
   Here `K` is the Cholesky-factorised prior covariance of `fz`.

`f` holds the last state on return, and a copy of `f` is stored for every step.
The 3-argument form reads the globals `lik` and `y`. Prefer passing them
explicitly, as `aug_elbo(lik, ...)` does.

Returns a vector of `nsamples` copies of `f`.
"""
function gibbs_sample(lik, fz, y, f, Ω; nsamples=200)
    K = ApproximateGPs._chol_cov(fz)
    # The prior terms are the same at every step: compute them (O(N³)) only once.
    Kinv = inv(K)
    prior_potential = K \ mean(fz)
    Σ = zeros(length(f), length(f))
    μ = zeros(length(f))
    return map(1:nsamples) do _
        aux_sample!(Ω, lik, y, f)
        Σ .= inv(Symmetric(Kinv + Diagonal(only(auglik_precision(lik, Ω, y)))))
        μ .= Σ * (only(auglik_potential(lik, Ω, y)) + prior_potential)
        rand!(MvNormal(μ, Σ), f)
        return copy(f)
    end
end
# Backward-compatible form: uses the global likelihood `lik` and observations `y`.
function gibbs_sample(fz, f, Ω; nsamples=200)
    return gibbs_sample(lik, fz, y, f, Ω; nsamples=nsamples)
end;
We initialize our random variables
f = randn(Float64, N)  # random starting point for the latent vector
Ω = init_aux_variables(lik, N);  # initial auxiliary variables
Run the sampler for the default number of iterations (200)
fs = gibbs_sample(fz, f, Ω; nsamples=200);  # 200 is the sampler's default
And visualize the samples overlaid on the variational posterior that we found earlier.
foreach(fs) do sample
    plot!(plt, x, sample; color=:black, alpha=0.07, label="")
end
plt
Package and system information
Package information (click to expand)
Status `~/work/AugmentedGPLikelihoods.jl/AugmentedGPLikelihoods.jl/examples/studentt/Project.toml` [99985d1d] AbstractGPs v0.5.19 [298c2ebc] ApproximateGPs v0.4.5 [4689c64d] AugmentedGPLikelihoods v0.4.18 `/home/runner/work/AugmentedGPLikelihoods.jl/AugmentedGPLikelihoods.jl#727b50a` [31c24e10] Distributions v0.25.102 [98b081ad] Literate v2.15.0 [91a5bcdd] Plots v1.39.0To reproduce this notebook's package environment, you can download the full Manifest.toml.
System information (click to expand)
Julia Version 1.9.3 Commit bed2cd540a1 (2023-08-24 14:43 UTC) Build Info: Official https://julialang.org/ release Platform Info: OS: Linux (x86_64-linux-gnu) CPU: 2 × Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz WORD_SIZE: 64 LIBM: libopenlibm LLVM: libLLVM-14.0.6 (ORCJIT, skylake-avx512) Threads: 1 on 2 virtual cores Environment: JULIA_IMAGE_THREADS = 1 JULIA_PKG_SERVER_REGISTRY_PREFERENCE = eager JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src
This page was generated using Literate.jl.