I’ve been reimplementing some of my PyTorch code in Julia as a way to learn the language, and I came across some behaviour I don’t understand when using Flux for gradient-based optimization on a CPU. As an exercise I was implementing type-II maximum likelihood for GP regression (so low-dimensional optimization with large linear algebra operations in each loss evaluation).
If I set the number of OpenBLAS threads to maximize CPU usage (through BLAS.set_num_threads(12) in my case), my Julia program, as expected, uses all of my computer’s cores while optimizing for a fixed number of iterations.
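(For reference, this is how I’m toggling the BLAS thread count before running the optimization; 12 just matches the core count on my machine, and this is separate from Julia’s own threads set via JULIA_NUM_THREADS.)

using LinearAlgebra
BLAS.set_num_threads(12)   # let OpenBLAS use all 12 cores
# ... run the optimization ...
BLAS.set_num_threads(1)    # or restrict OpenBLAS to a single thread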
If I set the number of OpenBLAS threads to 1, my Julia program, as expected, uses a single CPU core, but it completes the job in about 80% of the time taken by the multi-threaded run. (Note that this is still about 15% slower than my PyTorch implementation, which runs on all CPU cores.)
This makes me think I’ve misunderstood something about how to best use Julia + Flux. Are there known bottlenecks in Julia’s OpenBLAS interface that are being actively worked on?
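For concreteness, the comparison I’m describing is roughly the following (a sketch: timed_run is just a hypothetical helper wrapping the setup from main() in the code below, and the thread counts are what I used on my machine):

using LinearAlgebra

function timed_run(nthreads)
    BLAS.set_num_threads(nthreads)
    x = collect(LinRange(0., 10., 100))
    y = x + sin.(x) * 5 .- 10.0 + randn(length(x)) * 2.0
    model = GP(x, y, -4.0, sekernel(0.1, -4.0))
    @time train_gp!(model, 5000)   # note: the first call also includes compilation time
end

timed_run(12)   # all cores busy, but slower end-to-end
timed_run(1)    # single core, finishes in roughly 80% of the time for me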
I’ve included my code below in case this is the result of a newb mistake (which it very well might be!).
using Flux
using LinearAlgebra
using Plots
"""
Computes the pairwise differences between the coordinates of the points in x1 and x2.
Returns an array with one more dimension than the inputs.
Δ(x1::AbstractArray{T,2}, x2::AbstractArray{T,2}) -> AbstractArray{T,3}
Δ(x1::AbstractArray{T,1}, x2::AbstractArray{T,1}) -> AbstractArray{T,2}
Δ(x) is shorthand for Δ(x, x).
"""
function Δ(x1::AbstractArray{T,2}, x2::AbstractArray{T,2}) where T
    s1, s2 = size(x1)
    t1, t2 = size(x2)
    return reshape(x1, (s1, s2, 1)) .- reshape(x2, (t1, 1, t2))
end
function Δ(x1::AbstractArray{T,1}, x2::AbstractArray{T,1}) where T
    n1 = length(x1)
    n2 = length(x2)
    return reshape(Δ(reshape(x1, 1, n1), reshape(x2, 1, n2)), (n1, n2))
end
function Δ(x::AbstractArray)
    return Δ(x, x)
end
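# Quick sanity check of the convention (entry (i, j) holds x1[i] - x2[j]):
#   Δ([1.0, 2.0], [1.0, 3.0]) == [0.0 -2.0; 1.0 -1.0]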
"""
Squared exponential kernel with length scale ℓ and σf = exp(logσf)
"""
struct sekernel{T <: AbstractArray}
    ℓ::T
    logσf::T
end
Flux.@functor sekernel
sekernel(ℓ::AbstractFloat, logσf::AbstractFloat) = sekernel([ℓ], [logσf])
function (k::sekernel)(x1::AbstractArray{T,2}, x2::AbstractArray{T,2}) where T
    out = exp.(k.logσf).^2 .* exp.(-sum(Δ(x1, x2).^2, dims=1) ./ (2 * k.ℓ.^2))
    return out[1,:,:]
end
function (k::sekernel)(x1::AbstractArray{T,1}, x2::AbstractArray{T,1}) where T
    return k(reshape(x1, 1, length(x1)), reshape(x2, 1, length(x2)))
end
function (k::sekernel)(x::AbstractArray)
    return k(x, x)
end
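# Example with illustrative values: ℓ = 1.0 and logσf = 0.0 gives
#   k = sekernel(1.0, 0.0)
#   k([0.0, 1.0])   # 2×2 matrix: 1.0 on the diagonal, exp(-0.5) ≈ 0.61 off-diagonal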
"""
Gaussian process container
x : training inputs, f : training targets, k : kernel struct, logσn : log of the i.i.d. observation noise
"""
struct GP{U,S <: AbstractArray,T <: AbstractArray,I <: AbstractArray}
    x::S
    f::T
    k::U
    logσn::I
end
Flux.@functor GP
function GP(x::AbstractArray, f::AbstractArray, logσn::AbstractFloat, k)
    return GP(x, f, k, [logσn])
end
Flux.trainable(g::GP) = (g.logσn, g.k)
function (g::GP)(xs::AbstractArray)
    μ, Σ = infer(g.f, g.k(g.x) + I * exp.(g.logσn[1]), g.k(g.x, xs), g.k(xs))
    return (μ, Σ)
end
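# Calling the GP struct on test inputs returns the posterior mean and covariance at
# those points, e.g. μ, Σ = model(xt) as in main() below.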
"""
Function for performing inference with a GP
"""
function infer(f::AbstractArray{T}, # check this (maybe just AbstractArray{T}?)
               K11::AbstractArray{T,2},
               K12::AbstractArray{T,2},
               K22::AbstractArray{T,2}) where T
    C = cholesky(K11 + I * 1e-12) # adding 1e-12 for numerical stability
    K21 = transpose(K12)
    μ = K21 * (C \ f)
    Σ = K22 - K21 * (C \ K12)
    return (μ, Σ)
end
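# These are the standard GP posterior equations, with K21 = K12ᵀ:
#   μ = K21 * K11⁻¹ * f
#   Σ = K22 - K21 * K11⁻¹ * K12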
"""
Computes the log_mll for a Gaussian process
"""
function log_mll(f::AbstractArray{T}, K11::AbstractArray{T,2}) where T
    C = cholesky(K11)
    -0.5 * f' * (C \ f) - 0.5 * logdet(C) - length(f) / 2 * log(2 * π)
end
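# This is the usual log marginal likelihood,
#   log p(f | x) = -½ fᵀ K11⁻¹ f - ½ log|K11| - (n/2) log(2π),
# with the Cholesky factor reused for both the solve and logdet.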
function log_mll(g::GP)
    return log_mll(g.f, g.k(g.x) + I * exp.(g.logσn[1]))
end
"""
type-ii maximum likelihood training of GP
"""
function train_gp!(model::GP, epochs::Integer, lr=1e-1)
    opt = Flux.Optimise.ADAM(lr)
    ps = Flux.params(model)
    for i in 1:epochs
        gs = gradient(ps) do
            return -log_mll(model)
        end
        Flux.Optimise.update!(opt, ps, gs)
        if i % 1000 == 0
            @info "Epoch $i | len $(model.k.ℓ) | σf $(exp.(model.k.logσf)) | σn $(exp.(model.logσn)) | lml $(log_mll(model))"
        end
    end
end
function main()
    x = collect(LinRange(0., 10., 100))
    y_true = x + sin.(x) * 5 .- 10.0
    y = y_true + randn(length(x)) * 2.0
    model = GP(x, y, -4.0, sekernel(0.1, -4.0))
    train_gp!(model, 5000)
    xt = collect(LinRange(-0., 10., 200))
    μ, Σ = model(xt)
    σ = 2 * sqrt.(diag(Σ))
    p = scatter(x, y, label="Measurements", alpha=0.5, xlabel="x", ylabel="y")
    p = plot!(x, y_true, linestyle=:dash, label="Generating Func.", linewidth=2)
end