According to the Enzyme.jl documentation, "Enzyme differentiates arbitrary multivariate vector functions as the most general case in automatic differentiation". However, I encountered critical errors (including a kernel crash) just by using basic arithmetic operations.
Here’s a minimal working example (MWE):
```julia
using Enzyme

foo1(x, y) = x[1] * y[1] + x[2] * y[2]
foo2(x, y) = x[1] * y[1] + x[2] * y[2] + 1.0
foo3(x, y) = x' * y
foo4(x, y) = x' * y + 1.0
foo5(x, y) = x .* y
foo6(x, y) = sum(x .* y)
foo7(x, y) = x + y
foo8(x, y) = sum(x + y)
foo9(x, y) = sum(foo7(x, y))
foo10(x, y) = [x[1] + y[1], x[2] + y[2]]
foo11(x, y) = x .+ y
foo12(x, y) = x .+ y .+ 1.0
foo13(x, y) = sum(foo12(x, y))

function test_func(func)
    x = [2.0, 3.0]
    y = [4.0, 1.0]
    grad_reverse_fail = nothing
    grad_forward_fail = nothing

    try
        Enzyme.gradient(Reverse, func, x, y)
    catch e
        grad_reverse_fail = e
    end

    try
        Enzyme.gradient(Forward, func, x, y)
    catch e
        grad_forward_fail = e
    end

    if grad_reverse_fail !== nothing
        println("Reverse gradient failed for $(func): ", grad_reverse_fail)
    else
        println("Reverse gradient succeeded for $(func)")
    end

    if grad_forward_fail !== nothing
        println("Forward gradient failed for $(func): ", grad_forward_fail)
    else
        println("Forward gradient succeeded for $(func)")
    end
end

test_func(foo1)  # foo1 works
test_func(foo2)  # foo2 works
test_func(foo3)  # foo3 works
test_func(foo4)  # foo4 works
test_func(foo5)  # Reverse fails: Enzyme mutability error; Forward works
test_func(foo6)  # foo6 works
test_func(foo7)  # Reverse fails: Enzyme mutability error; Forward works
test_func(foo8)  # foo8 works
test_func(foo9)  # foo9 works
test_func(foo10) # Reverse fails: Enzyme mutability error; Forward works
test_func(foo11) # Reverse fails; Forward works
test_func(foo13) # Reverse fails; Forward fails: EnzymeRuntimeActivityError
test_func(foo12) # foo12 last, as it fails catastrophically and crashes the kernel
```
Interesting examples:
- `foo7` and `foo9`: `foo9` calls `foo7`, and suddenly differentiation works again. Presumably the problem is that `foo7` returns more than one value; for context, the documentation says that Enzyme should work on any f: R^n → R^m.
- `foo12` and `foo13`: `foo12` crashes the Julia kernel, while `foo13` (which calls `foo12`) fails but does not crash the kernel.
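My working assumption (not confirmed by the documentation quote above) is that `Enzyme.gradient` is meant for scalar-valued functions, and that for f: R^n → R^m reverse mode instead computes a vector-Jacobian product, one seeded output direction at a time. A minimal sketch of that with the lower-level `autodiff` call, using a hypothetical in-place variant `foo7!` that writes its vector result into an argument instead of returning it:

```julia
using Enzyme

# Hypothetical in-place variant of foo7: writes x + y into `out` and returns nothing,
# so the vector output is an argument rather than the return value.
function foo7!(out, x, y)
    for i in eachindex(out, x, y)
        out[i] = x[i] + y[i]
    end
    return nothing
end

x = [2.0, 3.0]
y = [4.0, 1.0]
out = zeros(2)

# Seed for the output shadow: selects the first component of foo7(x, y).
dout = [1.0, 0.0]
# Shadows that accumulate the vector-Jacobian product w.r.t. x and y.
dx = zeros(2)
dy = zeros(2)

# Reverse pass; Const marks the (nothing) return value as inactive.
autodiff(Reverse, foo7!, Const, Duplicated(out, dout), Duplicated(x, dx), Duplicated(y, dy))

# dx and dy now hold row 1 of the Jacobians of foo7 w.r.t. x and y.
println(dx, " ", dy)
```

Repeating the call with `dout = [0.0, 1.0]` (after resetting `dx` and `dy`, which accumulate) should give the second row. As far as I understand, Enzyme may also overwrite `dout` during the reverse pass, so it has to be re-seeded before each call.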
Why are these simple examples failing? What can I realistically expect from Enzyme for these operations? I'm considering Enzyme for sensitivity analysis of differential equations: should I proceed with this package, or is it currently unreliable for such applications?
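For the sensitivity-analysis use case, this is a minimal sketch of the kind of thing I would need to work: reverse-mode differentiation of a hand-written explicit Euler loop for du/dt = -p*u with respect to the parameter p. The function `euler_final` and its arguments are hypothetical and not taken from any package; a real ODE solver is a much larger target.

```julia
using Enzyme

# Hypothetical explicit Euler integrator for du/dt = -p*u; returns u after nsteps steps.
function euler_final(p, u0, dt, nsteps)
    u = u0
    for _ in 1:nsteps
        u -= dt * p * u
    end
    return u
end

p = 0.5
# Only p is Active; the initial value, step size and step count are held constant.
dp = autodiff(Reverse, euler_final, Active, Active(p), Const(1.0), Const(0.1), Const(10))[1][1]

# Analytic check: u_N = u0 * (1 - dt*p)^N, so du_N/dp = -N * dt * u0 * (1 - dt*p)^(N-1),
# which for these values is -0.95^9 ≈ -0.6302.
println(dp)
```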
Thanks for your insights!