I played around with second derivatives in ForwardDiff.jl and found something unexpected, so I would like to ask about it. Basically, the generated LLVM IR multiplies a whole branch of the code by zero before returning. Here's a simple example:
```julia
julia> using ForwardDiff: derivative

julia> f(x) = x^3
f (generic function with 1 method)

julia> ∂f(x) = derivative(f, x)
∂f (generic function with 1 method)

julia> ∂²f(x) = derivative(∂f, x)
∂²f (generic function with 1 method)

julia> f(1.0), ∂f(1.0), ∂²f(1.0) # everything OK so far
(1.0, 3.0, 6.0)
```
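As a sketch, assuming a recent Julia with ForwardDiff.jl installed, the LLVM IR for both derivatives can be printed with `@code_llvm` from the `InteractiveUtils` standard library (the REPL loads it automatically). `debuginfo=:none` only strips the source-line comments, and the exact IR will vary between Julia and ForwardDiff versions:

```julia
using ForwardDiff: derivative
using InteractiveUtils: @code_llvm   # loaded automatically in the REPL

f(x) = x^3
∂f(x) = derivative(f, x)
∂²f(x) = derivative(∂f, x)

# Print the optimized LLVM IR for the first and second derivative at a
# Float64 argument; `debuginfo=:none` drops the source-line annotations.
@code_llvm debuginfo=:none ∂f(1.0)
@code_llvm debuginfo=:none ∂²f(1.0)
```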
And here are the LLVM representations for both the first- and second-order derivatives (with some annotations of my own; please tell me if I'm reading them correctly):
Unfortunately, IEEE floating-point rules are not kind to making code optimisable.
In particular, there exist values x such that 0.0 * x != 0.0.
Namely x = NaN, and also x = ±Inf, since 0.0 * Inf is NaN. (For negative finite x the product is -0.0, which is == 0.0 but is not the same value.)
Thus the compiler is not generally allowed to simply remove that chain of expressions.
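A quick illustration of the point above in plain Julia, no packages needed. It multiplies `0.0` by a few sample values and uses `===` (which compares bit patterns, so it distinguishes -0.0 from 0.0) to check whether the result is really `+0.0`:

```julia
# Under IEEE 754 semantics, `0.0 * x` is not always `+0.0`, so a compiler
# cannot replace the product by a constant without knowing `x`.
for x in (1.5, -1.5, Inf, NaN)
    y = 0.0 * x
    # `===` compares bit patterns, so it also tells -0.0 apart from 0.0
    println("0.0 * ", x, " = ", y, "    same as 0.0? ", y === 0.0)
end
```

Only the first case yields `+0.0`: `-1.5` gives `-0.0`, and both `Inf` and `NaN` give `NaN`.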
In theory, ForwardDiff2.jl might handle this better, since it uses ChainRules' types.
ChainRules has a strong Zero() object, for which Zero() * x = Zero() for all x.
I don't know if it will show up in this circumstance.
Also, ForwardDiff2 is an early experiment and, last I heard, is on hiatus.
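To show what a "strong" zero buys, here is a minimal sketch with a hypothetical `StrongZero` type. It is only a stand-in for the idea, not ChainRules' actual implementation. Because the zero is encoded in the type, multiplication by it is resolved by dispatch and never inspects the other operand:

```julia
# Hypothetical stand-in for a "strong zero"; NOT ChainRules' implementation.
struct StrongZero end

# Multiplication by a strong zero ignores the other operand entirely,
# even when it is NaN or Inf.
Base.:*(::StrongZero, ::Number) = StrongZero()
Base.:*(::Number, ::StrongZero) = StrongZero()

# Adding a strong zero is the identity.
Base.:+(::StrongZero, x::Number) = x
Base.:+(x::Number, ::StrongZero) = x

println(0.0 * NaN)                  # ordinary Float64 zero: prints NaN
println(StrongZero() * NaN)         # prints StrongZero()
println(1.0 + StrongZero() * Inf)   # prints 1.0
```

Since the result of `StrongZero() * x` does not depend on the value of `x`, the compiler can typically drop a side-effect-free computation of `x` altogether, which is exactly what an ordinary `0.0 *` does not allow.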
Hi @oxinabox, thanks for the reply and for the suggestion. I'll try ForwardDiff2.jl. Are there any other AD libraries that currently support ChainRules.jl?
The reason this doesn't help is that `@fastmath` doesn't do anything here: the macro can't see any operations it knows about.
But adding `@fastmath` to the `x^3` alone was my first guess too, and that still doesn't improve the generated code. This time it's because the fast-math version gets thrown away at the first dispatch: DiffRules defines no methods for the FastMath functions, and the fallback is to call the non-fast variant.
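A small sketch of why the macro sees nothing: `@fastmath` is a syntactic rewrite, so `@macroexpand` shows exactly what it changes. Here `g` and `x` are placeholder names (nothing is evaluated), and the exact expansion of `x^3` may differ between Julia versions:

```julia
# `@fastmath` only replaces operators and functions that appear literally
# in the expression it is applied to. `g` and `x` are placeholders.
visible = @macroexpand @fastmath x^3               # `^` is visible
hidden  = @macroexpand @fastmath derivative(g, x)  # only a call is visible

println(visible)   # `^` replaced by a Base.FastMath.pow_fast call
println(hidden)    # unchanged: derivative(g, x)
```

Only the first expression changes, so wrapping a call such as `derivative(∂f, x)` in `@fastmath` leaves the `^` inside `f` untouched.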
I did not recommend that.
Yingbo doesn't recommend it either, because it is an experiment that is on hiatus.
That information was much more about the medium-term outlook than a suggestion.
> Are there any other AD libraries that currently support ChainRules.jl?
- Reverse mode: Zygote, and a branch of Nabla
- Forward and reverse: ReversePropagation.jl
- Reverse mode for some types, but not for rules: Yota.jl
But that doesn't mean this is solved in any of them, just that there is an option for it to be.
ChainRules.jl does define fast math rules for FastMath functions.
And the way the code works is really cool.
Even if I do say so myself.
The fast-math rules are a code transform of the original rules (basically, applying `@fastmath` to them).
Even having those is not always enough, though, since `@fastmath` does not apply recursively.
(Ideally, `@fastmath` would act on dynamic scope, like a Cassette pass, rather than on lexical scope.)
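The general idea of deriving fast-math rules by transforming existing rule code can be sketched as follows. This is a hypothetical illustration, not ChainRules.jl's actual mechanism: `rule_body`, `slow_rule` and `fast_rule` are made-up names, and the rule is just the hand-written derivative of `x^3`:

```julia
# Hypothetical derivative rule for f(x) = x^3, kept as an expression so it
# can be transformed. NOT ChainRules.jl's actual mechanism.
rule_body = :(3 * x^2)

# Code transform: expand `@fastmath` around the rule body to obtain a
# fast-math variant of the same rule.
fast_body = macroexpand(@__MODULE__, :(@fastmath $rule_body))

@eval slow_rule(x) = $rule_body
@eval fast_rule(x) = $fast_body

println(fast_body)   # operators replaced by Base.FastMath.*_fast calls
println(slow_rule(2.0), " ", fast_rule(2.0))
```

Note that `fast_rule` only gets fast-math versions of the operators written in its own body; any function it called would still run its ordinary, non-fast code, which is the lexical-scope limitation described above.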
Technically, ReversePropagation.jl is only reverse mode. We could (and should) easily make it emit forward mode as well. Then, once SymbolicUtils.jl supports array symbols, it'll be a fairly complete system and could essentially do this, with a switch to apply simplification rules.