Reactant.jl compile times and memory usage

I’m trying to do some ML experiments using Lux and Reactant. Without posting all of the code, I was wondering whether I’m using Reactant properly, because I’m getting compile times of around 3 minutes and gigabytes of memory usage for a single function, so that compiling another one can sometimes crash my WSL.

function apply_no_prealloc(model, u_batch, parameters, state)
    (;rhel_layer, dropout, output_layer) = model
    layer_parameters = parameters.rhel_layer
    layer_state = state.rhel_layer

    layer_state = layer_no_prealloc(rhel_layer, u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1)  # Concatenate input and hidden along feature dim (N_in + N, T, B)

    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)  # (N_in + N, 1, B)
        y = dropdims(y; dims=2)  # (N_in + N, 1, B) -> (N_in + N, B)
        y, _ = output_layer(y, parameters.output_layer, state.output_layer)  # (output_dim, B)
    else
        # CODE
    end

    updated_state = (; state..., dropout = dropout_state)
    return y, updated_state
end

function layer_no_prealloc(layer, input_sequence, params, state)
    # input is in (N_in, t, B)
    N_in, DT, B = size(input_sequence)
    
    (;W_in) = params

    # Apply the input layer to all timesteps and batches
    W_dot_us = reshape(W_in * reshape(input_sequence, N_in, :), layer.N, DT, B)  # flatten time and batch to (N_in, DT*B) for a single GEMM, then reshape back to (N, DT, B)

    return apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
end

function apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
    _, Tlen, _ = size(W_dot_us)
    for t in 1:Tlen
        @inline leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    end
    return state
end

function leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    (;ϕ, π) = state
    (;W, b, α, β, ϵ) = params
    dt = layer.dt

    half_dt = dt / 2

    W_in_u = @view W_dot_us[:, t, :]

    ϕ_t = @view ϕ[:, t, :]

    kinetic_grad = grad_state_kinetic_hamiltonian(π)
    ϕhalf = ϕ_t .+ half_dt .* kinetic_grad

    grad = grad_state_potential_hamiltonian_no_prealloc(ϕhalf, W_in_u, W, b, α, β)
    π .-= dt .* grad

    kinetic_grad = grad_state_kinetic_hamiltonian(π)
    ϕ_t .= ϕhalf .+ half_dt .* kinetic_grad
    π .+= ϵ

    return state
end

function grad_state_potential_hamiltonian_no_prealloc(ϕ::AbstractArray{T}, W_in_u, W, b, alpha, beta) where T
    # Gradient of norm term: ∂/∂ϕ [1/2 * α * ϕ²] = α * ϕ

    tβϕ = @. tanh(beta * ϕ)
    state_grad = @. alpha * ϕ

    sech2β = @. T(1) - tβϕ^2  # sech²(βϕ)
    
    W_tanh_ϕ = W*tβϕ
    WT_tanh_ϕ = W'*tβϕ

    state_grad += @. (T(1) / beta) * T(0.5) * sech2β * (W_tanh_ϕ + WT_tanh_ϕ)

    # Gradient of bias term: ∂/∂ϕ [b^T tanh(ϕ)] = b * sech²(ϕ)
    state_grad += @. b * (T(1) - tanh(ϕ)^2)

    # Gradient of input term: ∂/∂ϕ [tanh(ϕ)^T W_in_u] = W_in_u * sech²(ϕ)
    state_grad += @. W_in_u * (T(1) - tanh(ϕ)^2)
    return state_grad
end

I’m trying to do

f = @compile apply_no_prealloc(model, u, params, state)

for a model I have. We had something similar before in JAX, which seemed quite snappy to compile.
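Roughly, the surrounding setup is something like this (a simplified sketch, not my exact script; the data, parameters, and state get moved over with Reactant.to_rarray before compiling):

using Lux, Reactant

# Sketch: model, u, params, state are built elsewhere as in the code above.
# Reactant traces on its own array type, so array-like things go through to_rarray first.
u_r      = Reactant.to_rarray(u)
params_r = Reactant.to_rarray(params)
state_r  = Reactant.to_rarray(state)

f = @compile apply_no_prealloc(model, u_r, params_r, state_r)
y, new_state = f(model, u_r, params_r, state_r)  # the compiled function is called with the same arguments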

Try to use @trace on the biggest loops, otherwise they get unrolled and lead to enormous codegen. Note that not everything works with tracing, though; the one I’m thinking of here is the for t in 1:Tlen loop.
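Something like this on the time loop is what I mean (just a sketch, the name is made up; whether the body actually traces depends on what it mutates):

# Sketch: wrap the time loop in Reactant.@trace so it lowers to a single traced
# loop in the generated code instead of being unrolled Tlen times.
function apply_leapfrog_no_prealloc_traced!(layer, state, params, W_dot_us)
    _, Tlen, _ = size(W_dot_us)
    Reactant.@trace for t in 1:Tlen
        leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    end
    return state
end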

Ah, that makes sense, so the whole for loop was unrolled. Thanks for that!

If I add @trace to the for loop, however, I’m getting this error during compilation:

ERROR: Reactant.NoFieldMatchError(Base.ReshapedArray{Reactant.TracedRNumber{Float32}, 3, Reactant.TracedRArray{Float32, 2}, Tuple{}}, Base.ReshapedArray{Reactant.TracedRNumber{Float32}, 3, Reactant.TracedRArray{Float32, 2}, Tuple{}}, Union{TypeVar, Type}[Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Int64}, Reactant.TracedRNumber{Int64}, Reactant.TracedRNumber{Int64}}, Tuple{}])
Stacktrace:
  [1] traced_type_inner(T::Type, seen::Dict{…}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/.julia/packages/Reactant/lu2GU/src/Tracing.jl:764
  [2] traced_type_inner(PT::Type{…}, seen::Dict{…}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/.julia/packages/Reactant/lu2GU/src/Tracing.jl:570
  [3] ERROR: MethodError: no method matching namemap(::Type{Reactant.TraceMode})
The function `namemap` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  namemap(::Type{Reactant.TraceMode}) (method too new to be called from this world context.)
   @ Reactant Enums.jl:215
  namemap(::Type{Enzyme.Compiler.ActivityState}) (method too new to be called from this world context.)
   @ Enzyme Enums.jl:215
  namemap(::Type{HTTP.Cookies.SameSite}) (method too new to be called from this world context.)
   @ HTTP Enums.jl:215
  ...

Stacktrace:
  [1] _symbol(x::Reactant.TraceMode)
    @ Base.Enums ./Enums.jl:37
  [2] show(io::IOContext{IOBuffer}, x::Reactant.TraceMode)
    @ Base.Enums ./Enums.jl:45
  [3] show_typeparams(io::IOContext{IOBuffer}, env::Core.SimpleVector, orig::Core.SimpleVector, wheres::Vector{TypeVar})
    @ Base ./show.jl:722
  [4] show_datatype(io::IOContext{IOBuffer}, x::DataType, wheres::Vector{TypeVar})
    @ Base ./show.jl:1181
  [5] show_datatype
    @ ./show.jl:1089 [inlined]
  [6] _show_type(io::IOContext{IOBuffer}, x::Type)
    @ Base ./show.jl:973
  [7] show(io::IOContext{IOBuffer}, x::Type)
    @ Base ./show.jl:965
  [8] sprint(f::Function, args::Type; context::IOContext{IOBuffer}, sizehint::Int64)
    @ Base ./strings/io.jl:112
  [9] sprint
    @ ./strings/io.jl:107 [inlined]
 [10] #print_type_bicolor#659
    @ ./show.jl:2718 [inlined]
 [11] show_tuple_as_call(out::IOContext{Base.TTY}, name::Symbol, sig::Type; demangle::Bool, kwargs::Nothing, argnames::Vector{Symbol}, qualified::Bool, hasfirst::Bool)
    @ Base ./show.jl:2585
 [12] show_tuple_as_call
    @ ./show.jl:2552 [inlined]
 [13] show_spec_sig(io::IOContext{Base.TTY}, m::Method, sig::Type)
    @ Base.StackTraces ./stacktraces.jl:265
 [14] show_spec_linfo(io::IOContext{Base.TTY}, frame::Base.StackTraces.StackFrame)
    @ Base.StackTraces ./stacktraces.jl:232
 [15] print_stackframe(io::IOContext{Base.TTY}, i::Int64, frame::Base.StackTraces.StackFrame, n::Int64, ndigits_max::Int64, modulecolor::Symbol)
    @ Base ./errorshow.jl:762
 [16] print_stackframe(io::IOContext{Base.TTY}, i::Int64, frame::Base.StackTraces.StackFrame, n::Int64, ndigits_max::Int64, modulecolordict::IdDict{Module, Symbol}, modulecolorcycler::Base.Iterators.Stateful{Base.Iterators.Cycle{Vector{Symbol}}, Union{Nothing, Tuple{Symbol, Int64}}})
    @ Base ./errorshow.jl:729
 [17] show_full_backtrace(io::IOContext{Base.TTY}, trace::Vector{Any}; print_linebreaks::Bool)
    @ Base ./errorshow.jl:628
 [18] show_full_backtrace
    @ ./errorshow.jl:621 [inlined]
 [19] show_backtrace(io::IOContext{Base.TTY}, t::Vector{Base.StackTraces.StackFrame})
    @ Base ./errorshow.jl:823
 [20] showerror(io::IOContext{Base.TTY}, ex::Reactant.NoFieldMatchError, bt::Vector{Base.StackTraces.StackFrame}; backtrace::Bool)
    @ Base ./errorshow.jl:99
 [21] showerror(io::IOContext{Base.TTY}, ex::Reactant.NoFieldMatchError, bt::Vector{Base.StackTraces.StackFrame})
    @ Base ./errorshow.jl:95
 [22] display_repl_error(io::Base.TTY, stack::VSCodeServer.EvalErrorStack; unwrap::Bool)
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/repl.jl:261
 [23] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/eval.jl:224
 [24] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/repl.jl:276
 [25] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/eval.jl:179
 [26] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/repl.jl:38
 [27] #67
    @ ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
 [28] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
    @ Base.CoreLogging ./logging/logging.jl:524
 [29] with_logger
    @ ./logging/logging.jl:635 [inlined]
 [30] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/eval.jl:263
 [31] #invokelatest#2
    @ ./essentials.jl:1055 [inlined]
 [32] invokelatest(::Any)
    @ Base ./essentials.jl:1052
 [33] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.156.1/scripts/packages/VSCodeServer/src/eval.jl:34


That’s weird, can you try restarting Julia? Also, as I said, @trace is very, very experimental, so as always: a minimal runnable reproducer and an issue on Reactant.jl, please.

From a fresh restart I’m getting the same error. If I add the macro to the whole function call, @trace apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us), it works, but compiling is still very slow and uses a lot of memory (almost 10 GB). It does seem like adding @trace at this location shaved off about 30 seconds (160 s -> 130 s) of compilation time.

That’s not bad, I guess. By the way, I wonder why you care so much about minimizing this; Reactant.jl is primarily useful when you’re considering hour-long training runs, normally, unless you’re making a library on top of Lux.
PS: you can use @code_hlo instead of @compile to see everything going on in the code.
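For example (a sketch, with the same arguments as your @compile call):

# Prints the generated StableHLO without compiling it; an enormous dump is a
# good hint that a Julia-level loop got unrolled during tracing.
@code_hlo apply_no_prealloc(model, u, params, state)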

Well, that is exactly why: I’m running it in a training loop. Anyway, is 10 GB really expected?

It depends on the amount of data/parameters you use, I think. Also, it’s actually not that big of a deal unless you’re limited in memory on your machine; Reactant disables the GC, so it shouldn’t affect performance as much as it would in base Julia.

Do you happen to have a full example I can run? It doesn’t have to be minimal. 2 minutes for the first compile is not unexpected (there is ongoing work to reduce that), though the 10 GB memory usage is definitely suspicious.

Here’s a running example. The function is actually compiling to something more like 7 GB, not 10 GB; I underestimated how much my data was taking. Compile time is close to 2-3 minutes with @trace in the position where it doesn’t error. On the for loop itself it still errors.