How to achieve good performance with Zygote.pushforward on a neural network

Hello,

I am trying to implement a neural network whose loss function contains a derivative of the network with respect to its input.

To use Flux to update the weights of the network, the loss function needs to be differentiable by Zygote. This means that the derivative of the network with respect to its input, computed inside the loss function, must itself be differentiable by Zygote.

The easiest way to achieve this is to compute the derivative of the neural network in the loss function with the finite difference method. However, a method called pushforward was recently added to Zygote; it takes the derivative of Julia code using forward-mode automatic differentiation. It can be used here because the code it produces is itself differentiable by Zygote.

My problem is the following: I tried to implement a physics-informed neural network using Zygote.pushforward, and I haven’t been able to get good performance with it.

Minimal working example

In this example, a physics-informed neural network tries to approximate the sine function. The loss function is the mean squared error plus an additional term that is zero when the derivative of the network equals the cosine (the derivative of the sine).
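Written out, the loss described above looks roughly like the following. This is only a transcription of the `loss` function used in the code in this post; the symbols $m_\theta$ (the network with weights $\theta$), $x_i$ (the sample points in `X`), and $N$ (the number of samples) are notation introduced here for illustration.

```latex
\mathcal{L}(\theta)
  = \frac{1}{N} \sum_{i=1}^{N} \bigl( m_\theta(x_i) - \sin x_i \bigr)^2
  + \frac{1}{N} \sum_{i=1}^{N} \bigl( \cos x_i - m_\theta'(x_i) \bigr)^2
```

The first sum corresponds to `Flux.mse(m(X), Y)` and the second to `mean(abs2.(cos.(X) - m′(X)))`.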

The network using the finite difference method takes about 2.8 seconds to train:

using Flux
using Statistics

const X = reshape(0:1f-1:10, 1, :)
const Y = sin.(X)

m = Chain(
    Dense(1, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 1),
)

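# Forward-difference approximation of the derivative of m with respect to its input.
function m′(X::AbstractArray{T}) where T
    Δ = √eps(T)  # step size: square root of machine epsilon for element type T
    (m(X .+ Δ) - m(X)) / Δ
end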

loss(X, Y) = Flux.mse(m(X), Y) + mean(abs2.(cos.(X) - m′(X)))

opt = ADAM()
cb() = @show loss(X, Y)
@time Flux.@epochs 1000 Flux.train!(loss, params(m), [(X, Y)], opt; cb)
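
As a rough sanity check of how accurate this kind of forward difference is, here is a minimal sketch that applies the same step choice to `sin` and compares the result with `cos`. It uses only Base Julia; the helper name `fd` is hypothetical and not part of Flux or Zygote.

```julia
# Forward difference with step √eps(T), the same choice as in m′ above.
# `fd` is a hypothetical helper name used only for this illustration.
function fd(f, x::T) where {T<:AbstractFloat}
    Δ = sqrt(eps(T))
    (f(x + Δ) - f(x)) / Δ
end

xs = 0f0:1f-1:10f0                    # same Float32 grid as X
approx = [fd(sin, x) for x in xs]     # finite-difference derivative of sin
err = maximum(abs.(approx .- cos.(xs)))
println("max abs error vs cos: ", err)
```

In Float32 the result is only an approximation of the true derivative, which is one reason an exact forward-mode derivative is attractive here.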

The network using Zygote.pushforward takes about 2 minutes to train:

using Flux
using Statistics

const X = reshape(0:1f-1:10, 1, :)
const Y = sin.(X)

m = Chain(
    Dense(1, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 1),
)

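# Wrap the network as a scalar-to-scalar function.
scalar_m(x) = first(m([x]))
# Forward-mode derivative of the scalar function at x, with tangent seed 1.
scalar_m′(x) = Flux.pushforward(scalar_m, x)(1)
# Broadcast over the inputs: one network evaluation and one pushforward per element.
m′(X) = scalar_m′.(X)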

loss(X, Y) = Flux.mse(m(X), Y) + mean(abs2.(cos.(X) - m′(X)))

opt = ADAM()
cb() = @show loss(X, Y)
@time Flux.@epochs 1000 Flux.train!(loss, params(m), [(X, Y)], opt; cb)

Is there any way I could improve my code to get at least similar performance between using the finite difference method and using Zygote.pushforward?

If this is not currently possible, that is not too bad, as there seems to be a new automatic differentiation system in the works, as mentioned by Chris Rackauckas in this similar issue in the NeuralPDE.jl repository. However, I would love to know whether this is just a problem in my code.

Thanks.


Hello,

In fact, Zygote.pushforward is not limited to functions with scalar inputs, and it can be applied directly to the neural network!

Here is the same network using Zygote.pushforward on the full network. It takes about the same amount of time as the finite difference method:

using Flux
using Statistics

const X = reshape(0:1f-1:10, 1, :)
const Y = sin.(X)

m = Chain(
    Dense(1, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 1),
)

# Forward-mode derivative of the whole network, evaluated on the full batch in one call.
m′(X::AbstractArray) = Flux.pushforward(m, X)(1)

loss(X, Y) = Flux.mse(m(X), Y) + mean(abs2.(cos.(X) - m′(X)))

opt = ADAM()
cb() = @show loss(X, Y)
@time Flux.@epochs 1000 Flux.train!(loss, params(m), [(X, Y)], opt; cb)
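
A short sketch of why a single pushforward over the whole batch gives the per-sample derivatives, assuming the tangent seed `1` acts as a unit perturbation of every entry of `X`. The notation ($x_j$ for the columns of `X`, $\mathbf{1}$ for the all-ones row, $\varepsilon$ for the perturbation size) is introduced here for illustration.

```latex
X = \begin{bmatrix} x_1 & \cdots & x_N \end{bmatrix}, \qquad
m(X) = \begin{bmatrix} m(x_1) & \cdots & m(x_N) \end{bmatrix}

\left. \frac{d}{d\varepsilon}\, m(X + \varepsilon \mathbf{1}) \right|_{\varepsilon = 0}
  = \begin{bmatrix} m'(x_1) & \cdots & m'(x_N) \end{bmatrix}
```

Because the network processes each column independently, perturbing all inputs at once does not mix samples, so one forward-mode pass yields every $m'(x_j)$ instead of one pass per element.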