Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS

You are seeing the HTML output generated by Documenter.jl and Literate.jl from the Julia source file. The corresponding notebook can be viewed in nbviewer.


This example demonstrates how to carry out non-conjugate Gaussian process inference using the stochastic variational Gaussian process (SVGP) model. For a basic introduction to the functionality of this library, please refer to the User Guide.

Setup

using ApproximateGPs
using ParameterHandling
using Zygote
using Distributions
using LinearAlgebra
using Optim

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234);

Generate some training data

For our binary classification model, we will use the standard approach of a latent GP with a Bernoulli likelihood. This results in a generative model that we can use to produce some training data.

First, we define the underlying latent GP

\[f \sim \mathcal{GP}(0, k(\cdot, \cdot'))\]

and sample a function f.

k_true = [30.0, 1.5]
kernel_true = k_true[1] * (SqExponentialKernel() ∘ ScaleTransform(k_true[2]))

jitter = 1e-8  # for numeric stability
lgp = LatentGP(GP(kernel_true), BernoulliLikelihood(), jitter)
x_true = 0:0.02:6
f_true, y_true = rand(lgp(x_true))

plot(x_true, f_true; seriescolor="red", label="")  # Plot the sampled function
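To make the sampling step concrete, here is a minimal sketch of drawing a latent function from a zero-mean GP prior by hand, using only Julia's standard libraries. It builds the kernel matrix for a squared-exponential kernel intended to match kernel_true (variance 30, input scale 1.5). It then adds the jitter to the diagonal and multiplies the Cholesky factor by standard normal noise. The helpers sqexp and sample_gp_prior are hypothetical, and this is not necessarily how ApproximateGPs implements rand.

```julia
using LinearAlgebra, Random

# Hypothetical helper: squared-exponential kernel intended to mirror
# k_true[1] * (SqExponentialKernel() ∘ ScaleTransform(k_true[2])), i.e.
# k(a, b) = σ² * exp(-(s * (a - b))^2 / 2) with σ² = 30 and s = 1.5.
sqexp(a, b; σ²=30.0, s=1.5) = σ² * exp(-(s * (a - b))^2 / 2)

# Hypothetical helper: one draw of f at inputs xs from the zero-mean GP prior.
function sample_gp_prior(rng::AbstractRNG, xs; jitter=1e-8)
    K = [sqexp(a, b) for a in xs, b in xs]  # kernel (Gram) matrix
    K += jitter * I                          # jitter on the diagonal keeps the Cholesky stable
    L = cholesky(Symmetric(K)).L             # K = L * L'
    return L * randn(rng, length(xs))        # f = L * ε with ε ~ N(0, I), so Cov(f) = K
end

xs = 0:0.02:6
f = sample_gp_prior(MersenneTwister(1234), xs)
```

Without the jitter, the kernel matrix on such a dense grid is numerically singular and the Cholesky factorisation can fail.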

Then, the sampled latent function is pushed through a logistic sigmoid, μ = σ(f) = 1 / (1 + exp(-f)). This maps every value into the interval (0, 1), so it can be used as a probability.

μ = mean.(lgp.lik.(f_true))
plot(x_true, μ; seriescolor="red", label="")
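As a quick standalone check of what the sigmoid does, here is a plain-Julia sketch. The name logistic_sigmoid is our own, not a library function. The sigmoid is increasing, maps 0 to 0.5, and satisfies σ(-f) = 1 - σ(f). Large positive latent values therefore give probabilities near 1, and large negative values give probabilities near 0.

```julia
# Logistic sigmoid σ(f) = 1 / (1 + exp(-f)): maps any latent value f ∈ ℝ into (0, 1).
logistic_sigmoid(f) = 1 / (1 + exp(-f))

fs = [-5.0, -1.0, 0.0, 1.0, 5.0]  # a few latent values
μs = logistic_sigmoid.(fs)         # elementwise, like mean.(lgp.lik.(f_true))
```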

Finally, the outputs y of the process are sampled from a Bernoulli distribution with mean μ. We're only interested in the outputs at a subset of inputs x, so we first pick some random input locations and then find the corresponding values for y.

N = 30  # The number of training points
mask = sample(1:length(x_true), N; replace=false, ordered=true)  # Subsample some input locations
x, y = x_true[mask], y_true[mask]

scatter(x, y; label="Sampled outputs")
plot!(x_true, mean.(lgp.lik.(f_true)); seriescolor="red", label="True mean")
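The two sampling steps in this paragraph can be reproduced with the standard library alone. This sketch makes a few assumptions:

- A fixed, hypothetical stand-in curve replaces the GP draw.
- A Bernoulli draw is implemented by comparing a uniform random number with μ.
- The ordered subsample without replacement is built from randperm rather than sample.

```julia
using Random

logistic_sigmoid(f) = 1 / (1 + exp(-f))
rng = MersenneTwister(1234)

# Hypothetical stand-in for the latent function (a smooth curve, not a GP draw).
xs = collect(0:0.02:6)
f = 3 .* sin.(2 .* xs)
μ = logistic_sigmoid.(f)

# Bernoulli(μ) draws: y = 1 with probability μ, otherwise 0.
y_all = Int.(rand(rng, length(μ)) .< μ)

# Pick N distinct input indices in increasing order, which is what
# sample(1:n, N; replace=false, ordered=true) returns.
N = 30
mask = sort(randperm(rng, length(xs))[1:N])
x, y = xs[mask], y_all[mask]
```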

Creating an SVGP

Now that we have some data sampled from a generative model, we can try to recover the true generative function with an SVGP classification model.

For this, we shall use a combination of ParameterHandling.jl to deal with our constrained parameters and Optim.jl to perform optimisation.

The SVGP requires the following parameters: the kernel hyperparameters k, the inducing inputs z, and the mean m and covariance A of the variational distribution q. ParameterHandling provides an elegant way to deal with the constraints on these parameters: k must be positive and A must be positive definite. We additionally bound each inducing input in z to lie between 0 and 6, the range of the training inputs, to simplify optimisation. For more details, see the ParameterHandling.jl readme.

Initialise the parameters

M = 15  # number of inducing points
raw_initial_params = (
    k=(var=positive(rand()), precision=positive(rand())),
    z=bounded.(range(0.1, 5.9; length=M), 0.0, 6.0),  # constrain z to simplify optimisation
    m=zeros(M),
    A=positive_definite(Matrix{Float64}(I, M, M)),
);
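Constraints like these can be handled by optimising over unconstrained real numbers and mapping them onto the allowed set with bijections. The functions to_positive, to_bounded and to_posdef below are hypothetical illustrations of that idea, using only LinearAlgebra. They are not ParameterHandling's internals, whose transforms may differ in detail.

```julia
using LinearAlgebra

# Hypothetical bijections from unconstrained reals to constrained values.
to_positive(θ) = exp(θ)                                 # ℝ → (0, ∞), e.g. a kernel variance
to_bounded(θ, lo, hi) = lo + (hi - lo) / (1 + exp(-θ))  # ℝ → (lo, hi), e.g. an inducing input

# ℝ^(M(M+1)/2) → M×M positive-definite matrix, via A = L * L' where L is
# lower-triangular with an exponentiated (hence strictly positive) diagonal.
function to_posdef(θ::AbstractVector, M::Integer)
    length(θ) == M * (M + 1) ÷ 2 || throw(ArgumentError("need M(M+1)/2 entries"))
    L = zeros(M, M)
    k = 0
    for j in 1:M, i in j:M
        k += 1
        L[i, j] = i == j ? exp(θ[k]) : θ[k]
    end
    return Symmetric(L * L')
end

M = 4
A = to_posdef(collect(range(-1.0, 1.0; length=10)), M)
```

Because every real input vector produces a valid value, a gradient-based optimiser can move freely without ever leaving the feasible set.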

flatten takes the NamedTuple of parameters and returns a flat Vector{Float64} of unconstrained values. It also returns a function, unflatten, that reconstructs the NamedTuple from such a vector. value takes each parameter in the NamedTuple and applies the necessary transformation to return its constrained value, which can then be used to construct the SVGP model. unpack = value ∘ unflatten therefore takes a flat, unconstrained Vector{Float64} and returns a NamedTuple of constrained parameter values.

flat_init_params, unflatten = ParameterHandling.flatten(raw_initial_params)
unpack = ParameterHandling.value ∘ unflatten;
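The flatten/unflatten round trip can be illustrated with a hypothetical, much simpler analogue, my_flatten. It only supports NamedTuples of Float64 scalars and vectors. Here, the variance is stored on the log scale so that every flat entry is unconstrained. The hypothetical my_unpack plays the role of value ∘ unflatten.

```julia
# Hypothetical, minimal analogue of ParameterHandling.flatten for a NamedTuple whose
# fields are Float64 scalars or Vector{Float64}s; the real function supports far more.
function my_flatten(nt::NamedTuple)
    pieces = [v isa Real ? [Float64(v)] : Vector{Float64}(v) for v in values(nt)]
    lens = length.(pieces)                    # how many flat entries each field occupies
    isscalar = [v isa Real for v in values(nt)]
    function my_unflatten(v::AbstractVector)
        vals = Any[]
        offset = 0
        for (n, s) in zip(lens, isscalar)
            chunk = v[offset+1:offset+n]
            push!(vals, s ? chunk[1] : chunk)
            offset += n
        end
        return NamedTuple{keys(nt)}(Tuple(vals))
    end
    return reduce(vcat, pieces), my_unflatten
end

# Store the variance on the log scale so every flat entry is unconstrained.
raw = (log_var=log(2.0), z=[0.1, 0.5, 0.9])
flat, unflat = my_flatten(raw)

# Analogue of unpack = value ∘ unflatten: rebuild the NamedTuple, then map to constrained values.
function my_unpack(v)
    u = unflat(v)
    return (var=exp(u.log_var), z=u.z)
end

p = my_unpack(flat)
```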

Now, we define a function that builds everything needed for an SVGP model from the constrained parameters. The two necessary components are the LatentGP that we are trying to approximate and the SparseVariationalApproximation struct. This struct takes two arguments:

- fz, the prior of the latent GP evaluated at the inducing inputs z (a finite-dimensional Gaussian);
- q, the variational distribution over the function values at those inputs.

These elements can then be passed, along with the data x and y, to the loss function. The loss is the negative of the evidence lower bound (the ELBO), so minimising the loss maximises the ELBO.

lik = BernoulliLikelihood()
jitter = 1e-3  # added to aid numerical stability

function build_SVGP(params::NamedTuple)
    kernel = params.k.var * (SqExponentialKernel() ∘ ScaleTransform(params.k.precision))
    f = LatentGP(GP(kernel), lik, jitter)
    q = MvNormal(params.m, params.A)
    fz = f(params.z).fx
    return SparseVariationalApproximation(fz, q), f
end

function loss(params::NamedTuple)
    svgp, f = build_SVGP(params)
    fx = f(x)
    return -elbo(svgp, fx, y)  # Optim minimises, so negate the ELBO (which we want to maximise)
end;
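To show what the ELBO contains, here is a from-scratch sketch of the standard SVGP evidence lower bound for a Bernoulli likelihood with a logistic link, using only LinearAlgebra. The bound is the sum over data points of E_q(fᵢ)[log p(yᵢ | fᵢ)], minus KL(q(u) || p(u)). The sketch makes these assumptions:

- The kernel is squared-exponential.
- The jitter is applied only to the inducing-point covariance.
- The per-point expectations are evaluated with Gauss-Hermite quadrature.

All helper names (kern, gram, gauss_hermite, kl_gauss, svgp_elbo) are hypothetical, and ApproximateGPs' elbo may differ in such details.

```julia
using LinearAlgebra

# Squared-exponential kernel, intended to mirror var * (SqExponentialKernel() ∘ ScaleTransform(s)).
kern(a, b; var=1.0, s=1.0) = var * exp(-(s * (a - b))^2 / 2)
gram(as, bs; kw...) = [kern(a, b; kw...) for a in as, b in bs]

# Numerically stable log(1 + exp(t)), and the logistic-Bernoulli log-likelihood log p(y | f).
log1pexp(t) = t > 0 ? t + log1p(exp(-t)) : log1p(exp(t))
bernoulli_loglik(y, f) = -log1pexp(-(2y - 1) * f)

# Gauss-Hermite nodes and weights for ∫ exp(-t²) g(t) dt (Golub-Welsch algorithm).
function gauss_hermite(n)
    E = eigen(SymTridiagonal(zeros(n), sqrt.((1:n-1) ./ 2)))
    return E.values, sqrt(π) .* E.vectors[1, :] .^ 2
end

# KL(N(m, A) || N(0, K)) in closed form.
function kl_gauss(m, A, K)
    CK, CA = cholesky(Symmetric(K)), cholesky(Symmetric(A))
    return (tr(CK \ A) + dot(m, CK \ m) - length(m) + logdet(CK) - logdet(CA)) / 2
end

# ELBO = Σᵢ E_q(fᵢ)[log p(yᵢ | fᵢ)] - KL(q(u) || p(u)), with q(u) = N(m, A) and u = f(z).
function svgp_elbo(x, y, z, m, A; var=1.0, s=1.0, jitter=1e-3, nquad=20)
    Kzz = gram(z, z; var=var, s=s) + jitter * I
    Kxz = gram(x, z; var=var, s=s)
    C = cholesky(Symmetric(Kzz))
    P = permutedims(C \ permutedims(Kxz))  # Kxz * Kzz⁻¹
    μ = P * m                              # means of the marginals q(fᵢ)
    # variances: diag(Kxx - Kxz Kzz⁻¹ Kzx + Kxz Kzz⁻¹ A Kzz⁻¹ Kzx), using kern(a, a) = var
    σ² = var .- vec(sum(P .* Kxz; dims=2)) .+ vec(sum((P * A) .* P; dims=2))
    t, w = gauss_hermite(nquad)
    exp_ll = sum(eachindex(y)) do i
        fi = μ[i] .+ sqrt(2 * max(σ²[i], 0.0)) .* t  # quadrature points for q(fᵢ)
        sum(w .* bernoulli_loglik.(y[i], fi)) / sqrt(π)
    end
    return exp_ll - kl_gauss(m, A, Kzz)
end
```

The expected log-likelihood term rewards fitting the labels. The KL term penalises variational distributions that stray far from the prior at the inducing inputs.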

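Optimise the parameters using L-BFGS. Both the objective and its gradient (computed by Zygote) act on the flat, unconstrained parameter vector, so the optimiser never has to handle the constraints directly.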

opt = optimize(
    loss ∘ unpack,
    θ -> only(Zygote.gradient(loss ∘ unpack, θ)),
    flat_init_params,
    LBFGS(;
        alphaguess=Optim.LineSearches.InitialStatic(; scaled=true),
        linesearch=Optim.LineSearches.BackTracking(),
    ),
    Optim.Options(; iterations=4_000);
    inplace=false,
)
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Final objective value:     1.204873e+01

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 1.93e-05 ≰ 0.0e+00
    |x - x'|/|x'|          = 4.28e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 8.71e-11 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 7.23e-12 ≰ 0.0e+00
    |g(x)|                 = 5.39e-05 ≰ 1.0e-08

 * Work counters
    Seconds run:   30  (vs limit Inf)
    Iterations:    4000
    f(x) calls:    4398
    ∇f(x) calls:   4001
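Passing loss ∘ unpack means the optimiser only ever sees unconstrained numbers. The constraint is enforced by the transformation, and the gradient picks up the transformation's derivative through the chain rule, which Zygote handles automatically above. The toy sketch below makes this explicit for a single positive parameter v = exp(θ). It uses a hand-written gradient, and plain gradient descent stands in for L-BFGS. The objective and all function names are hypothetical.

```julia
# Hypothetical toy objective over a positive parameter v (minimum at v = 2) and its derivative.
loss_v(v) = (v - 2.0)^2
dloss_v(v) = 2 * (v - 2.0)

unpack_θ(θ) = exp(θ)                  # unconstrained θ ∈ ℝ ↦ v > 0
loss_θ(θ) = loss_v(unpack_θ(θ))       # analogue of loss ∘ unpack
grad_θ(θ) = dloss_v(exp(θ)) * exp(θ)  # chain rule: dv/dθ = exp(θ)

# Plain gradient descent in the unconstrained space (a stand-in for L-BFGS).
function gradient_descent(θ; step=0.05, iters=2_000)
    for _ in 1:iters
        θ -= step * grad_θ(θ)
    end
    return θ
end

θ_opt = gradient_descent(0.0)  # start at v = exp(0) = 1
v_opt = unpack_θ(θ_opt)
```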

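Optim reports a failure because it reached the 4,000-iteration limit before the gradient norm fell below its default tolerance of 1e-8. However, the final gradient norm (5.39e-05) and the last change in the objective (8.71e-11) are both small, so the parameters have effectively stopped changing.

Finally, build the optimised SVGP model and sample some functions to see whether they are close to the true function.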

final_params = unpack(opt.minimizer)

svgp_opt, f_opt = build_SVGP(final_params)
post_opt = posterior(svgp_opt)
l_post_opt = LatentGP(post_opt, BernoulliLikelihood(), jitter)

post_f_samples = rand(l_post_opt.f(x_true, 1e-6), 20)
post_μ_samples = mean.(l_post_opt.lik.(post_f_samples))

plt = plot(x_true, post_μ_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(final_params.z; label="Pseudo-points")
plot!(
    x_true, mean.(lgp.lik.(f_true)); seriescolor="green", linewidth=3, label="True function"
)
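The red curves show individual posterior samples of σ(f). At a single input, the predictive probability p(y* = 1) = E[σ(f*)] averages over such samples. Assume the posterior latent at that input is Gaussian with mean μ and standard deviation s; the values below are hypothetical. This sketch compares three estimates of the predictive probability:

- a 20-sample Monte Carlo average like the plot's;
- Gauss-Hermite quadrature;
- the probit-style approximation σ(μ / √(1 + πs²/8)).

Function names are hypothetical. For the fitted model, μ and s would come from the posterior's marginals.

```julia
using LinearAlgebra, Random, Statistics

logistic_sigmoid(f) = 1 / (1 + exp(-f))

# Monte Carlo estimate of E[σ(f)] for f ~ N(μ, s²), averaging over n draws.
mc_prob(rng, μ, s; n=20) = mean(logistic_sigmoid.(μ .+ s .* randn(rng, n)))

# Gauss-Hermite quadrature estimate of the same expectation (Golub-Welsch nodes).
function gh_prob(μ, s; n=40)
    E = eigen(SymTridiagonal(zeros(n), sqrt.((1:n-1) ./ 2)))
    t, w = E.values, sqrt(π) .* E.vectors[1, :] .^ 2
    return sum(w .* logistic_sigmoid.(μ .+ sqrt(2) * s .* t)) / sqrt(π)
end

# Closed-form approximation obtained by treating σ as a rescaled probit function.
probit_approx(μ, s) = logistic_sigmoid(μ / sqrt(1 + π * s^2 / 8))

rng = MersenneTwister(1234)
μ, s = 1.0, 2.0
estimates = (mc=mc_prob(rng, μ, s), gh=gh_prob(μ, s), approx=probit_approx(μ, s))
```

With only 20 samples the Monte Carlo average is noisy. Larger posterior uncertainty s pulls the predictive probability towards 0.5, which is why the sampled curves fan out away from the data.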

Package and system information
Package information
      Status `~/work/ApproximateGPs.jl/ApproximateGPs.jl/examples/b-classification/Project.toml`
  [298c2ebc] ApproximateGPs v0.4.5 `/home/runner/work/ApproximateGPs.jl/ApproximateGPs.jl#master`
  [31c24e10] Distributions v0.25.89
  [98b081ad] Literate v2.14.0
  [429524aa] Optim v1.7.5
  [2412ca09] ParameterHandling v0.4.6
  [91a5bcdd] Plots v1.38.11
  [e88e6eb3] Zygote v0.6.60
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random
To reproduce this notebook's package environment, you can download the full Manifest.toml.
System information
Julia Version 1.6.7
Commit 3b76b25b64 (2022-07-19 15:11 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, icelake-server)
Environment:
  JULIA_DEBUG = Documenter
  JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src

This page was generated using Literate.jl.