Training is 500x slower than inference for a custom loss function

I am trying to use gradient descent to optimize a matrix (a single-layer neural network) with a custom loss function: the sum of a Gaussian-kernel maximum mean discrepancy (MMD) and the L1 norm of the model weights (the matrix elements). Training is incredibly slow and I don't know why: a training pass over the data takes roughly 500x longer than simply evaluating the same loss on the same data (see the timings below). I have gone through the Flux performance tips but I am still seeing this massive gap.
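
Concretely, the MMD term in the code below is the standard biased estimator with a Gaussian kernel, summed over several bandwidths σ, plus an L1 penalty on the (dB-scaled) weights:

\[
\widehat{\mathrm{MMD}}^2(x, y)
  = \frac{1}{M^2}\sum_{i,j=1}^{M} k(x_i, x_j)
  - \frac{2}{MN}\sum_{i=1}^{M}\sum_{j=1}^{N} k(x_i, y_j)
  + \frac{1}{N^2}\sum_{i,j=1}^{N} k(y_i, y_j),
\qquad
k(a, b) = \exp\!\left(-\frac{\lvert a - b\rvert^2}{2\sigma^2}\right).
\]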

Any help would be appreciated.


Here is the output of my minimum working example (MWE):

❯ julia mwe.jl
[ Info: Train with desired loss function
134.781981 seconds (1.03 G allocations: 115.077 GiB, 22.93% gc time, 18.78% compilation time)
[ Info: Train with MSE loss function
  0.277686 seconds (1.70 M allocations: 84.768 MiB, 99.89% compilation time)
[ Info: Benchmark desired loss function
  0.227674 seconds (396.55 k allocations: 20.628 MiB, 43.10% compilation time)

and here is my MWE:

using Flux
using Flux.Optimise: Adam, train!
using Flux.Data: DataLoader
using LinearAlgebra
using Random: AbstractRNG, default_rng
using StatsBase: sample

function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)

    mmd = zero(T)
    running_total = zero(T)

    for i in 1:M, j in 1:M
        running_total += gaussian_kernel(x[i], x[j]; σ=σ)
    end
    mmd += (running_total / convert(T, M)^convert(T, 2))

    running_total = zero(T)
    for i in 1:M, j in 1:N
        running_total += gaussian_kernel(x[i], y[j]; σ=σ)
    end
    mmd -= (convert(T, 2) / convert(T, M * N) * running_total)

    running_total = zero(T)
    for i in 1:N, j in 1:N
        running_total += gaussian_kernel(y[i], y[j]; σ=σ)
    end
    mmd += (running_total / convert(T, N)^convert(T, 2))

    return mmd
end

function gaussian_kernel(x, y; σ=1)
    return exp(
        -one(typeof(x)) / (oftype(x / 1, 2) * oftype(x / 1, σ)^oftype(x / 1, 2)) *
        abs(x - y)^oftype(x / 1, 2),
    )
end

function mmd_loss(x, x̂; σs=[1])
    return sum(mmd(x, x̂; σ=σ) for σ in σs)
end

function generate_data(rng::AbstractRNG, n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    # Generate a Gaussian random matrix
    H = randn(rng, Float32, n, n_samples) ./ p
    # Set all but p indices in each row to zero
    for h in eachcol(H)
        indices = sample(1:n, n - p; replace=false)
        h[indices] .= 0
    end
    # Rescale
    H /= sqrt(norm(H) / n_samples)

    # Compute the label data
    U = randn(rng, Float32, m, n_samples)
    for u in eachcol(U)
        u .= u / norm(u)
    end
    return Float32.(H), Float32.(U)
end
function generate_data(n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    return generate_data(default_rng(), n_samples, m, n, p)
end

invdB(x) = oftype(x / 1, 10)^(x / oftype(x / 1, 10))

function main()
    model = Dense(16 => 100, identity; bias=false)
    opt_state = Flux.setup(Adam(0.0001f0, (0.9f0, 0.999f0)), model)
    λ = one(eltype(model.weight))

    H, U = generate_data(20, 100, 16, 3)
    dataloader = DataLoader((H, U), batchsize=4)

    @info "Train with desired loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        this_mmd_loss = mmd_loss(m(x), y; σs = [2, 5, 10, 20, 40, 80])
        this_l1_loss = λ * norm(invdB.(m.weight), 1)  # use the do-block argument m so the L1 term gets a gradient
        this_mmd_loss + this_l1_loss
    end

    @info "Train with MSE loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        Flux.mse(m(x), y)
    end

    @info "Benchmark desired loss function"
    @time for (x, y) in dataloader
        this_mmd_loss = mmd_loss(model(x), y; σs = [2, 5, 10, 20, 40, 80])
        this_l1_loss = λ * norm(invdB.(model.weight), 1)
        this_mmd_loss + this_l1_loss
    end
end

main()

I suspect that your problem is that Zygote is very slow when differentiating loops. You could try to use Enzyme instead but that would require a bit of legwork.

I’m sorry I don’t have time to go into more detail, but maybe this can at least point you in the right direction or start a discussion.
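
For what it's worth, here is a rough sketch of what that legwork might look like, assuming Enzyme.jl's autodiff/Duplicated API (untested against your MWE; x and y are just placeholder arrays, not your model output):

using Enzyme

# Placeholder inputs, only to show the call pattern.
x = randn(Float32, 100)
y = randn(Float32, 100)

dx = zero(x)   # gradient buffers, filled in place by Enzyme
dy = zero(y)

# Reverse-mode AD straight through the scalar loops in `mmd`.
Enzyme.autodiff(Reverse, (a, b) -> mmd(a, b; σ=1.0f0), Active,
                Duplicated(x, dx), Duplicated(y, dy))

# dx and dy now hold ∂mmd/∂x and ∂mmd/∂y.

Wiring this into Flux.train! would be the extra legwork (the model itself would need a shadow copy), but it at least shows that the loops themselves are not a problem for Enzyme.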

2 Likes

Exactly this. Writing performant code for Flux’s default AD (Zygote) is quite similar to writing for PyTorch or JAX: avoid loops, avoid scalar indexing and use vectorized functions (e.g. broadcasting, sum) wherever possible. If you really can’t do away with your very loopy, very scalar code, your options are a) write a rule which manually calculates the derivative, or b) use an alternative AD like ForwardDiff, ReverseDiff or Enzyme (all of which would accommodate it much better).
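
To illustrate option (a), something along these lines could work. This is only a rough, untested sketch: kernel_sum is a helper name made up for this example, and σ is treated as non-differentiable.

using ChainRulesCore

# One of the three double sums inside `mmd`: ∑ᵢ∑ⱼ exp(-|xᵢ - yⱼ|² / (2σ²)).
function kernel_sum(x, y, σ)
    c = one(eltype(x)) / (2 * eltype(x)(σ)^2)   # keep everything in eltype(x)
    s = zero(eltype(x))
    for a in x, b in y
        s += exp(-c * abs2(a - b))
    end
    return s
end

# Hand-written reverse rule: Zygote sees one rule instead of O(M*N) scalar
# operations, and the pullback is an ordinary (fast) Julia loop.
function ChainRulesCore.rrule(::typeof(kernel_sum), x, y, σ)
    c = one(eltype(x)) / (2 * eltype(x)(σ)^2)
    s = kernel_sum(x, y, σ)
    function kernel_sum_pullback(Δ)
        dx, dy = zero(x), zero(y)
        for (i, a) in pairs(x), (j, b) in pairs(y)
            k = exp(-c * abs2(a - b))
            g = Δ * 2c * (a - b) * k    # ∂k/∂a = -2c(a-b)k, ∂k/∂b = +2c(a-b)k
            dx[i] -= g
            dy[j] += g
        end
        return (NoTangent(), dx, dy, NoTangent())
    end
    return s, kernel_sum_pullback
end

Then mmd reduces to kernel_sum(x, x, σ)/M^2 - 2*kernel_sum(x, y, σ)/(M*N) + kernel_sum(y, y, σ)/N^2, and Zygote only differentiates those three calls.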

3 Likes

I made a macro for differentiating broadcast-and-reduce things like this efficiently. It gets about 1000x here.

using Tullio, ForwardDiff

function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)

    # This uses ForwardDiff in the loop to work out two gradients
    @tullio running_total := gaussian_kernel(x[i], x[j]; σ=σ) grad=Dual
    mmd = running_total / M^2

    @tullio running_total2 := gaussian_kernel(x[i], y[j]; σ=σ) grad=Dual
    mmd -= 2 * running_total2 / (M * N)

    @tullio running_total3 := gaussian_kernel(y[i], y[j]; σ=σ) grad=Dual
    mmd += running_total3 / N^2

    return mmd
end

#=

julia> main() # with @tullio grad=Dual
[ Info: Train with desired loss function
  1.018355 seconds (13.17 k allocations: 1.284 MiB)
[ Info: Train with MSE loss function
  0.000209 seconds (424 allocations: 124.219 KiB)
[ Info: Benchmark desired loss function
  0.126865 seconds (8.29 k allocations: 497.375 KiB)

=#

function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)
    pre = T(-1/(2 * σ^2))

    # Here the macro can see the function & compute symbolic derivatives:
    @tullio running_total := exp(pre * abs2(x[i] - x[j]))
    mmd = running_total / M^2

    @tullio running_total2 := exp(pre * abs2(x[i] - y[j]))
    mmd -= 2 * running_total2 / (M * N)

    @tullio running_total3 := exp(pre * abs2(y[i] - y[j]))
    mmd += running_total3 / N^2

    return mmd
end

#=

julia> main()  # with @tullio symbolic derivatives
[ Info: Train with desired loss function
  0.178408 seconds (13.17 k allocations: 1.282 MiB)
[ Info: Train with MSE loss function
  0.000200 seconds (424 allocations: 124.219 KiB)
[ Info: Benchmark desired loss function
  0.064159 seconds (8.29 k allocations: 493.500 KiB)

# this prints symbolic results, before applying CSE.
julia> @tullio running_total := exp(pre * abs2(x[i] - x[j])) verbose=true
┌ Info: symbolic gradients
│   inbody =
│    2-element Vector{Any}:
│     :(𝛥x[i] = 𝛥x[i] + 𝛥ℛ[1] * conj((((x[i] - x[j]) + (x[i] - x[j])) * pre) * exp(pre * abs2(x[i] - x[j]))))
└     :(𝛥x[j] = 𝛥x[j] + 𝛥ℛ[1] * conj((-(((x[i] - x[j]) + (x[i] - x[j]))) * pre) * exp(pre * abs2(x[i] - x[j]))))
┌ Warning: can't parallelise this gradient, no shared indices (symbolic gradient)
└ @ Tullio ~/.julia/packages/Tullio/NGyNM/src/macro.jl:1061
┌ Info: threading threshold (from cost = 21)
└   block = 12484
ERROR: UndefVarError: `x` not defined

=#

function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)

    # Variant using broadcasting, which avoids indexing in a loop,
    # but still materialises arrays of size N^2, seems about as slow!

    running_total = sum(gaussian_kernel.(vec(x), vec(x)'; σ=σ))
    mmd = running_total / M^2

    running_total2 = sum(gaussian_kernel.(vec(x), vec(y)'; σ=σ))
    mmd -= 2 * running_total2 / (M * N)

    running_total3 = sum(gaussian_kernel.(vec(y), vec(y)'; σ=σ))
    mmd += running_total3 / N^2

    return mmd
end

#=

julia> main()  # broadcasting variant
325.794644 seconds (1.10 G allocations: 39.603 GiB, 39.69% gc time, 1.11% compilation time: 18% of which was recompilation)

julia> main()  # with code from above, same computer, first run
[ Info: Train with desired loss function
307.520095 seconds (956.91 M allocations: 111.810 GiB, 54.09% gc time, 2.85% compilation time)

=#

While avoiding accidental type changes is good, this misses that a literal x^2 is lowered to x*x, while a floating-point power like x^2.0 is much slower. And abs(x)^oftype(x / 1, 2), as written in the kernel above, is likewise much slower than abs2(x). Not the big effect here, of course.
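
For example, a quick way to see this with BenchmarkTools (a micro-benchmark sketch; exact times depend on the machine):

using BenchmarkTools

x = 1.234f0
@btime $x^2           # literal exponent, lowered to x*x
@btime $x^2.0f0       # floating-point exponent, generic power path
@btime abs($x)^2.0f0  # effectively what abs(x - y)^oftype(x / 1, 2) does
@btime abs2($x)       # |x|² without the abs or the power call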

5 Likes

I implemented MMD a few years ago and put it into this package.

It is written without for-loops and should therefore be performant. I have not touched the library for two years, but I used it a few weeks ago without an issue. If you find a problem, ping me or file an issue (and then ping me).

I hope it helps.

Tomas

2 Likes

Thank you all for your help! Especially to @mcabbott for testing my code with Tullio, and to @Tomas_Pevny for the integral probability measures implementation.

Consider this solved.