Differentiating a loss function defined in terms of a network’s derivatives seems to have been an issue since forever, especially when Zygote is involved; see [1][2][3] and many more over at Zygote’s GitHub.
In my case, because I am working with exotic architectures, it is not possible to use @ChrisRackauckas’s NeuralPDE library. I have also been unable to reproduce the “ReverseDiff-Over-Zygote” hack that is commonly thrown around in these discussions.
Is there any other experimental AD library that is capable of this? Should I wait for it to be released? Or should I appeal to sketchy finite differences under the hood? That wouldn’t exactly thrill reviewers…
I can’t answer the question, but note that Diffractor.jl has seen >40 commits since Jan 1. That’s more than some packages that you would probably attest are very much alive.
On the flip side, it relies heavily on nightly compiler features and has yet to conquer issues such as “Even very basic broadcasting breaks inferability” (JuliaDiff/Diffractor.jl#147). To my knowledge only forward mode is coming any time soon, so even if all that is addressed, Diffractor may not be fit for the purpose in this thread. Given the history of promotion around this project (if anyone is curious, search for “PINN” and “Diffractor” on Discourse) and the still completely unspecified timelines, I think it’d be very hard to argue the community is overcompensating on the expectation management front.
Oh, I’m not arguing either way about that. The point was to add more context, since a lot of people are interested in Diffractor but don’t know where it currently sits readiness-wise. Seeing a bunch of commit activity in isolation doesn’t provide much signal (positive or negative) for that.
I have been doing higher-order diff with Zygote. It required redefining the rules of some functions, because they were not AD-friendly (they contained mutation), but it was doable. Time to first gradient was long, sometimes even half an hour. But it worked.
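To give a rough idea of the pattern (an illustrative sketch only, not my actual code; relu_inplace! and myrelu are made-up names): hide the mutating kernel behind a ChainRulesCore.rrule whose pullback is itself mutation-free, so Zygote can differentiate through the function once, and through the pullback again for the higher order.

using ChainRulesCore

# Hypothetical mutating kernel that Zygote cannot trace directly
function relu_inplace!(y, x)
    @. y = max(x, 0)
    return y
end

myrelu(x) = relu_inplace!(similar(x), x)

# Custom rule: Zygote never sees the mutation, and the pullback below is
# written without mutation, so it can itself be differentiated again.
function ChainRulesCore.rrule(::typeof(myrelu), x)
    y = myrelu(x)
    myrelu_pullback(ȳ) = (NoTangent(), ȳ .* (x .> 0))
    return y, myrelu_pullback
end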
Thanks Chris. Is this the route you went for when making NeuralPDE? I have gone over the repo and the associated article a few times but just couldn’t find where and how the numerical derivatives were calculated. It does work brilliantly however, great work.
The point though is that nested AD is not what you want to do here. Even if you can do nested AD, it’s still not really the solution so I don’t know why people keep mentioning it.
No, because it wasn’t ready when we built NeuralPDE, so we did a mixture of numerical and reverse mode to hit the asymptotically optimal form, and will be (over the summer) replacing the numerical parts with this TaylorDiff form to have an optimal all-AD solution.
I see. Then what is the intuition behind using TaylorDiff?
Let’s take your PINNs tutorial as an MWE: Flux can peek inside the finite-difference net_xFD and train the network:
using Flux, Statistics #, TaylorDiff, ForwardDiff, Statistics, Plots

NN = Chain(Dense(1 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 1))
net(x) = x * first(NN([x]))
net(1) #Works

ϵ = Float32(eps(Float32)^(1/2)) #Naive Finite Difference
net_xFD(x) = (net(x + ϵ) - net(x)) / ϵ
net_xFD(1) #Works

ts = 1f-2:1f-2:1f0 #Training set and loss function
loss() = mean(abs2(net_xFD(t) - cos(t)) for t in ts)

#Training Loop
opt = Flux.Adam()
data = Iterators.repeated((), 500)
iter = 0
cb = function ()
    global iter += 1
    if iter % 50 == 0
        display(loss())
    end
end
Flux.train!(loss, Flux.params(NN), data, opt; cb = cb) #Works
Now, obviously, if I just replaced the naive finite difference net_xFD with the TaylorDiff derivative, it wouldn’t work. In fact, it seems like TaylorDiff can’t even take derivatives of Flux networks natively:
using TaylorDiff, ForwardDiff

net_xAD(x) = ForwardDiff.derivative(net, x)
net_xAD(1) #Works

net_xTD(x) = TaylorDiff.derivative(net, x, 1)
net_xTD(1) #Doesn't work
So how would one go about training using TaylorDiff? Should I forego Flux completely? This is all very confusing.
Yes, I did reverse-mode Zygote over Zygote. The only real problem was the operators: I had to have my own version of logitcrossentropy, for example. But it was not bad; I think it was done in less than a day, though I had prior experience with writing custom rules.
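For illustration only (a sketch of the general shape, not my actual implementation), an AD-friendly logitcrossentropy just avoids in-place operations, so the whole thing is plain broadcasts that Zygote can differentiate twice:

using Statistics

function my_logitcrossentropy(ŷ, y; dims = 1)
    shifted = ŷ .- maximum(ŷ; dims = dims)                   # shift for numerical stability
    logp = shifted .- log.(sum(exp.(shifted); dims = dims))  # out-of-place log-softmax
    return mean(-sum(y .* logp; dims = dims))
end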
So the above Flux example does actually use nested AD. It has a gradient(loss_by_taylordiff, ...), and loss_by_taylordiff itself calls derivative. Moreover I don’t think there’s anything wrong with that – nested AD is the appropriate thing to do here.
I believe (please correct me if need be) that Chris’ admonition against nested AD, and preference for Taylor-mode AD, applies specifically to computing second derivatives directly, e.g. when you’re directly computing some d²y/dt².
That’s not what @Bizzi’s example appears to require. The second derivative is “indirect”: loss contains a derivative, but must itself also be differentiated.
Assuming I’ve got all that correct – Bizzi, the appropriate (asymptotically correct) thing to do here is exactly what your example above is already doing. Use forward-mode autodiff (ForwardDiff) to compute net_xAD, as the input t is a scalar, then use reverse-mode autodiff (Flux) to optimise the overall problem, as the output of loss is a scalar.
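To make the direct-versus-indirect distinction concrete, here is a tiny scalar sketch (purely illustrative, no Flux involved; in the real problem the outer derivative would be reverse mode over all the network parameters rather than ForwardDiff over a single scalar a):

using ForwardDiff

f(t) = sin(t)

# "Direct" second derivative: one function differentiated twice in t
d2f(t) = ForwardDiff.derivative(s -> ForwardDiff.derivative(f, s), t)

# "Indirect": the loss contains a first derivative in t and is then
# differentiated with respect to a parameter a (standing in for the weights)
g(a, t) = sin(a * t)
loss_a(a) = abs2(ForwardDiff.derivative(t -> g(a, t), 1.0) - cos(1.0))
dloss_a(a) = ForwardDiff.derivative(loss_a, a)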
Thanks for taking the time to clarify these bits, Patrick.
I’m afraid I don’t understand your last paragraph, however. Whenever I use, say, ForwardDiff to calculate the derivative of the network, it looks like Flux can no longer “look inside” the loss function and see its dependence on the network’s parameters. As a result, the gradient returns 0 and the training loss never decreases. Using the example above:
#PINNs Example ForwardDiff
using Flux, Statistics, ForwardDiff

NN = Chain(Dense(1 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 1))
net(x) = x * first(NN([x]))
net(1) #Works

net_xAD(x) = ForwardDiff.derivative(net, x)
net_xAD(1) #Works

ts = 1f-2:1f-2:1f0
loss() = mean(abs2(net_xAD(t) - cos(t)) for t in ts)

opt = Flux.Adam()
data = Iterators.repeated((), 500)
iter = 0
cb = function ()
    global iter += 1
    if iter % 50 == 0
        display(loss())
    end
end
Flux.train!(loss, Flux.params(NN), data, opt; cb = cb) #Runs, but does not decrease the loss
Am I missing something? I feel like this could be related to the @functor macro, but I can’t really see how. Why should the gradient work with the finite differences net_xFD but not with net_xAD? My only explanation for this was that nested AD was not supported.
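For what it’s worth, the failure can be seen directly by asking for the gradient by hand (continuing from the code above):

gs = Flux.gradient(loss, Flux.params(NN))   # reverse mode over the implicit parameters
W = NN[1].weight
haskey(gs, W) ? gs[W] : nothing             # zero (or missing entirely): Zygote never sees the ForwardDiff part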
I guess I should replace my last paragraph with just “Use forward-mode autodiff to compute net_xAD, then use reverse-mode autodiff to optimise the overall problem.” I believe the statements I made about autodiff are correct, and that what you’re seeing here is a Julia bug: that these packages are silently failing to compose.
Honestly, I don’t actually use Julia in my work, as I’ve run into far too many issues exactly like what you’re seeing here. Try using JAX instead, which has done nested AD for years without difficulty. (Shameless advert.)
The examples Chris posted above avoid this by replacing one or both of the ADs involved with TaylorDiff. You could try other ADs as well (e.g. Enzyme as suggested above). I’ll defer to the people who actually work in this area to make suggestions though.
I see. Although I have considered dropping Julia multiple times at this point, my immense appreciation for the language’s concept drives me to try a little bit more. If I am unable to overcome this issue by the end of the week, however, I am probably going back to Python. If I do, I’ll certainly take a look at JAX.
I see. So this is a limitation of Zygote + ForwardDiff specifically? I’ll try the implementation with TaylorDiff next. I’m aware that Lux is better behaved than Flux for a variety of applications, but it shouldn’t make a difference in this case, correct? (Given that Lux is also built on top of Zygote).