Lux + Enzyme: naively applying a model seems to cause a runtime activity error

Consider the simple setup

using Lux, LinearAlgebra
import Optimization, Random, Enzyme, ComponentArrays

model = Chain(
    Dense(2 => 16, Lux.tanh),
    Dense(16, 1),
)

parameters, states = Lux.setup(Random.Xoshiro(42), model) |> f64
point = [0.0, 0.0]

adtype = AutoEnzyme()

function f(parameters, model, states, point)
    return only(first(model(point, parameters, states)))
end


print("Testing f: $(f(parameters, model, states, point))\n")
manual_gradient = Enzyme.gradient(Enzyme.Reverse, f, parameters, Enzyme.Const(model), Enzyme.Const(states), Enzyme.Const(point))
print("Manual gradient: $manual_gradient\n")

On Julia v1.11.6 with Enzyme v0.13.66 this raises

Testing f: 0.4700911923782194
ERROR: 
If you are using Enzyme by selecting the `AutoEnzyme` object from ADTypes, you may want to try setting the `mode` option as follows:

        AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))
        AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))

This hint appears because DifferentiationInterface and Enzyme are both loaded. It does not necessarily imply that Enzyme is being called through DifferentiationInterface.

Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for:   %.pn46 = phi {} addrspace(10)* [ %38, %L49 ], [ %27, %L42 ] const val:   %27 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %26, align 8, !dbg !188, !tbaa !190, !alias.scope !168, !noalias !169, !dereferenceable_or_null !193, !align !141, !enzyme_type !194, !enzymejl_source_type_Memory\7BFloat64\7D !0, !enzymejl_byref_MUT_REF !0
 value=Unknown object of type Memory{Float64}
 llvalue=  %27 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %26, align 8, !dbg !188, !tbaa !190, !alias.scope !168, !noalias !169, !dereferenceable_or_null !193, !align !141, !enzyme_type !194, !enzymejl_source_type_Memory\7BFloat64\7D !0, !enzymejl_byref_MUT_REF !0

Stacktrace:
 [1] reshape
   @ ./reshapedarray.jl:60
 [2] reshape
   @ ./reshapedarray.jl:129
 [3] reshape
   @ ./reshapedarray.jl:128
 [4] make_abstract_matrix
   @ ~/.julia/packages/Lux/H3WdN/src/utils.jl:204
 [5] Dense
   @ ~/.julia/packages/Lux/H3WdN/src/layers/basic.jl:343

Stacktrace:
  [1] reshape
    @ ./reshapedarray.jl:54 [inlined]
  [2] reshape
    @ ./reshapedarray.jl:129 [inlined]
  [3] reshape
    @ ./reshapedarray.jl:128 [inlined]
  [4] make_abstract_matrix
    @ ~/.julia/packages/Lux/H3WdN/src/utils.jl:204 [inlined]
  [5] Dense
    @ ~/.julia/packages/Lux/H3WdN/src/layers/basic.jl:343
  [6] apply
    @ ~/.julia/packages/LuxCore/XUV80/src/LuxCore.jl:155 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:0 [inlined]
  [8] applychain
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:482 [inlined]
  [9] Chain
    @ ~/.julia/packages/Lux/H3WdN/src/layers/containers.jl:480 [inlined]

Adding Const to f as well only makes the error log less readable, giving

ERROR: Enzyme.Compiler.EnzymeRuntimeActivityError(Cstring(0x00007fc704148521))

What’s the proper way to use Lux models in e.g. a loss function with Enzyme?

What happens if you turn runtime activity on, like the error message suggests?

manual_gradient = Enzyme.gradient(Enzyme.set_runtime_activity(Enzyme.Reverse), f, parameters, Enzyme.Const(model), Enzyme.Const(states), Enzyme.Const(point))

Though that said, I’d recommend using this inside a Reactant.@compile [which will both be faster and never need runtime activity].
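Roughly, that could look something like this (an untested sketch reusing f, model, parameters, states, and point from the snippet above; the fuller MWE later in the thread follows the same pattern for a loss function):

using Reactant

# Move the inputs onto Reactant arrays (to_rarray keeps the nested structure).
parameters_ra = Reactant.to_rarray(parameters)
states_ra = Reactant.to_rarray(states)
point_ra = Reactant.to_rarray(point)

# Wrap the Enzyme call so the whole gradient computation is traced and compiled at once.
function grad_f(ps, model, st, x)
    return Enzyme.gradient(Enzyme.Reverse, f, ps,
        Enzyme.Const(model), Enzyme.Const(st), Enzyme.Const(x))
end

grad_f_compiled = Reactant.@compile grad_f(parameters_ra, model, states_ra, point_ra)
grad_f_compiled(parameters_ra, model, states_ra, point_ra)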

cc @avikpal


I had hoped to be able to avoid setting runtime activity specifically (the actual project is a lot more complicated, and turning runtime activity on seemed to slow everything down a whole lot). I will try using Reactant, but I had planned to maybe use AMDGPU.jl later on, and AFAIU there is no Reactant compatibility there?

There will be support for AMD GPUs in Reactant, but we haven’t prioritized the build for that yet (in no small part because we don’t have an AMD GPU to do the dev setup on). But hopefully that will start up some time in the fall.


Compiling with Reactant works wonderfully, thank you.

For my use case, though, I want to use Optimization.jl’s solvers (in particular LBFGS); can those be used with Reactant? I assume I’d just want to provide a function that computes the gradient to OptimizationFunction, but I’m unsure how to get ComponentArrays and Reactant to work together. Pushing a ComponentArray to reactant_device() seems to cause errors, and casting a Reactant array to a ComponentArray seems to make Enzyme gradients vanish. Is this interaction between Optimization, Lux, and Reactant something that should work?

On the other hand I still don’t really understand where the runtime activity comes from with just pure Enzyme. Is this something I should file an issue about?

Do you have a sample code snippet? Offhand I don’t see any reason there would be a problem.

I’m not exactly sure how the ComponentArray issue arises, but another option to try is Reactant.to_rarray(x) [which preserves aliasing and structure during the conversion].
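For a ComponentArray of parameters that would be something like this (a small sketch; the resulting type matches the ComponentVector-backed-by-ConcretePJRTArray values that show up in the error logs further down):

using Reactant, ComponentArrays

ps_ca = ComponentArray(parameters)   # structured parameter vector
ps_ra = Reactant.to_rarray(ps_ca)    # expected to stay a ComponentVector, with the
                                     # underlying data converted to a Reactant array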

cc @avikpal though for sure on this

Lux.jl/examples/HyperNet/main.jl at main · LuxDL/Lux.jl · GitHub seems to show ComponentArrays and Reactant working together, well, at least when it’s used inside a layer: Lux.jl/examples/HyperNet/main.jl at ae92687dbf837c02b08fbb5540a6816f38e2871f · LuxDL/Lux.jl · GitHub


Yeah, thank you. I can’t seem to reproduce the issue with ComponentArray now anyway; it works great with Reactant.

My current MWE is

using Lux, LinearAlgebra, Reactant, Enzyme, ComponentArrays
import Optimization, Random

model = Chain(
    Dense(2 => 16, Lux.tanh),
    Dense(16, 1),
)
const x_dev = reactant_device()
parameters, states = Lux.setup(Random.Xoshiro(42), model) |> f64
parameters = ComponentArray(parameters)
parameters_ra = parameters |> x_dev
states_ra = states |> x_dev


point = [0.0, 0.0] 
point_ra = point |> x_dev

function loss_function(parameters, full_data)
    model, states, point = full_data
    result, _ = model(point, parameters, states)
    return only(result)
end

full_data = (model, states, point)
full_data_ra = (model, states_ra, point_ra)

function enzyme_gradient(parameters, full_data)
    return gradient(Reverse, Const(loss_function), parameters, Const(full_data))
end

print("Testing loss: $(loss_function(parameters, full_data))\n") # this works

# commented out code below causes runtime activity
# print("Testing enzyme without Reactant: $(enzyme_gradient(parameters, full_data))\n") 


# below all works fine if loss_function returns result instead of only(result)
loss_function_compiled  = Reactant.@compile loss_function(parameters_ra, full_data_ra)
enzyme_gradient_compiled = Reactant.@compile enzyme_gradient(parameters_ra, full_data_ra)
enzyme_gradient_result = enzyme_gradient_compiled(parameters_ra, full_data_ra)
print("Loss function + Reactant result: $(loss_function_compiled(parameters_ra, full_data_ra))\n")
print("Reactant+Enzyme gradient: $enzyme_gradient_result\n")


f = Optimization.OptimizationFunction{false}(loss_function_compiled; grad=enzyme_gradient_compiled)
prob = Optimization.OptimizationProblem(f, parameters_ra, full_data_ra)
result = Optimization.solve(prob, Optimization.LBFGS())

Currently I’m struggling with Reactant compilation if I want loss_function to actually return a scalar: doing only(result) (or first(result)) raises a method ambiguity error. Am I doing something evil by trying to index here, or is this a bug? I think I need loss_function to return a scalar for Optimization.OptimizationFunction’s isinplace detection to work properly.

Full error:

ERROR: MethodError: getindex(::Base.ReshapedArray{Reactant.TracedRNumber{Float64}, 1, Reactant.TracedRArray{Float64, 2}, Tuple{}}, ::Int64) is ambiguous.

Candidates:
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices::Union{Int64, Reactant.TracedRNumber{Int64}}...)
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:289
  getindex(A::Base.ReshapedArray{T, N, P, Tuple{}} where {T, N, P<:AbstractArray}, index::Int64)
    @ Base reshapedarray.jl:253
  getindex(a::Base.ReshapedArray{Reactant.TracedRNumber{T}} where T, indices...)
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:293
  getindex(A::Base.ReshapedArray{T, N}, indices::Vararg{Int64, N}) where {T, N}
    @ Base reshapedarray.jl:259
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, N, <:Any, <:Any, <:Src}} where {N, Src, Dst}, index::Union{Int64, Reactant.TracedRNumber{Int64}}...) where T
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:264
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, 1, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, 1, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, 1, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, 1, <:Any, <:Any, <:Src}} where {Src, Dst}, indices) where T
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:274
  getindex(a::Union{Base.LogicalIndex{Reactant.TracedRNumber{T}, <:Src}, Base.ReinterpretArray{Reactant.TracedRNumber{T}, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:Src, Base.ReshapedArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:Src, SubArray{Reactant.TracedRNumber{T}, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:Src, Adjoint{Reactant.TracedRNumber{T}, <:Dst}, Diagonal{Reactant.TracedRNumber{T}, <:Dst}, LowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, Symmetric{Reactant.TracedRNumber{T}, <:Dst}, Transpose{Reactant.TracedRNumber{T}, <:Dst}, Tridiagonal{Reactant.TracedRNumber{T}, <:Dst}, UnitLowerTriangular{Reactant.TracedRNumber{T}, <:Dst}, UnitUpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, UpperTriangular{Reactant.TracedRNumber{T}, <:Dst}, PermutedDimsArray{Reactant.TracedRNumber{T}, N, <:Any, <:Any, <:Src}} where {Src, Dst}, indices::Vararg{Any, N}) where {T, N}
    @ Reactant.TracedRArrayOverrides ~/.julia/packages/Reactant/doj2y/src/TracedRArray.jl:277

Possible fix, define
  getindex(::Base.ReshapedArray{Reactant.TracedRNumber{T}, 1, P, Tuple{}} where P<:AbstractArray, ::Int64) where T

Never mind, just doing sum(result) instead of only(result) fixes this.
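That is, the loss from the MWE just becomes:

function loss_function(parameters, full_data)
    model, states, point = full_data
    result, _ = model(point, parameters, states)
    return sum(result)   # sum instead of only/first sidesteps the getindex ambiguity
end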

However, constructing the OptimizationFunction remains problematic. I still get

ERROR: type UnionAll has no field parameters
Stacktrace:
 [1] getproperty
   @ ./Base.jl:43 [inlined]
 [2] isinplace(f::Reactant.Compiler.Thunk{…}, inplace_param_number::Int64, fname::String, iip_preferred::Bool; has_two_dispatches::Bool, isoptimization::Bool, outofplace_param_number::Int64)
   @ SciMLBase ~/.julia/packages/SciMLBase/wfZCo/src/utils.jl:290
 [3] isinplace (repeats 2 times)
   @ ~/.julia/packages/SciMLBase/wfZCo/src/utils.jl:246 [inlined]
 [4] #_#155
   @ ~/.julia/packages/SciMLBase/wfZCo/src/scimlfunctions.jl:4249 [inlined]
 [5] OptimizationFunction
   @ ~/.julia/packages/SciMLBase/wfZCo/src/scimlfunctions.jl:4226 [inlined]
 [6] top-level scope
   @ ~/OneDrive/Education/Studium/Mathematik/ScientificComputing/RandomizedPositiveMass/src/MWE.jl:44
Some type information was truncated. Use `show(err)` to see complete types.

Is there some other (better) way to get an OptimizationFunction that uses compiled Reactant functions?

I can avoid this error if I do something like Optimization.OptimizationFunction(loss_function; grad=enzyme_gradient_compiled), but having to mix compiled and non-compiled functions here seems weird. And even then, I quickly get errors related to the gradient being called in unexpected ways (this is with Optimization.Sophia; Optimization.LBFGS fails even earlier by trying to convert a ConcretePJRTArray to a pointer). I tried a bunch of other optimisers from Optim and Optimisers as well, but got a variety of different errors.

Honestly that looks like a bug in SciMLBase.isinplace. cc @ChrisRackauckas

We might need to specialize it for Reactant. But I assume you already worked around it by just setting iip? OptimizationProblem{true}(…)

I tried to set OptimizationFunction{false}, but that didn’t prevent the check. As far as I can see the constructor for OptimizationFunction always tests the passed function.

But in general, OptimizationProblems and Reactant-compiled methods should interop nicely? I couldn’t find anything on that front; I only saw the tangentially related Add Reactant support · Issue #969 · SciML/Optimization.jl · GitHub.

I do get some weird errors when I try to forge ahead with OptimizationOptimJL.LBFGS instead of Optimization.LBFGS (this is without the loss_function itself being compiled, since I couldn’t resolve the inplace issue):

LoadError: MethodError: no method matching ConcretePJRTArray(::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}})
The type `ConcretePJRTArray` exists, but no method is defined for this combination of argument types when trying to construct it.
Rest of log

Closest candidates are:
  ConcretePJRTArray(::Array{T, N}; client, idx, device, sharding) where {T, N}
   @ Reactant ~/.julia/packages/Reactant/doj2y/src/Types.jl:181
  ConcretePJRTArray(::Number; kwargs...)
   @ Reactant deprecated.jl:103
  ConcretePJRTArray(::Union{ConcretePJRTArray{T, N, D, S}, Base.LogicalIndex{T, <:ConcretePJRTArray}, Base.ReinterpretArray{T, N, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s14"}, var"#s14"}} where var"#s14"<:ConcretePJRTArray, Base.ReshapedArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}}, SubArray{<:Any, <:Any, var"#s15"}, var"#s15"}} where var"#s15"<:ConcretePJRTArray, SubArray{T, N, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, Base.ReshapedArray{<:Any, <:Any, <:Union{Base.ReinterpretArray{<:Any, <:Any, <:Any, <:Union{SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, SubArray{<:Any, <:Any, var"#s16"}, var"#s16"}}, var"#s16"}} where var"#s16"<:ConcretePJRTArray, Adjoint{T, <:ConcretePJRTArray{T, N, D, S}}, Diagonal{T, <:ConcretePJRTArray{T, N, D, S}}, LowerTriangular{T, <:ConcretePJRTArray{T, N, D, S}}, Symmetric{T, <:ConcretePJRTArray{T, N, D, S}}, Transpose{T, <:ConcretePJRTArray{T, N, D, S}}, Tridiagonal{T, <:ConcretePJRTArray{T, N, D, S}}, UnitLowerTriangular{T, <:ConcretePJRTArray{T, N, D, S}}, UnitUpperTriangular{T, <:ConcretePJRTArray{T, N, D, S}}, UpperTriangular{T, <:ConcretePJRTArray{T, N, D, S}}, PermutedDimsArray{T, N, <:Any, <:Any, <:ConcretePJRTArray}} where {T, N, D, S}; kwargs...)
   @ Reactant ~/.julia/packages/Reactant/doj2y/src/Types.jl:230
Stacktrace:
  [1] copy(bc::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayStyle{ConcretePJRTArray}, Tuple{ComponentArrays.CombinedAxis{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}, Base.OneTo{Int64}}}, Type{Float64}, Tuple{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}}})
    @ Reactant ~/.julia/packages/Reactant/doj2y/src/ConcreteRArray.jl:435
  [2] materialize
    @ ./broadcast.jl:872 [inlined]
  [3] x_of_nans(x::ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tf::Type{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/n7XXO/src/NLSolversBase.jl:78
  [4] alloc_DF(x::ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, F::Float64)
    @ NLSolversBase ~/.julia/packages/NLSolversBase/n7XXO/src/objective_types/abstract.jl:22
  [5] __solve(cache::OptimizationCache{OptimizationFunction{true, SciMLBase.NoAD, typeof(loss_function), OptimizationBase.var"#grad#204"{Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, OptimizationFunction{true, SciMLBase.NoAD, typeof(loss_function), Reactant.Compiler.Thunk{typeof(enzyme_gradient), Symbol("##enzyme_gradient_reactant#366"), false, Tuple{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, OptimizationBase.ReInitCache{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, Nothing, Nothing, Nothing, Nothing, Nothing, LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Returns{Nothing}}, Bool, OptimizationOptimJL.var"#4#6", Nothing})
    @ OptimizationOptimJL ~/.julia/packages/OptimizationOptimJL/VaURt/src/OptimizationOptimJL.jl:200
  [6] solve!(cache::OptimizationCache{OptimizationFunction{true, SciMLBase.NoAD, typeof(loss_function), OptimizationBase.var"#grad#204"{Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, OptimizationFunction{true, SciMLBase.NoAD, typeof(loss_function), Reactant.Compiler.Thunk{typeof(enzyme_gradient), Symbol("##enzyme_gradient_reactant#366"), false, Tuple{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, OptimizationBase.ReInitCache{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, Nothing, Nothing, Nothing, Nothing, Nothing, LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Returns{Nothing}}, Bool, OptimizationOptimJL.var"#4#6", Nothing})
    @ SciMLBase ~/.julia/packages/SciMLBase/wfZCo/src/solve.jl:226
  [7] solve(::OptimizationProblem{true, OptimizationFunction{true, SciMLBase.NoAD, typeof(loss_function), Reactant.Compiler.Thunk{typeof(enzyme_gradient), Symbol("##enzyme_gradient_reactant#366"), false, Tuple{ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}}, Reactant.XLA.PJRT.LoadedExecutable, Reactant.XLA.PJRT.Device, Reactant.XLA.PJRT.Client, Tuple{}, Vector{Bool}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentVector{Float64, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{Axis{(layer_1 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_2 = ViewAxis(49:65, Axis(weight = ViewAxis(1:16, ShapedAxis((1, 16))), bias = ViewAxis(17:17, Shaped1DAxis((1,))))))}}}, Tuple{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, @Kwargs{}}, ::LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Returns{Nothing}}; kwargs::@Kwargs{maxiters::Int64})
    @ SciMLBase ~/.julia/packages/SciMLBase/wfZCo/src/solve.jl:128...

I think this was the reason I initially thought ComponentArrays didn’t work with Reactant. Is this an actual compatibility issue between the two, or some weirdness in the other packages involved? Both to_rarray and ... |> reactant_device() still work.

In general I seem to quickly run into lots of issues with Reactant. Is there really no way currently to use Lux + Enzyme + Optimization (or some other library providing LBFGS) without Reactant? Should I just rewrite my code to not use mutation or StaticArrays so that I can use Zygote?

We need to make a SciML LBFGS that matches our normal type support. None of our options (Optim.jl, NLopt.jl, L-BFGS-B (Fortran)) can support the kinds of things SciML does in general right now, and all of them have limitations at the edges. SimpleOptimization.jl’s LBFGS should be okay, but we need to finish and document it.


Hm, that’s weird. Can you open an issue on Reactant.jl with a reproducer for that error?

Ah, I did not know about SimpleOptimization. Thank you, I’ll try that.

I can only reproduce it by using the (apparently broken) LBFGS. Should I still open an issue? I can manually fix the issue via type piracy:

function Reactant.ConcreteRArray(v::ComponentArray)
    return Reactant.ConcreteRArray(Array(v))
end

and then LBFGS just continues hitting other issues, so I don’t really know how relevant this is.

Btw, is compatibility of Reactant with StaticArrays something that should work or is planned? I didn’t see an issue about that. For the structures involved in my problem just setting

function Reactant.ConcreteRArray(v::StaticArray)
    return Reactant.ConcreteRArray(Array(v))
end

was enough, though that of course discards the static size information; I don’t know if that is avoidable.

Yeah, go ahead and open an issue (ideally with reproducers that require those piracy pieces above). It may make sense to vendor them (my guess is those libraries assumed conversion into a base Array only, not a possible Reactant array).

What’s SimpleOptimization.jl? I couldn’t find it anywhere.

At some point we discussed a problem with Optimization.LBFGS (refresher), and I think you said you would pull it out of Optimization. What’s the plan for that, and how is this SimpleOptimization related?

It’s here; it doesn’t seem to be in the main Julia package registry yet.
