Reactant.jl recompiles function each time it is called

I am trying to use an optimiser from Manopt.jl to train a Lux.jl neural network using Reactant.jl. It is orders of magnitude slower than without Reactant.jl. Does anyone know how to fix it?

The compiled functions are very fast when benchmarked on their own. But when I try to use them inside an optimiser, they seem to be constantly recompiled. I’ve gathered that the cause might be that the input type of the function is changing, but surely an optimiser would call with consistent input types, or at least a finite number of them, so compilation would only be required a couple of times at most.

The code is here

using Reactant
using Zygote, Enzyme, Lux
using Random
using ComponentArrays
using ADTypes
using BenchmarkTools

using Manifolds
using Manopt

import ManifoldsBase.inner
import ManifoldsBase.norm
import Base.-

inner(M::Euclidean, p, X, Y) = sum(X .* Y)
norm(M::Euclidean, p, X) = (sum(X .* X) ^ (1/2))
-(x::ConcretePJRTNumber{Float64, 1}) = ConcretePJRTNumber(0 - x)

Random.seed!(32)

Reactant.set_default_backend("cpu")
xdev = reactant_device()

layer_dimension = 8
model = Chain(Dense(1, layer_dimension, swish),
              Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +),
                    Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +), Dense(layer_dimension, 1))))
ps, st = Lux.setup(Random.default_rng(), model);
# make it into Float64
ps = ps |> ComponentArray .|> Float64

ps_ra = ps |> xdev
st_ra = st |> xdev

xx = randn(1, 100)
xx_test = randn(1, 100)

yy = (xx .+ 1.0) .^ 2.0 .- 1.0
yy_test = (xx_test .+ 1.0) .^ 2.0 .- 1.0

xx_ra = xx |> xdev
yy_ra = yy |> xdev

xx_test_ra = xx_test |> xdev
yy_test_ra = yy_test |> xdev

model_compiled = @compile model(xx_ra, ps_ra, Lux.testmode(st_ra))

function Loss(params, state, model, xx, yy)
    return sum(abs2, (first(model(xx, params, state)) .- yy) ) / 2 # + 0.01 * sum(abs2, params) / 2
end

function Gradient(params, state, model, xx, yy)
    residual = first(model(xx, params, state)) - yy
    gr1 = vector_jacobian_product(z -> first(model(xx, z, state)), AutoEnzyme(), params, residual)
    return gr1 #+ 0.01 .* params
end

function Jacobian(params, state, model, xx)
    return batched_jacobian(z -> first(model(z, params, state)), AutoEnzyme(), xx)
end

Loss_compiled = @compile Loss(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Gradient_compiled = @compile Gradient(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Jacobian_compiled = @compile Jacobian(ps_ra, Lux.testmode(st_ra), model, xx_ra)

Manifold = Euclidean(length(ps_ra))

res = trust_regions(
    Manifold,
    (Manifold, ps_ra) -> Loss_compiled(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    (Manifold, ps_ra) -> Gradient_compiled(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    ps_ra;
    stopping_criterion=StopAfterIteration(1000) | StopWhenGradientNormLess(1e-6),
    debug = [:Iteration, :Cost, :GradientNorm, " | ", 1, "\n", :Stop]
)

I think it’s because you’re in global scope. Can you try making the compiled functions const?

Also, what makes you think it’s being recompiled? I don’t know if Manopt was made compatible with Reactant, and your Reactant array may be hitting a weird code path. What happens if you compile the full call instead of the inner functions?

I tried putting the whole thing in a module and making the compiled functions const. It did not help. It seems that Reactant wants to compile the inner workings of Manopt in every iteration. The optimisation does work, but it is really slow.

Yes, it’s because Manopt is working with the Reactant ps array, so every operation ends up falling back on the AbstractArray interface.

You need to try to compile the whole thing and hope there is nothing weird going on inside.

Reactant only performs compilation when you specify @compile or @jit, so [not knowing anything about Manopt] at first glance that shouldn’t be the case. What happens if you run this under a profiler, can you see where the time is being spent, and potentially why?
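For anyone wanting to try the profiling suggestion, here is a minimal sketch using Julia's built-in Profile standard library. The function `run_optimiser_standin` is hypothetical and only exists so the snippet runs on its own; in the setting of this thread, its body would be the `trust_regions(...)` call from the original post.

```julia
using Profile

# Hypothetical stand-in for the real workload. In this thread's setting the
# body would be the `trust_regions(...)` call from the original post.
function run_optimiser_standin(n)
    acc = 0.0
    for i in 1:n
        acc += sqrt(i)
    end
    return acc
end

# Warm-up call, so that Julia's own compilation of the stand-in is not sampled.
run_optimiser_standin(10)

Profile.clear()                              # discard samples from earlier runs
@profile run_optimiser_standin(10_000_000)   # collect samples for this call only

# Flat listing sorted by sample count; frames with fewer than 10 samples are hidden.
Profile.print(format = :flat, sortedby = :count, mincount = 10)
```

If compilation or host/device copies dominate, the corresponding frames should show up near the bottom of the flat listing, where the highest sample counts are. Whether to keep the warm-up call depends on what you want to see: if the suspicion is repeated compilation inside the optimiser loop, the warm-up only removes the first, expected compilation.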

Also definitely open a PR for the minus override! Though honestly I wouldn’t return a new ConcretePJRTNumber, and just have it as:

-(x::ConcretePJRTNumber{Float64, 1}) = 0 - x

Note that for CPU-heavy stuff like this, Lux is pretty good without Reactant though:

using Lux, Reactant, Enzyme
using BenchmarkTools
import Random
Reactant.set_default_backend("cpu")

layer_dimension = 8
model = Chain(Dense(1, layer_dimension, swish),
              Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +),
                    Chain(SkipConnection(Dense(layer_dimension, layer_dimension, swish), +), Dense(layer_dimension, 1))))
ps, st = Lux.setup(Random.default_rng(), model) |> f64;
dps = Enzyme.make_zero(ps)
xx = randn(1, 100) 


@btime $model($xx,$ps,$st)[1]
f(m,x,p,s) = sum(m(x,p,s)[1])
function get_grad!(dps,model,x,ps,st)
    Enzyme.make_zero!(dps)
    Enzyme.autodiff(Enzyme.Reverse,f,Const(model),Const(x),Duplicated(ps,dps),Const(st))
    dps
end
@btime get_grad!($dps,$model,$xx,$ps,$st);

xx_r = Reactant.to_rarray(xx)
ps_r = Reactant.to_rarray(ps)
st_r = Reactant.to_rarray(st)
dps_r = Reactant.to_rarray(dps)
model_c = @compile sync=true model(xx_r,ps_r,st_r)
get_grad_c = @compile sync=true get_grad!(dps_r,model,xx_r,ps_r,st_r)

@btime $model_c($xx_r,$ps_r,$st_r)
@btime $get_grad_c($dps_r,$model,$xx_r,$ps_r,$st_r)

gives

  13.197 μs (23 allocations: 32.88 KiB)
  46.571 μs (111 allocations: 127.98 KiB)
  11.578 μs (14 allocations: 496 bytes)
  21.314 μs (42 allocations: 1.25 KiB)

So yes, I understand that a 2x speedup on the gradient would be nice to have, but it’s not that horrible without it either.
I hope that one day we have a Julia dialect in MLIR, so that Reactant can produce the Julia function (for use on any kind of array) from its StableHLO when needed. I think maleadt/IRStructurizer.jl (pattern-matching structured control flow in Julia's SSA IR) isn’t far off. We would lose XLA, of course, but for gradients it would be very nice.

I really think the profile would be useful here. In particular, my guess is that however the manifold code is using the result, it is inadvertently copying data to and from the device many times.

For example

-(x::ConcretePJRTNumber{Float64, 1}) = ConcretePJRTNumber(0 - x)

is an example of this. 0 - x will copy the data out of the XLA device into a regular float and perform 0 - x on it as a regular float. Then the constructor copies the result back onto an XLA device.

That’s partially why I recommended

-(x::ConcretePJRTNumber{Float64, 1}) = 0 - x

If everything is within a Reactant compile, it doesn’t matter and Reactant will fuse/optimize everything fully.

Anything outside of the compile is out of our control, so there it will matter how you use the results.
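To illustrate the point about keeping everything inside one compile, here is a minimal sketch. It does not use Manopt or Lux; `descent_steps` is a hypothetical toy function that runs ten plain gradient-descent steps on the quadratic cost sum(abs2, p .- target) / 2, with the gradient written out by hand. It assumes only `Reactant.to_rarray` and `@compile`, as used earlier in the thread.

```julia
using Reactant

# Hypothetical toy problem: minimise sum(abs2, p .- target) / 2.
# The whole update loop lives in one function, so Reactant traces it once and
# no intermediate result has to leave the device between steps.
function descent_steps(p, target)
    lr = 0.1                    # step size, fixed at trace time
    for _ in 1:10               # plain Julia loop, unrolled while tracing
        grad = p .- target      # hand-written gradient of the quadratic cost
        p = p .- lr .* grad
    end
    return p
end

p_ra = Reactant.to_rarray(zeros(4))
target_ra = Reactant.to_rarray(ones(4))

# Compile once, then call the compiled function as often as needed.
descent_steps_c = @compile descent_steps(p_ra, target_ra)
p_new = descent_steps_c(p_ra, target_ra)
```

By contrast, calling `descent_steps(p_ra, target_ra)` directly, without `@compile`, would run each broadcast as a separate operation on the Reactant array, which is the situation described above for an optimiser that is not itself compiled.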