Lux + Enzyme: naively applying a model seems to cause a runtime activity error

Consider the following simple setup:

using Lux, LinearAlgebra
import Optimization, Random, Enzyme, ComponentArrays

model = Chain(
    Dense(2 => 16, Lux.tanh),
    Dense(16, 1),
)

parameters, states = Lux.setup(Random.Xoshiro(42), model) |> f64
point = [0.0, 0.0]

adtype = AutoEnzyme()

function f(parameters, model, states, point)
    return only(first(model(point, parameters, states)))
end


print("Testing f: $(f(parameters, model, states, point))\n")
manual_gradient = Enzyme.gradient(Enzyme.Reverse, f, parameters, Enzyme.Const(model), Enzyme.Const(states), Enzyme.Const(point))
print("Manual gradient: $manual_gradient\n")

On Julia v1.11.6 with Enzyme v0.13.66 this raises

Testing f: 0.4700911923782194
ERROR: 
If you are using Enzyme by selecting the `AutoEnzyme` object from ADTypes, you may want to try setting the `mode` option as follows:

        AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))
        AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))

This hint appears because DifferentiationInterface and Enzyme are both loaded. It does not necessarily imply that Enzyme is being called through DifferentiationInterface.

Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   %.pn46 = phi {} addrspace(10)* [ %38, %L49 ], [ %27, %L42 ] const val:   %27 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %26, align 8, !dbg !188, !tbaa !190, !alias.scope !168, !noalias !169, !dereferenceable_or_null !193, !align !141, !enzyme_type !194, !enzymejl_source_type_Memory\7BFloat64\7D !0, !enzymejl_byref_MUT_REF !0
 value=Unknown object of type Memory{Float64}
 llvalue=  %27 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %26, align 8, !dbg !188, !tbaa !190, !alias.scope !168, !noalias !169, !dereferenceable_or_null !193, !align !141, !enzyme_type !194, !enzymejl_source_type_Memory\7BFloat64\7D !0, !enzymejl_byref_MUT_REF !0

Stacktrace:
 [1] reshape
   @ ./reshapedarray.jl:60
 [2] reshape
   @ ./reshapedarray.jl:129
 [3] reshape
   @ ./reshapedarray.jl:128
 [4] make_abstract_matrix
   @ ~/.julia/packages/Lux/H3WdN/src/utils.jl:204
 [5] Dense
   @ ~/.julia/packages/Lux/H3WdN/src/layers/basic.jl:343

Stacktrace:
  [1] reshape
    @ ./reshapedarray.jl:54 [inlined]
  [2] reshape
    @ ./reshapedarray.jl:129 [inlined]
  [3] reshape
    @ ./reshapedarray.jl:128 [inlined]
  [4] make_abstract_matrix
    @ ~/.julia/packages/Lux/H3WdN/src/utils.jl:204 [inlined]
  [5] Dense
    @ ~/.julia/packages/Lux/H3WdN/src/layers/basic.jl:343
  [6] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:0 [inlined]
  [8] applychain
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:482 [inlined]
  [9] Chain
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:480 [inlined]

Wrapping f in Enzyme.Const as well only makes the error log less readable, giving

ERROR: Enzyme.Compiler.EnzymeRuntimeActivityError(Cstring(0x00007fc704148521))

What’s the proper way to use Lux models in e.g. a loss function with Enzyme?

What happens if you turn runtime activity on, as the error message suggests?

manual_gradient = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), f, parameters, Enzyme.Const(model), Enzyme.Const(states), Enzyme.Const(point))

That said, I'd recommend using this inside a Reactant.@compile, which will both be faster and never need runtime activity.
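
A rough, untested sketch of what that could look like with the names from the original snippet (x_dev, grad_fn, and the *_ra variables are my own; I'm assuming the inputs can be moved over with reactant_device()):

using Reactant

const x_dev = reactant_device()
parameters_ra = parameters |> x_dev  # move the inputs onto the Reactant device
states_ra = states |> x_dev
point_ra = point |> x_dev

# wrap the whole gradient call so Reactant can trace and compile it
grad_fn(ps, st, pt) = Enzyme.gradient(Enzyme.Reverse, Enzyme.Const(f), ps,
    Enzyme.Const(model), Enzyme.Const(st), Enzyme.Const(pt))
grad_compiled = Reactant.@compile grad_fn(parameters_ra, states_ra, point_ra)
grad_compiled(parameters_ra, states_ra, point_ra)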

cc @avikpal

I had hoped to specifically avoid setting runtime activity: the actual project is a lot more complicated, and turning runtime activity on seemed to slow everything down considerably. I will try using Reactant, but I had planned to possibly use AMDGPU.jl later on, and as far as I understand there is no Reactant compatibility there?

There will be support for AMD GPUs in Reactant, but we haven't prioritized the build for that yet (in no small part because we don't have an AMD GPU to do the dev setup on). Hopefully that will start up some time in the fall.

Compiling with Reactant works wonderfully, thank you.

For my use case, though, I want to use Optimization.jl's solvers (in particular LBFGS); can those be used with Reactant? I assume I'd just want to pass a function that computes the gradient to OptimizationFunction, but I'm unsure how to get ComponentArrays and Reactant to work together. Pushing a ComponentArray to reactant_device() seems to cause errors, and casting a Reactant array to a ComponentArray seems to make the Enzyme gradients vanish. Is this interaction between Optimization, Lux, and Reactant something that should work?

On the other hand, I still don't really understand where the runtime activity comes from with pure Enzyme. Is this something I should file an issue about?

Do you have a sample code snippet? Offhand I don't see any reason there would be a problem.

I'm not exactly sure how the ComponentArray issue arises, but another option to try is Reactant.to_rarray(x), which preserves aliasing and structure during the conversion.
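
For example, something along these lines (just a sketch; parameters and states are the ones from your earlier snippet):

parameters_ra = Reactant.to_rarray(parameters)  # keeps the ComponentArray wrapper, converts the underlying data
states_ra = Reactant.to_rarray(states)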

cc @avikpal for sure on this, though.

Lux.jl/examples/HyperNet/main.jl at main · LuxDL/Lux.jl · GitHub seems to show ComponentArrays and Reactant working together, at least when it's used inside a layer: Lux.jl/examples/HyperNet/main.jl at ae92687dbf837c02b08fbb5540a6816f38e2871f · LuxDL/Lux.jl · GitHub

Yeah, thank you. I can't seem to reproduce the issue with ComponentArray now anyway; it works great with Reactant.

My current MWE is

using Lux, LinearAlgebra, Reactant, Enzyme, ComponentArrays
import Optimization, Random

model = Chain(
    Dense(2 => 16, Lux.tanh),
    Dense(16, 1),
)
const x_dev = reactant_device()
parameters, states = Lux.setup(Random.Xoshiro(42), model) |> f64
parameters = ComponentArray(parameters)
parameters_ra = parameters |> x_dev
states_ra = states |> x_dev


point = [0.0, 0.0] 
point_ra = point |> x_dev

function loss_function(parameters, full_data)
    model, states, point = full_data
    result, _ = model(point, parameters, states)
    return only(result)
end

full_data = (model, states, point)
full_data_ra = (model, states_ra, point_ra)

function enzyme_gradient(parameters, full_data)
    return gradient(Reverse, Const(loss_function), parameters, Const(full_data))
end

print("Testing loss: $(loss_function(parameters, full_data))\n") # this works

# commented out code below causes runtime activity
# print("Testing enzyme without Reactant: $(enzyme_gradient(parameters, full_data))\n") 


# below all works fine if loss_function returns result instead of only(result)
loss_function_compiled  = Reactant.@compile loss_function(parameters_ra, full_data_ra)
enzyme_gradient_compiled = Reactant.@compile enzyme_gradient(parameters_ra, full_data_ra)
enzyme_gradient_result = enzyme_gradient_compiled(parameters_ra, full_data_ra)
print("Loss function + Reactant result: $(loss_function_compiled(parameters_ra, full_data_ra))\n")
print("Reactant+Enzyme gradient: $enzyme_gradient_result\n")


f = Optimization.OptimizationFunction{false}(loss_function_compiled; grad=enzyme_gradient_compiled)
prob = Optimization.OptimizationProblem(f, parameters_ra, full_data_ra)
result = Optimization.solve(prob, Optimization.LBFGS())

Currently I'm struggling with Reactant compilation when I want loss_function to actually return a scalar: doing only(result) (or first(result)) raises a method ambiguity error. Am I doing something evil by trying to index here, or is this a bug? I think I need loss_function to return a scalar for Optimization.OptimizationFunction's isinplace detection to work properly.

Full error:

ERROR: MethodError: getindex(::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 1, Reactant.TracedRArray{Float64, 2}, Tuple{}}, ::Int64) is ambiguous.

Candidates:
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices::Union{Int64, Reactant.TracedRNumber{Int64}}...)
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:289
  getindex(A::Base.ReshapedArray{T, N, P, Tuple{}} where {T, N, P<:AbstractArray}, index::Int64)
    @ Base reshapedarray.jl:253
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices...)
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:293
  getindex(A::Base.ReshapedArray{T, N}, indices::Vararg{Int64, N}) where {T, N}
    @ Base reshapedarray.jl:259
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, N, <:Any, <:Any, <:Src}} where {N, Src, Dst}, index::Union{Int64, Reactant.TracedRNumber{Int64}}...) where T
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:264
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, 1, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, 1, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, 1, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, 1, <:Any, <:Any, <:Src}} where {Src, Dst}, indices) where T
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:274
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, N, <:Any, <:Any, <:Src}} where {Src, Dst}, indices::Vararg{Any, N}) where {T, N}
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:277

Possible fix, define
  getindex(::Base.ReshapedArray{Reactant.TracedRNumber{T}, 1, P, Tuple{}} where P<:AbstractArray, ::Int64) where T

Never mind, just doing sum(result) instead of only(result) fixes this.
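
For reference, the only change to the MWE above is the return value of the loss:

function loss_function(parameters, full_data)
    model, states, point = full_data
    result, _ = model(point, parameters, states)
    return sum(result)  # sum instead of only(result) avoids the getindex ambiguity
end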

However, constructing the OptimizationFunction remains problematic. I still get

ERROR: type UnionAll has no field parameters
Stacktrace:
 [1] getproperty
   @ ./Base.jl:43 [inlined]
 [2] isinplace(f::Reactant.Compiler.Thunk{…}, inplace_param_number::Int64, fname::String, iip_preferred::Bool; has_two_dispatches::Bool, isoptimization::Bool, outofplace_param_number::Int64)
   @ SciMLBase ~/.julia/packages/SciMLBase/wfZCo/src/utils.jl:290
 [3] isinplace (repeats 2 times)
   @ ~/.julia/packages/SciMLBase/wfZCo/src/utils.jl:246 [inlined]
 [4] #_#155
   @ ~/.julia/packages/SciMLBase/wfZCo/src/scimlfunctions.jl:4249 [inlined]
 [5] OptimizationFunction
   @ ~/.julia/packages/SciMLBase/wfZCo/src/scimlfunctions.jl:4226 [inlined]
 [6] top-level scope
   @ ~/OneDrive/Education/Studium/Mathematik/ScientificComputing/RandomizedPositiveMass/src/MWE.jl:44
Some type information was truncated. Use `show(err)` to see complete types.

Is there some other (better) way to get an OptimizationFunction that uses compiled Reactant functions?

I can avoid this error if I do something like Optimization.OptimizationFunction(loss_function; grad=enzyme_gradient_compiled), but having to mix compiled and non-compiled functions here seems weird. And even then, I quickly get errors related to the gradient being called in unexpected ways (this is with Optimization.Sophia; Optimization.LBFGS fails even earlier by trying to convert a ConcretePJRTArray to a pointer). I tried a bunch of other optimisers from Optim and Optimisers as well, but got a variety of different errors.

Honestly, that looks like a bug in SciMLBase.isinplace. cc @ChrisRackauckas

We might need to specialize it for Reactant. But I assume you already worked around it by just setting iip, i.e. OptimizationProblem{true}(…)?

I tried setting OptimizationFunction{false}, but that didn't prevent the check. As far as I can see, the constructor for OptimizationFunction always tests the passed function.
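
One workaround I can imagine, though I haven't tested it: hide the compiled thunks behind ordinary two-argument closures (loss_wrapped and grad_wrapped are names I just made up), so that the isinplace check sees plain Julia functions instead of Reactant thunks:

# untested idea: plain closures forwarding to the compiled thunks
loss_wrapped(u, p) = loss_function_compiled(u, p)
grad_wrapped(u, p) = enzyme_gradient_compiled(u, p)

f = Optimization.OptimizationFunction{false}(loss_wrapped; grad=grad_wrapped)
prob = Optimization.OptimizationProblem(f, parameters_ra, full_data_ra)

Whether that would also avoid the later errors about how the gradient gets called, I don't know.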

But in general, should OptimizationProblems and Reactant-compiled methods interop nicely? I couldn't find anything on that front; I only saw the tangentially related Add Reactant support · Issue #969 · SciML/Optimization.jl · GitHub.