I am trying to use gradient descent to optimize a matrix (a single layer neural network) using a custom loss function. The loss function is a sum of a Gaussian-kernel-embedded maximum mean discrepancy plus the L1 norm of the model weights (matrix elements). It is *incredibly* slow to train and I don’t know why. Each step takes 100x longer than it should. I have gone through the Flux optimization tips but I am still seeing this massive issue in performance.

Any help would be appreciated.

Here is the output of my minimum working example (MWE):

```
❯ julia mwe.jl
[ Info: Train with desired loss function
134.781981 seconds (1.03 G allocations: 115.077 GiB, 22.93% gc time, 18.78% compilation time)
[ Info: Train with MSE loss function
0.277686 seconds (1.70 M allocations: 84.768 MiB, 99.89% compilation time)
[ Info: Benchmark desired loss function
0.227674 seconds (396.55 k allocations: 20.628 MiB, 43.10% compilation time)
```

and here is my MWE:

```
using Flux
using Flux.Optimise: Adam, train!
using Flux.Data: DataLoader
using LinearAlgebra
using Random: AbstractRNG, default_rng
using StatsBase: sample
function mmd(x, y; σ=1)
T = eltype(x)
M = length(x)
N = length(y)
mmd = zero(T)
running_total = zero(T)
for i in 1:M, j in 1:M
running_total += gaussian_kernel(x[i], x[j]; σ=σ)
end
mmd += (running_total / convert(T, M)^convert(T, 2))
running_total = zero(T)
for i in 1:M, j in 1:N
running_total += gaussian_kernel(x[i], y[j]; σ=σ)
end
mmd -= (convert(T, 2) / convert(T, M * N) * running_total)
running_total = zero(T)
for i in 1:N, j in 1:N
running_total += gaussian_kernel(y[i], y[j]; σ=σ)
end
mmd += (running_total / convert(T, N)^convert(T, 2))
return mmd
end
function gaussian_kernel(x, y; σ=1)
return exp(
-one(typeof(x)) / (oftype(x / 1, 2) * oftype(x / 1, σ)^oftype(x / 1, 2)) *
abs(x - y)^oftype(x / 1, 2),
)
end
function mmd_loss(x, x̂; σs=[1])
return sum(mmd(x, x̂; σ=σ) for σ in σs)
end
function generate_data(rng::AbstractRNG, n_samples::T, m::T, n::T, p::T) where {T<:Integer}
# Generate a Gaussian random matrix
H = randn(rng, Float32, n, n_samples) ./ p
# Set all but p indices in each row to zero
for h in eachcol(H)
indices = sample(1:n, n - p; replace=false)
h[indices] .= 0
end
# Rescale
H /= sqrt(norm(H) / n_samples)
# Compute the label data
U = randn(rng, Float32, m, n_samples)
for u in eachcol(U)
u .= u / norm(u)
end
return Float32.(H), Float32.(U)
end
function generate_data(n_samples::T, m::T, n::T, p::T) where {T<:Integer}
return generate_data(default_rng(), n_samples, m, n, p)
end
invdB(x) = oftype(x / 1, 10)^(x / oftype(x / 1, 10))
function main()
model = Dense(16 => 100, identity; bias=false)
opt_state = Flux.setup(Adam(0.0001f0, (0.9f0, 0.999f0)), model)
λ = one(eltype(model.weight))
H, U = generate_data(20, 100, 16, 3)
dataloader = DataLoader((H, U), batchsize=4)
@info "Train with desired loss function"
@time train!(model, dataloader, opt_state) do m, x, y
this_mmd_loss = mmd_loss(m(x), y; σs = [2, 5, 10, 20, 40, 80])
this_l1_loss = λ * norm(invdB.(model.weight), 1)
this_mmd_loss + this_l1_loss
end
@info "Train with MSE loss function"
@time train!(model, dataloader, opt_state) do m, x, y
Flux.mse(m(x), y)
end
@info "Benchmark desired loss function"
@time for (x, y) in dataloader
this_mmd_loss = mmd_loss(model(x), y; σs = [2, 5, 10, 20, 40, 80])
this_l1_loss = λ * norm(invdB.(model.weight), 1)
this_mmd_loss + this_l1_loss
end
end
main()
```