"static" autodiff

Hey,

Is there a way to have an autodiff system such as ForwardDiff “write out” the dual-number version of a piece of code, for further manual optimisation?

The function to differentiate is not that complicated, but differentiating it by hand would be a pain. On the other hand, I am pretty sure that naively running dual numbers through it is suboptimal, since the function value and its gradient might share pieces of computation, and I would like to be able to optimise that by hand.

Hey there!
The logic behind the ChainRules.jl package sounds vaguely related: it allows you to write rules that fuse primal and gradient computations to save memory and CPU. But from what I understand you want to get the rule automatically? I don’t know what your function looks like, but if it’s not too loopy, maybe Symbolics.jl can help?
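To illustrate what such a fused rule could look like, here is a minimal, package-free sketch for the `f` posted further down in the thread. It only mimics the shape of a ChainRules `rrule` (primal value plus a pullback closure); `f_with_pullback` and the test values are hypothetical. The point is that the `exp.(x .* t)` values are computed once and shared between the primal and the gradient.

```julia
# Plain loop version of f, as posted in the thread.
function f(x, t)
    r = zero(eltype(x))
    for i in eachindex(x)
        r += exp(x[i] * t)
    end
    return log(r)
end

# Hypothetical helper shaped like a ChainRules `rrule`: returns the primal value
# together with a pullback closure. The exp values are computed once in the
# primal pass and reused by the pullback instead of being recomputed.
function f_with_pullback(x, t)
    e = exp.(x .* t)      # shared intermediates
    s = sum(e)
    y = log(s)
    function f_pullback(ybar)
        sbar = ybar / s               # d(log s)/ds = 1/s
        xbar = (sbar * t) .* e        # ds/dx_i = t * exp(x_i * t)
        tbar = sbar * sum(x .* e)     # ds/dt   = sum_i x_i * exp(x_i * t)
        return xbar, tbar
    end
    return y, f_pullback
end

x = [0.1, 0.5, 0.9]
t = 3.0
y, back = f_with_pullback(x, t)
xbar, tbar = back(1.0)
```

With ChainRulesCore itself you would instead add a method to `ChainRulesCore.rrule(::typeof(f), x, t)`, whose pullback also returns `NoTangent()` for `f` itself.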

Hahaha you nailed it. It is really loopy, and Symbolics.jl cannot handle it.

What I want is to exploit the chain rules already built into AD systems to generate code that computes the gradient. Then I could curate that code by hand and reason about it. Give me a sec, I’ll write an example for you.

@gdalle This might give you a clearer view of what I am looking for. Consider the following function:

function f(x,t)
    r = zero(eltype(x))
    n = length(x)
    for i in eachindex(x)
        r += exp(x[i]*t)
    end
    r = log(r)
    return r
end

x = rand(10)
t = 3.0
f(x,t)

Say I want the gradient of f w.r.t. x:

using ForwardDiff
g(x,t) = ForwardDiff.gradient(y -> f(y,t), x)

g(x,t) # works correctly; 

What I want is an automatic mechanism that would output the following code:

# Wanted output : 
function ∂f_∂x(x,t)
    r = zero(eltype(x))
    r_dual = zero(r)
    n = length(x)
    for i in eachindex(x)
        r += exp(x[i]*t)
        r_dual += something(x,t)
    end
    r = log(r)
    r_dual = something_else(r)
    return r, r_dual
end

so that I can then curate the result by hand. My bet is that, for very loopy code, there will be opportunities to improve the efficiency of the gradient computation by hand.

Even if the obtained code is cluttered, de-cluttering by hand would still be worth it in my application case.
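For this particular `f`, the placeholders in the wanted output can be filled in by hand. The sketch below is one way to write the forward-mode (“dual”) version, assuming a single tangent direction `v` per call; `df_dx_tangent`, `gradient_by_tangents` and `v` are illustrative names, not the output of any tool. Note that `r_dual` has to be divided by `r` before `r` is overwritten by `log(r)`.

```julia
# Hand-written forward-mode version of f: for a direction v, returns f(x, t)
# and the directional derivative (gradient of f w.r.t. x) ⋅ v.
function df_dx_tangent(x, t, v)
    r = zero(eltype(x))
    r_dual = zero(eltype(x))          # tangent ("dual part") of r
    for i in eachindex(x)
        e = exp(x[i] * t)             # shared between value and tangent
        r += e
        r_dual += e * t * v[i]        # derivative of exp((x[i] + ε v[i]) t) at ε = 0
    end
    r_dual = r_dual / r               # chain rule through log: d(log r) = dr / r
    r = log(r)
    return r, r_dual
end

# One call per basis vector recovers the full gradient.
function gradient_by_tangents(x, t)
    n = length(x)
    g = similar(x)
    for j in 1:n
        v = zeros(eltype(x), n)
        v[j] = one(eltype(x))
        g[j] = df_dx_tangent(x, t, v)[2]
    end
    return g
end

x = [0.1, 0.5, 0.9]
t = 3.0
g = gradient_by_tangents(x, t)
```

Recovering the full gradient this way takes one pass per input, which is roughly what forward mode does; ForwardDiff reduces the number of passes by carrying several partials per dual number (its chunk mode).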

Right, that’s what I thought, and for your use case I probably don’t know the answer. Re-interpreting the code in that way is basically what most autodiff libraries do under the hood, but I’m not aware of any that lets you retrieve the gradient computation explicitly.

Yes, this is what they do, but they don’t output this code; they compile it and run it (which is fine). I want to interrupt this pipeline in the middle and take a look myself. I am talking about ForwardDiff because I think forward mode will be easier for me to reason about: the goal is not to produce efficient code automatically, but to produce code that is as uncluttered as possible, since I will then benchmark it and modify it by hand to make it efficient.

Edit: if it does not exist and I or someone else ends up building it, a funny name would be ManualDiff.jl ;)

The hard bit is loops.
You basically can’t generate loops from a tracing approach (this includes ForwardDiff.jl and JAX), because it records the operations that ran, not the code that ran them.

So you get a statically unrolled list of operations.
Which is fine if you don’t have dynamic control flow.
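A toy operator-overloading tracer makes the “unrolled list” concrete. It is a hypothetical illustration (`Traced`, `TAPE`, `emit!` and `trace_f` are made up, not ForwardDiff or JAX internals): each `+`, `*`, `exp` and `log` on a traced value appends one line to a tape, so the `for` loop in `f` disappears and the tape length grows with `length(x)`.

```julia
# Hypothetical traced value: just the index of the tape line that produced it.
struct Traced
    id::Int
end

const TAPE = String[]

# Append one operation to the tape and return a handle to its result.
function emit!(rhs::String)
    push!(TAPE, "v$(length(TAPE) + 1) = $rhs")
    return Traced(length(TAPE))
end

name(a::Traced) = "v$(a.id)"
name(a::Real) = string(a)

Base.:+(a::Traced, b::Traced) = emit!("$(name(a)) + $(name(b))")
Base.:*(a::Traced, b::Real) = emit!("$(name(a)) * $(name(b))")
Base.exp(a::Traced) = emit!("exp($(name(a)))")
Base.log(a::Traced) = emit!("log($(name(a)))")
Base.zero(::Type{Traced}) = emit!("0")

# The function from the thread, unchanged apart from the unused `n`.
function f(x, t)
    r = zero(eltype(x))
    for i in eachindex(x)
        r += exp(x[i] * t)
    end
    return log(r)
end

# Run f on n traced inputs and return the recorded tape.
function trace_f(n, t)
    empty!(TAPE)
    x = [emit!("x[$i]") for i in 1:n]
    f(x, t)
    return copy(TAPE)
end

foreach(println, trace_f(2, 3.0))
# v1 = x[1]
# v2 = x[2]
# v3 = 0
# v4 = v1 * 3.0
# v5 = exp(v4)
# v6 = v3 + v5
# v7 = v2 * 3.0
# v8 = exp(v7)
# v9 = v6 + v8
# v10 = log(v9)
```

The tape has `4n + 2` lines for `n` inputs, and nothing in it says “loop”.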

There used to be code for doing this with Zygote.
But Mike gave up in the end.

I think Tapenade can do this for C++ and Fortran.

Thanks @oxinabox for taking the time.

I did not know about Tapenade, this stuff is beautiful.

Maybe for the loop problem, a good approach would be to feed the “tool” a function without the loop:

function f(x,t)
    r = zero(eltype(x))
    n = length(x)
    # for i in eachindex(x)
        r += exp(x[i]*t)
    # end
    r = log(r)
    return r
end

and have it produce:

function ∂f_∂x(x,t)
    r = zero(eltype(x))
    r_dual = zero(r)
    n = length(x)
    # for i in eachindex(x)
        r += exp(x[i]*t)
        r_dual += something(x,t)
    # end
    r = log(r)
    r_dual = something_else(r)
    return r, r_dual
end

and then re-add the loop manually. So loops might not be my biggest problem if I can get back code from the AD system.

Maybe there are edge cases where this approach won’t work (I can’t think of one right now), but in my case that would be enough.

I don’t understand what you mean here.

In general using functional constructs like map and fold are better.
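For the `f` in this thread, the accumulator loop is just a fold, so it can be written with `mapreduce`. A minimal sketch; `f_fold` is an illustrative name:

```julia
# Loop version from the thread.
function f(x, t)
    r = zero(eltype(x))
    for i in eachindex(x)
        r += exp(x[i] * t)
    end
    return log(r)
end

# Same computation as a fold: `mapreduce` applies `xi -> exp(xi * t)` to each
# element and combines the results with `+`.
f_fold(x, t) = log(mapreduce(xi -> exp(xi * t), +, x))

x = [0.1, 0.5, 0.9]
t = 3.0
f_fold(x, t) ≈ f(x, t)
```

A tool that has a rule for the reduction can then treat it as one operation rather than as a loop whose length depends on the input.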

Well, please allow me to reformulate. My goal is to obtain the code. If the loops are a problem, I can simply remove them by commenting them out, even if this completely changes the meaning of the function.

If the places where the (now removed) loops start and end are still marked in the generated code, I could re-add them by hand afterwards while curating the result.

It’s not the loops themselves.
It’s how they interact with tracing.
Tracing doesn’t record the code, only the operations that are run.
So if the operations that run change depending on the values of the input, you have a problem, since your trace will be meaningless for other inputs.
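A small sketch of why value-dependent control flow breaks a recorded trace, assuming a toy tracer that stores each operation as a closure so the trace can be replayed on new inputs (`TracedVal`, `OPS`, `record_trace`, `replay` and `h` are hypothetical). The comparison `x > 0` is evaluated at trace time but not recorded, so only the branch taken for the tracing input ends up on the tape.

```julia
# Hypothetical traced value: tape index plus the concrete value at trace time
# (needed so that branches can be taken while tracing).
struct TracedVal
    id::Int
    val::Float64
end

# Each entry maps (values computed so far, replay input) -> new value.
const OPS = Function[]

function emit_op!(op, val)
    push!(OPS, op)
    return TracedVal(length(OPS), val)
end

Base.:*(a::TracedVal, b::TracedVal) = emit_op!((v, _) -> v[a.id] * v[b.id], a.val * b.val)
Base.:*(c::Real, a::TracedVal) = emit_op!((v, _) -> c * v[a.id], c * a.val)
Base.:>(a::TracedVal, b::Real) = a.val > b   # evaluated now, NOT recorded

# A function whose executed operations depend on the input value.
h(x) = x > 0 ? x * x : 3 * x

function record_trace(fun, x0)
    empty!(OPS)
    x = emit_op!((_, input) -> input, x0)
    fun(x)
    return copy(OPS)
end

function replay(ops, input)
    vals = Float64[]
    for op in ops
        push!(vals, op(vals, input))
    end
    return vals[end]
end

tape = record_trace(h, 2.0)   # only the `x * x` branch is recorded
replay(tape, 2.0)             # matches h(2.0)
replay(tape, -1.0)            # gives 1.0, but h(-1.0) is -3.0
```

ForwardDiff itself does not replay a tape; it reruns `f` with dual numbers at every input, so its derivatives stay correct. The problem appears when a trace is turned into standalone code, as requested here.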

If, on the other hand, you do a source code transformation AD, like Zygote or Tapenade, you can do this.
But it is much harder to write,
though it can be done.

I did just remember @dfdx’s XGrad.jl, which does a lot of what you want
and works via source code transformation.
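For contrast with tracing, here is a toy source-to-source differentiator working on Julia `Expr` trees. It is only a sketch (the function `d` and the generated `dfdx` are hypothetical, and it handles just symbols, numbers, `+`, binary `*`, `exp` and `log`), but it shows the idea: the output is code, not a list of executed operations, and it is unsimplified, which is the cluttered form one would then curate by hand. A `for` loop is a different kind of expression and is not handled here; that is the hard part discussed in this thread.

```julia
# Return an expression for the derivative of `ex` with respect to symbol `v`.
function d(ex, v::Symbol)
    if ex isa Number
        return 0
    elseif ex isa Symbol
        return ex == v ? 1 : 0
    elseif ex isa Expr && ex.head == :call
        op, args = ex.args[1], ex.args[2:end]
        if op == :+
            # sum rule
            return Expr(:call, :+, (d(a, v) for a in args)...)
        elseif op == :* && length(args) == 2
            # product rule
            a, b = args
            return :($(d(a, v)) * $b + $a * $(d(b, v)))
        elseif op == :exp
            # chain rule through exp
            return :(exp($(args[1])) * $(d(args[1], v)))
        elseif op == :log
            # chain rule through log
            return :($(d(args[1], v)) / $(args[1]))
        end
    end
    error("unsupported expression: $ex")
end

ex = d(:(log(exp(x * t))), :x)
println(ex)                  # unsimplified derivative code

# Turn the generated expression into an ordinary Julia function.
@eval dfdx(x, t) = $ex
dfdx(0.5, 3.0)               # ≈ 3.0, since log(exp(x t)) = x t
```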

Yes, this is why people use reverse mode for gradients. When you have a single output and many (n) inputs, forward mode costs roughly n times the cost of computing the output, whereas reverse mode costs roughly O(1) times.
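For the `f` in this thread the difference can be made concrete, as a rough operation count assuming the `exp` calls dominate:

$$
f(x,t) = \log s, \qquad s = \sum_{j=1}^{n} e^{x_j t}, \qquad \frac{\partial f}{\partial x_i} = \frac{t\, e^{x_i t}}{s}.
$$

Forward mode pushes one tangent direction $\dot{x} = e_i$ through the whole loop per input, so recovering all $n$ partials takes $n$ sweeps of $O(n)$ work, i.e. roughly $O(n^2)$ in total. Reverse mode runs the loop once, keeping the values $e^{x_j t}$ and $s$, and then a single backward sweep

$$
\bar{s} = \frac{1}{s}, \qquad \bar{x}_i = \bar{s}\, t\, e^{x_i t}, \quad i = 1, \dots, n,
$$

gives the whole gradient in $O(n)$.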

But by the same token, the forward-mode derivative calculation is not a good starting point for deriving the reverse mode / “adjoint” calculation. (See also our matrix calculus course notes.)

@stevengj Yes what you are saying makes a lot of sense.

Anyway, I could probably do something correct with pen & paper in a few days’ work, so if implementing a tool to do it for me is really hard, as the discussion with @oxinabox suggests, maybe it is not worth it.

I’ll still check out XGrad.jl, as it seems really cool!

If you like XGrad.jl, you probably want to take a look at Yota.jl - same ideas, but evolved. None of them support loops with non-static number of iterations though.