I’m using Enzyme to differentiate through a closure, and BatchDuplicated on the closure itself doesn’t seem to work, unlike Duplicated.
using Enzyme
struct MyClosure{A}
    a::A
end
function (mc::MyClosure)(x)
    # computes x^2 using internal storage
    mc.a[1] = x
    return mc.a[1]^2
end
g = MyClosure([0.0])
g(3.0) # 9.0
g_and_dg = Duplicated(g, make_zero(g))
x_and_dx = Duplicated(3.0, 5.0)
autodiff(Forward, g_and_dg, Duplicated, x_and_dx) # (9.0, 30.0)
g_and_dgs = BatchDuplicated(g, (make_zero(g), make_zero(g)))
x_and_dxs = BatchDuplicated(3.0, (5.0, 7.0))
autodiff(Forward, g_and_dgs, BatchDuplicated, x_and_dxs) # error
The last call triggers the following error:
ERROR: TypeError: in ccall argument 3, expected MyClosure{Vector{Float64}}, got a value of type Tuple{MyClosure{Vector{Float64}}, MyClosure{Vector{Float64}}}
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:7151 [inlined]
[2] enzyme_call
@ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6760 [inlined]
[3] ForwardModeThunk
@ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6640 [inlined]
[4] autodiff(::ForwardMode{…}, f::BatchDuplicated{…}, ::Type{…}, args::BatchDuplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:437
[5] top-level scope
Some type information was truncated. Use `show(err)` to see complete types.
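As a stopgap, calling autodiff once per tangent direction with plain Duplicated should sidestep the batched closure entirely. Below is a minimal sketch that reuses only the calls that already work above; the helper name `per_direction` is hypothetical (it is not part of Enzyme), and I'm assuming the same Enzyme version as in the stack trace, since the shape of the returned value may differ in other releases.

```julia
using Enzyme

struct MyClosure{A}
    a::A
end

function (mc::MyClosure)(x)
    mc.a[1] = x          # store the input in the internal buffer
    return mc.a[1]^2     # read it back and square it
end

g = MyClosure([0.0])

# Hypothetical helper (not an Enzyme function): one forward-mode call per
# tangent, each with its own freshly zeroed shadow of the closure.
function per_direction(g, x, dxs)
    return map(dxs) do dx
        autodiff(Forward, Duplicated(g, make_zero(g)), Duplicated, Duplicated(x, dx))
    end
end

# One result per tangent direction, in the same order as the tangents.
results = per_direction(g, 3.0, (5.0, 7.0))
```

This gives up whatever batching would save by sharing the primal computation, so it is only a fallback and not a replacement for BatchDuplicated.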
@wsmoses do you have any clue what I did wrong? For reference, it happened in this PR.