I tried Cache
import Random
using DifferentiationInterface,LinearAlgebra,BenchmarkTools
import Enzyme
@fastmath logistic_sigmoid(x) = 1.0f0 / (1 + exp(-x))
function loss(
Wmid, Wctx,
nvocab, ndim,
tok_mid, tok_ctx, x
)
Wmid = reshape(Wmid, nvocab, ndim)
Wctx = reshape(Wctx, nvocab, ndim)
nll ::Float32 = 0.0f0
@inbounds @fastmath for c in eachindex(x)
Xij = x[c]
dotprod = dot(@view(Wmid[tok_mid[c],:]),@view(Wctx[tok_ctx[c],:]))
nll += Xij * log(logistic_sigmoid(dotprod))
nll += (1 - Xij) * log(1 - logistic_sigmoid(dotprod))
end
nll = nll / length(x)
nll
end
function train(rng, nvocab::Integer, nsamples::Integer, ndim::Integer, backend)
ntrues = nsamples ÷ 2
tok_mid = rand(rng, 1:nvocab, nsamples)
tok_ctx = rand(rng, 1:nvocab, nsamples)
x = [trues(ntrues); falses(nsamples - ntrues)]
weights_mid = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
weights_ctx = 0.1f0 .* randn(rng, Float32, nvocab * ndim)
# @info "Computing gradient..." backend
dweights_mid = Enzyme.make_zero(weights_mid)
# dweights_ctx = Enzyme.make_zero(weights_ctx)
gradient!(loss,dweights_mid,backend,weights_mid,Cache(weights_ctx),Constant(nvocab),Constant(ndim),Constant(tok_mid),Constant(tok_ctx),Constant(x))
# Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Duplicated(weights_mid, dweights_mid), Enzyme.DuplicatedNoNeed(weights_ctx, dweights_ctx), Enzyme.Const(nvocab), Enzyme.Const(ndim), Enzyme.Const(tok_mid), Enzyme.Const(tok_ctx), Enzyme.Const(x))
dweights_mid
end
I get this error
ERROR: MethodError: no method matching _translate(::AutoEnzyme{…}, ::EnzymeCore.ReverseMode{…}, ::Val{…}, ::Cache{…})
The function `_translate` exists, but no method is defined for this combination of argument types.
Closest candidates are:
_translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::DifferentiationInterface.FunctionContext) where B
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:56
_translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::Union{DifferentiationInterface.BackendContext, Constant}) where B
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:50
Stacktrace:
[1] (::DifferentiationInterfaceEnzymeExt.var"#3#4"{1, AutoEnzyme{…}, EnzymeCore.ReverseMode{…}})(c::Cache{Vector{…}})
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:66
[2] map
@ .\tuple.jl:358 [inlined]
[3] translate
@ C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:65 [inlined]
[4] gradient!(::typeof(loss), ::Vector{…}, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\reverse_onearg.jl:266
[5] gradient!(::typeof(loss), ::Vector{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterface C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\src\fallbacks\no_prep.jl:55
[6] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
@ Main c:\Users\yolha\Desktop\bench_py_jl\benchmark.jl:39
[7] var"##core#237"(rng#235::Random.Xoshiro, bck#236::AutoEnzyme{Nothing, Nothing})
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:598
[8] var"##sample#238"(::Tuple{Random.Xoshiro, AutoEnzyme{Nothing, Nothing}}, __params::BenchmarkTools.Parameters)
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:607
[9] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:186
[10] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:181
[11] #invokelatest#2
@ .\essentials.jl:1055 [inlined]
[12] invokelatest
@ .\essentials.jl:1052 [inlined]
[13] #lineartrial#46
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:51 [inlined]
[14] lineartrial
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:50 [inlined]
[15] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:299
[16] tune!
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288 [inlined]
[17] tune!(b::BenchmarkTools.Benchmark)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288
[18] top-level scope
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:728
Some type information was truncated. Use `show(err)` to see complete types.
train (generic function with 1 method)
Random.Xoshiro(0xc3b1bf2ea9b69425, 0xfb15c0017deb2d31, 0xb5c67bbde9fbe995, 0x10b9c73e8107af27, 0xb6e7dc9e3c9975b2)
AutoEnzyme()
ERROR: MethodError: no method matching gradient!(::typeof(loss), ::AutoEnzyme{…}, ::Vector{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
The function `gradient!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
gradient!(::F, ::Any, ::ADTypes.AbstractADType, ::Any, ::Context...) where {F, C}
@ DifferentiationInterface C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\src\fallbacks\no_prep.jl:51
gradient!(::F, ::Any, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep, ::AutoEnzyme{<:Union{Nothing, EnzymeCore.ReverseMode}}, ::Any, ::Context...) where {F, C}
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\reverse_onearg.jl:254
gradient!(::Any, ::Any, ::DifferentiationInterface.NoGradientPrep, ::AutoZygote, ::Any, ::Union{DifferentiationInterface.BackendContext, Constant, DifferentiationInterface.FunctionContext}...) where C
@ DifferentiationInterfaceZygoteExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:147
...
Stacktrace:
[1] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
@ Main c:\Users\yolha\Desktop\bench_py_jl\benchmark.jl:39
[2] var"##core#237"(rng#235::Random.Xoshiro, bck#236::AutoEnzyme{Nothing, Nothing})
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:598
[3] var"##sample#238"(::Tuple{Random.Xoshiro, AutoEnzyme{Nothing, Nothing}}, __params::BenchmarkTools.Parameters)
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:607
[4] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:186
[5] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:181
[6] #invokelatest#2
@ .\essentials.jl:1055 [inlined]
[7] invokelatest
@ .\essentials.jl:1052 [inlined]
[8] #lineartrial#46
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:51 [inlined]
[9] lineartrial
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:50 [inlined]
[10] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:299
[11] tune!
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288 [inlined]
[12] tune!(b::BenchmarkTools.Benchmark)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288
[13] top-level scope
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:728
Some type information was truncated. Use `show(err)` to see complete types.
train (generic function with 1 method)
Random.Xoshiro(0xc3b1bf2ea9b69425, 0xfb15c0017deb2d31, 0xb5c67bbde9fbe995, 0x10b9c73e8107af27, 0xb6e7dc9e3c9975b2)
AutoEnzyme()
ERROR: MethodError: no method matching _translate(::AutoEnzyme{…}, ::EnzymeCore.ReverseMode{…}, ::Val{…}, ::Cache{…})
The function `_translate` exists, but no method is defined for this combination of argument types.
Closest candidates are:
_translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::DifferentiationInterface.FunctionContext) where B
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:56
_translate(::AutoEnzyme, ::EnzymeCore.Mode, ::Val{B}, ::Union{DifferentiationInterface.BackendContext, Constant}) where B
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:50
Stacktrace:
[1] (::DifferentiationInterfaceEnzymeExt.var"#3#4"{1, AutoEnzyme{…}, EnzymeCore.ReverseMode{…}})(c::Cache{Vector{…}})
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:66
[2] map
@ .\tuple.jl:358 [inlined]
[3] translate
@ C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\utils.jl:65 [inlined]
[4] gradient!(::typeof(loss), ::Vector{…}, ::DifferentiationInterfaceEnzymeExt.EnzymeGradientPrep{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterfaceEnzymeExt C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\ext\DifferentiationInterfaceEnzymeExt\reverse_onearg.jl:266
[5] gradient!(::typeof(loss), ::Vector{…}, ::AutoEnzyme{…}, ::Vector{…}, ::Cache{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…}, ::Constant{…})
@ DifferentiationInterface C:\Users\yolha\.julia\packages\DifferentiationInterface\a2pZk\src\fallbacks\no_prep.jl:55
[6] train(rng::Random.Xoshiro, nvocab::Int64, nsamples::Int64, ndim::Int64, backend::AutoEnzyme{Nothing, Nothing})
@ Main c:\Users\yolha\Desktop\bench_py_jl\benchmark.jl:39
[7] var"##core#237"(rng#235::Random.Xoshiro, bck#236::AutoEnzyme{Nothing, Nothing})
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:598
[8] var"##sample#238"(::Tuple{Random.Xoshiro, AutoEnzyme{Nothing, Nothing}}, __params::BenchmarkTools.Parameters)
@ Main C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:607
[9] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:186
[10] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:181
[11] #invokelatest#2
@ .\essentials.jl:1055 [inlined]
[12] invokelatest
@ .\essentials.jl:1052 [inlined]
[13] #lineartrial#46
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:51 [inlined]
[14] lineartrial
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:50 [inlined]
[15] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:299
[16] tune!
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288 [inlined]
[17] tune!(b::BenchmarkTools.Benchmark)
@ BenchmarkTools C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:288
[18] top-level scope
@ C:\Users\yolha\.julia\packages\BenchmarkTools\1i1mY\src\execution.jl:728
Some type information was truncated. Use `show(err)` to see complete types.
I open an issue ? I can try to lower all this a bit (same on DI#main)