Train Kernel Parameters


Here we show a few ways to train (optimize) kernel (hyper)parameters, using kernel-based regression with KernelFunctions.jl as the example. All options are functionally identical but differ slightly in readability, dependencies, and computational cost.

We load KernelFunctions and some other packages. Note that while we use Zygote for automatic differentiation and Flux.Optimise for optimization, you should be able to replace them with your favourite autodiff framework or optimizer.

using KernelFunctions
using LinearAlgebra
using Distributions
using Plots
using BenchmarkTools
using Flux
using Flux: Optimise
using Zygote
using Random: seed!
seed!(42);
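
To illustrate the point about swapping the autodiff backend, here is a minimal sketch that computes a gradient with ForwardDiff.jl instead of Zygote. It assumes ForwardDiff is installed and uses a hypothetical toy objective `toy_loss` (not the regression loss defined later in this example), so it only shows the shape of the call.

```julia
using ForwardDiff

# Hypothetical toy objective over log-parameters; not the loss used in this example.
fd_target = [1.0, 2.0]
toy_loss(θ) = sum(abs2, exp.(θ) .- fd_target)

θ_fd = log.([0.5, 0.5])

# Forward-mode counterpart of `only(Zygote.gradient(toy_loss, θ_fd))`.
grads_fd = ForwardDiff.gradient(toy_loss, θ_fd)
```

Forward mode tends to be a reasonable choice when the number of parameters is small, as it is here with four hyperparameters.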

Data Generation

We generate a toy dataset in 1 dimension:

xmin, xmax = -3, 3  # Bounds of the data
N = 50 # Number of samples
x_train = rand(Uniform(xmin, xmax), N)  # sample the inputs
σ = 0.1
y_train = sinc.(x_train) + randn(N) * σ  # evaluate a function and add some noise
x_test = range(xmin - 0.1, xmax + 0.1; length=300)

Plot the data

scatter(x_train, y_train; label="data")
plot!(x_test, sinc; label="true function")

Manual Approach

The first option is to rebuild the parametrized kernel from a vector of parameters in each evaluation of the cost function. This is similar to the approach taken in Stheno.jl.

To train the kernel parameters via Zygote.jl, we need a function that constructs a kernel from an array of parameters. A simple way to ensure that the kernel parameters are positive is to optimize over the logarithm of the parameters.

function kernel_creator(θ)
    return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel()) ∘
           ScaleTransform(exp(θ[3]))
end
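
The effect of the log-parametrization can be checked in isolation. The sketch below uses only Base Julia; the update vector is a hypothetical, deliberately large step chosen to show that positivity survives any move in the unconstrained space.

```julia
# Positive starting values (the same numbers used as p0 later in this example).
p_pos = [1.1, 0.1, 0.01, 0.001]
θ_log = log.(p_pos)              # unconstrained representation

# A hypothetical, arbitrarily large update in the unconstrained space.
θ_stepped = θ_log .- [5.0, -5.0, 10.0, -10.0]

# Mapping back through exp always yields strictly positive parameters.
p_stepped = exp.(θ_stepped)
```

An optimizer acting on `θ_log` therefore never needs an explicit positivity constraint.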

From theory we know the prediction for a test set x given the kernel parameters and the regularization constant (the noise variance exp(θ[4]), which is added to the diagonal of the training kernel matrix):

function f(x, x_train, y_train, θ)
    k = kernel_creator(θ[1:3])
    return kernelmatrix(k, x, x_train) *
           ((kernelmatrix(k, x_train) + exp(θ[4]) * I) \ y_train)
end
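
Written out, the function above evaluates the usual kernel regression predictor (equivalently, the posterior mean of a Gaussian process with noise variance $\lambda$). The notation is introduced here for illustration and does not appear in the original code: $X$ denotes the training inputs, $y$ the training targets, and $K(\cdot, \cdot)$ the matrix returned by `kernelmatrix`.

```latex
\hat{y}(x) = K(x, X)\,\bigl(K(X, X) + \lambda I\bigr)^{-1} y,
\qquad \lambda = \exp(\theta_4)
```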

Let's look at our prediction. With starting parameters p0 (picked so we get the right local minimum for demonstration) we get:

p0 = [1.1, 0.1, 0.01, 0.001]
θ = log.(p0)
ŷ = f(x_test, x_train, y_train, θ)
scatter(x_train, y_train; label="data")
plot!(x_test, sinc; label="true function")
plot!(x_test, ŷ; label="prediction")

We define the following loss:

function loss(θ)
    ŷ = f(x_train, x_train, y_train, θ)
    return norm(y_train - ŷ) + exp(θ[4]) * norm(ŷ)
end
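
In formula form, and reading directly off the code above, the loss is a data-fit term plus a penalty on the size of the fitted values. Here $X$ and $y$ are the training inputs and targets (notation assumed for illustration), and both norms are Euclidean, which is what `norm` computes for a vector.

```latex
\mathcal{L}(\theta) = \lVert y - \hat{y}_\theta \rVert_2
  + \exp(\theta_4)\,\lVert \hat{y}_\theta \rVert_2,
\qquad \hat{y}_\theta = f(X, X, y, \theta)
```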

The loss with our starting point:

loss(θ)
2.613933959118708

Computational cost for one step:

@benchmark let
    θ = log.(p0)
    opt = Optimise.ADAGrad(0.5)
    grads = only((Zygote.gradient(loss, θ)))
    Optimise.update!(opt, θ, grads)
end
BenchmarkTools.Trial: 6129 samples with 1 evaluation.
 Range (min … max):  694.496 μs …   8.422 ms  ┊ GC (min … max): 0.00% … 58.83%
 Time  (median):     750.151 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   811.622 μs ± 263.166 μs  ┊ GC (mean ± σ):  5.62% ± 11.05%

   ▄██▆▅▄▃▁                                            ▁▁▁▁▁    ▂
  █████████▆▇▅▄▅▄▃▁▁▁▃▁▃▁▁▁▃▅▅▅▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▇█▇██████▇▇ █
  694 μs        Histogram: log(frequency) by time       1.69 ms <

 Memory estimate: 2.98 MiB, allocs estimate: 1563.

Training the model

Setting an initial value and initializing the optimizer:

θ = log.(p0) # Initial vector
opt = Optimise.ADAGrad(0.5)
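
For readers who want to swap in their own optimizer, the following is a rough sketch of an AdaGrad-style update in plain Julia. It is not Flux's implementation (details such as how the accumulator is initialized may differ); `adagrad_step!`, `toy_objective`, and `toy_grad` are hypothetical names used only for illustration.

```julia
# Hypothetical helper: one AdaGrad-style step, updating θ and the accumulator in place.
function adagrad_step!(θ, acc, grads; η=0.5, ϵ=1e-8)
    @. acc += grads^2                    # running sum of squared gradients
    @. θ -= η * grads / (sqrt(acc) + ϵ)  # per-coordinate scaled step
    return θ
end

# Toy quadratic objective with a hand-written gradient (illustration only).
toy_target = [1.0, -2.0]
toy_objective(θ) = sum(abs2, θ .- toy_target)
toy_grad(θ) = 2 .* (θ .- toy_target)

θ_toy = zeros(2)
acc_toy = zeros(2)
initial_value = toy_objective(θ_toy)
for _ in 1:15
    adagrad_step!(θ_toy, acc_toy, toy_grad(θ_toy))
end
final_value = toy_objective(θ_toy)
```

In the training loop of this example, a helper like this would take the place of `Optimise.update!`, with the gradient still coming from the autodiff backend.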

Optimize

anim = Animation()
for i in 1:15
    grads = only((Zygote.gradient(loss, θ)))
    Optimise.update!(opt, θ, grads)
    scatter(
        x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
    )
    plot!(x_test, sinc; lab="true function")
    plot!(x_test, f(x_test, x_train, y_train, θ); lab="Prediction", lw=3.0)
    frame(anim)
end
gif(anim, "train-kernel-param.gif"; show_msg=false, fps=15);

Final loss

loss(θ)
0.5241118228076058

Using ParameterHandling.jl

Alternatively, we can use the ParameterHandling.jl package to handle the requirement that all kernel parameters should be positive. The package also supports arbitrarily nested named tuples, which make the parameters more human-readable because you no longer have to remember their position in a flat vector.
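
As a small illustration of nesting, the sketch below groups two kernel parameters under a `kernel` field. The layout and field names are hypothetical and differ from the flat named tuple used in the rest of this section; it assumes ParameterHandling.jl is installed.

```julia
using ParameterHandling

# Hypothetical nested layout; the field names are chosen for illustration only.
nested_raw = (
    kernel=(variance=positive(1.1), lengthscale=positive(0.5)),
    noise_var=positive(0.001),
)

# Flatten to an unconstrained vector, and get a function that maps such a
# vector back to a named tuple of plain (constrained) values.
nested_flat, nested_unflatten = ParameterHandling.value_flatten(nested_raw)
nested_θ = nested_unflatten(nested_flat)

nested_θ.kernel.variance   # accessed by name rather than by index
```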

using ParameterHandling

raw_initial_θ = (
    k1=positive(1.1), k2=positive(0.1), k3=positive(0.01), noise_var=positive(0.001)
)

flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ)
flat_θ
4-element Vector{Float64}:
  0.09531016625781467
 -2.3025852420056685
 -4.6051716761053205
 -6.907770180254354

We define a few relevant functions and note that compared to the previous kernel_creator function, we do not need explicit exps.

function kernel_creator(θ)
    return (θ.k1 * SqExponentialKernel() + θ.k2 * Matern32Kernel()) ∘ ScaleTransform(θ.k3)
end

function f(x, x_train, y_train, θ)
    k = kernel_creator(θ)
    return kernelmatrix(k, x, x_train) *
           ((kernelmatrix(k, x_train) + θ.noise_var * I) \ y_train)
end

function loss(θ)
    ŷ = f(x_train, x_train, y_train, θ)
    return norm(y_train - ŷ) + θ.noise_var * norm(ŷ)
end

initial_θ = ParameterHandling.value(raw_initial_θ)

The loss at the initial parameter values:

(loss ∘ unflatten)(flat_θ)
2.613933959118708

Cost per step

@benchmark let
    θ = flat_θ[:]
    opt = Optimise.ADAGrad(0.5)
    grads = (Zygote.gradient(loss ∘ unflatten, θ))[1]
    Optimise.update!(opt, θ, grads)
end
BenchmarkTools.Trial: 5247 samples with 1 evaluation.
 Range (min … max):  822.244 μs …   4.932 ms  ┊ GC (min … max): 0.00% … 19.87%
 Time  (median):     898.757 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   949.880 μs ± 228.438 μs  ┊ GC (mean ± σ):  4.76% ± 10.50%

  ▄▄▄██▆▅▃▁                                                ▁▁ ▁ ▁
  ██████████▅▇▆▅▅▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▅▆██████ █
  822 μs        Histogram: log(frequency) by time       1.94 ms <

 Memory estimate: 3.08 MiB, allocs estimate: 2228.

Training the model

Optimize

opt = Optimise.ADAGrad(0.5)
for i in 1:15
    grads = (Zygote.gradient(loss ∘ unflatten, flat_θ))[1]
    Optimise.update!(opt, flat_θ, grads)
end

Final loss

(loss ∘ unflatten)(flat_θ)
0.524117624126251

Flux.destructure

If we don't want to write an explicit function to construct the kernel, we can alternatively use the Flux.destructure function. Again, we need to ensure that the parameters are positive. Note that the exp function is now part of the loss function, instead of part of the kernel construction.

We could also use ParameterHandling.jl here. To do so, one would remove the exps from the loss function below and call loss ∘ unflatten as above.

θ = [1.1, 0.1, 0.01, 0.001]

kernel = (θ[1] * SqExponentialKernel() + θ[2] * Matern32Kernel()) ∘ ScaleTransform(θ[3])

params, kernelc = Flux.destructure(kernel);

This returns the trainable params of the kernel and a function to reconstruct the kernel.

kernelc(params)
Sum of 2 kernels:
	Squared Exponential Kernel (metric = Distances.Euclidean(0.0))
			- σ² = 1.1
	Matern 3/2 Kernel (metric = Distances.Euclidean(0.0))
			- σ² = 0.1
	- Scale Transform (s = 0.01)
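
To see the round trip on a smaller case, the sketch below destructures a kernel whose variance is expected to be its only trainable parameter, then rebuilds it from a modified vector. The names `k_small`, `ps_small`, and `rebuild` are hypothetical; the sketch assumes KernelFunctions and Flux are installed, as in the rest of this example.

```julia
using KernelFunctions
using Flux

# A small kernel; the variance 2.0 is expected to be its only trainable parameter.
k_small = 2.0 * SqExponentialKernel()
ps_small, rebuild = Flux.destructure(k_small)

# Rebuild the kernel from a modified parameter vector (all entries doubled).
k_doubled = rebuild(2 .* ps_small)

# Evaluate both kernels at zero distance, where the base kernel equals 1.
k_small(0.0, 0.0), k_doubled(0.0, 0.0)
```

Because the reconstruction function accepts any vector of the right length, constraints such as positivity have to be enforced outside of it, which is why the exp appears in the loss below.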

As before, theory gives the prediction for a test set x given the kernel parameters and the regularization constant (here θ[4], added to the diagonal of the training kernel matrix):

function f(x, x_train, y_train, θ)
    k = kernelc(θ[1:3])
    return kernelmatrix(k, x, x_train) * ((kernelmatrix(k, x_train) + (θ[4]) * I) \ y_train)
end

function loss(θ)
    ŷ = f(x_train, x_train, y_train, exp.(θ))
    return norm(y_train - ŷ) + exp(θ[4]) * norm(ŷ)
end

Cost for one step

@benchmark let θt = θ[:], optt = Optimise.ADAGrad(0.5)
    grads = only((Zygote.gradient(loss, θt)))
    Optimise.update!(optt, θt, grads)
end
BenchmarkTools.Trial: 6309 samples with 1 evaluation.
 Range (min … max):  679.158 μs …   3.291 ms  ┊ GC (min … max): 0.00% … 28.18%
 Time  (median):     742.155 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   789.581 μs ± 218.341 μs  ┊ GC (mean ± σ):  4.97% ± 10.46%

  ▂▅██▆▅▃▁                                                  ▁   ▂
  ████████▇▅▆▅▅▁▅▁▁▄▃▁▁▁▁▁▁▄▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▇██████ █
  679 μs        Histogram: log(frequency) by time       1.86 ms <

 Memory estimate: 2.98 MiB, allocs estimate: 1558.

Training the model

The loss at our initial parameter values:

θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
loss(θ)
2.613933959118708

Initialize optimizer

opt = Optimise.ADAGrad(0.5)

Optimize

for i in 1:15
    grads = only((Zygote.gradient(loss, θ)))
    Optimise.update!(opt, θ, grads)
end

Final loss

loss(θ)
0.5241118228076058

Package and system information
Package information
Status `~/work/KernelFunctions.jl/KernelFunctions.jl/examples/train-kernel-parameters/Project.toml`
  [6e4b80f9] BenchmarkTools v1.5.0
  [31c24e10] Distributions v0.25.109
  [587475ba] Flux v0.14.16
  [f6369f11] ForwardDiff v0.10.36
  [ec8451be] KernelFunctions v0.10.64 `/home/runner/work/KernelFunctions.jl/KernelFunctions.jl#ab866b9`
  [98b081ad] Literate v2.19.0
⌅ [2412ca09] ParameterHandling v0.4.10
  [91a5bcdd] Plots v1.40.5
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`
To reproduce this notebook's package environment, you can download the full Manifest.toml.
System information
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Documenter
  JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src

This page was generated using Literate.jl.