Excessive allocations in basic SciMLSensitivity example

The first example (copied below) in the SciMLSensitivity docs shows how to optimize the parameters of the Lotka-Volterra system to achieve a steady state. There’s no neural network or anything, just an ODE system with 2 variables, evaluated at 101 time steps, with 4 parameters to optimize. The optimization is capped at 100 iterations. So it’s really not a large problem, and certainly nothing like the size of the problems I want to solve.

And yet, the allocations are pretty out of control: 233.69 M allocations totaling 11.576 GiB! (There’s no compilation happening.) It reports taking 4,578 function and gradient evaluations (no Hessians). If we count each of those function-gradient evaluations as 10 numbers, evaluated at 101 time steps, 4,578 times, there are still 50x more allocations and 2500x more bytes than I can see any possible need for. I’m sure there’s another factor from stages of the timestepper, and another factor of a few because adjoint sensitivity is hard, but not such large factors. And that’s only if literally every number requires a separate allocation. What’s going on?!
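The arithmetic behind that estimate, spelled out in Julia for anyone who wants to plug in their own numbers. The figure of 10 numbers per evaluation per time step is the post's own generous assumption. Dividing the total bytes by the number of counted values gives roughly 2,700 bytes per value, which seems to be where the ~2500x comes from; measured against an 8-byte Float64, the excess is closer to 340x, still far beyond any plausible need.

```julia
# Back-of-envelope numbers quoted in the post above
n_evals   = 4_578      # function + gradient evaluations reported
n_steps   = 101        # saved time points (0.0:0.1:10.0)
n_numbers = 10         # generous count of numbers per evaluation per step

numbers_needed = n_numbers * n_steps * n_evals      # 4_623_780

allocs_observed = 233.69e6
bytes_observed  = 11.576 * 2^30                     # 11.576 GiB in bytes

alloc_ratio      = allocs_observed / numbers_needed     # allocations per counted number
bytes_per_number = bytes_observed / numbers_needed      # bytes per counted number
byte_ratio_f64   = bytes_per_number / sizeof(Float64)   # excess over one Float64 each

println("allocations per number:   ", round(alloc_ratio; digits = 1))
println("bytes per number:         ", round(bytes_per_number; digits = 0))
println("ratio vs. 8-byte Float64: ", round(byte_ratio_f64; digits = 0))
```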

Profiling, it looks like most of the allocations (~95%) are caused by gradient computations in HagerZhang, which correlates very well with where the time is being spent. This example problem still runs in a reasonable amount of time (~30s), but I’m trying a much harder problem (with more evolved variables, more cases, and a NN with many more parameters), and even my simplest case is basically dead in the water (about 1/4 of the run time seems to be GC).
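For readers who want to reproduce this kind of breakdown, Julia's built-in allocation profiler (Julia 1.8+) can attribute allocations by type. The sketch below profiles a hypothetical `toy_loss` that allocates a temporary matrix; in practice you would profile the `Zygote.gradient` or `Optimization.solve` call instead, probably with a much lower `sample_rate`.

```julia
using Profile

# Hypothetical stand-in for a loss that allocates a temporary matrix per call
toy_loss(u) = sum(abs2, reduce(hcat, u) .- 1)

# Call the loss many times so the profile has something to record
function run_many(u, n)
    s = 0.0
    for _ in 1:n
        s += toy_loss(u)
    end
    return s
end

u = [rand(2) for _ in 1:101]
run_many(u, 1)                       # compile first so compilation is not profiled

Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=1.0 run_many(u, 100)
results = Profile.Allocs.fetch()

# Aggregate sampled allocations by type and print the largest contributors
bytes_by_type = Dict{Any,Int}()
for a in results.allocs
    bytes_by_type[a.type] = get(bytes_by_type, a.type, 0) + a.size
end
for (T, b) in sort(collect(bytes_by_type); by = last, rev = true)
    println(rpad(string(T), 40), b, " bytes")
end
```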

Is there anything I can do about those allocations?


Full example code
using OrdinaryDiffEq,
      Optimization, OptimizationPolyalgorithms, SciMLSensitivity,
      Zygote, Plots

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α * x - β * x * y
    du[2] = dy = -δ * y + γ * x * y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0

# LV equation parameter. p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]

# Setup the ODE problem, then solve
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
sol = solve(prob, Tsit5())

# Plot the solution
using Plots
plot(sol)
savefig("LV_ode.png")

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    loss = sum(abs2, sol .- 1)
    return loss
end

callback = function (state, l)
    display(l)
    pred = solve(prob, Tsit5(), p = state.u, saveat = tsteps)
    plt = plot(pred, ylim = (0, 6))
    display(plt)
    # Tell Optimization.solve to not halt the optimization. If return true, then
    # optimization stops.
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)

result_ode = Optimization.solve(optprob, PolyOpt(),
    callback = callback,
    maxiters = 100)

A little more profiling of those allocations led to the loss function:

function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    loss = sum(abs2, sol .- 1)
    return loss
end

Obviously, that sol .- 1 isn’t great because it will allocate a whole new array (if the AbstractArray interface to ODESolution is working correctly). So I tried

loss = sum(x->abs2(x-1), sol)

which was even worse somehow! (~30% longer to run, about the same number of allocations, but ~35% more memory allocated.) That feels like a bug, so I think I’ll open an issue.

Anyway, I tried the really dumb option of

loss = sum(x->sum(y->abs2(y-1), x), sol.u)

and that sped it up by 50x!!! It also made 17x fewer allocations and used 15x less memory. I’m still not really sure why there are so many allocations, but that’s much better. I’m wondering whether the AD couldn’t figure out the AbstractArray interface to ODESolution.
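The difference between the formulations shows up even in the forward pass. The sketch below (plain OrdinaryDiffEq, no AD) checks that a matrix-based loss and the nested-sum loss give the same value and compares what each allocates. `loss_matrix` and `loss_nested` are illustrative names, and `Array(sol) .- 1` stands in for `sol .- 1`. The much larger effect reported above comes from the Zygote reverse pass, which this does not measure.

```julia
using OrdinaryDiffEq

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
sol = solve(prob, Tsit5(); saveat = 0.0:0.1:10.0)

# Variant 1: materialise a 2×101 matrix, then a second one for `.- 1`
loss_matrix(sol) = sum(abs2, Array(sol) .- 1)

# Variant 2: walk the saved states directly; no temporary arrays
loss_nested(sol) = sum(x -> sum(y -> abs2(y - 1), x), sol.u)

l1, l2 = loss_matrix(sol), loss_nested(sol)   # first calls also compile
a1 = @allocated loss_matrix(sol)
a2 = @allocated loss_nested(sol)
println("values:          ", l1, " vs ", l2)
println("bytes allocated: ", a1, " vs ", a2)
```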

The remaining allocations seem to come from derivatives of the solve, which I understand is hard to do, so I can probably move on.

I’m seeing a few non-const global variables being accessed by the methods of loss and the (non-const) callback. Fully admitting this is low-hanging fruit: what changes if you move everything except the imports and the methods that don’t access those variables into a main function?


Good question. No change at all. :confused:

Does using StaticArrays help? (edit: ok I guess regardless of whether it helps here that has limited utility with more variables)


Yeah, that cuts down the allocations and size by ~20% — though it also slows down the whole thing a tiny bit… I’m so confused.

My real problem may still have few enough variables (as opposed to parameters) that StaticArrays might help. Or maybe FixedSizeArrays.jl could help a bit…
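For reference, switching to StaticArrays means using the out-of-place form of the problem: the right-hand side returns a new SVector instead of mutating `du`. Below is a minimal sketch, assuming the same parameters and save points as the tutorial; `lotka_volterra` (no `!`) and `loss_oop` are illustrative names. This only pays off for small state vectors, and the reverse pass still has to store the trajectory.

```julia
using OrdinaryDiffEq, StaticArrays

# Out-of-place right-hand side: returns an immutable SVector, no `du` argument
function lotka_volterra(u, p, t)
    x, y = u
    α, β, δ, γ = p
    return SVector(α * x - β * x * y, -δ * y + γ * x * y)
end

u0 = SVector(1.0, 1.0)
p  = [1.5, 1.0, 3.0, 1.0]
tsteps = 0.0:0.1:10.0

prob_oop = ODEProblem(lotka_volterra, u0, (0.0, 10.0), p)
sol_oop  = solve(prob_oop, Tsit5(); saveat = tsteps)

# Same nested-sum loss, now over a Vector{SVector{2,Float64}}
loss_oop(sol) = sum(x -> sum(y -> abs2(y - 1), x), sol.u)
println(loss_oop(sol_oop))
```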

Unfortunate. I’d check @code_warntype on that main function, but I doubt there are any glaring type-inference issues (abstract types on variables and fields of closures). It’s gotta be a deeper bottleneck if better type inference doesn’t help at all.

It looks like SciML intentionally uses a lot of runtime dispatch, so I guess that explains why putting everything inside a main function (or adding const everywhere I can) doesn’t help at all. But at least there are function barriers within SciML, and (just looking at the profiles) things do seem to get pretty well inferred behind them. I’m no expert on this, though, so I may be completely off.
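As a generic illustration of the function-barrier idea (not SciML's actual code): a caller that only knows `Any` pays one dynamic dispatch, after which a concretely typed kernel runs fully inferred. `Test.@inferred` is a quick programmatic alternative to reading `@code_warntype` output. The names `sum_sq_minus_one` and `outer` are hypothetical.

```julia
using Test

# Kernel: specialised on the concrete element type of `v`, so fully inferred
sum_sq_minus_one(v::AbstractVector{T}) where {T<:Real} = sum(y -> abs2(y - 1), v)

# Caller: `settings[:data]` is only known to be `Any` ...
function outer(settings::Dict{Symbol,Any})
    v = settings[:data]
    # ... so this call is a single runtime dispatch (the "function barrier");
    # everything inside sum_sq_minus_one runs on concrete types.
    return sum_sq_minus_one(v)
end

settings = Dict{Symbol,Any}(:data => [1.0, 2.0, 3.0])
@inferred sum_sq_minus_one([1.0, 2.0, 3.0])   # passes: return type is Float64
println(outer(settings))                      # 0 + 1 + 4 = 5.0
```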


This whole thread seems to have a lot of misunderstandings of reverse mode AD. So a few things:

  1. In order to reverse, you need to store the forward pass. So reverse mode pretty much requires allocations (there are ways to handle this, but they need a bit more work). The exception is BacksolveAdjoint, which is numerically unstable. For this case you could turn on BacksolveAdjoint, but the tutorial won’t do that because it’s generally unsafe and we don’t want to point people towards incorrect or unstable numerical methods.
  2. The speedup from loss = sum(x->sum(y->abs2(y-1), x), sol.u) is mostly a Zygote thing, due to constructing the matrix for the reverse pass. Keeping everything as smaller arrays just helps Zygote a bit, though we can probably optimize that with a few views that don’t seem to inline.
  3. “I’m seeing a few non-const global variables being accessed by the methods of loss and the (non-const) callback”: these won’t matter, because the majority of the time is going to be spent in the adjoint pass.
  4. “Does using StaticArrays help? (edit: ok I guess regardless of whether it helps here that has limited utility with more variables)”: it can help somewhat, but you still need to allocate the holder that stores the solution “indefinitely” until the reverse pass reaches it, so unlike forward modes you cannot just stack-allocate it.
  5. “It looks like SciML intentionally uses a lot of runtime dispatch”: this is Zygote turning dispatches into runtime dispatches, though most of the time should be in the adjoint pass, which acts as a function barrier, so that part should mostly be fine.
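The trade-off in point 1 can be made concrete by choosing the adjoint explicitly through the `sensealg` keyword. A sketch, assuming a current SciMLSensitivity: `InterpolatingAdjoint` keeps the forward solution for the reverse pass, while `BacksolveAdjoint` re-solves the ODE backwards and stores less, at the cost of the stability caveat above. Using `ReverseDiffVJP(true)` for the vector-Jacobian products is an assumption here, not the tutorial's default.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])
tsteps = 0.0:0.1:10.0

# Loss parameterised by the adjoint method; nested-sum form over sol.u
function loss(p, sensealg)
    sol = solve(prob, Tsit5(); p = p, saveat = tsteps, sensealg = sensealg)
    return sum(x -> sum(y -> abs2(y - 1), x), sol.u)
end

p = [1.5, 1.0, 3.0, 1.0]

# Stores/interpolates the forward solution: more memory, numerically stable
interp = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))
g_interp = Zygote.gradient(q -> loss(q, interp), p)[1]

# Re-solves the ODE backwards instead of storing it: less memory, can be unstable
back = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true))
g_back = Zygote.gradient(q -> loss(q, back), p)[1]

println(g_interp)
println(g_back)
```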

With all of that said, I’m still in the middle of diagnosing a regression, highlighted in SciML/SciMLBenchmarks.jl Pull Request #1275 (“Bump autodiff tests” by ChrisRackauckas). If you’re interested in diving in, I can help point to some places to look, since calling the reverse-mode profile… difficult would be an oversimplification.


Thanks, Chris. I appreciate the input, and all the amazing work you do.

I’m just concerned that the obvious thing to do — which is also recommended in the first tutorial — can easily be made 62x faster, with 17x fewer allocations. (Side note: Using StaticArrays and Enzyme gets me to 103x faster and 30x fewer allocations.) This will be enough to keep me going, but I worry about other users (viz., one of our students) running into this unawares.

Also, I opened an issue about this here.

Don’t know that I’ll be much help, but I’m game to try!

The first thing to try is Reactant tracing. We plan to put a ton of Reactant in these docs, as that should help the performance outside of the adjoint. Then share a flame graph profile made with StatProfilerHTML.jl; I think you can paste the whole zip in here to discuss?
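For anyone following along, producing such a report might look like the following minimal sketch, assuming StatProfilerHTML.jl and its `@profilehtml` macro; `work` is a hypothetical stand-in for the real optimization call. The generated `statprof` directory is what would be zipped and shared.

```julia
using StatProfilerHTML

# Hypothetical workload; replace with the Optimization.solve call being studied
work(n) = sum(sum(abs2, rand(100)) for _ in 1:n)
work(10)                      # compile first so compilation does not dominate

# Collects a sampling profile and writes an HTML flame graph report
# into ./statprof/ (open statprof/index.html in a browser)
@profilehtml work(2_000_000)
```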

Seems like it’s missing some tracing support for now?

using OrdinaryDiffEq,SciMLSensitivity,Reactant,Enzyme,Optimization,OptimizationPolyalgorithms

function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] =  -δ * y + γ * x * y
end

Reactant.set_default_backend("cpu")
dev = to_rarray
Reactant.allowscalar(true)

# Initial condition
u0 = [1.0, 1.0] |> dev

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0 

# LV equation parameter. p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0] |> dev

# Setup the ODE problem, then solve
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
dt = 1e-2
sol = @jit solve(prob,Tsit5(),dt=dt)
ERROR: MethodError: objects of type Base.RefValue{typeof(DiffEqBase.ODE_DEFAULT_NORM)} are not callable
The object of type `Base.RefValue{typeof(DiffEqBase.ODE_DEFAULT_NORM)}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/Efoni/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::Base.RefValue{…}, ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:875
  [3] calculate_residuals
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/calculate_residuals.jl:11 [inlined]
  [4] (::Nothing)(none::typeof(DiffEqBase.calculate_residuals), none::Reactant.TracedRNumber{…}, none::Reactant.TracedRNumber{…}, none::Reactant.TracedRNumber{…}, none::Reactant.TracedRNumber{…}, none::Reactant.TracedRNumber{…}, none::Base.RefValue{…}, none::Reactant.TracedRNumber{…})
    @ Reactant ./<missing>:0
  [5] calculate_residuals
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/calculate_residuals.jl:11 [inlined]
  [6] call_with_reactant(::typeof(DiffEqBase.calculate_residuals), ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…}, ::Reactant.TracedRNumber{…}, ::Base.RefValue{…}, ::Reactant.TracedRNumber{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:0
  [7] make_mlir_fn(f::typeof(DiffEqBase.calculate_residuals), args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Nothing, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/Efoni/src/TracedUtils.jl:330
  [8] elem_apply(::Function, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Base.RefValue{…}, ::Reactant.TracedRArray{…})
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/Efoni/src/TracedUtils.jl:1082
  [9] _copyto!
    @ ~/.julia/packages/Reactant/Efoni/src/TracedRArray.jl:779 [inlined]
 [10] (::Nothing)(none::typeof(Reactant.TracedRArrayOverrides._copyto!), none::Reactant.TracedRArray{…}, none::Base.Broadcast.Broadcasted{…})
    @ Reactant ./<missing>:0
 [11] getproperty
    @ ./Base.jl:49 [inlined]
 [12] size
    @ ~/.julia/packages/Reactant/Efoni/src/TracedRArray.jl:489 [inlined]
 [13] axes
    @ ./abstractarray.jl:98 [inlined]
 [14] _copyto!
    @ ~/.julia/packages/Reactant/Efoni/src/TracedRArray.jl:772 [inlined]
 [15] call_with_reactant(::typeof(Reactant.TracedRArrayOverrides._copyto!), ::Reactant.TracedRArray{…}, ::Base.Broadcast.Broadcasted{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:0
 [16] materialize!
    @ ~/.julia/packages/Reactant/Efoni/src/TracedRArray.jl:730 [inlined]
 [17] materialize!
    @ ./broadcast.jl:880 [inlined]
 [18] fast_materialize!
    @ ~/.julia/packages/FastBroadcast/wfdTr/src/FastBroadcast.jl:279 [inlined]
 [19] calculate_residuals!
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/calculate_residuals.jl:97 [inlined]
 [20] perform_step!
    @ ~/.julia/packages/OrdinaryDiffEqTsit5/UmDPY/src/tsit_perform_step.jl:224 [inlined]
 [21] (::Nothing)(none::typeof(OrdinaryDiffEqCore.perform_step!), none::OrdinaryDiffEqCore.ODEIntegrator{…}, none::OrdinaryDiffEqTsit5.Tsit5Cache{…}, none::Bool)
    @ Reactant ./<missing>:0
 [22] getproperty
    @ ~/.julia/packages/SciMLBase/DbVzk/src/integrator_interface.jl:487 [inlined]
 [23] perform_step!
    @ ~/.julia/packages/OrdinaryDiffEqTsit5/UmDPY/src/tsit_perform_step.jl:181 [inlined]
 [24] call_with_reactant(::typeof(OrdinaryDiffEqCore.perform_step!), ::OrdinaryDiffEqCore.ODEIntegrator{…}, ::OrdinaryDiffEqTsit5.Tsit5Cache{…}, ::Bool)
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:0
 [25] perform_step!
    @ ~/.julia/packages/OrdinaryDiffEqTsit5/UmDPY/src/tsit_perform_step.jl:181 [inlined]
 [26] solve!
    @ ~/.julia/packages/OrdinaryDiffEqCore/zs1s7/src/solve.jl:620 [inlined]
 [27] (::Nothing)(none::typeof(solve!), none::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ Reactant ./<missing>:0
 [28] call_with_reactant(::typeof(solve!), ::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:501
 [29] #__solve#62
    @ ~/.julia/packages/OrdinaryDiffEqCore/zs1s7/src/solve.jl:7 [inlined]
 [30] (::Nothing)(none::OrdinaryDiffEqCore.var"##__solve#62", none::@Kwargs{…}, none::typeof(SciMLBase.__solve), none::ODEProblem{…}, none::Tsit5{…}, none::Tuple{})
    @ Reactant ./<missing>:0
 [31] merge
    @ ./namedtuple.jl:349 [inlined]
 [32] #__solve#62
    @ ~/.julia/packages/OrdinaryDiffEqCore/zs1s7/src/solve.jl:6 [inlined]
 [33] call_with_reactant(::OrdinaryDiffEqCore.var"##__solve#62", ::@Kwargs{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:0
 [34] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/zs1s7/src/solve.jl:1 [inlined]
 [35] #solve_call#36
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:657 [inlined]
 [36] solve_call
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:614 [inlined]
 [37] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{…}, none::typeof(DiffEqBase.solve_call), none::ODEProblem{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [38] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Tsit5{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:501
 [39] #solve_up#45
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:1211 [inlined]
 [40] (::Nothing)(none::DiffEqBase.var"##solve_up#45", none::SciMLBase.ChainRulesOriginator, none::@Kwargs{…}, none::typeof(DiffEqBase.solve_up), none::ODEProblem{…}, none::Nothing, none::Reactant.TracedRArray{…}, none::Reactant.TracedRArray{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [41] call_with_reactant(::DiffEqBase.var"##solve_up#45", ::SciMLBase.ChainRulesOriginator, ::@Kwargs{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Nothing, ::Reactant.TracedRArray{…}, ::Reactant.TracedRArray{…}, ::Tsit5{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:501
 [42] solve_up
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:1188 [inlined]
 [43] #solve#43
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:1083 [inlined]
 [44] (::Nothing)(none::DiffEqBase.var"##solve#43", none::Nothing, none::Nothing, none::Nothing, none::Val{…}, none::@Kwargs{…}, none::typeof(solve), none::ODEProblem{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [45] call_with_reactant(::DiffEqBase.var"##solve#43", ::Nothing, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:501
 [46] solve
    @ ~/.julia/packages/DiffEqBase/SwbtQ/src/solve.jl:1073 [inlined]
 [47] (::Nothing)(none::typeof(Core.kwcall), none::@NamedTuple{…}, none::typeof(solve), none::ODEProblem{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [48] call_with_reactant(::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Reactant ~/.julia/packages/Reactant/Efoni/src/utils.jl:501
 [49] make_mlir_fn(f::typeof(solve), args::Tuple{…}, kwargs::@NamedTuple{…}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/Efoni/src/TracedUtils.jl:332
 [50] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{…}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/Efoni/src/Compiler.jl:1528
 [51] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/Efoni/src/Compiler.jl:1495 [inlined]
 [52] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/Efoni/src/Compiler.jl:3386
 [53] compile_xla
    @ ~/.julia/packages/Reactant/Efoni/src/Compiler.jl:3359 [inlined]
 [54] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/Efoni/src/Compiler.jl:3458
Some type information was truncated. Use `show(err)` to see complete types.

Kind of trying Reactant everywhere right now :')

You get a bit farther with

sol = @jit solve(prob, Tsit5(), dt=1e-2)

But not all the way: it fails with the same MethodError (objects of type Base.RefValue{typeof(DiffEqBase.ODE_DEFAULT_NORM)} are not callable) and the same stack trace as above.

Ironically, the error is not in tracing but in our AbstractInterpreter override (which we’re actually using to wean Reactant off tracing).

If you’re able to, can you file an issue with as minimal a working example as possible (ideally something within that calculate_residuals call, per the stack trace)?

No issue with FastBroadcast, the muladd macro, or fastmath, so it comes down to finding where the internalnorm in

@inline @muladd function calculate_residuals(ũ::Number, u₀::Number, u₁::Number,
        α, ρ, internalnorm, t)
    @fastmath ũ / (α + max(internalnorm(u₀, t), internalnorm(u₁, t)) * ρ)
end

comes from, but it’s hard to track it down, sorry :cry:
edit: it seems to default to ODE_DEFAULT_NORM, which has tons of dispatches:

ODE_DEFAULT_NORM(u::Union{AbstractFloat, Complex}, t) = @fastmath abs(u)

function ODE_DEFAULT_NORM(f::F, u::Union{AbstractFloat, Complex}, t) where {F}
    return @fastmath abs(f(u))
end

function ODE_DEFAULT_NORM(u::Array{T}, t) where {T <: Union{AbstractFloat, Complex}}
    x = zero(T)
    @inbounds @fastmath for ui in u
        x += abs2(ui)
    end
    Base.FastMath.sqrt_fast(real(x) / max(length(u), 1))
end

function ODE_DEFAULT_NORM(f::F,
        u::Union{Array{T}, Iterators.Zip{<:Tuple{Vararg{Array{T}}}}},
        t) where {F, T <: Union{AbstractFloat, Complex}}
    x = zero(T)
    @inbounds @fastmath for ui in u
        x += abs2(f(ui))
    end
    Base.FastMath.sqrt_fast(real(x) / max(length(u), 1))
end

function ODE_DEFAULT_NORM(u::StaticArraysCore.StaticArray{<:Tuple, T},
        t) where {T <: Union{AbstractFloat, Complex}}
    Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1))
end

function ODE_DEFAULT_NORM(f::F, u::StaticArraysCore.StaticArray{<:Tuple, T},
        t) where {F, T <: Union{AbstractFloat, Complex}}
    Base.FastMath.sqrt_fast(real(sum(abs2 ∘ f, u)) / max(length(u), 1))
end

function ODE_DEFAULT_NORM(
        u::Union{
            AbstractArray,
            RecursiveArrayTools.AbstractVectorOfArray
        },
        t)
    Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / max(recursive_length(u), 1))
end

function ODE_DEFAULT_NORM(f::F, u::AbstractArray, t) where {F}
    Base.FastMath.sqrt_fast(UNITLESS_ABS2(f, u) / max(recursive_length(u), 1))
end

ODE_DEFAULT_NORM(u, t) = norm(u)
ODE_DEFAULT_NORM(f::F, u, t) where {F} = norm(f.(u))

However, the error message says it was a RefValue of its type that got called, which may mean it never made it to those dispatches?
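One plausible reading of the RefValue in the error: Base broadcasting itself wraps function arguments in a `Ref` (via `Broadcast.broadcastable`) so they are treated as scalars, and the stack trace shows Reactant's `elem_apply` receiving that `Base.RefValue{…}` directly. If so, ODE_DEFAULT_NORM was never dispatched on; the `Ref` wrapper was called instead. The pure-Base sketch below shows only the wrapping step; `resid` is a hypothetical stand-in for `calculate_residuals`, and the Reactant side of this explanation is an assumption drawn from the trace, not verified.

```julia
# Stand-in for calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t):
# a scalar kernel that takes a function (the norm) as one of its arguments.
resid(u, norm_f) = u / (1 + norm_f(u))

x = [1.0, -2.0]

# Build the lazy broadcast without executing it
bc = Broadcast.broadcasted(resid, x, abs)
println(typeof(bc.args[2]))       # Base.RefValue{typeof(abs)}

# Base's own materialisation unwraps the Ref before calling resid
println(Base.materialize(bc))     # [0.5, -0.6666666666666666]
```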

I’m trying to build a wrapper of a wrapper of a wrapper of an inlined function, but Reactant is truly hard to break, so the problem may not come from this. How can I trigger the abstract interpreter?

You can call Reactant.call_with_reactant(f, args...) and it’ll execute the code in our interpreter.