Okay, thanks! Could you help me understand how to make this work with a vector output? The example in the docs uses a scalar output, and I don’t understand how I’m supposed to pass the cotangent:
julia> using Enzyme
julia> g(x) = abs2.(x)
g (generic function with 1 method)
julia> x = [2.0, 3.0];
julia> dx = zero(x);
julia> dy = [5.0, 7.0];
julia> forw, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(g)}, Duplicated, Duplicated{typeof(x)})
(Enzyme.Compiler.AugmentedForwardThunk{Ptr{Nothing}, Const{typeof(g)}, Duplicated{Vector{Float64}}, Tuple{Duplicated{Vector{Float64}}}, Val{1}, Val{true}(), @NamedTuple{1, 2, 3, 4, 5::Bool, 6::Bool, 7::Core.LLVMPtr{Float64, 0}, 8::Core.LLVMPtr{Float64, 0}, 9::Core.LLVMPtr{Float64, 0}, 10::Core.LLVMPtr{Float64, 0}, 11::Core.LLVMPtr{Float64, 0}, 12::Core.LLVMPtr{Float64, 0}, 13::Core.LLVMPtr{Float64, 0}, 14::Core.LLVMPtr{Float64, 0}, 15::Core.LLVMPtr{Float64, 0}}}(Ptr{Nothing} @0x00007871c8880450), Enzyme.Compiler.AdjointThunk{Ptr{Nothing}, Const{typeof(g)}, Duplicated{Vector{Float64}}, Tuple{Duplicated{Vector{Float64}}}, Val{1}, @NamedTuple{1, 2, 3, 4, 5::Bool, 6::Bool, 7::Core.LLVMPtr{Float64, 0}, 8::Core.LLVMPtr{Float64, 0}, 9::Core.LLVMPtr{Float64, 0}, 10::Core.LLVMPtr{Float64, 0}, 11::Core.LLVMPtr{Float64, 0}, 12::Core.LLVMPtr{Float64, 0}, 13::Core.LLVMPtr{Float64, 0}, 14::Core.LLVMPtr{Float64, 0}, 15::Core.LLVMPtr{Float64, 0}}}(Ptr{Nothing} @0x00007871c8880a40))
julia> tape, y, shadow_y = forw(Const(g), Duplicated(x, dx))
(var"1" = @NamedTuple{1, 2, 3, 4, 5::Bool, 6::Bool, 7::Core.LLVMPtr{Float64, 0}, 8::Core.LLVMPtr{Float64, 0}, 9::Core.LLVMPtr{Float64, 0}, 10::Core.LLVMPtr{Float64, 0}, 11::Core.LLVMPtr{Float64, 0}, 12::Core.LLVMPtr{Float64, 0}, 13::Core.LLVMPtr{Float64, 0}, 14::Core.LLVMPtr{Float64, 0}, 15::Core.LLVMPtr{Float64, 0}}(([0.0, 0.0], [4.0, 9.0], nothing, nothing, false, false, Core.LLVMPtr{Float64, 0}(0x000078715e006610), Core.LLVMPtr{Float64, 0}(0x7ffffffffffffffe), Core.LLVMPtr{Float64, 0}(0x000078715e0065e0), Core.LLVMPtr{Float64, 0}(0x000078715e181730), Core.LLVMPtr{Float64, 0}(0x000078725cb10311), Core.LLVMPtr{Float64, 0}(0x000078724ca32f60), Core.LLVMPtr{Float64, 0}(0x000078725ca6bee0), Core.LLVMPtr{Float64, 0}(0x000078724ca32dc0), Core.LLVMPtr{Float64, 0}(0x0000000009900c00))), var"2" = [4.0, 9.0], var"3" = [0.0, 0.0])
julia> y
2-element Vector{Float64}:
4.0
9.0
julia> rev(Const(g), Duplicated(x, dx), Duplicated(y, dy), tape)
ERROR: AssertionError: length(argtypes) + needs_tape == length(argexprs)
Stacktrace:
⋮ internal @ Enzyme.Compiler, GPUCompiler, Core, Unknown
[6] (::Enzyme.Compiler.AdjointThunk{Ptr{…}, Const{…}, Duplicated{…}, Tuple{…}, Val{…}, @NamedTuple{…}})(::Const{typeof(g)}, ::Duplicated{Vector{…}}, ::Vararg{Any})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5004
Use `err` to retrieve the full stack trace.
Some type information was truncated. Use `show(err)` to see complete types.