Flux, higher order derivatives and forward mode

I would like to minimize a functional that involves computing partial derivatives.
As a toy problem, I want to find a function u(pt) that minimizes

∫ ∂ₓ(u)^2

In this toy case, any u that is constant along the x direction is a solution. I would like to rediscover such a solution using Flux.jl: u is a neural network, the integral is replaced by a sum over a large number of points, and some variant of gradient descent is used to optimize it.
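
Concretely, the discretized objective (just restating the above in symbols, with u_θ the network and ptᵢ the sampled points) is

loss(θ) = Σᵢ (∂ₓ u_θ(ptᵢ))²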

How can I implement this? I tried the following, but Zygote does not seem to play nicely with ForwardDiff here:

using Zygote
using ForwardDiff
using Flux

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end
const DD = DirectionalDerivative

function (dd::DD)(pt)
    # forward-mode directional derivative: d/dh f(pt + h*direction) at h = 0
    ForwardDiff.derivative(0) do h
        dd.f(pt + h * dd.direction)
    end
end

u = Chain(
    Dense(2, 20, Flux.sigmoid),
    Dense(20, 1, Flux.sigmoid),
)

npts = 10
pts = randn(Float32, 2, npts)

vx = similar(pts)  # direction field: unit x-vector at every sample point
vx .= 0
vx[1,:] .= 1

du_dx = DD(u, vx)
loss = () -> sum(abs2, du_dx(pts))
Flux.train!(loss, params(u), [()], ADAM())
MethodError: no method matching *(::NamedTuple{(:value, :partials),Tuple{Nothing,Array{Float32,1}}}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4"{DirectionalDerivative{Chain{Tuple{Dense{typeof(σ),Array{Float32,2},Array{Float32,1}},Dense{typeof(σ),Array{Float32,2},Array{Float32,1}}}},Array{Float32,2}},Array{Float32,2}},Int64},Float32,1})
Closest candidates are:
  *(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:529
  *(!Matched::Bool, ::ForwardDiff.Dual) at /home/jan/.julia/packages/ForwardDiff/cXTw0/src/dual.jl:413
  *(!Matched::Complex{Bool}, ::Real) at complex.jl:309
  ...

Stacktrace:
 [1] _broadcast_getindex_evalf at ./broadcast.jl:631 [inlined]
 [2] _broadcast_getindex at ./broadcast.jl:604 [inlined]
 [3] getindex at ./broadcast.jl:564 [inlined]
 [4] copy at ./broadcast.jl:854 [inlined]
 [5] materialize at ./broadcast.jl:820 [inlined]
 [6] (::Zygote.var"#1725#1726"{Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4"{DirectionalDerivative{Chain{Tuple{Dense{typeof(σ),Array{Float32,2},Array{Float32,1}},Dense{typeof(σ),Array{Float32,2},Array{Float32,1}}}},Array{Float32,2}},Array{Float32,2}},Int64},Float32,1},2}})(::Array{NamedTuple{(:value, :partials),Tuple{Nothing,Array{Float32,1}}},2}) at /home/jan/.julia/packages/Zygote/2JFwA/src/lib/broadcast.jl:96
 ...
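
Reading the MethodError, it seems Zygote's broadcasting adjoint represents the cotangent of a Dual as a (value, partials) NamedTuple and then tries to multiply that with a Dual from the forward pass, for which no * method exists.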

Use sciml_train with Optim optimizers. It has a type fix, with workarounds, to make this work (see SciML/DiffEqFlux.jl on GitHub). We've been using it in our physics-informed neural networks.

It looks like you're training some PINNs, so you might want to join the discussion in #diffeq-bridged on the Slack: right now we're starting up a project on automated "from symbolic" PINN training, and from the looks of your Discourse posts you seem interested.


Thanks! I will see if I can make it run using sciml_train and will probably continue asking questions on Slack. These ZygoteRules for Dual make my head hurt :smiley:

While I do not yet fully understand the motivation for these ZygoteRules, I at least see how they correspond to the math: you identify a Dual with its cotangent space using the pairing (_ * _).partials instead of the usual "fieldwise scalar product of structs".
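
For reference, the kind of rule being discussed looks roughly like this (a sketch along the lines of the DiffEqFlux type fix mentioned above; the authoritative rules live in the DiffEqFlux source):

using ZygoteRules
using ForwardDiff

# Sketch: adjoints that let Zygote differentiate through Dual construction
# and field access by pairing a Dual's cotangent with its partials.
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
    @assert length(ẋ) == 1
    ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
    d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
    d.value, ẋ -> (ForwardDiff.Dual{T}(ẋ, 0),)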

Following the pointers from @ChrisRackauckas, this works for me:

using DiffEqFlux
using Optim
using ForwardDiff
using Flux

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end
const DD = DirectionalDerivative

function (dd::DD)(pt)
    # let-binding dd is one of the suggested workarounds (avoids closure-capture issues)
    let dd = dd
        ForwardDiff.derivative(0) do h
            dd.f(pt + h * dd.direction)
        end
    end
end

u0 = Chain(
    Dense(2, 20, Flux.sigmoid),
    Dense(20, 1, Flux.sigmoid),
)

npts = 100
pts = randn(Float32, 2, npts)

vx = similar(pts)  # direction field: unit x-vector at every sample point
vx[1,:] .= 1
vx[2,:] .= 0

theta0, u_from_theta = Flux.destructure(u0)  # flatten parameters into a vector; u_from_theta rebuilds the network
loss = function (theta)
    u = u_from_theta(theta)
    du_dx = DD(u, vx)
    sum(abs2, du_dx(pts))
end

sol = DiffEqFlux.sciml_train(loss, theta0, LBFGS())
u_sol = u_from_theta(sol.minimizer)  # rebuild the trained network from the optimized parameters

u_sol([-10 -5 1 2 3; 1 1 1 1 1])  # outputs should be (nearly) constant as x varies with y fixed
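
As a further sanity check (a sketch reusing the definitions above; du_dx_sol is just a new name here), the learned directional derivative should now be close to zero at the sample points:

du_dx_sol = DD(u_sol, vx)
maximum(abs, du_dx_sol(pts))  # expect a small value if the optimization succeeded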