Hello,
In fact, Zygote.pushforward is not limited to functions with scalar inputs; it can be applied directly to the neural network.
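To see this on a non-scalar input, here is a minimal sketch (the toy function `f` is just for illustration): `Zygote.pushforward(f, x)` returns a closure over the tangent, and seeding it with a vector of ones yields the Jacobian-vector product, which for an elementwise function is the elementwise derivative.

```julia
using Zygote

# Toy elementwise function on a vector input (for illustration only)
f(x) = sin.(x)

x = Float32[0.0, 1.0, 2.0]

# Seed the pushforward with ones: the result is J * ones(x),
# i.e. the elementwise derivative cos.(x) for this function.
jvp = Zygote.pushforward(f, x)(one.(x))

jvp ≈ cos.(x)  # true
```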
Here is the same network, using Zygote.pushforward on the full model. Training takes about the same amount of time as with the finite-difference method:
using Flux
using Zygote
using Statistics
const X = reshape(0:1f-1:10, 1, :)
const Y = sin.(X)
m = Chain(
    Dense(1, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 10, tanh),
    Dense(10, 1),
)
# Seed the pushforward with an array of ones (same shape as X) to get dm/dX elementwise
m′(X::AbstractArray) = Zygote.pushforward(m, X)(one.(X))
loss(X, Y) = Flux.mse(m(X), Y) + mean(abs2.(cos.(X) - m′(X)))
opt = ADAM()
cb() = @show loss(X, Y)
@time Flux.@epochs 1000 Flux.train!(loss, params(m), [(X, Y)], opt; cb)