As a learning project I reimplemented a simple Bayesian neural network in Julia from scratch (hopefully without mistakes!).
Benchmarking against my PyTorch implementation on the same CPU, the Julia version takes about 30% longer to run (measured with BenchmarkTools.jl).
I've tried to follow the Performance Tips section of the manual as best I could, so this gap makes me think I've missed something major.
Hopefully someone can spot at a glance something silly I've done, because after profiling, benchmarking, and working through tutorials I've hit a wall.
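For reference, the Julia-side timing was gathered along these lines (a sketch, not the exact harness; `main` is the training function at the bottom of the listing):

using BenchmarkTools
main()                               # warm-up call so compilation is excluded
@benchmark main() samples=1 evals=1  # a single end-to-end training run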
using Flux
using Flux: glorot_uniform
using Flux.NNlib: batched_mul  # batched_mul is defined in NNlib (loaded here through Flux)
using Zygote
using Flux.Optimise: update!, ADAM
using Plots
# Overload Base.:* for 3-d arrays (note: this is type piracy, since neither the
# function nor the argument types are owned here; tolerable in a self-contained script)
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,3}) where T = batched_mul(x, y)
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,2}) where T = batched_mul_loop(x, y)
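# With these overloads, `W_sample * x` below dispatches to batched_mul for
# (3-d, 3-d) pairs and to batched_mul_loop for (3-d, 2-d) pairs, e.g.:
#   rand(Float32, 2, 3, 5) * rand(Float32, 3, 4, 5)  # -> 2×4×5 via batched_mul
#   rand(Float32, 2, 3, 5) * rand(Float32, 3, 4)     # -> 2×4×5 via batched_mul_loop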
"""
Multiplies a batch of matrices with a
"""
function batched_mul_loop(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    B = [view(A, :, :, k) * x for k in axes(A, 3)]
    # B = [Ai * x for Ai in eachslice(A, dims=3)]
    return cat(B..., dims=3)
end
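# Aside (not in the original post): splatting via `cat(B..., dims=3)` allocates
# and re-specializes on the number of slices. On Julia >= 1.9, `stack` builds the
# same 3-d array without splatting; a sketch, assuming your Zygote/ChainRules
# versions ship a differentiation rule for `stack`:
function batched_mul_stack(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    return stack(view(A, :, :, k) * x for k in axes(A, 3))
end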
"""
Draws `m` samples from `N(μ, exp(logσ)^2)` via the reparameterization trick:
`sample = μ + ε .* exp(logσ)` with `ε ~ N(0, I)`, so gradients flow to `μ` and `logσ`.
"""
function reparamSample(μ::AbstractArray, logσ::AbstractArray, m::Integer)
    r = randn(Float32, size(μ)..., m)
    return @. μ + r * exp(logσ)
end
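# Example (illustrative): 20 samples of a (100 × 1) weight matrix
#   W = glorot_uniform(100, 1); logσW = -6.0f0 * ones(Float32, 100, 1)
#   size(reparamSample(W, logσW, 20))  # -> (100, 1, 20)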
"""
Linear layer for Bayes neural network
"""
mutable struct VariationalLinear{F <: Function, S <: AbstractArray, T <: AbstractArray, U <: AbstractArray, V <: AbstractArray, N <: Integer}
    in::N       # input dimension
    out::N      # output dimension
    W::S        # variational mean of the weights
    logσW::S    # variational log-std of the weights
    b::T        # variational mean of the biases
    logσb::T    # variational log-std of the biases
    σ::F        # activation function
    W_sample::U # η Monte-Carlo weight samples, refreshed each forward pass
    b_sample::V # η Monte-Carlo bias samples
    η::N        # number of Monte-Carlo samples
end
function VariationalLinear(in::Integer, out::Integer, η::Integer, σ=identity)
    W = glorot_uniform(out, in)
    logσW = -6.0f0 * ones(Float32, out, in)
    b = zeros(Float32, out)
    logσb = -6.0f0 * ones(Float32, out)
    return VariationalLinear(in, out, W, logσW, b, logσb, σ,
                             reparamSample(W, logσW, η),
                             reparamSample(b, logσb, η), η)
end
# Draw fresh Monte-Carlo samples of the weights and biases (mutates the layer)
function update_samples(a::VariationalLinear)
    a.W_sample = reparamSample(a.W, a.logσW, a.η)
    a.b_sample = reparamSample(a.b, a.logσb, a.η)
end
function (a::VariationalLinear)(x::AbstractArray{T}) where T
    update_samples(a)
    # (out × in × η) * (in × bs) -> (out × bs × η), with the bias broadcast over the batch
    out = a.σ.(a.W_sample * x .+ reshape(a.b_sample, a.out, 1, a.η))
    return out
end
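# Example (illustrative): a layer mapping 1 input to 100 hidden units, η = 20
#   layer = VariationalLinear(1, 100, 20, relu)
#   size(layer(rand(Float32, 1, 50)))  # -> (100, 50, 20)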
"""
BayesNN
"""
struct BayesNN{T <: Array{VariationalLinear,1}, P <: Zygote.Params, N <: Integer}
    layers::T
    θ::P
    η::N
end
function BayesNN(in::Integer, out::Integer, η::Integer, num_layers::Integer, num_hidden::Integer, σ=relu)
    # putting the layers into an array
    layers = VariationalLinear{<:Function}[VariationalLinear(in, num_hidden, η, σ)]
    append!(layers, [VariationalLinear(num_hidden, num_hidden, η, σ) for i in 1:(num_layers - 1)])
    push!(layers, VariationalLinear(num_hidden, out, η))
    # collecting the variational parameters of every layer
    P = AbstractArray[]
    for L in layers
        append!(P, [L.W, L.b, L.logσW, L.logσb])
    end
    return BayesNN(layers, Flux.params(P), η)
end
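# Example (illustrative): BayesNN(1, 1, 20, 3, 100) builds a 1 -> 100 -> 100 -> 100 -> 1
# network (three hidden relu layers plus a linear output layer) with η = 20 samples.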
# apply the layers in sequence
(b::BayesNN)(x) = foldl((x, layer) -> layer(x), b.layers, init=x)
# Monte-Carlo estimate of the Gaussian log-likelihood (up to an additive
# constant), averaged over the η weight samples
function log_likelihood(x, y, noise, model)
    return sum(-0.5f0 / noise^2 .* (model(x) .- y).^2) / model.η
end
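# Up to that constant this is
#   (1/η) Σᵢ Σₙ -(fᵢ(xₙ) - yₙ)² / (2 · noise²),
# summed over data points n and averaged over the η weight samples i.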
# elementwise log-density of N(μ, exp(logσ)²)
function log_normal(x, μ, logσ)
    return -0.5f0 .* (x .- μ).^2 ./ exp.(logσ).^2 .- logσ .- log(sqrt(2.0f0 * π))
end
# total log-probability over all elements and samples
# (equivalent to the original sum(sum(m, dims=3)))
function mean_log_prob(m)
    return sum(m)
end
function kl_divergence(model)
    logq = 0.0f0
    logp = 0.0f0
    for layer in model.layers
        # log q(w) under the variational posterior and log p(w) under the prior;
        # note the third argument of log_normal is log σ, so 0.0f0 gives a
        # unit-variance prior (the original passed 1.0f0, i.e. σ = e, which looks like a slip)
        logq += mean_log_prob(log_normal(layer.W_sample, layer.W, layer.logσW)) / model.η
        logp += mean_log_prob(log_normal(layer.W_sample, 0.0f0, 0.0f0)) / model.η
        # the biases are variational parameters too, so they belong in the KL as well
        logq += mean_log_prob(log_normal(layer.b_sample, layer.b, layer.logσb)) / model.η
        logp += mean_log_prob(log_normal(layer.b_sample, 0.0f0, 0.0f0)) / model.η
    end
    return logq - logp
end
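# For reference (not in the original post): kl_divergence is the standard
# Monte-Carlo estimator of the KL term in the ELBO,
#   KL(q ‖ p) = E_q[log q(w) − log p(w)] ≈ (1/η) Σᵢ [log q(wᵢ) − log p(wᵢ)],
# with the wᵢ drawn from q via the reparameterization trick above.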
function main()
    @info "Start"
    model = BayesNN(1, 1, 20, 3, 100)
    # making dummy data
    bs = 100
    epochs = 10000
    period = 2.0f0
    noise = 1.0f0
    # build everything directly in Float32 (the original built Float64 arrays
    # and converted afterwards)
    x = reshape(collect(LinRange(0.0f0, 10.0f0, bs)), (1, bs))
    y = x .* 2 .* sin.(x .* (2.0f0 * π / period))
    yt = y .+ randn(Float32, size(x)...) .* noise
    θ = model.θ
    opt = ADAM(1e-3)
    for i in 1:epochs
        # gradient of the negative ELBO (reconstruction term + KL) w.r.t. θ;
        # fit the noisy observations yt (the original fit the noiseless y, which looks unintended)
        gs = gradient(θ) do
            return -log_likelihood(x, yt, noise, model) + kl_divergence(model)
        end
        update!(opt, θ, gs)
        if i % 1000 == 0 # log every 1000 epochs
            @info "Epoch $i | neg. log-likelihood $(-log_likelihood(x, yt, noise, model)) | kl-div $(kl_divergence(model))"
        end
    end
@info "Done!"
theme(:dark)
plot(x',yt', seriestype=:scatter, alpha=0.4)
plot!(x',y')
plot!(x',model(x)[1,:,:], legend = false , alpha=0.8 )
end
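For completeness, the script is driven by a bare call at the end:

main()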