Memory allocations in Flux evaluation/training (w/MVE)

Edit: I just found that I’m not the first to hit this issue; it seems to be specific to the Adam optimizer. The solution suggested in that earlier thread was to delete and re-download all dependencies, but I’m already running this in a new notebook, so shouldn’t the environment already be “fresh”?

Edit 2: I just switched to Descent and it works, so I guess this is just another data point for the Adam issue?

When I try to use `@compile` to capture the whole update step with a simpler loss function, I run into a new error. The code, for reference, is:

# Build the optimiser state tree for the model `m`, then move that state to the
# same device as the model and data via `dev`.
# NOTE(review): per the post's edits, swapping Adam for Descent avoids the
# MethodError below. The trace fails on a `convert` of TracedRNumber{Float32} to
# Float32 while Optimisers stores new state into a `Leaf` (`setproperty!`), so some
# Float32 scalar in Adam's state presumably becomes traced — TODO confirm.
opt = Flux.setup(Adam(), m)
ora = dev(opt)

"""
    f_grad(f, m, z)

Run `Enzyme.gradient` in reverse mode on `f` at the arguments `(m, z)` and return
only the first component of the result (presumably the gradient with respect to
`m`; the component for `z` is dropped).
"""
function f_grad(f, m, z)
    grads = Enzyme.gradient(Enzyme.Reverse, f, m, z)
    return first(grads)
end

"""
    ud(f, m, z, opt)

Perform one training step. Differentiate `f` at `(m, z)` with `f_grad`, then hand
that gradient, the optimiser state `opt` and the model `m` to
`Optimisers.update`, returning whatever `Optimisers.update` returns.
"""
function ud(f, m, z, opt)
    return Optimisers.update(opt, m, f_grad(f, m, z))
end

# Trace and compile the whole update step (gradient + optimiser update) with
# Reactant, using the device-resident model `mra`, data `zra` and state `ora`.
# NOTE(review): the value produced by `@compile` is not bound here; if it is the
# compiled callable (presumably), assign it, e.g. `ud_c = @compile ...`, so it can
# be invoked afterwards — confirm against the Reactant docs.
@compile ud(𝒱ₙₙ, mra, zra, ora)

where `m` is the model, `z` is some random data and `𝒱ₙₙ` is the loss function; `mra` and `zra` are the same model and data moved to RArrays, and `ora` is the optimiser state moved the same way. The error I’m getting is:

MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})

The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.



Closest candidates are:

  (::Type{T})(::T) where T<:Number

   @ Core boot.jl:965

  Float32(::IrrationalConstants.Halfπ)

   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/RokwY/src/macro.jl:112

  Float32(::IrrationalConstants.Logten)

   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/RokwY/src/macro.jl:112

  ...

with the full stack trace:

call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Reactant.TracedRNumber{…}) ...show types...
from 
utils.jl
cvt1
from 
essentials.jl:612
ntuple
from 
ntuple.jl:51
convert
from 
essentials.jl:614
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Tuple{…}) ...show types...
from 
utils.jl
cvt1
from 
essentials.jl:612
ntuple
from 
ntuple.jl:52
convert
from 
essentials.jl:614
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(convert), ::Type{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(setproperty!), ::Optimisers.Leaf{…}, ::Symbol, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#11
from 
interface.jl:96
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#11", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:389
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Conv{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:389
mapvalue
from 
utils.jl:2
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:385
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:386
mapvalue
from 
utils.jl:2
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::Tuple{…}, ::Tuple{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:385
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::Flux.Chain{…}) ...show types...
from 
utils.jl
map
from 
tuple.jl:386
map
from 
namedtuple.jl:263
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(map), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.mapvalue), ::Optimisers.var"#9#10"{…}, ::@NamedTuple{…}, ::@NamedTuple{…}) ...show types...
from 
utils.jl
#_update!#7
from 
interface.jl:85
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::Optimisers.var"##_update!#7", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Optimisers._update!), ::@NamedTuple{…}, ::gan_model) ...show types...
from 
utils.jl
update!
from 
interface.jl:77
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.update!), ::@NamedTuple{…}, ::gan_model, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 
call_with_reactant(::Reactant.EnsureReturnType{…}, ::typeof(Optimisers.update), ::@NamedTuple{…}, ::gan_model, ::gan_model) ...show types...
from 
utils.jl
call_with_reactant
from 

It feels like this should be fairly straightforward, but perhaps I’m missing something here?