I’m using ForwardDiff to differentiate a function that will be computed across a cluster. When I run the code in a single process, everything works fine. However, when I use multiple processes, I get the following error (full stack trace at the end of the post):
Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#3#4"{DTable}, Int64} and ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}
Here’s an MWE to replicate the problem. If you run this with command-line argument 0 (so no workers are started), it works; with any higher number it crashes with the above error:
using Distributed

length(ARGS) ≥ 1 ? addprocs(parse(Int64, ARGS[1])) : addprocs()

# comment this out if using global environment
@everywhere begin
    using Pkg
    Pkg.activate(Base.source_dir())
end

@everywhere begin
    using ForwardDiff, Dagger

    function do_calc(x, t)
        fetch(reduce(+, map(r -> (val=x*r.a,), t))).val
    end
end

function main()
    table = (a=[1, 1, 1, 1, 1], b=[6, 7, 8, 9, 10]);
    d = DTable(table, 2)
    deriv = ForwardDiff.derivative(x -> do_calc(x, d), 42)
    println(deriv) # should be 5
end

main()
I’ve read a bit about ForwardDiff’s tagging system, most helpfully this issue, which seems to suggest that the tagging is important for exactly the situation I have here: a closure over the variable being differentiated. I suspect the issue is that tags are being created on different workers and then cannot be compared when they are added together to produce the final result, but I’m at a loss as to how to fix it.
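The mismatch itself can be reproduced locally, without a cluster or Dagger. The following is a minimal sketch, assuming that two anonymous functions with different types stand in for the original closure and its deserialized copy. `f`, `g`, and `y` are illustrative names, and `ForwardDiff.extract_derivative` and `ForwardDiff.DualMismatchError` are ForwardDiff internals (the former appears in the stack trace), not public API, so they may change between versions.

```julia
using ForwardDiff
using ForwardDiff: Dual, Tag

# Two anonymous functions with identical bodies still have distinct types,
# much like var"#3#4" and its __deserialized_types__ counterpart.
f = x -> 2x
g = x -> 2x

# The tag type is parameterized on the function type and the input value type.
TagF = Tag{typeof(f), Int}
TagG = Tag{typeof(g), Int}
@assert TagF != TagG

# A dual number carrying g's tag, standing in for a result returned by a worker
# (value 84, derivative 2).
y = Dual{TagG}(84, 2)

# Extracting with the matching tag returns the partial derivative.
@assert ForwardDiff.extract_derivative(TagG, y) == 2

# Extracting with the other tag throws the error from the stack trace.
err = try
    ForwardDiff.extract_derivative(TagF, y)
    nothing
catch e
    e
end
@assert err isa ForwardDiff.DualMismatchError
```

This only illustrates why the final extraction step fails when the tag type on the result differs from the one created by `ForwardDiff.derivative`; it does not show where in the distributed computation the type changes, and it is not a fix.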
Full MWE with project files: https://gist.github.com/mattwigway/7d4eaa2576e2720ba3f1a116cbdf44d2
Full stack trace:
ERROR: LoadError: Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#3#4"{DTable}, Int64} and ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}
Stacktrace:
 [1] partials
   @ ~/.julia/packages/ForwardDiff/tZ5o1/src/dual.jl:111 [inlined]
 [2] extract_derivative(#unused#::Type{ForwardDiff.Tag{var"#3#4"{DTable}, Int64}}, y::ForwardDiff.Dual{ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}, ForwardDiff.Dual{ForwardDiff.Tag{Serialization.__deserialized_types__.var"#3#4"{DTable}, Int64}, Int64, 1}, 1})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/tZ5o1/src/derivative.jl:81
 [3] derivative
   @ ~/.julia/packages/ForwardDiff/tZ5o1/src/derivative.jl:14 [inlined]
 [4] main()
   @ Main ~/DaggerDiffErr/forwarddiff_tagging_issue.jl:26
 [5] top-level scope
   @ Main ~/DaggerDiffErr/forwarddiff_tagging_issue.jl:30