How to avoid ForwardDiff.jl generating a second-order derivative that wastes flops by eventually multiplying by zero

I played around with second derivatives in ForwardDiff.jl and found something unexpected, so I would like to ask about it. Basically, the generated LLVM code multiplies a whole branch by zero before returning. Here’s a simple example:

julia> using ForwardDiff: derivative

julia> f(x) = x^3
f (generic function with 1 method)

julia> ∂f(x) = derivative(f, x)
∂f (generic function with 1 method)

julia> ∂²f(x) = derivative(∂f, x)
∂²f (generic function with 1 method)

julia> f(1.0), ∂f(1.0), ∂²f(1.0)  # everything OK so far
(1.0, 3.0, 6.0)

And here are the LLVM representations for both the first- and second-order derivatives (with some annotations added by me; please tell me if I’m reading them correctly):

julia> @code_llvm debuginfo=:none ∂f(1.0)

define double @"julia_\E2\88\82f_17305"(double) {
  %1 = fmul double %0, %0            # %1 = x * x
  %2 = fmul double %1, 3.000000e+00  # %2 = 3 * x^2
  ret double %2                      # very nice, great success
}

julia> @code_llvm debuginfo=:none ∂²f(1.0)

define double @"julia_\E2\88\82\C2\B2f_17314"(double) {
  %1 = fmul double %0, %0            # %1 = x * x
  %2 = fmul double %0, 2.000000e+00  # %2 = 2 * x
  %3 = fmul double %1, 3.000000e+00  # %3 = 3 * x^2
  %4 = fmul double %2, 3.000000e+00  # %4 = 3 * 2 * x
  %5 = fmul double %3, 0.000000e+00  # %5 = 0 * 3 * x^2  # <= 🤔
  %6 = fadd double %4, %5            # %6 = 6 * x
  ret double %6                      # %1, %3 and %5 are not required!
}

Is this expected? Can it be optimized? I think it can, since the same thing does not happen when x is an Int64:

julia> @code_llvm debuginfo=:none ∂²f(1)

define i64 @"julia_\E2\88\82\C2\B2f_17357"(i64) {
  %1 = mul i64 %0, 6
  ret i64 %1
}
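
For reference, I believe the literal 0.0 is a seed of the nested dual numbers used for the second derivative. A toy dual type (just an illustration, not ForwardDiff's actual implementation) shows where it enters:

```julia
# A minimal hand-rolled dual number -- a toy, not ForwardDiff's Dual.
struct D{T}
    val::T  # value
    der::T  # derivative
end

# Product rule; both terms are computed even when one seed is zero.
Base.:*(a::D, b::D) = D(a.val * b.val, a.der * b.val + a.val * b.der)

cube(x) = x * x * x

# Second derivative via nesting: the inner seed 1.0 is constant with
# respect to the outer perturbation, so its derivative is the literal
# 0.0 that survives into the generated code.
x = D(D(1.0, 1.0), D(1.0, 0.0))
cube(x)  # D(D(1.0, 3.0), D(3.0, 6.0)): value 1, f' = 3 (twice), f'' = 6
```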

Unfortunately, IEEE floating-point rules are not kind to making code optimisable.
In particular, there exists a value x such that 0.0 * x != 0.0.
Namely x = NaN.
Thus the compiler is not, in general, allowed to simply remove that chain of expressions.
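
The problematic cases are easy to check at the REPL:

```julia
julia> 0.0 * NaN   # not 0.0 -- so `x * 0.0` cannot be folded to 0.0
NaN

julia> 0.0 * Inf   # overflowed intermediates hit the same rule
NaN
```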

In theory ForwardDiff2.jl might handle this better, since it uses ChainRules’ types.
ChainRules has a strong Zero() object which does have Zero() * x == Zero() for all x.
I don’t know if it will show up in this circumstance.
Also, ForwardDiff2 is an early experiment and is on hiatus, last I heard.
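
The strong-zero idea can be sketched without the package (StrongZero here is a hypothetical stand-in, not ChainRules’ actual Zero type):

```julia
# A hypothetical "strong zero": annihilates under *, disappears under +.
# This mimics the behaviour described above; it is not ChainRules' code.
struct StrongZero end

Base.:*(::StrongZero, ::Any) = StrongZero()
Base.:*(::Any, ::StrongZero) = StrongZero()
Base.:*(::StrongZero, ::StrongZero) = StrongZero()  # resolve ambiguity
Base.:+(::StrongZero, x) = x
Base.:+(x, ::StrongZero) = x
Base.:+(::StrongZero, ::StrongZero) = StrongZero()  # resolve ambiguity

StrongZero() * NaN          # StrongZero(), even for NaN -- unlike 0.0 * NaN
6.0 + StrongZero() * Inf    # 6.0: the dead branch vanishes at the type level
```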


Hi @oxinabox, thanks for the reply and for the suggestion. I’ll try to use ForwardDiff2.jl. Are there any other AD libraries that currently support ChainRules.jl?

Strangely, I find that annotating with @fastmath:

julia> ∂²f(x) = @fastmath derivative(x->derivative(f, x), x)

doesn’t seem to produce the optimised version:

julia> @code_llvm debuginfo=:none ∂²f(1.0)

; Function Attrs: uwtable
define double @"julia_\E2\88\82\C2\B2f_475"(double) #0 {
  %1 = fmul double %0, %0
  %2 = fmul double %0, 2.000000e+00
  %3 = fmul double %1, 3.000000e+00
  %4 = fmul double %2, 3.000000e+00
  %5 = fmul double %3, 0.000000e+00
  %6 = fadd double %4, %5
  ret double %6
}

Is this still expected?


The reason this doesn’t help is that the macro doesn’t do anything here: it can’t see any operations it knows about.

But adding @fastmath to the x^3 alone was my first guess too, and that still doesn’t improve what’s generated. This time it’s because the annotation gets thrown away at the first dispatch: DiffRules defines no methods for FastMath functions, and the fallback is to call the un-fast variant:

julia> @macroexpand @fastmath derivative(x->derivative(f, x), x)
:(derivative((x->begin
              #= REPL[6]:1 =#
              derivative(f, x)
          end), x))

julia> @macroexpand @fastmath x^3
:(Base.FastMath.pow_fast(x, Val{3}()))

julia> @less Base.FastMath.pow_fast(ForwardDiff.Dual(1.0, 1.0), Val{3}()) 
# @inline pow_fast(x, v::Val) = Base.literal_pow(^, x, v)

Edit: the obvious attempt does not lead to an improvement:

julia> using ForwardDiff: Dual, partials, value, derivative

julia> import Base.FastMath: pow_fast, mul_fast

julia> function pow_fast(x::Dual{Z}, ::Val{p}) where {Z,p}
         y = pow_fast(value(x), Val(p))
         dys = map(partials(x).values) do dx
           mul_fast(p, dx, pow_fast(value(x), Val(p-1)))
         end
         Dual{Z}(y, dys)
       end
julia> g(x) = @fastmath x^3

julia> ∂²g(x) = derivative(x -> derivative(g, x), x)

Ah, of course: @fastmath just sees the symbols.

I’m not sure how to solve this either.


I did not recommend that.
Yingbo does not recommend it either, because it is an experiment that is on hiatus.

That info was much more a medium-term outlook than a suggestion.

> Are there any other AD libraries that currently support ChainRules.jl?

Reverse mode: Zygote, and there is a branch of Nabla.
Forward and Reverse: ReversePropagation.jl
Reverse mode for some types, but not for rules: Yota.jl

But that doesn’t mean this is solved on any of these.
Just that there are options.


It’s worse than that, because Inf * 0.0 is also NaN. Because of that, the code will give an incorrect result when 3x^2 overflows:

julia> ∂²f(1e200)
NaN
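
Stepping through the generated instructions by hand shows how the overflow poisons the result:

```julia
julia> x = 1e200; 3 * x^2        # %3 overflows to Inf
Inf

julia> 6 * x + 0.0 * (3 * x^2)   # %6 = %4 + %5: the "dead" branch is now NaN
NaN
```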

For interest:

ChainRules.jl does define fast-math rules for FastMath functions.
And the way the code works is really cool.
Even if I do say so myself.
The fast-math rules are a code transform (basically applying @fastmath) of the original rules.

Even having those, though, is not always enough, since @fastmath does not apply recursively.
(Ideally @fastmath would act on dynamic scope, like a Cassette pass, not on lexical scope.)
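
To illustrate the lexical-scope point (`g` here is just a throwaway example):

```julia
julia> g(x) = x + 0.0   # the removable add lives inside g

julia> @macroexpand @fastmath g(2.0)     # `g` is not a known op: nothing rewritten
:(g(2.0))

julia> @macroexpand @fastmath 2.0 + 0.0  # only lexically visible ops are rewritten
:(Base.FastMath.add_fast(2.0, 0.0))
```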


Obligatory posting of NaN+Inf computing:

Technically, ReversePropagation.jl is only reverse mode. We could (and should) easily make it emit forward mode too. Then, when SymbolicUtils.jl supports array symbols, it’ll be a fairly complete system, and essentially do this with a switch to apply simplification rules.


It can already emit forward mode, it’s just a bit hidden.