BatchDuplicated functions in Enzyme

I'm using Enzyme in forward mode to differentiate through a closure (a callable struct with internal storage). Wrapping the closure itself in Duplicated works, but wrapping it in BatchDuplicated doesn't seem to.

```julia
using Enzyme

struct MyClosure{A}
    a::A
end

function (mc::MyClosure)(x)
    # computes x^2 using internal storage
    mc.a[1] = x
    return mc.a[1]^2
end

g = MyClosure([0.0])
g(3.0)  # 9.0

g_and_dg = Duplicated(g, make_zero(g))
x_and_dx = Duplicated(3.0, 5.0)
autodiff(Forward, g_and_dg, Duplicated, x_and_dx)  # (9.0, 30.0)

g_and_dgs = BatchDuplicated(g, (make_zero(g), make_zero(g)))
x_and_dxs = BatchDuplicated(3.0, (5.0, 7.0))
autodiff(Forward, g_and_dgs, BatchDuplicated, x_and_dxs)  # error
```

The last call triggers the following error:

```
ERROR: TypeError: in ccall argument 3, expected MyClosure{Vector{Float64}}, got a value of type Tuple{MyClosure{Vector{Float64}}, MyClosure{Vector{Float64}}}
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:7151 [inlined]
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6760 [inlined]
 [3] ForwardModeThunk
   @ ~/.julia/packages/Enzyme/Tb3Iu/src/compiler.jl:6640 [inlined]
 [4] autodiff(::ForwardMode{…}, f::BatchDuplicated{…}, ::Type{…}, args::BatchDuplicated{…})
   @ Enzyme ~/.julia/packages/Enzyme/Tb3Iu/src/Enzyme.jl:437
 [5] top-level scope
Some type information was truncated. Use `show(err)` to see complete types.
```

@wsmoses do you have any clue what I did wrong? For reference, it happened in this PR.

Seems like no one has ever used batched forward mode with closures before. Could you open an issue?

Fix here: "Handle batch closures" by wsmoses, EnzymeAD/Enzyme.jl Pull Request #1784
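For context on what the batched call is supposed to compute: because the closure writes into its own storage, a batched forward pass needs one shadow closure per tangent, so each batch lane gets its own copy of that storage. That is why `BatchDuplicated(g, (make_zero(g), make_zero(g)))` carries a tuple of shadows, which is the tuple the error message complains about. The sketch below propagates the tangents by hand to show this. It is only an illustration, not Enzyme's generated code or API; `manual_batch_forward` is a hypothetical helper. With x = 3.0 and tangents (5.0, 7.0), the derivatives should be 2·3·5 = 30 and 2·3·7 = 42.

```julia
# Same callable struct as in the question.
struct MyClosure{A}
    a::A
end

function (mc::MyClosure)(x)
    mc.a[1] = x
    return mc.a[1]^2
end

# Hypothetical helper: hand-written batched forward-mode rule for MyClosure,
# for illustration only (not Enzyme's generated code). Each shadow closure in
# `dmcs` holds the tangent storage for one batch lane.
function manual_batch_forward(mc::MyClosure, dmcs::Tuple, x::Real, dxs::Tuple)
    @assert length(dmcs) == length(dxs)
    # Primal: write x into the internal storage.
    mc.a[1] = x
    # Shadow of the store: each lane writes its own tangent into its own shadow.
    for (dmc, dx) in zip(dmcs, dxs)
        dmc.a[1] = dx
    end
    # Primal output and per-lane tangents of a[1]^2, i.e. 2 * a[1] * da[1].
    y = mc.a[1]^2
    dys = map(dmc -> 2 * mc.a[1] * dmc.a[1], dmcs)
    return y, dys
end

g = MyClosure([0.0])
dgs = (MyClosure([0.0]), MyClosure([0.0]))
manual_batch_forward(g, dgs, 3.0, (5.0, 7.0))  # (9.0, (30.0, 42.0))
```

The exact shape of the value Enzyme itself returns for a batched call depends on the Enzyme version. The numbers should match the hand-computed tangents above.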
