I would like to minimize a functional that involves computing partial derivatives. As a toy problem, I want to find a function u(pt) that minimizes

∫ (∂ₓu)^2

In this toy case, any u that is constant along the x direction is a solution. I would like to rediscover such a solution using Flux.jl. So u is a neural network, the integral is replaced by a sum over a large number of points, and some variant of gradient descent is used to optimize it.

How can I implement this? I tried the following, but Zygote does not seem to play nicely with ForwardDiff here:
using Zygote
using ForwardDiff
using Flux

struct DirectionalDerivative{F, V}
    f::F
    direction::V
end
const DD = DirectionalDerivative

function (dd::DD)(pt)
    ForwardDiff.derivative(0) do h
        dd.f(pt + h * dd.direction)
    end
end

u = Chain(
    Dense(2, 20, Flux.sigmoid),
    Dense(20, 1, Flux.sigmoid),
)

npts = 10
pts = randn(Float32, 2, npts)
vx = similar(pts)
vx .= 0
vx[1,:] .= 1
du_dx = DD(u, vx)
loss = () -> sum(abs2, du_dx(pts))
Flux.train!(loss, params(u), [()], ADAM())
MethodError: no method matching *(::NamedTuple{(:value, :partials),Tuple{Nothing,Array{Float32,1}}}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4"{DirectionalDerivative{Chain{Tuple{Dense{typeof(σ),Array{Float32,2},Array{Float32,1}},Dense{typeof(σ),Array{Float32,2},Array{Float32,1}}}},Array{Float32,2}},Array{Float32,2}},Int64},Float32,1})
Closest candidates are:
*(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:529
*(!Matched::Bool, ::ForwardDiff.Dual) at /home/jan/.julia/packages/ForwardDiff/cXTw0/src/dual.jl:413
*(!Matched::Complex{Bool}, ::Real) at complex.jl:309
...
Stacktrace:
[1] _broadcast_getindex_evalf at ./broadcast.jl:631 [inlined]
[2] _broadcast_getindex at ./broadcast.jl:604 [inlined]
[3] getindex at ./broadcast.jl:564 [inlined]
[4] copy at ./broadcast.jl:854 [inlined]
[5] materialize at ./broadcast.jl:820 [inlined]
[6] (::Zygote.var"#1725#1726"{Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#3#4"{DirectionalDerivative{Chain{Tuple{Dense{typeof(σ),Array{Float32,2},Array{Float32,1}},Dense{typeof(σ),Array{Float32,2},Array{Float32,1}}}},Array{Float32,2}},Array{Float32,2}},Int64},Float32,1},2}})(::Array{NamedTuple{(:value, :partials),Tuple{Nothing,Array{Float32,1}}},2}) at /home/jan/.julia/packages/Zygote/2JFwA/src/lib/broadcast.jl:96
...
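
For reference, the discretization described in the question (integral replaced by a sum over sample points, minimized by gradient descent) can be sketched in plain Julia without any packages. This is only an illustration under stated assumptions, not a fix for the error above: the names `toy_u`, `directional_derivative`, `fd_gradient` and `minimize` are hypothetical, a three-parameter `tanh` model stands in for the Flux network, and central finite differences are swapped in for both ForwardDiff and Zygote.

```julia
# Hypothetical stand-in for the network: u(pt) = tanh(w1*x + w2*y + b),
# with parameters θ = [w1, w2, b].
toy_u(θ, pt) = tanh(θ[1] * pt[1] + θ[2] * pt[2] + θ[3])

# Central-difference approximation of d/dh f(pt + h*dir) at h = 0,
# used here in place of ForwardDiff.derivative.
directional_derivative(f, pt, dir; h = 1e-4) =
    (f(pt .+ h .* dir) - f(pt .- h .* dir)) / (2h)

# Discretized functional: sum over sample points of the squared x-derivative.
function loss(θ, pts)
    ex = [1.0, 0.0]  # unit vector along x
    return sum(abs2(directional_derivative(p -> toy_u(θ, p), pt, ex))
               for pt in eachcol(pts))
end

# Finite-difference gradient with respect to the parameters,
# used here in place of Zygote.
function fd_gradient(f, θ; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

# Plain gradient descent with a fixed step size.
function minimize(θ0, pts; lr = 0.05, steps = 2000)
    θ = copy(θ0)
    for _ in 1:steps
        θ .-= lr .* fd_gradient(t -> loss(t, pts), θ)
    end
    return θ
end

pts = [-1.0  0.0  1.0  0.5 -0.5;
        0.5 -1.0  0.0  1.0 -0.5]   # fixed 2×5 sample points, one per column
θ0 = [1.0, 0.5, 0.0]
θ = minimize(θ0, pts)

println("initial loss = ", loss(θ0, pts))
println("final loss   = ", loss(θ, pts))
println("x-weight w1  = ", θ[1])
```

In this sketch the weight on x is driven toward zero, so the fitted function becomes constant along the x direction, which is the kind of solution the toy problem asks for. Finite differences are chosen only to keep the example dependency-free; they are less accurate than automatic differentiation and scale poorly with the number of parameters.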