Convolution via Quadrature

You are seeing the HTML output generated by Documenter.jl and Literate.jl from the Julia source file. The corresponding notebook can be viewed in nbviewer.

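This example implements the convolution of a GP against the function $\phi(x) = \exp(-x^2)$, i.e. $(Lf)(x') = \int \phi(x' - x) f(x) \, \mathrm{d}x$. Substituting $u = x' - x$ gives $(Lf)(x') = \int \exp(-u^2) f(x' - u) \, \mathrm{d}u$. This is exactly the kind of exp(-u²)-weighted integral that Gauss–Hermite quadrature approximates, so the convolution is computed approximately using the Gauss–Hermite rule provided by FastGaussQuadrature.jl. The implementation is limited to GPs whose index set is the real line.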

using AbstractGPs
using AbstractGPsMakie
using CairoMakie
using FastGaussQuadrature
using KernelFunctions
using LaTeXStrings
using LinearAlgebra
using Literate
using Random
using Stheno

import AbstractGPs: AbstractGP, mean, cov, var

using CairoMakie: RGB
using Stheno: DerivedGP

Define a new affine transformation

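The input GP $f$ is assumed to have zero mean, so the convolved process $Lf$ has zero mean as well. The cross-kernel between $Lf$ and $f$ (a function of $x'$ and $y$) is $\int \phi(x' - x) k(x, y) \, \mathrm{d}x$. The kernel of $Lf$ (a function of $x'$ and $y'$) is $\iint \phi(x' - x) k(x, y) \phi(y' - y) \, \mathrm{d}x \, \mathrm{d}y$.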

# Return a new derived GP, registered with `f.gpc`, that represents the convolution of `f`
# with ϕ(x) = exp(-x^2). The GP stores the tuple `(convolve, f)` as its arguments; the
# `mean`, `cov` and `var` methods for this transformation dispatch on that tuple's type.
function convolve(f::AbstractGP)
    transform_args = (convolve, f)
    return DerivedGP(transform_args, f.gpc)
end

# Type of the argument tuple stored by `convolve`, used for dispatch.
const conv_args = Tuple{typeof(convolve), AbstractGP}

# `f` is assumed to be zero-mean, so the convolved process is zero-mean as well.
mean((_, f)::conv_args, x::AbstractVector{<:Real}) = zeros(length(x))

# Single-input covariance / variance are the two-input forms evaluated at x′ = x.
cov(args::conv_args, x::AbstractVector{<:Real}) = cov(args, x, x)
var(args::conv_args, x::AbstractVector{<:Real}) = var(args, x, x)

# Elementwise covariance: entry n is cov(Lf(x[n]), Lf(x′[n])), i.e. the diagonal of
# `cov(args, x, x′)`. Each entry is computed on its own via a 1×1 `cov` call. Forming the
# full length(x) × length(x′) matrix and discarding its off-diagonal entries would cost
# O(N²) quadrature evaluations instead of O(N).
# Throws `DimensionMismatch` if `x` and `x′` differ in length.
function var(args::conv_args, x::AbstractVector{<:Real}, x′::AbstractVector{<:Real})
    if length(x) != length(x′)
        msg = "x and x′ must have the same length; got $(length(x)) and $(length(x′))"
        throw(DimensionMismatch(msg))
    end
    return map((xn, x′n) -> only(cov(args, [xn], [x′n])), x, x′)
end

# Apply a quadrature rule: return Σᵢ ws[i] * f(xs[i]) for nodes `xs` and weights `ws`.
function _quadrature(f, xs, ws)
    weighted_terms = map(xs, ws) do node, weight
        weight * f(node)
    end
    return sum(weighted_terms)
end

# Covariance matrix of the convolved process Lf between inputs `x` (rows) and `x′`
# (columns). With u = xn - z and v = x′n - z′,
#     cov(Lf(xn), Lf(x′n)) = ∫∫ exp(-u²) exp(-v²) k(xn - u, x′n - v) du dv
#                          ≈ Σᵢ Σⱼ wᵢ wⱼ k(xn - uᵢ, x′n - uⱼ),
# where k is the kernel of `f` and (uᵢ, wᵢ) are the nodes/weights from `gausshermite`.
function cov((_, f)::conv_args, x::AbstractVector{<:Real}, x′::AbstractVector{<:Real})
    num_points = 15
    xs, ws = gausshermite(num_points)

    # For each input pair, evaluate the kernel once on the whole num_points × num_points
    # grid of shifted nodes and contract it with the weights on both sides (wᵀ K w).
    # This replaces num_points² separate 1×1 `cov` calls. The 2-D comprehension also
    # yields an empty matrix for empty `x` or `x′`.
    return [dot(ws, cov(f, xn .- xs, x′n .- xs), ws) for xn in x, x′n in x′]
end

# Cross-covariance matrix between the convolved process Lf (at `x`, rows) and another GP
# `f′` (at `x′`, columns):
#     cov(Lf(xn), f′(x′n)) = ∫ exp(-u²) cov(f(xn - u), f′(x′n)) du
#                          ≈ Σᵢ wᵢ cov(f(xn - uᵢ), f′(x′n)),
# with Gauss-Hermite nodes/weights (uᵢ, wᵢ).
function cov(
    (_, f)::conv_args,
    f′::AbstractGP,
    x::AbstractVector{<:Real},
    x′::AbstractVector{<:Real},
)
    num_points = 15
    xs, ws = gausshermite(num_points)

    # The cross-covariance must be taken between `f` and `f′`. Calling `cov(f, ...)`
    # would silently return cov(Lf, f) no matter which GP `f′` was requested.
    # One num_points × 1 evaluation per entry, contracted with the weights.
    return [dot(ws, vec(cov(f, f′, xn .- xs, [x′n]))) for xn in x, x′n in x′]
end

# Cross-covariance with the convolved process as the *second* argument: cov(f′(x), Lf(x′))
# is the transpose of cov(Lf(x′), f′(x)), which the `conv_args`-first method computes.
function cov(
    f′::AbstractGP,
    args::conv_args,
    x::AbstractVector{<:Real},
    x′::AbstractVector{<:Real},
)
    C_swapped = cov(args, f′, x′, x)
    return permutedims(C_swapped)
end
cov (generic function with 88 methods)

Plotting configuration

# Plot sizing and font settings.
pt_per_unit() = 1
font_size() = 12
listing_font_size() = 10
page_width() = 6

# Convert a figure size in inches to a `(width, height)` tuple, scaling each by 72.
function size_from_inches(; height=4, width=4)
    scale = 72
    return (scale * width, scale * height)
end

# Set the global Makie theme's font to Times.
set_theme!(; font="Times")

# Named colour palette used by the plotting helpers. Most entries are specified as
# 0–255 channel values and scaled to [0, 1].
function colours()
    from_255(r, g, b) = RGB(r / 255, g / 255, b / 255)
    return Dict(
        :blue => from_255(0, 107, 164),
        :cyan => from_255(75, 166, 251),
        :red => from_255(200, 82, 0),
        :pink => from_255(169, 90, 161),
        :black => RGB(0.0, 0.0, 0.0),
        :orange => from_255(245, 121, 58),
    )
end

# Marker shape symbols (not used elsewhere in this example).
shapes() = Symbol[:utriangle, :diamond, :square, :circle, :cross]

# Opacity levels: uncertainty bands, GP samples, and data points respectively.
function band_alpha()
    return 0.3
end
function sample_alpha()
    return 0.2
end
function point_alpha()
    return 1.0
end


# Draw a symmetric band (`bandscale=3` standard deviations) around the marginal means of
# `fx` over `x_plot` on `ax`. It uses palette colour `colour` at half the band opacity and
# is labelled `label` for the legend.
function plot_band!(ax, x_plot, fx, colour, label)
    marginal_dists = marginals(fx)
    μ = mean.(marginal_dists)
    σ = std.(marginal_dists)
    band_colour = (colours()[colour], 0.5 * band_alpha())
    return symband!(ax, x_plot, μ, σ; bandscale=3, color=band_colour, label=label)
end

# Overlay 4 samples of `fx` over `x_plot` on `ax`, in palette colour `colour` at the
# sample opacity.
function plot_sample!(ax, x_plot, fx, colour)
    sample_colour = (colours()[colour], sample_alpha())
    return gpsample!(ax, x_plot, fx; samples=4, color=sample_colour)
end

# Plot a GP on `ax`: its labelled uncertainty band first, then a few samples on top.
# Returns whatever the sample-plotting call returns.
function plot_gp!(ax, x_plot, fx, colour, label)
    plot_band!(ax, x_plot, fx, colour, label)
    result = plot_sample!(ax, x_plot, fx, colour)
    return result
end
plot_gp! (generic function with 1 method)

Plot the results

Build a GPPP in which one GP is a convolution of the other, using the convolve transformation defined above.

let
    # A GPPP containing a GP `f` and its convolution `g` with ϕ(x) = exp(-x^2).
    f = @gppp let
        f = GP(with_lengthscale(Matern52Kernel(), 0.5))
        g = convolve(f)
    end

    # Observation locations: two for f in [1, 2), two for g in (-2, -1]. All randomness
    # is drawn from the single `rng`. Re-seeding a second Xoshiro(123) for `y` would
    # replay the same stream that produced the input locations.
    rng = Xoshiro(123)
    x_f_obs = GPPPInput(:f, rand(rng, 2) .+ 1)
    x_g_obs = GPPPInput(:g, -rand(rng, 2) .- 1)
    x_obs = vcat(x_f_obs, x_g_obs)

    # Sample noisy joint observations from the prior, then condition on them.
    fx_obs = f(x_obs, 1e-3)
    y = rand(rng, fx_obs)
    y_f, y_g = split(x_obs, y)
    f_post = posterior(fx_obs, y)

    x_plot = range(-5.0, 5.0; length=100)
    x_f = GPPPInput(:f, x_plot)
    x_g = GPPPInput(:g, x_plot)
    x = vcat(x_f, x_g)
    fig = Figure()

    # Plot posterior.
    ax = Axis(fig[1, 1]; xlabel=L"x")
    plot_gp!(ax, x_plot, f_post(x_f, 1e-6), :blue, "f")
    plot_gp!(ax, x_plot, f_post(x_g, 1e-6), :orange, "g")
    scatter!(ax, x_f_obs.x, y_f; color=colours()[:blue], markersize=7)
    scatter!(ax, x_g_obs.x, y_g; color=colours()[:orange], markersize=7)

    # Plot legend.
    Legend(fig[1, 2], ax; orientation=:vertical)

    fig
end


This page was generated using Literate.jl.