@vavrines the mixed neural DEs example gives a hint for how to do it: https://github.com/JuliaDiffEq/DiffEqFlux.jl#mixed-neural-des. Essentially, the idea is that you need to grab the parameters of the neural network and update them manually. The secret sauce is that DiffEqFlux.destructure flattens all of the parameters of a neural network into a single vector, like:
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = Flux.data(DiffEqFlux.destructure(ann))
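As a quick sanity check on the size of that flattened vector, here is a minimal sketch in plain Julia (no Flux needed). It assumes each Dense(in, out) layer stores an out-by-in weight matrix plus a bias of length out; dense_param_count is a hypothetical helper written for this illustration, not a library function. It shows where the 41 in the p[1:41] slice used in the full example comes from.

```julia
# Hypothetical helper: number of parameters in a Dense(in, out) layer,
# i.e. an out-by-in weight matrix plus a bias vector of length out.
dense_param_count(nin, nout) = nin * nout + nout

# Layer shapes of Chain(Dense(2,10,tanh), Dense(10,1))
layer_sizes = [(2, 10), (10, 1)]

# 2*10 + 10 = 30 for the first layer, 10*1 + 1 = 11 for the second
n_nn = sum(dense_param_count(nin, nout) for (nin, nout) in layer_sizes)
println(n_nn)  # prints 41
```

If you change the layer sizes, the hard-coded 41 (and the 43 used later) has to change with them.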
So now here’s a full example. Let’s optimize an ODE which is only partially defined by a neural network. We parameterize the initial condition, the neural network, and the partially defined ODE as follows:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim
tspan = (0.0f0,25.0f0)
save_t = 0.0:0.1:25.0
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float32[-2.0,1.1]
u0 = Float32[0.8; 0.8]
p = [p1;p2;u0]
function dudt_(du,u,p,t)
    x, y = u
    # The first 41 entries rebuild the neural network; the last two are the known linear terms.
    du[1] = DiffEqFlux.restructure(ann,p[1:41])(u)[1]
    du[2] = p[end-1]*y + p[end]*x
end
Notice that we will need p for later since p is the vector of all parameters involved. Now we can define a stub ODEProblem which we remake with parameters as part of the loss function:
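prob = ODEProblem(dudt_,u0,tspan)
function predict(p)
    # The last two entries are the initial condition; entries 1:43 are the ODE parameters.
    _prob = remake(prob,u0=p[end-1:end],p=p[1:43])
    Array(solve(_prob,Tsit5(),saveat=save_t))
end
loss_numerical(p) = sum(abs2,x-1 for x in predict(p))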
Each new p gets split up into the u0 part and the p part. Then, inside the ODE, the first 41 entries of the p part reparameterize the neural network and the last two are for the part of the ODE that is known. Then I just made up a silly loss function that is the L2 loss against 1. Optim.jl can then optimize this using numerical differencing:
result = optimize(loss_numerical, p, BFGS())
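To make the layout of the flat vector concrete, here is a small sketch in plain Julia. split_params and the names ending in _demo are hypothetical and only for illustration; the index ranges mirror the ones used in predict and dudt_ above, applied to a stand-in vector of length 45 rather than the real parameters.

```julia
# Hypothetical helper mirroring the index arithmetic used in predict and dudt_.
function split_params(p)
    nn  = p[1:41]        # neural network weights and biases
    ode = p[42:43]       # coefficients of the known part of the ODE
    u0  = p[end-1:end]   # initial condition
    return nn, ode, u0
end

p_demo = collect(1.0:45.0)   # stand-in for [p1; p2; u0]
nn_demo, ode_demo, u0_demo = split_params(p_demo)
println((length(nn_demo), ode_demo, u0_demo))
```

Since remake passes p[1:43] to the ODE, the p[end-1] and p[end] seen inside dudt_ are entries 42 and 43 of the full vector, not the initial condition.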
In some sense we are done because that’s all that’s necessary, but we can speed that up by utilizing the adjoint to get the gradients. Using the OnceDifferentiable form from NLSolversBase, we can define a function which computes the loss and the gradient simultaneously. The gradient is given by the adjoint sensitivity analysis defined in the DiffEq docs: http://docs.juliadiffeq.org/latest/analysis/sensitivity.html#Adjoint-Sensitivity-Analysis-1
The code is thus as follows:
using DiffEqSensitivity
dg(out,u,p,t,i) = (out.=(1.0.-u)./2)
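function fg!(F, G, x)
    # x is the optimizer's current parameter vector, laid out like p.
    _prob = remake(prob,u0=x[end-1:end],p=x[1:43])
    sol = solve(_prob,Tsit5(),saveat=save_t)
    if !(G == nothing)
        res = adjoint_sensitivities(sol,Tsit5(),dg,save_t,sensealg=SensitivityAlg(backsolve=true))
        # Optim.only_fg! expects the gradient to be written into G in place here.
    end
    if !(F == nothing)
        return sum(abs2,x-1 for x in Array(sol))
    end
end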
loss_adjoint = OnceDifferentiable(Optim.only_fg!(fg!), p)
result = optimize(loss_adjoint, p, BFGS())
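For reference, the calling convention that Optim.only_fg! relies on can be sketched on a toy problem in plain Julia. This is only an illustration, not the ODE loss: fg_toy! and the _demo names are hypothetical, and the objective is the same silly loss applied directly to x, so its gradient 2(x - 1) is known in closed form. The function fills G in place when a gradient is requested and returns the objective value when a value is requested.

```julia
# Toy objective f(x) = sum((x_i - 1)^2) with analytic gradient 2(x - 1).
function fg_toy!(F, G, x)
    r = x .- 1               # shared work used by both value and gradient
    if G !== nothing
        G .= 2 .* r          # write the gradient into G in place
    end
    if F !== nothing
        return sum(abs2, r)  # return the objective value
    end
    return nothing
end

x_demo = [0.0, 3.0]
G_demo = zeros(2)
f_demo = fg_toy!(0.0, G_demo, x_demo)
println((f_demo, G_demo))  # prints (5.0, [-2.0, 4.0])
```

The block calls fg_toy! directly so it runs without Optim; with Optim loaded, the same function should be usable as optimize(Optim.only_fg!(fg_toy!), x_demo, BFGS()).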
This funny model I chose seems to go unstable when using BFGS but is fine with Flux Tracker. Is it good to use? Who knows: that’s probably a good research topic. Anyways, for reference, the same thing using Flux optimizers directly is:
u0 = param(Float32[0.8; 0.8])
tspan = (0.0f0,25.0f0)
ann = Chain(Dense(2,10,tanh), Dense(10,1))
p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float32[-2.0,1.1]
p3 = param([p1;p2])
ps = Flux.params(p3,u0)
function dudt_(du,u,p,t)
    x, y = u
    du[1] = DiffEqFlux.restructure(ann,p[1:41])(u)[1]
    du[2] = p[end-1]*y + p[end]*x
end
prob = ODEProblem(dudt_,u0,tspan,p3)
function predict_adjoint()
    diffeq_adjoint(p3,prob,Tsit5(),u0=u0,saveat=0.0:0.1:25.0)
end
loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())
data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function ()
    display(loss_adjoint())
end
# Display the ODE with the current parameter values.
cb()
Flux.train!(loss_adjoint, ps, data, opt, cb = cb)