Performance issues?

As a learning project I reimplemented a simple Bayesian NN in julia from scratch (hopefully without mistakes!)
Benchmarking against my PyTorch code on the same CPU (using BenchmarkTools), the Julia version takes about 30% longer…

While I’ve tried to follow the performance-tips as best as I could, this difference makes me think I’ve missed something major.

Hopefully someone can spot at a glance something stupid I’ve done, because after profiling, benchmarking, and working through tutorials I’ve hit a wall.

using Flux 
using Flux: glorot_uniform
using Zygote
using Flux.Optimise: update!, ADAM
using Plots 
# NOTE(review): extending `Base.:*` for 3-D arrays is type piracy (neither `*` nor
# `AbstractArray` is owned by this file), so these methods change `*` for every package
# loaded in the session. A dedicated function would be safer; they are kept because
# `VariationalLinear`'s forward pass computes `W_sample * x` with these shapes.

# 3-D × 3-D: delegated to `batched_mul`.
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,3}) where T = batched_mul(x, y)
# 3-D × 2-D: every slice of `x` times the same matrix `y`.
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,2}) where T = batched_mul_loop(x, y)

"""
    batched_mul_loop(A::AbstractArray{T,3}, x::AbstractArray{T,2})

Multiply each matrix slice of `A` by the same matrix `x` and stack the products
along the third dimension, so that `C[:, :, i] == A[:, :, i] * x` for every `i`.

`A` has size `(m, k, n)` and `x` has size `(k, p)`; the result has size `(m, p, n)`.
An empty batch (`n == 0`) gives an empty `(m, p, 0)` result. Throws
`DimensionMismatch` when `size(A, 2) != size(x, 1)`.

The whole batch is computed with a single matrix product on a permuted copy of
`A`. The previous version did `n` small products and then `cat(B...; dims=3)`,
splatting one argument per slice. No array is mutated, so Zygote can
differentiate through this function.
"""
function batched_mul_loop(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    m, k, n = size(A)
    # (m, k, n) -> (m, n, k): put the batch index next to the row index so the
    # whole batch flattens into the rows of one (m*n, k) matrix.
    A_rows = reshape(permutedims(A, (1, 3, 2)), m * n, k)
    # One product for the whole batch: (m*n, k) * (k, p) -> (m*n, p).
    C = reshape(A_rows * x, m, n, size(x, 2))
    # (m, n, p) -> (m, p, n): move the batch index back to the last dimension.
    return permutedims(C, (1, 3, 2))
end

"""
    reparamSample(μ, logσ, m)

Draw `m` samples of `μ .+ ε .* exp.(logσ)`, where `ε` is standard-normal `Float32`
noise. The samples are stacked along a new trailing dimension, so when `logσ`
broadcasts against `μ` the result has size `(size(μ)..., m)`.
"""
function reparamSample(μ::AbstractArray, logσ::AbstractArray, m::Integer)
    ε = randn(Float32, size(μ)..., m)   # independent draw for every entry of every sample
    return μ .+ ε .* exp.(logσ)
end

"""
    VariationalLinear

Fully connected layer of a Bayesian neural network. Each weight and bias has a
mean (`W`, `b`) and a log standard deviation (`logσW`, `logσb`). `W_sample` and
`b_sample` hold the most recent `η` draws and are replaced on every forward pass,
which is why the struct is mutable. `σ` is the activation applied to the output.
"""
mutable struct VariationalLinear{F<:Function,S<:AbstractArray,T<:AbstractArray,
                                 U<:AbstractArray,V<:AbstractArray,N<:Integer}
    in::N           # input width
    out::N          # output width
    W::S            # weight means
    logσW::S        # log std. dev. of each weight (same array type as W)
    b::T            # bias means
    logσb::T        # log std. dev. of each bias (same array type as b)
    σ::F            # activation function
    W_sample::U     # current weight draws
    b_sample::V     # current bias draws
    η::N            # number of draws taken per forward pass
end

"""
    VariationalLinear(in, out, η, σ=identity)

Build a layer mapping `in` features to `out` features that keeps `η` weight/bias
draws per forward pass.

- Weight means `W` (`out × in`) use `glorot_uniform`.
- Bias means `b` (length `out`) start at zero.
- All log standard deviations start at `-6`, i.e. a small initial spread `exp(-6)`.
- An initial set of samples is drawn immediately with `reparamSample`.

`in`, `out` and `η` may be given as different `Integer` types. They are promoted to a
common type because the struct stores all three in fields sharing one type parameter.
"""
function VariationalLinear(in::Integer, out::Integer, η::Integer, σ=identity)
    # Without promotion, e.g. `in::Int` with `η::Int32` matches no struct constructor.
    nin, nout, nη = promote(in, out, η)
    W = glorot_uniform(nout, nin)
    logσW = fill(-6.0f0, nout, nin)
    b = zeros(Float32, nout)
    logσb = fill(-6.0f0, nout)
    return VariationalLinear(nin, nout, W, logσW, b, logσb, σ,
                             reparamSample(W, logσW, nη),
                             reparamSample(b, logσb, nη), nη)
end

"""
    update_samples(a::VariationalLinear)

Replace `a.W_sample` and `a.b_sample` with fresh draws from `reparamSample`
(weights first, then biases). Mutates `a` and returns the new bias draw.
"""
# NOTE(review): Julia convention names mutating functions with a trailing `!`
# (`update_samples!`); kept as-is because callers use this name.
function update_samples(a::VariationalLinear)
    a.W_sample = reparamSample(a.W, a.logσW, a.η)
    bias_draw = reparamSample(a.b, a.logσb, a.η)
    a.b_sample = bias_draw
    return bias_draw
end

# Forward pass: redraw this layer's samples, then compute σ.(W_sample * x .+ bias).
function (a::VariationalLinear)(x::AbstractArray{T}) where T
    # Every call replaces the stored samples before using them.
    update_samples(a)
    # Give the bias draws a singleton second axis, (out, 1, η), so they broadcast
    # across the second axis of W_sample * x.
    bias = reshape(a.b_sample, a.out, 1, a.η)
    return a.σ.(a.W_sample * x .+ bias)
end

"""
    BayesNN

Bayesian neural network built from a vector of `VariationalLinear` layers.

Fields:
- `layers`: applied in order.
- `θ`: the `Zygote.Params` used for training.
- `η`: number of weight draws per forward pass.
"""
struct BayesNN{T<:Vector{VariationalLinear},P<:Zygote.Params,N<:Integer}
    layers::T
    θ::P
    η::N
end
"""
    BayesNN(in, out, η, num_layers, num_hidden, σ=relu)

Build a network of `num_layers` `VariationalLinear` layers with activation `σ`,
followed by one output layer with the default `identity` activation:

- the first layer maps `in` → `num_hidden`;
- the remaining `num_layers - 1` layers map `num_hidden` → `num_hidden`;
- the output layer maps `num_hidden` → `out`.

Every layer keeps `η` weight draws. Each layer's `W`, `b`, `logσW` and `logσb` are
collected, in that order, into the trainable `Flux.params`.
"""
function BayesNN(in::Integer, out::Integer, η::Integer, num_layers::Integer,
                 num_hidden::Integer, σ=relu)
    layers = VariationalLinear{<:Function}[VariationalLinear(in, num_hidden, η, σ)]
    for _ in 1:(num_layers - 1)
        push!(layers, VariationalLinear(num_hidden, num_hidden, η, σ))
    end
    push!(layers, VariationalLinear(num_hidden, out, η))

    # The literal fixes P's element type from the first layer. Later layers are
    # appended with a plain loop. The old code broadcast a closure purely for its
    # side effect, which also built a throwaway array of results.
    P = [layers[1].W, layers[1].b, layers[1].logσW, layers[1].logσb]
    for L in Iterators.drop(layers, 1)
        append!(P, (L.W, L.b, L.logσW, L.logσb))
    end
    return BayesNN(layers, Flux.params(P), η)
end
# Feed `x` through the layers of `b` from first to last; each layer call redraws
# that layer's weight samples.
function (b::BayesNN)(x)
    apply_layer(h, layer) = layer(h)
    return foldl(apply_layer, b.layers; init=x)
end

"""
    log_likelihood(x, y, noise, model)

Gaussian log-likelihood of targets `y` given predictions `model(x)` with standard
deviation `noise`, up to an additive constant: `sum(-(ŷ - y)^2 / (2 noise^2))`.
The sum runs over all entries, including every weight draw, and is divided by
`model.η`.
"""
function log_likelihood(x, y, noise, model)
    residual = model(x) .- y
    total = sum(-0.5f0 / noise^2 .* residual .^ 2)
    return total / model.η
end

"""
    log_normal(x, μ, logσ)

Elementwise log-density of `x` under a normal distribution with mean `μ` and
standard deviation `exp(logσ)`. The arguments broadcast, so scalars are allowed.
"""
function log_normal(x, μ, logσ)
    halflog2π = log(sqrt(2.0f0 * π))   # constant term, computed once
    return @. -0.5f0 * (x - μ)^2 / exp(logσ)^2 - logσ - halflog2π
end

"""
    mean_log_prob(m)

Total of all entries of `m`: sum along dimension 3 first, then sum what remains.
Despite the name, no mean is taken here; `kl_divergence` divides the result by
the number of draws itself.
"""
function mean_log_prob(m)
    collapsed = sum(m; dims=3)
    return sum(collapsed)
end

"""
    kl_divergence(model)

Sum over layers of `log q(W_sample) - log p(W_sample)`, evaluated at each layer's
current weight draws.

- `q` is the layer's own Gaussian (`W`, `logσW`).
- `p` is a Gaussian with mean 0 and `logσ = 1`.

Each layer's summed log-density is divided by `model.η`.
"""
# NOTE(review): only the weights are included (biases are not), and the second term
# passes 1.0f0 as *logσ*, i.e. σ = e rather than 1 — confirm both are intended.
function kl_divergence(model)
    η = model.η
    logq_total = 0.0f0
    logp_total = 0.0f0
    for L in model.layers
        w = L.W_sample
        logq_total += mean_log_prob(log_normal(w, L.W, L.logσW)) / η
        logp_total += mean_log_prob(log_normal(w, 0.0f0, 1.0f0)) / η
    end
    return logq_total - logp_total
end

"""
    main(; bs=100, epochs=10_000, period=2.0f0, noise=1.0f0, η=20, num_layers=3,
         num_hidden=100, lr=1e-3, log_every=1000)

Train a `BayesNN` on a synthetic curve and plot the result.

Data:
- `bs` points `x` evenly spaced on [0, 10];
- target `y = 2x·sin(2πx/period)`;
- `yt` is `y` plus Gaussian noise scaled by `noise`, used only for the scatter plot.

Model: `num_layers` activated layers of width `num_hidden` plus an output layer,
each keeping `η` weight draws.

Training: `epochs` ADAM steps (learning rate `lr`) on
`-log_likelihood + kl_divergence`. Progress is logged every `log_every` epochs;
`log_every ≤ 0` disables logging.

The defaults equal the values that were previously hard-coded. Returns the final plot.
"""
function main(; bs::Integer=100, epochs::Integer=10_000, period=2.0f0, noise=1.0f0,
              η::Integer=20, num_layers::Integer=3, num_hidden::Integer=100,
              lr=1e-3, log_every::Integer=1000)
    @info "Start"

    model = BayesNN(1, 1, η, num_layers, num_hidden)

    # Dummy data, generated in Float64 and converted below.
    x = reshape(collect(LinRange(0.0, 10.0, bs)), (1, bs))
    y = x .* 2 .* sin.(x .* (2 * π / period))
    yt = y .+ randn(size(x)...) * noise

    # Convert to Float32 to match the model's parameter arrays.
    x = convert(Array{Float32}, x)
    y = convert(Array{Float32}, y)
    yt = convert(Array{Float32}, yt)

    θ = model.θ
    opt = Flux.Optimise.ADAM(lr)

    # NOTE(review): the loss uses the noise-free `y`; the noisy `yt` is only plotted.
    # Confirm whether training was meant to use `yt`.
    for i in 1:epochs
        gs = gradient(θ) do
            return -log_likelihood(x, y, noise, model) + kl_divergence(model)
        end
        Flux.Optimise.update!(opt, θ, gs)
        # Guard against log_every == 0, which would make `i % log_every` throw.
        if log_every > 0 && i % log_every == 0
            @info "Epoch $i| log-likelihood $(-log_likelihood(x, y, noise, model)) | kl-div $(kl_divergence(model))"
        end
    end

    @info "Done!"
    theme(:dark)
    plot(x', yt', seriestype=:scatter, alpha=0.4)
    plot!(x', y')
    return plot!(x', model(x)[1, :, :], legend=false, alpha=0.8)
end

Flux is slower than PyTorch on the CPU (Conv2d is even 2x slower; see the linked issue), but on the GPU they’re comparable.

3 Likes

Another part of the problem is the speed of log and exp in Julia. I have a PR that should fix exp for 1.6 (not yet merged but close).

7 Likes

The batched matrix multiply will allocate some extra memory due to the call to cat. Try preallocating the result and using mul! to store into it. Alternatively, you can try something like Tullio.jl to speed it up further.

You also don’t need to allocate the full array of random variables; calling randn() inside the broadcast might be faster (no explicit dot is needed, since @. already adds it).

3 Likes

Performance of scalar randn() is currently not as good as it could be, see https://github.com/JuliaLang/julia/pull/37234.
If you really need to get the few additional percent of performance, you can “patch” the corresponding function in the Random module on-the-fly using @eval (see linked issue).

1 Like

Thanks for letting me know; I had assumed they would have similar performance on the CPU. Fortunately, the work I started learning Julia for can be offloaded to the GPU.

Thanks for the suggestions!

I had avoided mul! because Zygote can’t differentiate through code that mutates arrays (unless there’s an easy way around this that I’ve missed?)

Thanks this is really good to know. I didn’t realize interpolating arguments would speed things up.

Great news thanks!

You should be able to use batched_mul for this, but someone needs to merge #191 first:

julia> using NNlib # v0.7.4

julia> batched_mul(rand(2,3,5), rand(3,4,1))
ERROR: DimensionMismatch("batch size mismatch")

Edit – in reply to a deleted comment.

And xref a recent thread. However, whether this is the bottleneck here or not I’ve no idea.

2 Likes