Memory allocations in Flux evaluation/training (w/MVE)

This really depends on the code, the device, and the versions of Reactant, Enzyme, etc. After Reactant compilation it's just not the Julia code that is actually being run anymore. But yes, EnzymeJAX (what's behind Enzyme.gradient when called within a Reactant compilation) differentiates at the MLIR level, which can’t really be beaten at this point.
By the way, if you wonder what’s actually being run, you can use @code_hlo instead of @compile and it will show you the HLO code. If you need lower-level code with device information, use @code_xla.
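
As a minimal sketch of the @compile versus @code_hlo distinction mentioned above (the function f and the input x are made up for illustration, and the exact printed IR will vary with the Reactant version):

```julia
using Reactant

# Hypothetical toy function and input, only for illustration.
f(x) = sum(abs2, x)
x = Reactant.to_rarray(rand(Float32, 8))

# @compile returns a callable compiled for arguments of this type and shape.
f_compiled = @compile f(x)
f_compiled(x)

# @code_hlo takes the same call expression but returns the generated
# HLO/MLIR module instead of something you can call.
hlo = @code_hlo f(x)
println(hlo)
```

Reading the printed module is a quick way to check which operations actually reach the device.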

@yolhan_mannes minor nit, but I’d use the term EnzymeMLIR here rather than EnzymeJAX (lest we confuse people, since this isn’t JAX).

And that said it is very possible the CPU code without derivative is sped up as well, thanks to all the linear algebra and other optimizations (but of course this will depend on the code).

Re Windows and GPU: yeah, we’re aware of this and it is something we should try to address. WSL does give native GPU access on Windows for now, though. Honestly, the biggest bottleneck (besides there simply being too many fun things we want to add) is that most Reactant devs use Linux/macOS, so both the intrinsic motivation and the dev resources aren’t as compelling.

So I’ve been using WSL2 to run my code in Pluto (thanks for the direction), but I’m getting a weird error when I set the backend to GPU and try to move the model to RArrays (Reactant.to_rarray). The error doesn’t seem to prevent the code from running and still returns a model with Reactant arrays, but implies my model will be slow to execute. Is this something I should pay attention to, or might it be a bug? I’m also not sure what the issue would be since I’ve run the model in vanilla Flux on GPU without this problem. The error output is what happens when I run Dₙ|>Reactant.to_rarray.

Error message from GPUArraysCore
Failed to show value:

Scalar indexing is disallowed.
Invocation of getindex(::ConcretePJRTArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.

Stack trace
Here is what happened, the most recent locations are first:

error(s::String)
from 
julia → error.jl:44
 
errorscalar(op::String)
from 
GPUArraysCore → GPUArraysCore.jl:151
_assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
from 
GPUArraysCore → GPUArraysCore.jl:124
assertscalar(op::String)
from 
GPUArraysCore → GPUArraysCore.jl:112
getindex(::Reactant.ConcretePJRTArray{…}, ::Int64, ::Int64, ::Int64, ::Int64) ...show types...
from 
ConcreteRArray.jl:389
_getindex
from 
abstractarray.jl:1388
getindex
from 
abstractarray.jl:1342
iterate
from 
abstractarray.jl:1235
iterate
from 
abstractarray.jl:1233
_any(f::ComposedFunction{…}, itr::Reactant.ConcretePJRTArray{…}, ::Colon) ...show types...
from 
anyall.jl:123
any
from 
reducedim.jl:989
_any
from 
show.jl:170
#_any##0(x::Reactant.ConcretePJRTArray{…}) ...show types...
from 
show.jl:172
_any(f::Flux.var"#_any##0#_any##1"{…}, itr::Vector{…}, ::Colon) ...show types...
from 
anyall.jl:124
#any#756(f::Function, a::Vector{…}; dims::Function) ...show types...
from 
reducedim.jl:989
any
from 
reducedim.jl:989
_any
from 
show.jl:172
_all(f::Function, xs::Vector{…}) ...show types...
from 
show.jl:176
_nan_show(io::IO, x::Any)
from 
show.jl:159
_layer_show(io::IO, layer::Any, indent::Int64, name::Any)
from 
show.jl:116
_big_show(io::IO, obj::Any, indent::Int64, name::Any)
from 
show.jl:27
_big_show(io::IO, obj::Any, indent::Int64)
from 
show.jl:23
_big_show(io::IO, obj::Any, indent::Int64, name::Any)
from 
show.jl:43
_big_show(io::IO, obj::Any, indent::Int64)
from 
show.jl:23
_big_show(io::IO, obj::Any, indent::Int64, name::Any)
from 
show.jl:43
_big_show
from 
show.jl:23
show(io::IOContext{…}, m::MIME{…}, x::Flux.Chain{…}) ...show types...
from 
show.jl:9
show_richest(io::IOContext{…}, x::Any) ...show types...
from 
mime dance.jl:103
show_richest_withreturned
from 
mime dance.jl:17
format_output_default(val::Any, context::Any)
from 
format_output.jl:87
format_output
from 
format_output.jl:104
anonymous function
from 
format_output.jl:53
with_auto_id_counter
from 
auto_id.jl:24
anonymous function
from 
format_output.jl:52
#with_io_to_logs#133(f::PlutoRunner.var"#56#57"{…}; enabled::Bool, loglevel::Base.CoreLogging.LogLevel) ...show types...
from 
stdout.jl:64
with_io_to_logs
from 
stdout.jl:11
anonymous function
from 
logging.jl:133
with_logstate(f::PlutoRunner.var"#131#132"{…}, logstate::Base.CoreLogging.LogState) ...show types...
from 
logging.jl:542
with_logger(f::Function, logger::PlutoRunner.PlutoCellLogger)
from 
logging.jl:653
#with_logger_and_io_to_logs#129
from 
logging.jl:132
with_logger_and_io_to_logs
from 
logging.jl:131
formatted_result_of(notebook_id::Base.UUID, cell_id::Base.UUID, ends_with_semicolon::Bool, known_published_objects::Vector{String}, showmore::Nothing, workspace::Module; capture_stdout::Bool)
from 
format_output.jl:44
from 
WorkspaceManager.jl:603
eval(m::Module, e::Any)
from 
julia → boot.jl:489
 
#handle##0() ...show types...
from 
worker.jl:120

Indeed, a little weird. Here is an MWE:

using Flux
using Reactant
using Enzyme

# Disallow scalar indexing so that any accidental use raises an error
Reactant.allowscalar(false)
dev = Reactant.to_rarray

# No trailing semicolon: displaying the model is what triggers the error
model = Chain(
    Dense(1 => 32, tanh),
    Dense(32 => 32, tanh),
    Dense(32 => 32, tanh),
    Dense(32 => 1),
) |> dev

x = rand(Float32, 1, 1000) |> dev
y = rand(Float32, 1, 1000) |> dev

loss(model, x, y) = sum(abs2, model(x) .- y)
# Gradient with respect to the model only; x and y are marked constant
loss_grad(model, x, y) = Enzyme.gradient(Enzyme.Reverse, loss, model, Const(x), Const(y)) |> first

# compile
modelc = @compile model(x)
lossc = @compile loss(model, x, y)
loss_gradc = @compile loss_grad(model, x, y)

# run
modelc(x)
lossc(model, x, y)
loss_gradc(model, x, y)

# inspect the generated HLO
@code_hlo loss_grad(model, x, y)

Comments:

  • the model construction leads to scalar indexing
  • the gradient calculation (loss_gradc(model,x,y)) also triggers scalar indexing. However, with allowscalar enabled, the resulting HLO shows no scalar indexing in the gradient calculation:
julia> @code_hlo loss_grad(model,x,y)
module @reactant_loss_grad attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<1x32xf32> {enzymexla.memory_effects = []}, %arg1: tensor<32xf32> {enzymexla.memory_effects = []}, %arg2: tensor<32x32xf32> {enzymexla.memory_effects = []}, %arg3: tensor<32xf32> {enzymexla.memory_effects = []}, %arg4: tensor<32x32xf32> {enzymexla.memory_effects = []}, %arg5: tensor<32xf32> {enzymexla.memory_effects = []}, %arg6: tensor<32x1xf32> {enzymexla.memory_effects = []}, %arg7: tensor<1xf32> {enzymexla.memory_effects = []}, %arg8: tensor<1000x1xf32> {enzymexla.memory_effects = []}, %arg9: tensor<1000x1xf32> {enzymexla.memory_effects = []}) -> (tensor<1x32xf32>, tensor<32xf32>, tensor<32x32xf32>, tensor<32xf32>, tensor<32x32xf32>, tensor<32xf32>, tensor<32x1xf32>, tensor<1xf32>) attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<32x1000xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.reshape %arg9 : (tensor<1000x1xf32>) -> tensor<1x1000xf32>
    %1 = stablehlo.dot_general %arg0, %arg8, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<1x32xf32>, tensor<1000x1xf32>) -> tensor<32x1000xf32>
    %2 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<32xf32>) -> tensor<32x1000xf32>
    %3 = stablehlo.add %1, %2 : tensor<32x1000xf32>
    %4 = stablehlo.tanh %3 : tensor<32x1000xf32>
    %5 = stablehlo.dot_general %arg2, %4, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x1000xf32>) -> tensor<32x1000xf32>
    %6 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<32xf32>) -> tensor<32x1000xf32>
    %7 = stablehlo.add %5, %6 : tensor<32x1000xf32>
    %8 = stablehlo.tanh %7 : tensor<32x1000xf32>
    %9 = stablehlo.dot_general %arg4, %8, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x1000xf32>) -> tensor<32x1000xf32>
    %10 = stablehlo.broadcast_in_dim %arg5, dims = [0] : (tensor<32xf32>) -> tensor<32x1000xf32>
    %11 = stablehlo.add %9, %10 : tensor<32x1000xf32>
    %12 = stablehlo.tanh %11 : tensor<32x1000xf32>
    %13 = stablehlo.dot_general %arg6, %12, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x1xf32>, tensor<32x1000xf32>) -> tensor<1x1000xf32>
    %14 = stablehlo.broadcast_in_dim %arg7, dims = [0] : (tensor<1xf32>) -> tensor<1x1000xf32>
    %15 = stablehlo.add %13, %14 : tensor<1x1000xf32>
    %16 = stablehlo.subtract %15, %0 : tensor<1x1000xf32>
    %17 = stablehlo.add %16, %16 : tensor<1x1000xf32>
    %18 = stablehlo.reduce(%17 init: %cst_0) applies stablehlo.add across dimensions = [1] : (tensor<1x1000xf32>, tensor<f32>) -> tensor<1xf32>
    %19 = stablehlo.dot_general %17, %12, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<1x1000xf32>, tensor<32x1000xf32>) -> tensor<1x32xf32>
    %20 = stablehlo.reshape %19 : (tensor<1x32xf32>) -> tensor<32x1xf32>
    %21 = stablehlo.dot_general %arg6, %17, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x1xf32>, tensor<1x1000xf32>) -> tensor<32x1000xf32>
    %22 = stablehlo.multiply %12, %12 : tensor<32x1000xf32>
    %23 = stablehlo.subtract %cst, %22 : tensor<32x1000xf32>
    %24 = stablehlo.multiply %21, %23 : tensor<32x1000xf32>
    %25 = stablehlo.reduce(%24 init: %cst_0) applies stablehlo.add across dimensions = [1] : (tensor<32x1000xf32>, tensor<f32>) -> tensor<32xf32>
    %26 = stablehlo.dot_general %8, %24, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x1000xf32>, tensor<32x1000xf32>) -> tensor<32x32xf32>
    %27 = stablehlo.dot_general %arg4, %24, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x1000xf32>) -> tensor<32x1000xf32>
    %28 = stablehlo.multiply %8, %8 : tensor<32x1000xf32>
    %29 = stablehlo.subtract %cst, %28 : tensor<32x1000xf32>
    %30 = stablehlo.multiply %27, %29 : tensor<32x1000xf32>
    %31 = stablehlo.reduce(%30 init: %cst_0) applies stablehlo.add across dimensions = [1] : (tensor<32x1000xf32>, tensor<f32>) -> tensor<32xf32>
    %32 = stablehlo.dot_general %4, %30, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x1000xf32>, tensor<32x1000xf32>) -> tensor<32x32xf32>
    %33 = stablehlo.dot_general %arg2, %30, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x1000xf32>) -> tensor<32x1000xf32>
    %34 = stablehlo.multiply %4, %4 : tensor<32x1000xf32>
    %35 = stablehlo.subtract %cst, %34 : tensor<32x1000xf32>
    %36 = stablehlo.multiply %33, %35 : tensor<32x1000xf32>
    %37 = stablehlo.reduce(%36 init: %cst_0) applies stablehlo.add across dimensions = [1] : (tensor<32x1000xf32>, tensor<f32>) -> tensor<32xf32>
    %38 = stablehlo.dot_general %36, %arg8, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x1000xf32>, tensor<1000x1xf32>) -> tensor<32x1xf32>
    %39 = stablehlo.reshape %38 : (tensor<32x1xf32>) -> tensor<1x32xf32>
    return %39, %37, %32, %31, %26, %25, %20, %18 : tensor<1x32xf32>, tensor<32xf32>, tensor<32x32xf32>, tensor<32xf32>, tensor<32x32xf32>, tensor<32xf32>, tensor<32x1xf32>, tensor<1xf32>
  }
}

Never mind: for the gradient calculation it's just the |> first.

Looking at it, I would say it's not an issue for performance, but it is indeed a little scary on the user side.

So the conclusion (for now at least) is weird but likely not a real problem?

Speaking of weird issues, I’m getting one from Enzyme:

EnzymeRuntimeActivityError: Detected potential need for runtime activity.
Constant memory is stored (or returned) to a differentiable variable and correctness cannot be guaranteed with static activity analysis.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#faq-runtime-activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
   a) rewrite this variable to not be conditionally active (fastest performance, slower to setup), or
   b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.

Failure within method:

_match_eltype(::Flux.ConvTranspose{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}, ::Type{Float32}, ::Array{Float64, 4})
     @ Flux ~/.julia/packages/Flux/hrg9M/src/layers/stateless.jl:58

Hint: catch this exception as `err` and call `code_typed(err)` to inspect the surrounding code.

Mismatched activity for:   store atomic ptr addrspace(10) %55, ptr addrspace(11) %59 unordered, align 8, !dbg !561, !tbaa !563, !alias.scope !567, !noalias !568 const val:   %55 = call fastcc nonnull ptr addrspace(10) @julia_summary_253798(ptr addrspace(10) %"x::Array") #548, !dbg !557
Type tree: {[-1]:Pointer}
 LLVM view of erring value:     %55 = call fastcc nonnull ptr addrspace(10) @julia_summary_253798(ptr addrspace(10) %"x::Array") #548, !dbg !557

Stacktrace:
 [1] macro expansion
   @ ./logging/logging.jl:419
 [2] _match_eltype
   @ ~/.julia/packages/Flux/hrg9M/src/layers/stateless.jl:60

From some googling, it seems like this is an issue with Enzyme itself at the moment, so I may wait and see if it gets resolved.

One of the reasons I was asking about the scalar indexing warning is that going from calculating the gradient to updating the model is taking an enormous amount of time. The gradient calculation seems to be done in about 20 ms, but Flux.update! takes between 20 and 35 s. That seems excessive, but other than dumping profiling data I’m not sure where to start investigating.

Based on the Reactant examples, I’m not supposed to @compile Flux.update!, right? The compiling should stop at the gradient calculation?

Replace Enzyme.Reverse with Enzyme.set_runtime_activity(Enzyme.Reverse) for now.
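
For reference, a minimal sketch of that substitution outside of Reactant. The function f and the input x are made up for illustration, and the `|> first` assumes an Enzyme version where Enzyme.gradient returns a tuple with one entry per argument (as the code elsewhere in this thread does):

```julia
using Enzyme

# Hypothetical toy function and input, only for illustration.
f(x) = sum(abs2, x)
x = Float32[1, 2, 3]

# Same call shape as with Enzyme.Reverse; only the mode argument changes.
mode = Enzyme.set_runtime_activity(Enzyme.Reverse)
g = Enzyme.gradient(mode, f, x) |> first

println(g)   # gradient of sum(abs2, x) is 2x
```

Runtime activity trades a little performance for correctness on activity-unstable code, so it is best treated as a workaround rather than a default.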

Weird. Try using Optimisers and Enzyme directly to see. I don't know whether Flux traces the training loop correctly yet (Lux does), so for now you have to write the training loop by hand.

That error says that you have activity-unstable code [aka two entry points to the variable, one which is differentiable, one which is not – it should be explained in the linked doc].

That said, definitionally Reactant should remove all such code paths, so that implies you are not compiling the gradient?

No, you should very much compile the whole update step. This will result in significant performance gains, as Reactant can fuse the update and the derivative calculation together [reducing memory use].

I think I know what was causing that, especially after checking the docs. I’m trying to recreate/leverage some multiple dispatch, but the way I’m going about it is using some conditionals. My loss function has several forms that differ on number of inputs, so I’ve been trying to wrangle the compiled versions under a single function to make things a bit neater. Perhaps I just need to bite the bullet, unless there’s a better way to handle multiple dispatch when compiling?

Edit: I just found that I’m not the first to have this issue, it seems to be specific to the Adam optimizer. Still, the solution from the previous issue was to delete and redownload all dependencies. I’m trying this on a new notebook, so it should be “fresh”?

Edit 2: Just changed to Descent and it works. I guess this is just a data point for the Adam issue?
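
On the question above about handling several loss forms without conditionals: one option is to give each form its own method and let Julia pick by the number of arguments. This is only a sketch with a made-up toy model and plain arrays; with Reactant, each call signature would presumably still need its own @compile.

```julia
# Hypothetical toy model and data; plain arrays stand in for Reactant arrays.
model(x) = 2f0 .* x

# One method per loss form, selected by arity instead of an if/else on the inputs.
loss(m, x) = sum(abs2, m(x))
loss(m, x, y) = sum(abs2, m(x) .- y)

x = Float32[1, 2, 3]
y = Float32[2, 4, 6]

println(loss(model, x))      # 56.0
println(loss(model, x, y))   # 0.0
```

Because no branch depends on which arguments are present, each method body stays a single straight-line computation.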

So when trying to use @compile to capture the whole updating step with a simpler loss function, I’m running into a new error. The code for reference is

opt = Flux.setup(Adam(), m)
ora = opt|>dev

function f_grad(f, m, z)
	Enzyme.gradient(Enzyme.Reverse, f, m, z) |> first
end

function ud(f, m, z, opt)
	gs = f_grad(f, m, z)
	Optimisers.update(opt, m, gs)
end

@compile ud(𝒱ₙₙ, mra, zra, ora)

where m is the model and z is some random data, with mra and zra being the same objects moved to RArrays. The error I’m getting is

MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:965
  Float32(::IrrationalConstants.Halfπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/RokwY/src/macro.jl:112
  Float32(::IrrationalConstants.Logten)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/RokwY/src/macro.jl:112
  ...

with the full stack trace

call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Reactant.TracedRNumber{…}) ...show types...
from 
utils.jl
cvt1
from 
essentials.jl:612
ntuple
from 
ntuple.jl:51
convert
from 
essentials.jl:614
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Tuple{…}) ...show types...
from 
utils.jl
cvt1
from 
essentials.jl:612
ntuple
from 
ntuple.jl:52
convert
from 
essentials.jl:614
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(setproperty!), ::Optimisers.Leaf{…}, ::Symbol, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#11
from 
interface.jl:96
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#11", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:389
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:389
mapvalue
from 
utils.jl:2
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:385
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:386
mapvalue
from 
utils.jl:2
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:385
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:386
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::gan_model) ...show types...
from 
utils.jl
update!
from 
interface.jl:77
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.update!), ::@NamedTuple{…}, ::gan_model, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.update), ::@NamedTuple{…}, ::gan_model, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 

It feels like it should be fairly straightforward, but perhaps I’m missing something here?

Seems like a bug in Optimisers.jl where it hardcodes Float explicitly [and thus cannot hold the Reactant float]. We fixed some of these in the past, but it looks like we missed one. File an issue on Optimisers.jl with the MWE and tag me and @avikpal?

Also feel free to take a first stab at it by looking at the other ones, if you’d like!

Issue filed: "unnecessary jit on setup for Reactant", Issue #226, FluxML/Optimisers.jl on GitHub.

This works, if it helps you:

using Reactant, Flux, Optimisers, Enzyme
Reactant.allowscalar(true)
model = Chain(Dense(1=>16,tanh),Dense(16=>1)) |> Reactant.to_rarray;
x = rand(Float32,1,1000) |> Reactant.to_rarray;
y = rand(Float32,1,1000) |> Reactant.to_rarray;
loss(model,x,y) = sum(abs2,model(x) .- y);
opt = @jit Optimisers.setup(Adam(0.01f0),model);
function single_step(opt,model,x,y)
    g = Enzyme.gradient(Enzyme.Reverse,loss,model,Const(x),Const(y))
    Optimisers.update!(opt,model,g[1])
    return nothing
end
@jit loss(model,x,y) # ConcretePJRTNumber{Float32, 1}(342.2859f0)
@jit single_step(opt,model,x,y)
@jit loss(model,x,y) # ConcretePJRTNumber{Float32, 1}(252.75954f0)
@jit single_step(opt,model,x,y)
@jit loss(model,x,y) # ConcretePJRTNumber{Float32, 1}(184.58466f0)

If you change the train step at some point, don’t forget to trace it again.
I think something people forget is to jit the setup call:
@jit Optimisers.setup(Adam(0.01f0),model);
The issue is the following:
the internal states of Adam (beta …) end up being plain Floats when setup is not traced, but they change within the update! function, which does not tolerate a change of type.

julia> opt = @jit Optimisers.setup(Adam(0.01f0),model)
(layers = ((weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 2, 1}(Float32[0.0; 0.0; … ; 0.0; 0.0;;]), ConcretePJRTArray{Float32, 2, 1}(Float32[0.0; 0.0; … ; 0.0; 0.0;;]), (ConcretePJRTNumber{Float32, 1}(0.9), ConcretePJRTNumber{Float32, 1}(0.999)))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 1, 1}(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), ConcretePJRTArray{Float32, 1, 1}(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (ConcretePJRTNumber{Float32, 1}(0.9), ConcretePJRTNumber{Float32, 1}(0.999)))), σ = ()), (weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 2, 1}(Float32[0.0 0.0 … 0.0 0.0]), ConcretePJRTArray{Float32, 2, 1}(Float32[0.0 0.0 … 0.0 0.0]), (ConcretePJRTNumber{Float32, 1}(0.9), ConcretePJRTNumber{Float32, 1}(0.999)))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 1, 1}(Float32[0.0]), ConcretePJRTArray{Float32, 1, 1}(Float32[0.0]), (ConcretePJRTNumber{Float32, 1}(0.9), ConcretePJRTNumber{Float32, 1}(0.999)))), σ = ())),)

julia> typeof(opt)
@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, σ::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{ConcretePJRTNumber{Float32, 1}, ConcretePJRTNumber{Float32, 1}}}}, σ::Tuple{}}}}

julia> opt = Optimisers.setup(Adam(0.01f0),model)
(layers = ((weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 2, 1}(Float32[0.0; 0.0; … ; 0.0; 0.0;;]), ConcretePJRTArray{Float32, 2, 1}(Float32[0.0; 0.0; … ; 0.0; 0.0;;]), (0.9, 0.999))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 1, 1}(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), ConcretePJRTArray{Float32, 1, 1}(Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (0.9, 0.999))), σ = ()), (weight = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 2, 1}(Float32[0.0 0.0 … 0.0 0.0]), ConcretePJRTArray{Float32, 2, 1}(Float32[0.0 0.0 … 0.0 0.0]), (0.9, 0.999))), bias = Leaf(Adam(eta=0.01, beta=(0.9, 0.999), epsilon=1.0e-8), (ConcretePJRTArray{Float32, 1, 1}(Float32[0.0]), ConcretePJRTArray{Float32, 1, 1}(Float32[0.0]), (0.9, 0.999))), σ = ())),)

julia> typeof(opt)
@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 2, 1}, ConcretePJRTArray{Float32, 2, 1}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{ConcretePJRTArray{Float32, 1, 1}, ConcretePJRTArray{Float32, 1, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}

I don’t know how to fix it though. Maybe dispatch on setup, but since it accepts any kind of structure, that's complicated. It should be fixed though, since people only call this function once and may not understand the need to compile it.
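
To illustrate why the state type matters, here is a pure-Julia sketch of the mechanism. MockTraced and MockLeaf are hypothetical stand-ins (not the real Reactant.TracedRNumber or Optimisers.Leaf), under the assumption that the failure comes from assigning traced numbers into a field whose type was fixed to plain floats at setup time:

```julia
# Hypothetical stand-in for a traced scalar: a Number with no conversion to Float32.
struct MockTraced <: Number
    val::Float32
end

# Simplified stand-in for an optimiser leaf: the state type is fixed at construction.
mutable struct MockLeaf{S}
    state::S
end

# "Untraced setup": the state holds plain Float32 values.
leaf = MockLeaf((0.9f0, 0.999f0))

# Assigning traced-like numbers forces convert(Tuple{Float32, Float32}, ...),
# which ends in Float32(::MockTraced) and fails.
err = try
    leaf.state = (MockTraced(0.81f0), MockTraced(0.998f0))
    nothing
catch e
    e
end
println(typeof(err))

# "Traced setup": the state already holds traced-like numbers, so the update fits.
leaf_traced = MockLeaf((MockTraced(0.9f0), MockTraced(0.999f0)))
leaf_traced.state = (MockTraced(0.81f0), MockTraced(0.998f0))
```

The failing path goes through setproperty! and convert on a tuple, which matches the shape of the stack trace posted earlier in the thread.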

This was mentioned on one of the issues, but you still need to run to_rarray on the optimizer itself [though you shouldn’t need to jit].

Like with the Adam optimizer, are there any compatibility issues known with the binarycrossentropy loss function? I’m trying to compile it and I get

MethodError: no method matching var"#binarycrossentropy#21"(::typeof(Statistics.mean), ::Reactant.TracedRNumber{Float32}, ::typeof(Flux.Losses.binarycrossentropy), ::Base.ReshapedArray{Reactant.TracedRNumber{Float32}, 4, Reactant.TracedRArray{Float32, 2}, Tuple{}}, ::Float32)
The function `#binarycrossentropy#21` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  var"#binarycrossentropy#21"(::Any, ::Real, ::typeof(Flux.Losses.binarycrossentropy), ::Any, ::Any)
   @ Flux ~/.julia/packages/Flux/hrg9M/src/losses/functions.jl:319

Looking at it, this might be due to a hardcoded Real type? crossentropy seems to have the same issue, but oddly enough logitbinarycrossentropy does not.

Yeah, that just seems like a hardcoded type in Flux. Open a PR to Flux and change it to AbstractFloat?

I don’t really follow why this goes wrong. Looking at the code in Flux, it’s

function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ))
  _check_sizes(ŷ, y)
  agg(@.(-xlogy(y, ŷ + eps) - xlogy(1 - y, 1 - ŷ + eps)))
end

Here eps is a Real, which is an abstract type sitting above AbstractFloat, right? I’m not well versed in the type system, so I must be missing something.

I think that’s it (but it is confusing because of the kwarg name mangling). Presumably ŷ is a TracedRArray{Float32}, so its eltype is TracedRNumber{Float32}, which is not a subtype of Real (but is a subtype of Number).
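
A pure-Julia sketch of that explanation. MockTracedNumber, default_eps, strict and relaxed are hypothetical names invented for this illustration; the only assumption carried over from the thread is that the traced number type is a Number but not a Real:

```julia
# Hypothetical stand-in for a traced scalar: a Number that is NOT a Real.
struct MockTracedNumber{T} <: Number
    val::T
end

# The default follows the input type, like a default derived from the array eltype.
default_eps(x) = eps(Float32)
default_eps(x::MockTracedNumber{T}) where {T} = MockTracedNumber(eps(T))

# Same pattern as a keyword annotated with ::Real.
strict(x; eps::Real = default_eps(x)) = eps
# Loosened annotation that also admits non-Real numbers.
relaxed(x; eps::Number = default_eps(x)) = eps

println(AbstractFloat <: Real)              # true
println(MockTracedNumber{Float32} <: Real)  # false
println(strict(1f0))                        # plain floats pass the ::Real check

# A traced-like default does not satisfy ::Real, so the call throws.
err = try
    strict(MockTracedNumber(1f0))
catch e
    e
end
println(typeof(err))

println(relaxed(MockTracedNumber(1f0)))     # accepted once the annotation is ::Number
```

Note that in this sketch an annotation of AbstractFloat would reject the mock type just as Real does, since the mock sits directly under Number.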

Got it. Thanks for clearing that up.