Is it possible to do Nested AD ~elegantly~ in Julia? (PINNs)

Differentiating a loss function defined in terms of a network’s derivatives seems to have been an issue forever, especially when Zygote is involved; see [1], [2], [3] and many more over on Zygote’s GitHub issue tracker.

In many of these threads (especially pre-2022) it is said that the release of Diffractor.jl would fix this issue. Now that Diffractor seems mostly dead, what is left?

In my case, because I am working with exotic architectures, it is not possible to use @ChrisRackauckas’s NeuralPDE library. I have also been unable to reproduce the “ReverseDiff-Over-Zygote” hack that is commonly thrown around in these discussions.

Is there any other experimental AD library that is capable of this? Should I wait for one to be released? Or should I resort to sketchy finite differences under the hood? That wouldn’t exactly thrill reviewers…

I can’t answer the question, but note that Diffractor.jl has seen >40 commits since Jan 1. That’s more than some packages that you would probably attest are very much alive.

Nested AD for this kind of thing is asymptotically much worse than numerical differentiation. If you’re going to use anything else, the thing to try is TaylorDiff.jl.
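
For a feel for what that buys you: Taylor-mode AD gives you a higher-order derivative in a single forward pass, instead of nesting first-order passes (whose cost blows up with the order). A tiny sketch on a plain scalar function, assuming the derivative(f, x, order) method used later in this thread (the exact argument form may differ between TaylorDiff versions):

using TaylorDiff

f(t) = sin(3t)

TaylorDiff.derivative(f, 1.0, 1)   # first derivative:  3cos(3)  ≈ -2.97
TaylorDiff.derivative(f, 1.0, 2)   # second derivative: -9sin(3) ≈ -1.27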

On the flip side, it relies heavily on nightly compiler features and has yet to conquer issues such as “Even very basic broadcasting breaks inferability” (JuliaDiff/Diffractor.jl issue #147). To my knowledge only forward mode is coming any time soon, so even if all that is addressed, Diffractor may not be fit for the purpose in this thread. Given the history of promotion around this project (if anyone is curious, search for “PINN” and “Diffractor” on Discourse) and the still completely unspecified timelines, I think it’d be very hard to argue the community is overcompensating on the expectation-management front.

Digressing a bit, I wonder how much crossover this thread has with “Difficulties writing a program that computes PDEs involving Laplacians with AD”, which doesn’t appear to have any answers yet.

Sure, I’m not saying it’s near-ready, just that pronouncing something “dead” is very different from the question of “is it ready yet?”

Oh, I’m not arguing either way about that. The point was to add more context since a lot of people are interested in Diffractor but don’t know where it currently sits readiness-wise. Seeing a bunch of commit activity in isolation doesn’t provide much signal (positive or negative) for that :)

I have been doing higher-order diff with Zygote. It required redefining the rules of some functions, because they were not AD-friendly (they contained mutation), but it was doable. Time to first gradient was long, sometimes even half an hour. But it worked.

Thanks Chris. Is this the route you went for when making NeuralPDE? I have gone over the repo and the associated article a few times but just couldn’t find where and how the numerical derivatives were calculated. It does work brilliantly, however; great work.

That sounds quite lengthy. What is the size of the network? Have you tried Reverse-over-Zygote?

Enzyme does nested AD.

The point though is that nested AD is not what you want to do here. Even if you can do nested AD, it’s still not really the solution so I don’t know why people keep mentioning it.

No, because it wasn’t ready when we built NeuralPDE, so we did a mixture of numerical differentiation and reverse mode to hit the asymptotically optimal form, and we will be (over the summer) replacing the numerical parts with this TaylorDiff form to have an optimal all-AD solution.

I see. Then what is the intuition behind using TaylorDiff?

Let’s take your PINNs tutorial as an MWE: Flux can peek inside the finite-difference net_xFD and train the network:

using Flux, Statistics #, TaylorDiff, ForwardDiff, Plots

NN = Chain(Dense(1 => 12,tanh),
           Dense(12 => 12,tanh),
           Dense(12 => 12,tanh),
           Dense(12 => 1))
net(x) = x*first(NN([x]))
net(1) #Works

ϵ = Float32((eps(Float32))^(1/2)) #Naive Finite Difference
net_xFD(x) = (net(x+ϵ)-net(x))/(ϵ)
net_xFD(1) #Works

ts = 1f-2:1f-2:1f0 #Training set and loss function
loss() = mean(abs2(net_xFD(t)-cos(t)) for t in ts) 

#Training Loop
opt = Flux.Adam()
data = Iterators.repeated((), 500)
iter = 0
cb = function ()
  global iter += 1
  if iter % 50 == 0
    display(loss())
  end
end
Flux.train!(loss, Flux.params(NN), data, opt; cb=cb) #Works 

Now, obviously, if I just replaced the naive finite difference net_xFD with the TaylorDiff derivative, it wouldn’t work. In fact, it seems like TaylorDiff can’t even take derivatives of Flux networks natively:

using TaylorDiff, ForwardDiff

net_xAD(x) = ForwardDiff.derivative(net,x) 
net_xAD(1) #Works

net_xTD(x) = TaylorDiff.derivative(net,x,1) 
net_xTD(1) #Doesn't work

So how would one go about training using TaylorDiff? Should I forego Flux completely? This is all very confusing.

Yes, I did reverse-mode Zygote over Zygote. The real problem was only the operators: I had to write my own version of logitcrossentropy, for example. But it was not bad; I think it was done in less than a day, though I had prior experience with writing custom rules.
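
In broad strokes, the two ingredients look like this. A toy sketch only: sumsq below just stands in for the kind of mutating function that needed a custom rule, it is not the actual code from my model.

using Zygote, ChainRulesCore

# (1) Reverse-over-reverse on a simple scalar function:
#     d/dx (d/dx sin(x)) = -sin(x)
dsin(x)  = Zygote.gradient(sin, x)[1]    # cos(x)
d2sin(x) = Zygote.gradient(dsin, x)[1]   # -sin(x)
d2sin(1.0)                               # ≈ -0.8415

# (2) A function whose implementation mutates, which plain Zygote rejects
#     ("Mutating arrays is not supported"):
function sumsq(x)
    buf = similar(x)
    buf .= x .^ 2
    return sum(buf)
end

# A hand-written rule lets Zygote skip the mutating body entirely:
function ChainRulesCore.rrule(::typeof(sumsq), x)
    sumsq_pullback(ȳ) = (NoTangent(), 2 .* x .* ȳ)
    return sumsq(x), sumsq_pullback
end

Zygote.gradient(sumsq, [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],)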

Yes, use Lux, as the tests show.
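
Very roughly, here is the shape that takes. This is an untested sketch of the Lux + TaylorDiff route, not the official example: the training loop, the Optimisers usage, and the helper names are mine, and whether Zygote really differentiates through TaylorDiff.derivative in this exact form is what the TaylorDiff tests cover, so treat it as a starting point.

using Lux, Random, Zygote, TaylorDiff, Optimisers, Statistics

model = Chain(Dense(1 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 1))
ps, st = Lux.setup(Random.default_rng(), model)   # explicit parameters and state

# Same trial function as before, but the parameters are passed explicitly
net(x, ps) = x * first(first(model([x], ps, st)))

# Inner derivative in Taylor (forward) mode, outer gradient in reverse mode
net_x(x, ps) = TaylorDiff.derivative(t -> net(t, ps), x, 1)

ts = 1f-2:1f-2:1f0
loss(ps) = mean(abs2(net_x(t, ps) - cos(t)) for t in ts)

opt_state = Optimisers.setup(Optimisers.Adam(), ps)
for iter in 1:500
  global ps, opt_state
  g = Zygote.gradient(loss, ps)[1]
  opt_state, ps = Optimisers.update(opt_state, ps, g)
  iter % 50 == 0 && @show loss(ps)
end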

There’s a PINN example with Flux though:

So the above Flux example does actually use nested AD. It has a gradient(loss_by_taylordiff, ...), and loss_by_taylordiff itself calls derivative. Moreover I don’t think there’s anything wrong with that – nested AD is the appropriate thing to do here.

I believe (please correct me if need be) that Chris’s admonition against nested AD, and preference for Taylor-mode AD, applies specifically when computing second derivatives directly, e.g. when you’re directly computing some d²y/dt².

That’s not what @Bizzi’s example appears to require. The second derivative is “indirect”: loss contains a derivative, but must itself also be differentiated.

Assuming I’ve got all that correct – Bizzi, the appropriate (asymptotically correct) thing to do here is exactly what your example above is already doing. Use forward-mode autodiff (ForwardDiff) to compute net_xAD as the input t is a scalar, then use reverse-mode autodiff (Flux) to optimise the overall problem, as the output of loss is a scalar.

Thanks for taking the time to clarify these bits, Patrick.

I’m afraid I don’t understand your last paragraph, however. Whenever I use, say, ForwardDiff to calculate the derivative of the network, it looks like Flux can no longer “look inside” the loss function and see its dependence on the network’s parameters. As a result, the gradient returns 0 and the training loss never decreases. Using the example above:

#PINNs Example ForwardDiff
using Flux, Statistics, ForwardDiff

NN = Chain(Dense(1 => 12,tanh),
           Dense(12 => 12,tanh),
           Dense(12 => 12,tanh),
           Dense(12 => 1))
net(x) = x*first(NN([x]))
net(1) #Works

net_xAD(x) = ForwardDiff.derivative(net,x) 
net_xAD(1) #Works

ts = 1f-2:1f-2:1f0
loss() = mean(abs2(net_xAD(t)-cos(t)) for t in ts) 

opt = Flux.Adam()
data = Iterators.repeated((), 500)
iter = 0
cb = function ()
  global iter += 1
  if iter % 50 == 0
    display(loss())
  end
end
Flux.train!(loss, Flux.params(NN), data, opt; cb=cb) #Runs, but does not decrease the loss

The displayed losses are:

1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0
1.3544431f0

Am I missing something? I feel like this could be related to the @functor macro, but I can’t really see how. Why should the gradient work with the finite-difference net_xFD but not with net_xAD? My only explanation for this was that nested AD is not supported.

I guess I should replace my last paragraph with just “Use forward-mode autodiff to compute net_xAD, then use reverse-mode autodiff to optimise the overall problem.” I believe the statements I made about autodiff are correct, and that what you’re seeing here is a Julia bug: these packages are silently failing to compose.

Honestly, I don’t actually use Julia in my work, as I’ve run into far too many issues exactly like what you’re seeing here. Try using JAX instead, which has done nested AD for years without difficulty. (Shameless advert.)

Did you happen to see the warning from “Add warnings to ForwardDiff functions” (FluxML/Zygote.jl PR #1224)? If not, that’s probably a bug. Patrick’s suspicion is basically right though: Zygote (the AD Flux uses by default) over ForwardDiff will not work for your code as-is. For more on how this affects PINNs specifically, see “PINN loss doesn’t converge to 0?” (FluxML/Flux.jl issue #1966).

The examples Chris posted above avoid this by replacing one or both of the ADs involved with TaylorDiff. You could try other ADs as well (e.g. Enzyme as suggested above). I’ll defer to the people who actually work in this area to make suggestions though.

I see. Although I have considered dropping Julia multiple times at this point, my immense appreciation for the language’s concept drives me to try a little bit more. If I am unable to overcome this issue by the end of the week, however, I am probably going back to Python. If I do, I’ll certainly take a look at JAX.

I see. So this is a limitation of Zygote + ForwardDiff specifically? I’ll try the implementation with TaylorDiff next. I’m aware that Lux is better behaved than Flux for a variety of applications, but it shouldn’t make a difference in this case, correct? (Given that Lux also uses Zygote by default.)