How to force Flux to use FiniteDiff

I have a loss() function.

FiniteDiff.finite_difference_derivative works great with this loss().

ForwardDiff.derivative doesn’t work at all.

As a result, Flux.train!(loss, Flux.params(NNODE), data, ADAM(0.001)) doesn’t work either!

How can I push Flux to use FiniteDiff?

Thank you in advance.

Well, one (potentially very inefficient) way would be something like this (a simple example of the inner part of a custom training loop):

using Flux, FiniteDiff

# Flatten the model into a parameter vector θ plus a reconstructor re
θ, re = Flux.destructure(model)
loss(p) = Flux.mse(re(p)(x), y)       # loss as a function of the flat parameters
Δθ = FiniteDiff.finite_difference_gradient(loss, θ)
Flux.update!(opt, θ, Δθ)              # apply the optimizer update to θ in place
model = re(θ)                         # rebuild the model from the updated vector

But in general, finite differencing is inefficient and inexact (to the degree that it can be horribly wrong).
So it would probably be a better investment to first figure out why AD didn’t work. What was the error you got?

I presume you tried Zygote first, which is the default in Flux? ForwardDiff is not very efficient for gradients. Did you try ReverseDiff?
What were the errors you got?

Thank you for the example, I will try this.

I didn’t try ReverseDiff. I’m sorry, I am new to Flux. Do you mean that Flux can use ReverseDiff instead of Zygote?

Flux is designed to work with Zygote, so step no. 1 is figuring out if/why your model doesn’t work with Zygote. Generally speaking, switching ADs should be either a measure of last resort or late stage optimization.

Note that both finite differences and forward-mode AD (e.g. ForwardDiff) are generally terrible ways to differentiate loss functions in ML (i.e. scalar functions of many parameters). The problem is that if you have N parameters, the derivative computation essentially costs \sim N evaluations of your loss function, making optimization very expensive if N is large.

In contrast, reverse-mode AD essentially only costs ≈1 additional loss-function evaluation to get the gradient.
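
To make that scaling concrete, here’s a minimal sketch (the toy quadratic loss, the matrix M, and the size N are made up for illustration):

using FiniteDiff, Zygote

# Toy scalar loss of N parameters (sizes are only for illustration)
N = 1_000
M = randn(N, N)
loss(θ) = sum(abs2, M * θ)

θ0 = randn(N)

# Finite differences: on the order of N extra loss evaluations
g_fd = FiniteDiff.finite_difference_gradient(loss, θ0)

# Reverse mode: roughly the cost of one extra backward pass, independent of N
g_rev = Zygote.gradient(loss, θ0)[1]

isapprox(g_fd, g_rev; rtol = 1e-5)   # same answer, very different cost for large N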

The problem is that if you have N parameters, the derivative computation essentially costs ∼N evaluations of your loss function, making optimization very expensive if N is large.

…no? Forward diff would require \lceil\frac{N}{c}\rceil evaluations, where c is the chunk size. Of course, increasing the chunk size has diminishing returns, which is why it’s still bad for large N.
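
For reference, the chunk size can be set explicitly; a minimal sketch with a toy function standing in for the real loss:

using ForwardDiff

g(x) = sum(abs2, x)                 # toy loss, not the original one
x0 = randn(20)

# With N inputs and chunk size c, ForwardDiff makes ceil(N / c) passes over g,
# each pass propagating c-dimensional dual numbers.
cfg = ForwardDiff.GradientConfig(g, x0, ForwardDiff.Chunk{4}())
ForwardDiff.gradient(g, x0, cfg)    # here: 20 inputs, chunk size 4 → 5 passes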

And recent work by the SciML guys has demonstrated that forward-diff is the most efficient method for N<100(-ish)… which can be the case for a number of potential ML and parameter estimation tasks. That’s pretty far from being “generally terrible”.

So really your statement should be that forward-diff is terrible for optimizing neural networks (and not ML in general…), which would be true.

For ML, that would make your largest matrix be 10x10. I don’t think it would be controversial to say that almost all ML usage would be over this limit.

ML != neural networks. So I don’t really agree. But I’ll stop there, I don’t want to divert the thread any further.

Even a 10-dimensional linear regression with 10 data points is pretty small.

But those evaluations are \sim c times as expensive as evaluating your loss function for scalar inputs, so overall the computational cost still scales as \Theta(N) times the cost of your loss function.

(Even though each chunk calls your loss function “once” in ForwardDiff.jl, it’s with a multidimensional dual number of size c, which makes arithmetic \sim c times as expensive as for scalar inputs.)
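
To make that concrete, you can build such a dual number by hand (the numbers are just an illustration):

using ForwardDiff

# A dual number carrying c = 3 partials: every arithmetic operation on it
# updates all three partials, hence roughly c times the scalar arithmetic.
d = ForwardDiff.Dual(2.0, 1.0, 0.0, 0.0)
d * d                               # value 4.0, partials (4.0, 0.0, 0.0)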

That is, the cost is that of \Theta(N) linearized loss-function evaluations but the constant factor is not exactly 1 only because CPU time ≠ arithmetic count. (Vectorized evaluations can be cheaper than the corresponding number of separate scalar evaluations.)

Yes, I did not dispute this. However, you did not state it in terms of asymptotics before but rather in terms of the number of function evaluations, which strictly speaking was not true, or at least somewhat misleading.

What I was disputing was what struck me as an overly pessimistic generalization about forward-mode differentiation’s applicability in ML problems. It is often useful for smaller problems, and you shouldn’t just rule it out ahead of time because you think you’re doing “ML”. I have frequently found it to be more efficient than reverse-mode on long running, medium sized PDE problems because reverse-mode, while asymptotically better, can still have prohibitively high runtime constants that make it slower.

But the context of this post is using Flux with what is presumably a NODE, in which case neither finite diff nor forward diff is appropriate; you were certainly correct about that.

In an effort to bring this back to the original problem:

My comment was more about reverse differentiation in general, not specific to ReverseDiff.jl.
But in principle, Flux could use ReverseDiff.jl, in pretty much the same way as my little example using FiniteDiff (just replace FiniteDiff.finite_difference_gradient with ReverseDiff.gradient).

That being said, such an approach would be quite inefficient, and you would probably only want to do it if your problem requires something that Zygote can’t handle (e.g. mutating arrays) and that couldn’t be worked around (e.g. by making it non-mutating). Even then, you’d probably want to use Tracker.jl instead of ReverseDiff.jl.
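
Concretely, the swap in the earlier sketch would look like this (assuming the same model, x, y, and opt as before):

using Flux, ReverseDiff

θ, re = Flux.destructure(model)
loss(p) = Flux.mse(re(p)(x), y)
Δθ = ReverseDiff.gradient(loss, θ)   # reverse-mode AD in place of finite differences
Flux.update!(opt, θ, Δθ)
model = re(θ)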

I think people would be quite eager to help if you could provide a minimal (non-)working example that reproduces the error.

I didn’t give an example, since I am interested in an answer in general form.

The loss() function has a maximization problem inside it (I use Optim.jl). The solution is differentiable, so the numerical derivative works fine, but it looks like Zygote can’t find the analytical derivative (if I understand it correctly).

In the current case, I can rewrite the loss() function without the optimization problem, but that won’t be possible in other cases, so I am looking for a solution in this general form.

Thank you so much for so many answers and insights.

Is this a theoretical concern or something you’ve tested? I think it would be more productive to figure out why Zygote wouldn’t be able to handle it in practice.

The loss() function has a maximization problem inside it (I use Optim.jl). The solution is differentiable, so the numerical derivative works fine, but it looks like Zygote can’t find the analytical derivative (if I understand it correctly).

If I am understanding correctly that the function you want to differentiate (i.e. the loss function) has a nested non-linear solve (by Optim.jl), then you can use Zygote.@ignore (carefully) on the solve, since the solve itself should not be relevant to the gradient computation.
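
As a minimal sketch of the mechanics (the quadratic toy objective and everything around it are assumptions, not your code; here the envelope theorem is what makes ignoring the solve legitimate):

using LinearAlgebra, Optim, Zygote

# Toy stand-in: f(x, θ) = -‖x‖² + x⋅θ is maximized over x at x = θ/2,
# so loss(θ) = max_x f(x, θ) = ‖θ‖²/4 and dloss/dθ = θ/2.
f(x, θ) = -sum(abs2, x) + dot(x, θ)

function loss(θ)
    # Hide the inner maximization from Zygote. By the envelope theorem the
    # θ-dependence of the maximizer does not contribute to dloss/dθ, so
    # treating x_opt as a constant still gives the correct gradient.
    x_opt = Zygote.ignore() do
        res = optimize(x -> -f(x, θ), zero(θ), NelderMead())
        Optim.minimizer(res)
    end
    return f(x_opt, θ)
end

θ0 = randn(3)
Zygote.gradient(loss, θ0)[1]   # ≈ θ0 / 2, up to the inner solver's tolerance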

If you are doing a bilevel optimization (optimizing a function that itself solves an optimization problem), you can declare your own rrule (vector–Jacobian product) to tell Zygote how to differentiate it efficiently using the implicit-function theorem. (Basically, you differentiate using the KKT conditions describing your inner optimum.)
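
A hedged sketch of what such an rrule can look like, for a tiny unconstrained quadratic inner problem (the objective, the matrix Q, and the helper inner_solve are illustrative assumptions; a constrained problem would use the full KKT system in the same way):

using ChainRulesCore, LinearAlgebra, Optim, Zygote

# Toy bilevel setup: the inner problem is x*(p) = argmin_x 0.5 x'Qx - p'x,
# whose exact solution is x* = Q \ p.  Differentiating the stationarity
# condition Qx* - p = 0 (the unconstrained KKT condition) gives dx*/dp = Q⁻¹.
Q = [2.0 0.5; 0.5 1.0]              # SPD matrix, illustration only
f_inner(x, p) = 0.5 * dot(x, Q * x) - dot(p, x)

inner_solve(p) = Optim.minimizer(optimize(x -> f_inner(x, p), zeros(2), LBFGS()))

# Custom reverse rule: the pullback is p̄ = (dx*/dp)' x̄ = Q⁻¹ x̄ (Q is symmetric).
function ChainRulesCore.rrule(::typeof(inner_solve), p)
    x = inner_solve(p)
    inner_solve_pullback(x̄) = (NoTangent(), Q \ collect(x̄))
    return x, inner_solve_pullback
end

# Zygote now differentiates through the nested optimization via the rrule:
outer(p) = sum(abs2, inner_solve(p))
p0 = [1.0, 2.0]
Zygote.gradient(outer, p0)[1]       # ≈ 2 * (Q \ (Q \ p0)), the exact gradient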

In general, AD tools need a bit of “help” whenever the function you are differentiating solves a problem approximately by an iterative method (e.g. Newton iterations for root finding, or iterative optimization algorithms, or adaptive quadrature) — even if AD can analyze the iterations, it will end up wasting a lot of effort trying to exactly differentiate the error in your approximation.

See also Differentiating optimization problem solutions in Julia

Thank you dear @stevengj. I hope it will help me.