Regression with Student-t noise
You are seeing the HTML output generated by Documenter.jl and Literate.jl from the Julia source file. The corresponding notebook can be viewed in nbviewer.
We load all the necessary packages
using AbstractGPs
using ApproximateGPs
using AugmentedGPLikelihoods
using Distributions
using LinearAlgebra
Plotting libraries
using Plots
We create some random data (sorted for plotting reasons)
# Evenly spaced inputs, so they are already sorted for plotting
N = 100
x = range(-10, 10; length=N)
# Student-t likelihood parameters
ν, σ = 3.5, 2.0
# Latent GP prior with a squared-exponential kernel of lengthscale 2
kernel = with_lengthscale(SqExponentialKernel(), 2.0)
gp = GP(kernel)
lik = StudentTLikelihood(ν, σ)
lf = LatentGP(gp, lik, 1e-6)
# Jointly sample the latent function values and the noisy observations
f, y = rand(lf(x));
We plot the sampled data
# Observations as points, true latent function as a red line
plt = scatter(x, y; label="Data")
plot!(plt, x, f; label="Latent GP", color=:red)
CAVI Updates
We write our CAVI (coordinate ascent variational inference) algorithm
"""
    u_posterior(fz, m, S)

Return the posterior GP of a centered `SparseVariationalApproximation`. The finite GP
`fz` is the prior over the inducing variables, and `MvNormal(m, S)` (mean `m`,
covariance `S`) is their variational distribution.
"""
function u_posterior(fz, m, S)
    q_u = MvNormal(m, S)
    approx = SparseVariationalApproximation(Centered(), fz, q_u)
    return posterior(approx)
end
"""
    cavi!(lik, fz, x, y, m, S, qΩ; niter=10)
    cavi!(fz, x, y, m, S, qΩ; niter=10)

Run `niter` sweeps of coordinate-ascent variational inference (CAVI) for the augmented
likelihood `lik`. Each sweep does two things:

1. It refreshes the auxiliary posterior `qΩ` in place (via `aux_posterior!`), using the
   marginals of the current `q(f)` at the inputs `x`.
2. It overwrites `S` and `m` with the closed-form Gaussian updates
   `S = (K⁻¹ + Λ)⁻¹` and `m = S (η + K⁻¹ μ₀)`.

Here `K` and `μ₀` are the covariance and mean of the prior `fz`. `Λ` and `η` are the
expected augmented-likelihood precision and potential under `qΩ`.

The 6-argument method is kept for backward compatibility and uses the global `lik`
defined earlier in this script.

Returns `(m, S)`; both are also updated in place.
"""
function cavi!(lik, fz::AbstractGPs.FiniteGP, x, y, m, S, qΩ; niter=10)
    K = ApproximateGPs._chol_cov(fz)
    # Prior precision and prior-mean term are constant across sweeps: compute them once.
    Kinv = inv(K)
    prior_potential = K \ mean(fz)
    for _ in 1:niter
        post_fs = marginals(u_posterior(fz, m, S)(x))
        aux_posterior!(qΩ, lik, y, post_fs)
        # `only` asserts there is exactly one latent component.
        Λ = Diagonal(only(expected_auglik_precision(lik, qΩ, y)))
        S .= inv(Symmetric(Kinv + Λ))
        m .= S * (only(expected_auglik_potential(lik, qΩ, y)) + prior_potential)
    end
    return m, S
end
function cavi!(fz::AbstractGPs.FiniteGP, x, y, m, S, qΩ; niter=10)
    return cavi!(lik, fz, x, y, m, S, qΩ; niter=niter)
end;
Now we just initialize the variational parameters
# Variational mean starts at zero and covariance at the identity
m = zeros(N)
S = Matrix{Float64}(I, N, N)
# Finite GP prior at the training inputs (1e-8 noise), used as the inducing prior
fz = gp(x, 1e-8)
qΩ = init_aux_posterior(lik, N);
And visualize the current posterior
# Dense test grid over the same interval as the training inputs
x_te = -10:0.01:10
let q_init = u_posterior(fz, m, S)
    plot!(plt, x_te, q_init; label="Initial VI Posterior", color=:blue, alpha=0.3)
end
We run CAVI for a few iterations (here 4)
let niter = 4  # number of CAVI sweeps
    cavi!(fz, x, y, m, S, qΩ; niter=niter)
end;
And visualize the obtained variational posterior
let q_final = u_posterior(fz, m, S)
    plot!(plt, x_te, q_final; label="Final VI Posterior", color=:darkgreen, alpha=0.3)
end
ELBO
How can one compute the augmented ELBO? Again, AugmentedGPLikelihoods provides helper functions so that you do not have to compute everything yourself.
"""
    aug_elbo(lik, u_post, x, y)

Compute the augmented ELBO as the sum of three terms:

- the expected log-tilt of `lik` at the observations `y`;
- minus the KL divergence of the auxiliary posterior;
- minus the KL divergence between the variational distribution over the inducing
  variables and their prior.

The auxiliary posterior is computed with `aux_posterior`, from the marginals of
`u_post` at the inputs `x`.
"""
function aug_elbo(lik, u_post, x, y)
    approx = u_post.approx
    marginals_f = marginals(u_post(x))
    q_aux = aux_posterior(lik, y, marginals_f)
    expected_term = expected_logtilt(lik, q_aux, y, marginals_f)
    aux_kl = aux_kldivergence(lik, q_aux, y)
    # approx.q is the variational distribution over u; approx.fz is its prior
    u_kl = kldivergence(approx.q, approx.fz)
    return expected_term - aux_kl - u_kl
end
let q_u = u_posterior(fz, m, S)
    aug_elbo(lik, q_u, x, y)
end
-332.3931458838723
Gibbs Sampling
We create our Gibbs sampling algorithm (we could do something fancier with AbstractMCMC)
"""
    gibbs_sample(lik, y, fz, f, Ω; nsamples=200)
    gibbs_sample(fz, f, Ω; nsamples=200)

Draw `nsamples` latent-function samples with a blocked Gibbs sampler for the augmented
likelihood `lik` and observations `y`. Each step does two things:

1. It resamples the auxiliary variables `Ω` in place (via `aux_sample!`), conditioned on
   the current `f`.
2. It draws `f ~ N(μ, Σ)`, with `Σ = (K⁻¹ + Λ)⁻¹` and `μ = Σ (η + K⁻¹ μ₀)`.

Here `K` and `μ₀` are the covariance and mean of the prior `fz`. `Λ` and `η` are the
augmented-likelihood precision and potential given `Ω`.

`f` and `Ω` are overwritten in place. A copy of `f` is returned for every step, as a
vector of samples.

The 3-argument method is kept for backward compatibility and uses the globals `lik` and
`y` defined earlier in this script.
"""
function gibbs_sample(lik, y, fz, f, Ω; nsamples=200)
    K = ApproximateGPs._chol_cov(fz)
    # Prior precision and prior-mean term do not depend on the sampled state.
    Kinv = inv(K)
    prior_potential = K \ mean(fz)
    Σ = zeros(length(f), length(f))
    μ = zeros(length(f))
    return map(1:nsamples) do _
        aux_sample!(Ω, lik, y, f)
        Σ .= inv(Symmetric(Kinv + Diagonal(only(auglik_precision(lik, Ω, y)))))
        μ .= Σ * (only(auglik_potential(lik, Ω, y)) + prior_potential)
        rand!(MvNormal(μ, Σ), f)
        return copy(f)
    end
end
function gibbs_sample(fz, f, Ω; nsamples=200)
    return gibbs_sample(lik, y, fz, f, Ω; nsamples=nsamples)
end;
We initialize our random variables
# Random starting point for the latent values, plus initial auxiliary variables
f = randn(Float64, N)
Ω = init_aux_variables(lik, N);
Run the sampler for the default number of samples (200)
fs = gibbs_sample(fz, f, Ω; nsamples=200);
And visualize the samples overlaid on the variational posterior that we found earlier.
# Overlay each Gibbs sample as a faint black line
foreach(fs) do sample
    plot!(plt, x, sample; color=:black, alpha=0.07, label="")
end
plt
Package and system information
Package information (click to expand)
Status `~/work/AugmentedGPLikelihoods.jl/AugmentedGPLikelihoods.jl/examples/studentt/Project.toml` [99985d1d] AbstractGPs v0.5.21 [298c2ebc] ApproximateGPs v0.4.5 [4689c64d] AugmentedGPLikelihoods v0.4.18 `/home/runner/work/AugmentedGPLikelihoods.jl/AugmentedGPLikelihoods.jl#main` [31c24e10] Distributions v0.25.109 [98b081ad] Literate v2.19.0 [91a5bcdd] Plots v1.40.5To reproduce this notebook's package environment, you can download the full Manifest.toml.
System information (click to expand)
Julia Version 1.10.4 Commit 48d4fd48430 (2024-06-04 10:41 UTC) Build Info: Official https://julialang.org/ release Platform Info: OS: Linux (x86_64-linux-gnu) CPU: 4 × AMD EPYC 7763 64-Core Processor WORD_SIZE: 64 LIBM: libopenlibm LLVM: libLLVM-15.0.7 (ORCJIT, znver3) Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores) Environment: JULIA_PKG_SERVER_REGISTRY_PREFERENCE = eager JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src
This page was generated using Literate.jl.