Reactant.jl recompiles a function every time it is called

I am trying to use an optimiser from Manopt.jl to train a Lux.jl neural network with Reactant.jl. It runs orders of magnitude slower than the same code without Reactant.jl. Does anyone know how to fix it?

The compiled functions are very fast when benchmarked on their own, but when I use them inside an optimiser they are recompiled constantly. I've gathered that this might happen because the input types of the functions change between calls. Surely, though, an optimiser calls them with consistent input types, or at most a small number of distinct ones, so compilation should only be needed a couple of times.
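
One way to test the changing-input-type theory is to wrap the cost and gradient handed to the optimiser in a small logger that prints every new combination of argument types it sees. `TypeLogger` below is a hypothetical helper written for this purpose (it is not part of Manopt.jl or Reactant.jl); the demonstration at the end uses plain arrays so the block runs on its own.

```julia
# Hypothetical helper: wraps a function and prints each new
# combination of argument types it is called with.
struct TypeLogger{F}
    f::F
    seen::Vector{Any}   # distinct argument-type tuples, in order of first appearance
end
TypeLogger(f) = TypeLogger(f, Any[])

function (t::TypeLogger)(args...)
    sig = map(typeof, args)          # tuple of the argument types for this call
    if !(sig in t.seen)
        push!(t.seen, sig)
        println("new argument types: ", sig)
    end
    return t.f(args...)
end

# Possible use with the cost from this post (not run here):
#   cost = TypeLogger((M, p) -> Loss_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra))
#   trust_regions(Manifold, cost, grad, ps_ra; ...)
#   length(cost.seen)   # number of distinct argument-type combinations

# Self-contained demonstration with plain arrays:
logged_sum = TypeLogger((M, p) -> sum(p))
logged_sum(nothing, [1.0, 2.0])     # first call: prints the types
logged_sum(nothing, [3.0, 4.0])     # same types: prints nothing
logged_sum(nothing, Float32[1, 2])  # new element type: prints again
```

If `length(cost.seen)` stays at one or two after a full run, the argument types are stable and the slowdown has some other cause.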

Here is the code:

```julia
using Reactant
using Zygote, Enzyme, Lux
using Random
using ComponentArrays
using ADTypes
using BenchmarkTools

using Manifolds
using Manopt

import ManifoldsBase.inner
import ManifoldsBase.norm
import Base.-

# Inner product and norm on Euclidean, written as plain reductions
inner(M::Euclidean, p, X, Y) = sum(X .* Y)
norm(M::Euclidean, p, X) = (sum(X .* X) ^ (1/2))
# Unary minus for Reactant scalars, defined via binary subtraction
-(x::ConcretePJRTNumber{Float64, 1}) = ConcretePJRTNumber(0 - x)

Random.seed!(32)

Reactant.set_default_backend("cpu")
xdev = reactant_device()

layer_dimension = 8
model = Chain(Dense(1, layer_dimension, swish),
              Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +),
                    Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +), Dense(layer_dimension, 1))))
ps, st = Lux.setup(Random.default_rng(), model);
# convert the parameters to a Float64 ComponentArray
ps = ps |> ComponentArray .|> Float64

ps_ra = ps |> xdev
st_ra = st |> xdev

# training and test data for y = (x + 1)^2 - 1
xx = randn(1, 100)
xx_test = randn(1, 100)

yy = (xx .+ 1.0) .^ 2.0 .- 1.0
yy_test = (xx_test .+ 1.0) .^ 2.0 .- 1.0

xx_ra = xx |> xdev
yy_ra = yy |> xdev

xx_test_ra = xx_test |> xdev
yy_test_ra = yy_test |> xdev

model_compiled = @compile model(xx_ra, ps_ra, Lux.testmode(st_ra))

function Loss(params, state, model, xx, yy)
    return sum(abs2, (first(model(xx, params, state)) .- yy) ) / 2 # + 0.01 * sum(abs2, params) / 2
end

# Gradient of Loss: vector-Jacobian product of the model output
# (as a function of the parameters) with the residual
function Gradient(params, state, model, xx, yy)
    residual = first(model(xx, params, state)) - yy
    gr1 = vector_jacobian_product(z -> first(model(xx, z, state)), AutoEnzyme(), params, residual)
    return gr1 #+ 0.01 .* params
end

function Jacobian(params, state, model, xx)
    return batched_jacobian(z -> first(model(z, params, state)), AutoEnzyme(), xx)
end

Loss_compiled = @compile Loss(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Gradient_compiled = @compile Gradient(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Jacobian_compiled = @compile Jacobian(ps_ra, Lux.testmode(st_ra), model, xx_ra)

Manifold = Euclidean(length(ps_ra))

# Cost and gradient take (M, p), as Manopt expects
res = trust_regions(
    Manifold,
    (M, p) -> Loss_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    (M, p) -> Gradient_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    ps_ra;
    stopping_criterion=StopAfterIteration(1000) | StopWhenGradientNormLess(1e-6),
    debug = [:Iteration, :Cost, :GradientNorm, " | ", 1, "\n", :Stop]
)
```

I think it's because you're in global scope. Can you try making the compiled functions `const`?
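
A plain-Julia sketch of what `const` changes: a function that reads a non-constant global cannot infer that global's type, so each call goes through dynamic dispatch, while a `const` binding lets the compiler specialise the caller. In the code above this would mean, for example, `const Loss_compiled = @compile Loss(...)` (and the same for `Gradient_compiled`, `model`, `st_ra`, `xx_ra`, `yy_ra`, which the closures also read). The names below are illustrative only. Note that non-constant globals mainly add dispatch overhead; whether that alone explains the slowdown here is not certain.

```julia
# Two globals holding the same function: one reassignable, one `const`.
scale_nonconst = x -> 2x
const scale_const = x -> 2x

# Callers that read these globals, like the closures passed to trust_regions.
call_nonconst(x) = scale_nonconst(x)
call_const(x) = scale_const(x)

# In the code above this would be, e.g.:
#   const Loss_compiled = @compile Loss(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)

# Inference cannot see through the non-const global, so its result type is Any;
# through the const global it is inferred as Float64.
println(Base.return_types(call_nonconst, (Float64,)))
println(Base.return_types(call_const, (Float64,)))
```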

Also, what makes you think it's being recompiled? I don't know whether Manopt was made compatible with Reactant, and your Reactant arrays may hit some unexpected code path. What happens if you compile the full call instead of the inner functions?
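
To check whether the cost and gradient really are recompiled, one option is to time every call the optimiser makes to them. `TimedCalls` below is a hypothetical wrapper (not part of Manopt.jl or Reactant.jl). If only the first few recorded times are large, compilation happens once; if every call is slow, there is some per-call overhead such as repeated compilation. If the recorded times are small but the whole `trust_regions` run is still slow, the time is being spent outside these functions (for example in the optimiser's own arithmetic on Reactant arrays), which would point towards compiling the full call.

```julia
# Hypothetical helper: wraps a function and records the wall time of each call.
struct TimedCalls{F}
    f::F
    times::Vector{Float64}   # seconds per call, in call order
end
TimedCalls(f) = TimedCalls(f, Float64[])

function (t::TimedCalls)(args...)
    start = time_ns()
    result = t.f(args...)
    push!(t.times, (time_ns() - start) / 1e9)
    return result
end

# Possible use with the cost from the question (not run here):
#   cost = TimedCalls((M, p) -> Loss_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra))
#   trust_regions(Manifold, cost, grad, ps_ra; ...)
#   cost.times        # per-call times
#   sum(cost.times)   # compare with the total run time of trust_regions

# Self-contained demonstration with a plain Julia function:
timed_sumsq = TimedCalls(x -> sum(abs2, x))
for _ in 1:5
    timed_sumsq(rand(1000))
end
println(timed_sumsq.times)
```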