Train Kernel Parameters
Here we show a few ways to train (optimize) the kernel (hyper)parameters, using kernel-based regression with KernelFunctions.jl as the example. All options are functionally identical but differ a little in readability, dependencies, and computational cost.
We load KernelFunctions and some other packages. Note that while we use Zygote for automatic differentiation and Flux.Optimise for optimization, you should be able to replace them with your favourite autodiff framework or optimizer.
using KernelFunctions
using LinearAlgebra
using Distributions
using Plots
using BenchmarkTools
using Flux
using Flux: Optimise
using Zygote
using Random: seed!
seed!(42);
Data Generation
We generate a toy dataset in 1 dimension:
xmin, xmax = -3, 3 # Bounds of the data
N = 50 # Number of samples
x_train = rand(Uniform(xmin, xmax), N) # sample the inputs
σ = 0.1
y_train = sinc.(x_train) + randn(N) * σ # evaluate a function and add some noise
x_test = range(xmin - 0.1, xmax + 0.1; length=300)
Plot the data
scatter(x_train, y_train; label="data")
plot!(x_test, sinc; label="true function")
Manual Approach
The first option is to rebuild the parametrized kernel from a vector of parameters in each evaluation of the cost function. This is similar to the approach taken in Stheno.jl.
To train the kernel parameters via Zygote.jl, we need a function that builds a kernel from an array of parameters. A simple way to ensure that the kernel parameters stay positive is to optimize over the logarithm of the parameters.
function kernel_creator(θ)
return (exp(θ[1]) * SqExponentialKernel() + exp(θ[2]) * Matern32Kernel()) ∘
ScaleTransform(exp(θ[3]))
end
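As a quick sanity check of the log-parameterization (a throwaway snippet, not used in the rest of the tutorial), exponentiating inside kernel_creator recovers the intended positive values:
θ_demo = log.([1.1, 0.1, 0.01])  # unconstrained parameters
kernel_creator(θ_demo)           # kernel with variances 1.1 and 0.1 and input scale 0.01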
From theory we know the prediction for a test set x given the kernel parameters and the noise (regularization) constant:
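In matrix form this is the standard kernel ridge regression predictor (equivalently, the posterior mean of a GP with noise variance σ² = exp(θ[4])), which the function below implements:

\hat{y}_* = k(x_*, X_{\mathrm{train}}) \left( k(X_{\mathrm{train}}, X_{\mathrm{train}}) + \sigma^2 I \right)^{-1} y_{\mathrm{train}}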
function f(x, x_train, y_train, θ)
k = kernel_creator(θ[1:3])
return kernelmatrix(k, x, x_train) *
((kernelmatrix(k, x_train) + exp(θ[4]) * I) \ y_train)
end
Let's look at our prediction. With starting parameters p0 (picked so we get the right local minimum for demonstration) we get:
p0 = [1.1, 0.1, 0.01, 0.001]
θ = log.(p0)
ŷ = f(x_test, x_train, y_train, θ)
scatter(x_train, y_train; label="data")
plot!(x_test, sinc; label="true function")
plot!(x_test, ŷ; label="prediction")
We define the following loss, the norm of the residuals plus a regularization term scaled by the noise parameter:
function loss(θ)
ŷ = f(x_train, x_train, y_train, θ)
return norm(y_train - ŷ) + exp(θ[4]) * norm(ŷ)
end
The loss with our starting point:
loss(θ)
2.613933959118708
Computational cost for one step:
@benchmark let
θ = log.(p0)
opt = Optimise.ADAGrad(0.5)
grads = only((Zygote.gradient(loss, θ)))
Optimise.update!(opt, θ, grads)
end
BenchmarkTools.Trial: 6244 samples with 1 evaluation.
 Range (min … max):  648.911 μs … 6.297 ms   ┊ GC (min … max): 0.00% … 78.21%
 Time  (median):     731.971 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   796.419 μs ± 247.882 μs ┊ GC (mean ± σ):  5.88% ± 11.37%
 [histogram of times omitted: log(frequency) by time, 649 μs … 1.69 ms]
 Memory estimate: 2.98 MiB, allocs estimate: 1563.
Training the model
Setting an initial value and initializing the optimizer:
θ = log.(p0) # Initial vector
opt = Optimise.ADAGrad(0.5)
Optimize
anim = Animation()
for i in 1:15
grads = only((Zygote.gradient(loss, θ)))
Optimise.update!(opt, θ, grads)
scatter(
x_train, y_train; lab="data", title="i = $(i), Loss = $(round(loss(θ), digits = 4))"
)
plot!(x_test, sinc; lab="true function")
plot!(x_test, f(x_test, x_train, y_train, θ); lab="Prediction", lw=3.0)
frame(anim)
end
gif(anim, "train-kernel-param.gif"; show_msg=false, fps=15);
Final loss
loss(θ)
0.5241118228076058
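If you also want the trained kernel object and noise variance themselves (not just predictions), you can rebuild them from the optimized parameter vector, for example:
trained_kernel = kernel_creator(θ[1:3])  # kernel with the optimized (positive) parameters
noise_var = exp(θ[4])                    # optimized noise/regularization constant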
Using ParameterHandling.jl
Alternatively, we can use the ParameterHandling.jl package to handle the requirement that all kernel parameters be positive. The package also allows arbitrary nesting of named tuples, which makes the parameters more human-readable without having to remember their position in a flat vector.
using ParameterHandling
raw_initial_θ = (
k1=positive(1.1), k2=positive(0.1), k3=positive(0.01), noise_var=positive(0.001)
)
flat_θ, unflatten = ParameterHandling.value_flatten(raw_initial_θ)
4-element Vector{Float64}:
0.09531016625781467
-2.3025852420056685
-4.6051716761053205
-6.907770180254354
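To illustrate the nesting mentioned above, here is a hypothetical, more structured layout of the same parameters (it relies on the ParameterHandling import above and is not used in the rest of this tutorial):
nested_θ = (
    kernel=(k1=positive(1.1), k2=positive(0.1), scale=positive(0.01)),
    noise_var=positive(0.001),
)
nested_flat, nested_unflatten = ParameterHandling.value_flatten(nested_θ)
nested_unflatten(nested_flat).kernel.k1  # ≈ 1.1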
We define a few relevant functions and note that, compared to the previous kernel_creator function, we no longer need explicit calls to exp.
function kernel_creator(θ)
return (θ.k1 * SqExponentialKernel() + θ.k2 * Matern32Kernel()) ∘ ScaleTransform(θ.k3)
end
function f(x, x_train, y_train, θ)
k = kernel_creator(θ)
return kernelmatrix(k, x, x_train) *
((kernelmatrix(k, x_train) + θ.noise_var * I) \ y_train)
end
function loss(θ)
ŷ = f(x_train, x_train, y_train, θ)
return norm(y_train - ŷ) + θ.noise_var * norm(ŷ)
end
initial_θ = ParameterHandling.value(raw_initial_θ)
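initial_θ is now a plain named tuple of positive values (k1 = 1.1, k2 = 0.1, k3 = 0.01, noise_var = 0.001), so the loss can also be evaluated on it directly:
loss(initial_θ)  # equal (up to floating-point round-trip) to (loss ∘ unflatten)(flat_θ) below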
The loss at the initial parameter values:
(loss ∘ unflatten)(flat_θ)
2.613933959118708
Cost per step
@benchmark let
θ = flat_θ[:]
opt = Optimise.ADAGrad(0.5)
grads = (Zygote.gradient(loss ∘ unflatten, θ))[1]
Optimise.update!(opt, θ, grads)
end
BenchmarkTools.Trial: 5410 samples with 1 evaluation.
 Range (min … max):  768.474 μs … 5.099 ms   ┊ GC (min … max): 0.00% … 18.36%
 Time  (median):     865.681 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   920.858 μs ± 235.392 μs ┊ GC (mean ± σ):  4.90% ± 10.59%
 [histogram of times omitted: log(frequency) by time, 768 μs … 1.92 ms]
 Memory estimate: 3.08 MiB, allocs estimate: 2228.
Training the model
Optimize
opt = Optimise.ADAGrad(0.5)
for i in 1:15
grads = (Zygote.gradient(loss ∘ unflatten, flat_θ))[1]
Optimise.update!(opt, flat_θ, grads)
end
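If you want the trained parameters back in human-readable, named form, you can apply unflatten to the optimized flat vector, for example:
final_params = unflatten(flat_θ)          # named tuple with fields k1, k2, k3, noise_var
trained_kernel = kernel_creator(final_params)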
Final loss
(loss ∘ unflatten)(flat_θ)
0.524117624126251
Flux.destructure
If we don't want to write an explicit function to construct the kernel, we can alternatively use the Flux.destructure function. Again, we need to ensure that the parameters are positive. Note that the exp function is now part of the loss function, instead of part of the kernel construction.
We could also use ParameterHandling.jl here. To do so, one would remove the calls to exp from the loss function below and call loss ∘ unflatten as above.
θ = [1.1, 0.1, 0.01, 0.001]
kernel = (θ[1] * SqExponentialKernel() + θ[2] * Matern32Kernel()) ∘ ScaleTransform(θ[3])
params, kernelc = Flux.destructure(kernel);
This returns the trainable params of the kernel and a function to reconstruct the kernel.
kernelc(params)
Sum of 2 kernels:
Squared Exponential Kernel (metric = Distances.Euclidean(0.0))
- σ² = 1.1
Matern 3/2 Kernel (metric = Distances.Euclidean(0.0))
- σ² = 0.1
- Scale Transform (s = 0.01)
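The reconstructor accepts any parameter vector of the same length, so, for example (hypothetical values):
kernelc([2.0, 0.5, 0.1])  # same kernel structure with variances 2.0 and 0.5 and scale 0.1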
From theory we know the prediction for a test set x given the kernel parameters and the noise (regularization) constant:
function f(x, x_train, y_train, θ)
k = kernelc(θ[1:3])
return kernelmatrix(k, x, x_train) * ((kernelmatrix(k, x_train) + (θ[4]) * I) \ y_train)
end
function loss(θ)
ŷ = f(x_train, x_train, y_train, exp.(θ))
return norm(y_train - ŷ) + exp(θ[4]) * norm(ŷ)
end
Cost for one step
@benchmark let θt = θ[:], optt = Optimise.ADAGrad(0.5)
grads = only((Zygote.gradient(loss, θt)))
Optimise.update!(optt, θt, grads)
end
BenchmarkTools.Trial: 6487 samples with 1 evaluation.
 Range (min … max):  653.430 μs … 2.728 ms   ┊ GC (min … max): 0.00% … 58.95%
 Time  (median):     718.190 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   767.841 μs ± 215.285 μs ┊ GC (mean ± σ):  4.88% ± 10.42%
 [histogram of times omitted: log(frequency) by time, 653 μs … 1.76 ms]
 Memory estimate: 2.98 MiB, allocs estimate: 1558.
Training the model
The loss at our initial parameter values:
θ = log.([1.1, 0.1, 0.01, 0.001]) # Initial vector
loss(θ)
2.613933959118708
Initialize optimizer
opt = Optimise.ADAGrad(0.5)
Optimize
for i in 1:15
grads = only((Zygote.gradient(loss, θ)))
Optimise.update!(opt, θ, grads)
end
Final loss
loss(θ)
0.5241118228076058
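To visualize the final fit, analogous to the earlier plot, one could do the following (remembering that this version of f expects the already-exponentiated parameters):
ŷ = f(x_test, x_train, y_train, exp.(θ))
scatter(x_train, y_train; label="data")
plot!(x_test, sinc; label="true function")
plot!(x_test, ŷ; label="prediction")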
This page was generated using Literate.jl.