How to achieve good performance with Zygote.pushforward on a neural network

Hello,

In fact, Zygote.pushforward is not limited to functions with scalar inputs; it can be applied directly to the whole neural network!

Here is the same network, this time with Zygote.pushforward applied to the full network at once. Training takes about as long as with the finite difference method:

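```julia
using Flux
using Zygote      # pushforward is defined in Zygote, not exported by Flux
using Statistics

# Training data: inputs as a 1×N row (one column per sample), targets sin(x)
const X = reshape(0:1f-1:10, 1, :)
const Y = sin.(X)

m = Chain(
    Dense(1, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 1),
)

# Derivative of the network output with respect to its input, in forward mode
m′(X::AbstractArray) = Zygote.pushforward(m, X)(1)

# Fit sin(x) while also pushing the network's derivative towards cos(x)
loss(X, Y) = Flux.mse(m(X), Y) + mean(abs2.(cos.(X) - m′(X)))

opt = ADAM()
cb() = @show loss(X, Y)
@time Flux.@epochs 1000 Flux.train!(loss, params(m), [(X, Y)], opt; cb)
```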
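Why a single pushforward over the whole batch gives the per-sample derivative: every Dense layer treats each column of X independently, so a tangent of all ones pushed through the network comes out as dm/dx for each column in one forward pass. The sketch below illustrates this without Flux or Zygote. It is a hypothetical hand-written forward-mode (JVP) rule set, not Zygote's implementation: the `Layer` struct, `pushforward_layer`, `pushforward_chain` and `forward` are illustrative names made up for this example. It then checks the result against central finite differences.

```julia
using Random

# Hypothetical stand-in for a Flux Dense layer: weights, bias, activation (:tanh or :identity)
struct Layer
    W::Matrix{Float64}
    b::Vector{Float64}
    act::Symbol
end

# Push the primal x and the tangent ẋ through one layer (forward-mode rule)
function pushforward_layer(l::Layer, x, ẋ)
    z = l.W * x .+ l.b
    ż = l.W * ẋ                          # the bias does not depend on x, so it drops out
    if l.act === :tanh
        y = tanh.(z)
        return y, (1 .- y .^ 2) .* ż     # d tanh(z) = (1 - tanh(z)^2) dz
    else
        return z, ż
    end
end

# Chain the per-layer rules; columns (samples) never mix
function pushforward_chain(layers, x, ẋ)
    for l in layers
        x, ẋ = pushforward_layer(l, x, ẋ)
    end
    return x, ẋ
end

# Plain evaluation: reuse the same code with a zero tangent and keep only the output
forward(layers, x) = pushforward_chain(layers, x, zero(x))[1]

Random.seed!(0)
layers = [
    Layer(randn(10, 1), randn(10), :tanh),
    Layer(randn(10, 10), randn(10), :tanh),
    Layer(randn(1, 10), randn(1), :identity),
]

X = reshape(collect(0:0.1:10), 1, :)     # 1×101 batch, one sample per column

# One pass with an all-ones tangent gives dm/dx for every sample at once
_, dY = pushforward_chain(layers, X, ones(size(X)))

# Central finite differences for comparison
h = 1e-6
fd = (forward(layers, X .+ h) .- forward(layers, X .- h)) ./ (2h)

println(maximum(abs.(dY .- fd)))         # should be tiny (round-off level)
```

The same column independence is what lets the posted m′(X) return one derivative per training point from a single call.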