ForwardDiff error "Cannot determine ordering of Dual tags" when differentiating distributed function with closure

I’m using ForwardDiff to differentiate a function that will be computed across a cluster. When I run the code on a single process, everything works fine. However, when I use multiple processes, I get the following error (full stack trace at end of post): Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#3#4"{DTable}, Int64} and ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}

Here’s an MWE to reproduce the problem. If you run it with command-line argument 0 (so no workers are started), it works; with any positive number of workers it crashes with the above error:

using Distributed
length(ARGS) ≥ 1 ? addprocs(parse(Int64, ARGS[1])) : addprocs()

# comment this out if using global environment
@everywhere begin
    using Pkg
    Pkg.activate(Base.source_dir())
end

@everywhere begin
    using ForwardDiff, Dagger

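    # Multiply column `a` of every row by x (run by Dagger on the workers),
    # then sum the products and fetch the result back to the caller.
    function do_calc(x, t)
        fetch(reduce(+, map(r -> (val=x*r.a,), t))).val
    end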
end

function main()
    table = (a=[1, 1, 1, 1, 1], b=[6, 7, 8, 9, 10]);
    d = DTable(table, 2)

    # ForwardDiff calls the closure with a Dual; the closure's type becomes part of the Dual's tag
    deriv = ForwardDiff.derivative(x -> do_calc(x, d), 42)
    println(deriv)  # should be 5
end

main()

I’ve read a bit about ForwardDiff’s tagging system, most helpfully this issue, which seems to suggest that the tagging is important for exactly the situation I have here: a closure over the variable being differentiated. I suspect the issue is that tags are being created on different workers and then cannot be compared when they are added together to produce the final result, but I’m at a loss as to how to fix it.
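
To make the tagging point concrete, here is a minimal single-process sketch of how the tag type depends on the type of the function being differentiated. It assumes that ForwardDiff.Tag(f, V), an internal constructor rather than documented API, is what builds the tag; the names f, g, tag_f and tag_g are only for illustration.

```julia
using ForwardDiff

# Two anonymous functions with identical bodies still have distinct types.
f = x -> 3x
g = x -> 3x

# ForwardDiff.Tag is internal; the tag type is parameterized by the
# function's type and the input's value type.
tag_f = typeof(ForwardDiff.Tag(f, Int))
tag_g = typeof(ForwardDiff.Tag(g, Int))

println(tag_f)           # a Tag type whose first parameter is typeof(f)
println(tag_f == tag_g)  # the two tag types differ because typeof(f) != typeof(g)
```

If the closure type comes back from a worker under a different name (as the Serialization.__deserialized_types__ prefix in the error suggests), the tag types would differ in the same way, which would be consistent with the error above.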

Full MWE with project files: https://gist.github.com/mattwigway/7d4eaa2576e2720ba3f1a116cbdf44d2

Full stack trace:

ERROR: LoadError: Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#3#4"{DTable}, Int64} and ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}
Stacktrace:
 [1] partials
   @ ~/.julia/packages/ForwardDiff/tZ5o1/src/dual.jl:111 [inlined]
 [2] extract_derivative(#unused#::Type{ForwardDiff.Tag{var"#3#4"{DTable}, Int64}}, y::ForwardDiff.Dual{ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}, ForwardDiff.Dual{ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}, Int64, 1}, 1})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/tZ5o1/src/derivative.jl:81
 [3] derivative
   @ ~/.julia/packages/ForwardDiff/tZ5o1/src/derivative.jl:14 [inlined]
 [4] main()
   @ Main ~/DaggerDiffErr/forwarddiff_tagging_issue.jl:26
 [5] top-level scope
   @ ~/DaggerDiffErr/forwarddiff_tagging_issue.jl:30

This issue is documented in #320 in ForwardDiff. I’ve managed to work around it by avoiding closures and non-Base types in function calls.
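
For anyone who wants to experiment, below is a sketch of a different approach from the workaround above: seeding the Dual by hand with a named tag type, so that the tag no longer depends on an anonymous closure type. This is an assumption, not something verified on a cluster. CalcTag and derivative_named_tag are hypothetical names, and in a distributed run the struct would need to be defined under @everywhere so that every process knows the same type.

```julia
using ForwardDiff

# Hypothetical named tag type (define under @everywhere on a cluster).
struct CalcTag end

# Hypothetical helper: differentiate f at x using a Dual tagged with CalcTag
# instead of a tag derived from the type of f.
function derivative_named_tag(f, x::Real)
    xf = float(x)
    xd = ForwardDiff.Dual{CalcTag}(xf, one(xf))  # value xf, seed derivative 1
    y = f(xd)
    return ForwardDiff.partials(y, 1)            # first (and only) partial
end

# Local stand-in for the MWE: the sum of x * a over a = [1, 1, 1, 1, 1].
println(derivative_named_tag(x -> sum(x .* [1, 1, 1, 1, 1]), 42))
```

Run locally, this gives the same derivative as ForwardDiff.derivative for the stand-in function. Whether it survives the round trip through Dagger is untested here.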