Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS
This example demonstrates how to carry out non-conjugate Gaussian process inference using the stochastic variational Gaussian process (SVGP) model. For a basic introduction to the functionality of this library, please refer to the User Guide.
Setup
using ApproximateGPs
using ParameterHandling
using Zygote
using Distributions
using LinearAlgebra
using Optim
using Plots
default(; legend=:outertopright, size=(700, 400))
using Random
Random.seed!(1234);
Generate some training data
For our binary classification model, we will use the standard approach of a latent GP with a Bernoulli likelihood. This results in a generative model that we can use to produce some training data.
First, we define the underlying latent GP
\[f \sim \mathcal{GP}(0, k(\cdot, \cdot'))\]
and sample a function f.
k_true = [30.0, 1.5]
kernel_true = k_true[1] * (SqExponentialKernel() ∘ ScaleTransform(k_true[2]))
jitter = 1e-8 # for numeric stability
lgp = LatentGP(GP(kernel_true), BernoulliLikelihood(), jitter)
x_true = 0:0.02:6
f_true, y_true = rand(lgp(x_true))
plot(x_true, f_true; seriescolor="red", label="") # Plot the sampled function
Then, the output of this sampled function is pushed through a logistic sigmoid μ = σ(f) to constrain the output to [0, 1].
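As a quick sanity check (a sketch, not part of the original example; it assumes the default logistic link of BernoulliLikelihood), calling the likelihood on a latent value f returns a Bernoulli distribution whose mean is σ(f).
σ(f) = 1 / (1 + exp(-f)) # logistic sigmoid
mean(BernoulliLikelihood()(2.0)) ≈ σ(2.0) # true: the likelihood maps f to Bernoulli(σ(f))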
μ = mean.(lgp.lik.(f_true))
plot(x_true, μ; seriescolor="red", label="")
Finally, the outputs y of the process are sampled from a Bernoulli distribution with mean μ. We're only interested in the outputs at a subset of inputs x, so we first pick some random input locations and then find the corresponding values for y.
N = 30 # The number of training points
mask = sample(1:length(x_true), N; replace=false, ordered=true) # Subsample some input locations
x, y = x_true[mask], y_true[mask]
scatter(x, y; label="Sampled outputs")
plot!(x_true, mean.(lgp.lik.(f_true)); seriescolor="red", label="True mean")
Creating an SVGP
Now that we have some data sampled from a generative model, we can try to recover the true generative function with an SVGP classification model.
For this, we shall use a combination of ParameterHandling.jl, to deal with our constrained parameters, and Optim.jl, to perform the optimisation.
The required parameters for the SVGP are: the kernel hyperparameters k, the inducing inputs z, and the mean and covariance of the variational distribution q, given by m and A respectively. ParameterHandling provides an elegant way to deal with the constraints on these parameters, since k must be positive and A must be positive definite. For more details, see the ParameterHandling.jl readme.
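As a small illustration of how these constraints work (a sketch, not part of the original example): positive stores an unconstrained representation of a parameter, and ParameterHandling.value recovers the constrained value.
p = positive(2.0) # an unconstrained representation of a parameter constrained to be > 0
ParameterHandling.value(p) ≈ 2.0 # true (up to round-off in the transform)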
Initialise the parameters
M = 15 # number of inducing points
raw_initial_params = (
k=(var=positive(rand()), precision=positive(rand())),
z=bounded.(range(0.1, 5.9; length=M), 0.0, 6.0), # constrain z to simplify optimisation
m=zeros(M),
A=positive_definite(Matrix{Float64}(I, M, M)),
);
flatten takes the NamedTuple of parameters and returns a flat vector of Float64, along with a function unflatten to reconstruct the NamedTuple from a flat vector. value takes each parameter in the NamedTuple and applies the necessary transformation to return the constrained value, which can then be used to construct the SVGP model. unpack therefore takes a flat, unconstrained Vector{Float64} and returns a NamedTuple of constrained parameter values.
flat_init_params, unflatten = ParameterHandling.flatten(raw_initial_params)
unpack = ParameterHandling.value ∘ unflatten;
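To see what unpack does, we can round-trip the initial parameters (a sketch, not part of the original example); the reconstructed NamedTuple contains the constrained values.
params0 = unpack(flat_init_params) # NamedTuple of constrained parameter values
params0.k.var ≈ ParameterHandling.value(raw_initial_params.k.var) # true (up to transform round-off)
params0.m == raw_initial_params.m # true: m is unconstrained, so it passes through unchanged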
Now, we define a function to build everything needed for an SVGP model from the constrained parameters. The two necessary components are the LatentGP which we are trying to approximate and the SparseVariationalApproximation struct. This struct takes as arguments the inducing points fz and the variational posterior distribution q. These elements can then be passed to the loss function (the elbo) along with the data x and y.
lik = BernoulliLikelihood()
jitter = 1e-3 # added to aid numerical stability
function build_SVGP(params::NamedTuple)
kernel = params.k.var * (SqExponentialKernel() ∘ ScaleTransform(params.k.precision))
f = LatentGP(GP(kernel), lik, jitter)
q = MvNormal(params.m, params.A)
fz = f(params.z).fx
return SparseVariationalApproximation(fz, q), f
end
function loss(params::NamedTuple)
svgp, f = build_SVGP(params)
fx = f(x)
return -elbo(svgp, fx, y)
end;
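Before optimising, it can be useful to check that the objective and its gradient evaluate without error at the initial parameters (a sketch, not part of the original example; the exact value depends on the random initialisation).
loss(unpack(flat_init_params)) # the negative ELBO at the initial parameters
only(Zygote.gradient(loss ∘ unpack, flat_init_params)); # a gradient vector the same length as flat_init_params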
Optimise the parameters using LBFGS.
opt = optimize(
loss ∘ unpack,
θ -> only(Zygote.gradient(loss ∘ unpack, θ)),
flat_init_params,
LBFGS(;
alphaguess=Optim.LineSearches.InitialStatic(; scaled=true),
linesearch=Optim.LineSearches.BackTracking(),
),
Optim.Options(; iterations=4_000);
inplace=false,
)
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     1.204873e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|                 = 1.93e-05 ≰ 0.0e+00
    |x - x'|/|x'|            = 4.28e-06 ≰ 0.0e+00
    |f(x) - f(x')|           = 8.71e-11 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')|   = 7.23e-12 ≰ 0.0e+00
    |g(x)|                   = 5.39e-05 ≰ 1.0e-08

 * Work counters
    Seconds run:   30  (vs limit Inf)
    Iterations:    4000
    f(x) calls:    4398
    ∇f(x) calls:   4001
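The optimiser reports "failure" only because it hit the iteration cap; the convergence measures above show that the objective had essentially stopped changing. These quantities can also be inspected programmatically with Optim's standard accessors (a sketch, not part of the original example).
Optim.converged(opt) # false here, since the iteration limit was reached first
Optim.iterations(opt) # 4000
Optim.minimum(opt) # the final objective value, i.e. the negative ELBO at the optimum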
Finally, build the optimised SVGP model, and sample some functions to see if they are close to the true function.
final_params = unpack(opt.minimizer)
svgp_opt, f_opt = build_SVGP(final_params)
post_opt = posterior(svgp_opt)
l_post_opt = LatentGP(post_opt, BernoulliLikelihood(), jitter)
post_f_samples = rand(l_post_opt.f(x_true, 1e-6), 20)
post_μ_samples = mean.(l_post_opt.lik.(post_f_samples))
plt = plot(x_true, post_μ_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(final_params.z; label="Pseudo-points")
plot!(
x_true, mean.(lgp.lik.(f_true)); seriescolor="green", linewidth=3, label="True function"
)
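As an extra point summary (a sketch, not part of the original example), we can also push the posterior latent mean through the likelihood mean. Note that this is only a plug-in approximation: the mean of σ(f) under the posterior is not, in general, σ of the posterior mean.
latent_mean = mean(post_opt(x_true, jitter)) # posterior mean of the latent GP at the plotting grid
plot!(plt, x_true, mean.(lik.(latent_mean)); seriescolor="purple", linewidth=2, label="σ(posterior mean)")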
Package and system information
Package information

Status `~/work/ApproximateGPs.jl/ApproximateGPs.jl/examples/b-classification/Project.toml`
  [298c2ebc] ApproximateGPs v0.4.5 `/home/runner/work/ApproximateGPs.jl/ApproximateGPs.jl#master`
  [31c24e10] Distributions v0.25.89
  [98b081ad] Literate v2.14.0
  [429524aa] Optim v1.7.5
  [2412ca09] ParameterHandling v0.4.6
  [91a5bcdd] Plots v1.38.11
  [e88e6eb3] Zygote v0.6.60
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random

To reproduce this notebook's package environment, you can download the full Manifest.toml.

System information

Julia Version 1.6.7
Commit 3b76b25b64 (2022-07-19 15:11 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, icelake-server)
Environment:
  JULIA_DEBUG = Documenter
  JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src
This page was generated using Literate.jl.