I'm still having a problem, so let me be a bit more explicit:
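In global scope, the following works fine:

```julia
using ForwardDiff
import ForwardDiff: value, partials, Dual

h(x) = 3*x                                 # primal function
dh(x) = (println("Called"); 3 * one(x))    # hand-written derivative; prints when it is used
# Overload h for Dual numbers so ForwardDiff calls dh instead of differentiating through h
h(x::Dual) = Dual(h(value(x)), partials(x) * dh(value(x)))
ForwardDiff.derivative(h, 2.0)
```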
This has the output:

```julia
julia> ForwardDiff.derivative(h, 2.0)
Called
3.0
```

which shows that my custom derivative is being called. However, for automatic differentiation, users frequently use closures.
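For reference, the global-scope `h(x::Dual)` method is just the first-order chain rule written out by hand. Writing a dual number as $v + p\,\varepsilon$ with $\varepsilon^2 = 0$, where $v$ plays the role of `value(x)` and $p$ the role of `partials(x)`:

$$
h(v + p\,\varepsilon) = h(v) + h'(v)\,p\,\varepsilon
$$

So the overload is correct exactly when `dh` returns $h'$. For $h(x) = 3x$ at $x = 2$ with seed $p = 1$, this gives $h(2 + \varepsilon) = 6 + 3\varepsilon$, and the coefficient $3$ is the `3.0` that `ForwardDiff.derivative` returns.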
My attempt, as in the OP, was:

```julia
function closure(a)
    h(x) = a*x
    dh(x) = (println("Called"); a * one(x))
    h(x::Dual) = Dual(h(value(x)), partials(x) * dh(value(x)))
    ForwardDiff.derivative(h, 2.0)
end
```
As mentioned, this fails with `UndefVarError: dh not defined`. (This feels like an outright bug: `dh` is defined!)
Copy-pasting the function bodies into the `::Dual` method does work, but this is of course not very user-friendly:
```julia
julia> function closure2(a)
           h(x::Dual) = (println("Called"); Dual(a * value(x), partials(x) * a * one(value(x))))
           ForwardDiff.derivative(h, 2.0)
       end
closure2 (generic function with 1 method)

julia> closure2(1.0)
Called
1.0

julia> closure2(2.0)
Called
2.0
```
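One possible way to avoid the copy-pasting while still building the functions inside the enclosing function is sketched below. It is only a sketch under assumptions: `WithDerivative` and `closure4` are hypothetical names (not part of ForwardDiff), and it assumes a ForwardDiff version whose `Dual` carries a tag type parameter (`Dual{T}`), as current releases do. The idea is to put the only method that dispatches on `Dual` on a wrapper type defined at top level, so no local function needs more than one method.

```julia
using ForwardDiff
import ForwardDiff: value, partials, Dual

# Hypothetical wrapper (not a ForwardDiff API): bundles a primal function `f`
# with a hand-written derivative `df`. Declared at top level, so its methods
# are ordinary global methods defined once.
struct WithDerivative{F,DF}
    f::F
    df::DF
end

# Plain real numbers: just evaluate the primal function.
(w::WithDerivative)(x::Real) = w.f(x)

# Dual numbers: apply the chain rule with the supplied derivative,
# passing through the tag T of the incoming Dual.
(w::WithDerivative)(x::Dual{T}) where {T} =
    Dual{T}(w.f(value(x)), partials(x) * w.df(value(x)))

function closure4(a)
    # The closures themselves stay local; only the wrapper's methods live at top level.
    h = WithDerivative(x -> a * x, x -> (println("Called"); a * one(x)))
    ForwardDiff.derivative(h, 2.0)
end
```

Used like `closure2`, `closure4(2.0)` should print `Called` and return `2.0`. Passing the tag `T` through explicitly keeps the returned `Dual` tagged like its input, which is the tag `ForwardDiff.derivative` expects when it extracts the derivative.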
Trying to do it @mauro3 style:

```julia
function closure3(a)
    h(x) = a*x
    dh(x) = (println("Called"); a * one(x))
    (::typeof(h))(x::Dual) = Dual(h(value(x)), partials(x) * dh(value(x)))
    ForwardDiff.derivative(h, 2.0)
end
```
Done this way, the custom derivative function is never actually called. Moreover, calling `closure3` more than once produces method-redefinition warnings:
```julia
julia> closure3(2.0)
2.0

julia> closure3(2.0)
WARNING: Method definition (::Main.#h#15{Float64})(ForwardDiff.Dual) in module Main at REPL[3]:5 overwritten at REPL[3]:5.
2.0
```
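For comparison, here is a sketch of the @mauro3-style call overloading with the callable type declared at top level instead of being a closure type: the captured `a` becomes a field of a struct, and the `Dual` method is an ordinary top-level method that is defined once rather than re-executed on every call (which is what the redefinition warning above reports). `ScaledBy`, `dscaled`, and `closure5` are hypothetical names, and as before this assumes a tag-carrying `Dual{T}`.

```julia
using ForwardDiff
import ForwardDiff: value, partials, Dual

# Hypothetical callable struct standing in for the closure `h`:
# the captured variable `a` becomes a field.
struct ScaledBy{A<:Real}
    a::A
end

# Primal function: x -> a*x
(h::ScaledBy)(x::Real) = h.a * x

# Hand-written derivative of x -> a*x; the println shows when it is used.
dscaled(h::ScaledBy, x) = (println("Called"); h.a * one(x))

# Dual method defined once, at top level, keeping the incoming tag T.
(h::ScaledBy)(x::Dual{T}) where {T} =
    Dual{T}(h(value(x)), partials(x) * dscaled(h, value(x)))

# Each call only constructs a ScaledBy value; no methods are (re)defined.
closure5(a) = ForwardDiff.derivative(ScaledBy(a), 2.0)
```

With this, `closure5(2.0)` should print `Called` and return `2.0`, and calling it repeatedly should not produce any redefinition warning. The trade-off is that every "closure" needs its own named type, which is more boilerplate than a local function.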
I’m kinda out of ideas now. It feels like this should just work.