Is it possible to do Nested AD ~elegantly~ in Julia? (PINNs)

Zygote does indeed have issues with nested differentiation. Here, though, another problem seems to be its implicit handling of parameters. Switching to the newer, more explicit API of Optimisers, your example works with ReverseDiff over ForwardDiff:

``````
using Flux, ForwardDiff, ReverseDiff, Zygote, Optimisers, Statistics

NN = Chain(Dense(1 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 12, tanh),
           Dense(12 => 1))
θ, reNN = Flux.destructure(NN)

net(x, θ) = x*first(reNN(θ)([x]))
net(1, θ) #Works

net_xAD(x, θ) = ForwardDiff.derivative(x -> net(x, θ),x)

ts = 1f-2:1f-2:1f0
loss(θ) = mean(abs2(net_xAD(t, θ)-cos(t)) for t in ts)

# Both of these work

function train(θ; opt = Optimisers.Adam(), steps = 250)
    state = Optimisers.setup(opt, θ)
    for i = 1:steps
        if i % 50 == 0
            display(loss(θ))
        end
        ∇θ = ReverseDiff.gradient(loss, θ)
        state, θ = Optimisers.update(state, θ, ∇θ)
    end
    θ
end
``````
``````
julia> train(θ);
0.083571285f0
0.022315795f0
0.0012668703f0
0.00012052046f0
0.00011890126f0
``````

Interestingly, trying Zygote for the outer differentiation explains why it fails:

``````
julia> Zygote.gradient(loss, θ)
┌ Warning: `ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
│ typeof(f) = var"#23#24"{Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/oGI57/src/lib/forward.jl:158
(nothing,)
``````

Fantastic. This works perfectly. Making the dependence upon the network’s parameters explicit seems to have made the code both simpler and more readable.

I’ll mark it as an answer so that we can point towards this when this question inevitably shows up again. Thank you all for your time and patience.


An update for those who find this thread in the future: It is not advised to implement PINNs this way in Julia. The code posted above runs, but the gradients are incorrect (by a very small amount).

Presently, there seems to be no AD library in Julia that can do (reverse) nested differentiation of closures safely (see this issue). Moreover, even this slightly incorrect approach demands constant reconstruction of structures, which makes it very slow.

As far as I can tell, only ForwardDiff has addressed this problem, but it’s generally too slow for computing gradients with respect to the parameters. Those looking to implement PINNs in Julia will probably have to do it through some other, more clever approach that does not rely on the intuitive syntax the snippets in this thread try to reproduce.


I would’ve thought TaylorDiff could do it since it was recommended above, but as noted earlier I lack the know-how to provide any guidance. Did you run into any issues trying it in place of Reverse/ForwardDiff?

I’m currently trying TaylorDiff!

It doesn’t seem very robust, at least in my naive implementation, but it might work. The issues are twofold:
a) it does not support most activation functions in NNlib, so one must instead use their non-optimized versions;
b) it doesn’t play nicely with ReverseDiff (but Zygote seems to work when the parameters are stored as a NamedTuple).

Code snippet for a):

``````
using Lux, Random, TaylorDiff, Zygote

#Network Setup
n = 16
model = Chain(Dense(1 => n, tanh),
              Dense(n => n, tanh),
              Dense(n => 1))
rng = Random.default_rng()
p0, s0 = Lux.setup(rng,model)
x0 = 1f0

net(x,p) = model([x],p,s0) |> first |> first

diff(x,p) = TaylorDiff.derivative(x->net(x,p),x,2)
diff(2x0,p0) # ERROR: MethodError: no method matching tanh_fast(::TaylorScalar{Float32, 3})
``````

Code snippet for b):

``````
using Lux, Random, TaylorDiff, Zygote, ReverseDiff, ComponentArrays

#Network Setup
n = 16
model = Chain(Dense(1 => n, cos),
              Dense(n => n, cos),
              Dense(n => 1))
rng = Random.default_rng()
p0, s0 = Lux.setup(rng, model)
x0 = 1f0

net(x, p) = model([x], p, s0) |> first |> first

diff(x, p) = TaylorDiff.derivative(x -> net(x, p), x, 2)
diff(x0, p0) # WORKS

# gradZY and gradRV are not defined in this snippet; presumably they wrap
# Zygote.gradient and ReverseDiff.gradient of `diff` with respect to p.
gradZY(x0, p0) # WORKS (only for p0 a NamedTuple, not for p0 a ComponentArray)

pc = p0 |> ComponentArray
gradRV(x0, pc) # ERROR: TypeError: in TrackedReal, in V, expected V<:Real, got Type{TaylorScalar{Float32, 3}}
``````

In any case, I still haven’t figured out if TaylorDiff counts as AD and whether ReverseDiff-over-TaylorDiff or Zygote-over-TaylorDiff would suffer from the issues with nested differentiation of closures I mentioned above. I’ll update this thread with a solution once I’m sure of it.

No closures are required with Lux. PINN code shouldn’t have any, so the thing Zygote is warning about shouldn’t happen with Lux and ComponentArrays. And TaylorDiff is going to be more efficient.

I think the issue here is that NNlib is missing forward-mode rules for some of its kernels, so it falls back to differentiating the operations. That’s fine for many cases, but those should indeed get higher-level rules. I’m going to chat with @tansongchen about a bit of the coverage here.


I have fixed that in a recent PR; you can try installing TaylorDiff.jl from source.

It’s not hard to add support for those functions:

``````
using ChainRulesCore, NNlib, TaylorDiff

# Second-order case: reuse the existing forward rule (frule) for tanh_fast.
function NNlib.tanh_fast(t::TaylorScalar{T,2}) where {T}
    t0, t1 = TaylorDiff.value(t)
    return TaylorScalar{T,2}(frule((NoTangent(), t1), tanh_fast, t0))
end

# Higher orders: build the derivative 1 - tanh^2 at one order lower, then raise.
function NNlib.tanh_fast(t::TaylorScalar{T,N}) where {T,N}
    t1 = TaylorScalar{T,N-1}(t)
    df = 1 - tanh_fast(t1)^2
    return TaylorDiff.raise(tanh_fast(TaylorDiff.value(t)[1]), df, t)
end
``````

I can confirm that it is perfectly fine to use TaylorDiff for PINNs. The real problem is that, when used with Zygote, it is not as fast as finite differences.


Fantastic, thank you both very much.

Although I still don’t quite understand why TaylorDiff succeeds where the other differentiation tools failed, everything seems to indicate that it does work very well for PINNs.

I’ll later benchmark FiniteDifferences over Zygote and compare it to TaylorDiff. What I do know is that FiniteDifferences did not work well over/under ReverseDiff/ForwardDiff.

When I say finite difference, I mean writing the difference formula manually, as done in NeuralPDE.jl, which should work smoothly with any automatic differentiation package.
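As a rough illustration of what writing the difference formula by hand can look like, here is a minimal sketch in plain Julia. It is not NeuralPDE.jl's implementation: `model` is a hypothetical stand-in for a neural network, and `dmodel_dx`, the step `h = 1e-3`, and the sample points are illustrative assumptions. The point is that the derivative in `x` reduces to ordinary evaluations of the model, so whatever tool differentiates the loss with respect to `θ` only ever sees plain function calls.

```julia
# Hypothetical stand-in for a neural network: any function of (x, θ) works here.
model(x, θ) = θ[1] * tanh(θ[2] * x + θ[3])

# Central difference in x, written out by hand: only plain evaluations of `model`.
dmodel_dx(x, θ; h = 1e-3) = (model(x + h, θ) - model(x - h, θ)) / (2 * h)

# PINN-style residual loss for the toy problem u'(x) = cos(x) from earlier in the thread.
xs = range(0.01, 1.0; length = 100)
loss(θ) = sum(abs2(dmodel_dx(x, θ) - cos(x)) for x in xs) / length(xs)

θ = [1.0, 1.0, 0.0]
println("loss(θ) = ", loss(θ))

# Sanity check against the analytic derivative θ1 * θ2 * sech(θ2 * x + θ3)^2.
analytic = θ[1] * θ[2] * sech(θ[2] * 0.5 + θ[3])^2
println("FD vs analytic at x = 0.5: ", dmodel_dx(0.5, θ), " vs ", analytic)
```

In a real setup `model` would be the network call and `loss` would be handed to the AD package of choice; the step size `h` then becomes a tuning parameter that depends on the floating-point precision in use.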

TaylorDiff.jl has not been heavily tested on GPUs. Finite difference is a much safer option. But you can take your chances.

I’m curious what the architectures refer to here. Neural networks?


Note however the vast majority of these have been on Diffractor’s forwards-mode AD.
I/we will likely have more to say about that come JuliaCon. (It’s becoming a really good Forward Mode AD)


There are kind of two factors here. One is that finite differences reach the asymptotic speed of Taylor-mode forward AD. If you think about, say, second order, f(x - dx) - 2f(x) + f(x + dx), you have 3 f-evaluations, which essentially matches the compute complexity that you’d get from a Taylor mode optimizing away the extra calculations. Taylor-mode AD could in theory achieve a bit more speed by doing some tricks to force SIMD across the primal and derivative calculations, which is something that ForwardDiff.jl does for 1st derivatives and TaylorDiff.jl does not do, but we’re talking about much smaller optimizations, like 2x, here. So for this kind of thing, it shouldn’t be surprising that finite differences are really difficult to beat performance-wise: you need a very optimized Taylor mode to make it work.
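To make the three-evaluation count concrete, here is a minimal sketch in plain Julia of the second-order stencil mentioned above, divided by the squared step to give a second derivative. The helper name `second_derivative_fd`, the step `h = 1e-4`, and the use of `sin` as a test function are illustrative assumptions, not taken from any package.

```julia
# Second-order central difference for f''(x): exactly three evaluations of f.
second_derivative_fd(f, x; h = 1e-4) = (f(x - h) - 2 * f(x) + f(x + h)) / h^2

# Wrap sin so that every call is counted.
calls = Ref(0)
counted_sin(x) = (calls[] += 1; sin(x))

approx = second_derivative_fd(counted_sin, 1.0)
exact = -sin(1.0)  # d²/dx² sin(x) = -sin(x)

println("approx = ", approx)
println("exact  = ", exact)
println("f-evaluations = ", calls[])
```

This runs in Float64; in Float32 the step would need to be considerably larger, since the division by `h^2` amplifies rounding error.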

The second thing about AD is always accuracy. But there is some clear literature around this point with PINNs, like https://arxiv.org/abs/2110.15832 (it is a bit weird that they wrote a paper on the method NeuralPDE had been using for a few years without mentioning it, but I digress). They describe this other reason quite well, in that finite differencing is essentially smoothing by calculating derivatives with multiple points. With PINNs you have an issue where you need quite dense sampling, because you’re only making the PDE derivative “correct” via the loss function at the points you sample. But with finite differencing in these steps, you are quite naturally incorporating the evaluation of the neural network at nearby points into the loss, which seems to make it converge better over the full domain than straight AD. This paper does not clearly highlight the numerical issues that can exist, especially when there is high stiffness in some operators (or high CFL), which may make the FD form not converge well. We have found a case of that, which is the reason for an option to switch to the double AD form, but it wasn’t found to be so essential for most applications.

So while we can chat about AD details yadayadayada, the point remains in the end that, for the PINN application, you would expect numerical differentiation to be competitive with or even better than AD, both in terms of performance and in terms of convergence over the whole domain, since it runs fast and has other effects that can be helpful for the training process.

Agreed, just write it out. It’s quick to do:


Yes, that is strong support for using finite differences, and it even has a better convergence rate.

TaylorDiff+Zygote is twice as slow as finite differences in my implementation, so for me there doesn’t seem to be a particular reason to use it.

``````
julia> @time grad2=gradient(θ->sum((mlp(x .+ [0.0001f0,0.0f0], θ) .- mlp(x .- [0.0001f0,0.0f0], θ))*10000.0f0 ./ 2), ps)[1];
0.176942 seconds (370.02 k allocations: 20.227 MiB)

julia> @time grad1=gradient(θ->sum(TaylorDiff.derivative(c->mlp(c,θ), x, [1.0f0, 0.0f0], Val(2))), ps)[1];
0.362480 seconds (931.16 k allocations: 51.276 MiB, 9.20% gc time)
``````

We’re trying to construct some Neural networks that obey the underlying symmetries of some systems of PDEs. Something like Equivariant Networks but applied to Lie Groups. One of the ideas involves making weights across different layers communicate, which means that usual training cycles wouldn’t work.


Interesting! Coming from a background in traditional computational methods for fluid dynamics has led to me being biased against finite differences. Still, it seems plausible that in the context of neural networks the issues with instability and catastrophic accumulation of error across time steps are less of a concern.
Even more interesting to me is that the “naive” first order approximation works so well as compared to sophisticated packages like FiniteDifferences.jl. I’ll try it out!

A notable difference is that with neural networks you have mesh-less finite differences.


Cool! I look forward to reading your paper on this.


Should one use 64-bit precision with finite differences, though? If I don’t, I get substantial discrepancies between the exact and approximate derivatives:

``````
using Zygote
using Flux
using Random

Random.seed!(1234)

d = 1
u = Chain(
    Dense(d => 8, tanh),
    Dense(8 => 8, tanh),
    Dense(8 => 8, tanh),
    Dense(8 => 1)
)
ϵ = 1e-6
_ϵ = inv(first(ϵ[ϵ .!= zero(ϵ)]))
∇u(x) = Zygote.gradient(x -> sum(u(x)),x)[1]
n∇u(x) = (u(x.+ϵ)-u(x.-ϵ))*_ϵ./2
x = ones(Float32,d,1)
∇u(x)
# -0.15050948
n∇u(x)
#-0.17881393
``````

If you’re using 32 bit floats, that’s way too small of an epsilon. You may want to see this introduction to automatic differentiation which describes why in detail:

So notice that:

``````
julia> eps(Float32)
1.1920929f-7
``````

which means that with a step of `1e-6` the perturbation effectively lives in the 7th decimal place onwards, but you only have about 7 decimal digits, and thus your derivative will not be accurate to more than about one digit. So from quick pen and paper, it’s not surprising that it differs in the second decimal place. You’ll do a bit better to use half of the digits for the differentiation, or in other words, use epsilon as:

``````
julia> sqrt(eps(Float32))
0.00034526698f0
``````

If you really need to use Float32, you may want to use central differencing to improve this a bit (and change epsilon to `cbrt(eps(Float32))`)
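To see the effect of the step size in Float32 concretely, here is a small sketch using only Base Julia. The test function `sin`, the point `x = 1`, and the helper names `central_diff` and `fd_error` are illustrative choices, not part of the code posted above; the sketch simply compares a step of `1f-6` with the two step sizes suggested here.

```julia
# Central difference; all arithmetic stays in the floating-point type of x and h.
central_diff(f, x, h) = (f(x + h) - f(x - h)) / (2 * h)

x = 1f0
exact = cos(1.0)  # reference derivative of sin at x = 1, computed in Float64

# Absolute error of the Float32 central difference for a given step h.
fd_error(h) = abs(Float64(central_diff(sin, x, h)) - exact)

err_tiny = fd_error(1f-6)                # step only a few eps(Float32) wide
err_sqrt = fd_error(sqrt(eps(Float32)))  # step suggested for one-sided differences
err_cbrt = fd_error(cbrt(eps(Float32)))  # step suggested for central differences

println("h = 1f-6:               error = ", err_tiny)
println("h = sqrt(eps(Float32)): error = ", err_sqrt)
println("h = cbrt(eps(Float32)): error = ", err_cbrt)
```

The exact numbers depend on the function and the evaluation point, but the tiny step should lose most of its digits to rounding, while the larger steps trade a little truncation error for far less cancellation.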


FWIW - we’ve had success in combining Zygote with Dual and Hyperdual numbers. E.g., take the gradient of a model with respect to its parameters and then use Hyperduals to take the Laplacian.

That’s great… If possible, could you please share an example of this?