Neural ODE fitting really slow

I’m trying to fit this simple neural ODE, first with ADAM and then with BFGS, but training is extremely slow (about 15 minutes for 100 ADAM iterations). I am basing the code on the SciML talk from JuliaCon 2020. When I time the loss function with @time loss(p), it seems to do well, on the order of 10 ms per call. How can I make this faster?

using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()

function triNN!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ,nₓ,n₃ = u
    k₁,k₋₁,k₂,k₋₂ = cons
    mᵣ,mₗ,mₓ,A = dens
    z = L(u,p)
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ + z[1]
    du[2] = dnₓ = A*k₂*mₓ*mₗ - k₋₂*nₓ + z[2]
    du[3] = dn₃ = z[1] + z[2] 

end

L = FastChain(FastDense(3, 20, tanh),FastDense(20, 20, tanh), FastDense(20, 2))
p = initial_params(L)


# Define the experimental parameter
tspan = (0.0,17.0)
u0 = Float32[0.0,0.0,0.0]

Xₙ = n_tc1 #load data

f = (du,u,p,t) -> triNN!(du,u, p,t,densities_tc1,cons_tc)
prob_nn = ODEProblem(f,u0, tspan, p)
sol_nn = concrete_solve(prob_nn, Tsit5(), u0, p, saveat = t)
# sol_nn_cont = concrete_solve(prob_nn, Tsit5(), u0, p, saveat = 0.1)

plot(sol_nn)

function predict(θ)

    tmp_prob = remake(prob_nn,u0=u0,p=θ)
    tmp_sol =  solve(tmp_prob, Tsit5(), saveat = t,
                  abstol=1e-5, reltol=1e-5,
                  # sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))
                  sensealg = ReverseDiffAdjoint())

    Σ_sol = sum(Array(tmp_sol),dims=1)
end

# No regularisation right now
function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end

# Test
@time loss(p)

const losses = []

callback(θ,l,pred) = begin
    push!(losses, l)
    @show l
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

# First train with ADAM for better convergence
res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.5), cb=callback, maxiters = 50)

# Train with BFGS
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01),
                                cb=callback, maxiters = 10000)

Thanks!

This is a case where you should use ReverseDiffVJP(true) to compile the tape. You probably don’t want to use ReverseDiffAdjoint unless you really need a discrete adjoint (I am not sure you need that here?). Also try something like a backsolve.

Thanks for your prompt reply. I just tried it with
sensealg = InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
but it seems to be similarly slow. Could it just be my computer?
Also, I would like to try backsolve, but I am not sure what you mean. Is that just the backsolve=true keyword, or do you mean sensealg=BacksolveAdjoint()?

BacksolveAdjoint.

Note that not all equations will train at the same speed, of course. This one has more parameters and a longer time span than what I was training on in the video, but it still seems a lot slower than I would expect for this kind of equation. Is it stiff? Your example doesn’t include the data, so I can’t actually run it to profile. Is there some baseline speed you’re expecting?

Hi, thanks again for the help. I made the change you suggested and added the data to the code, so it should be self-contained now. Is there any speed benchmark based on the size of the NN that I could reference?

# using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()

function triNN!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ,nₓ,n₃ = u
    k₁,k₋₁,k₂,k₋₂ = cons
    mᵣ,mₗ,mₓ,A = dens
    z = L(u,p)
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ + z[1]
    du[2] = dnₓ = A*k₂*mₓ*mₗ - k₋₂*nₓ + z[2]
    du[3] = dn₃ = z[1] + z[2] 

end
# L = FastChain(FastDense(3, 32, tanh),FastDense(32, 32, tanh), FastDense(32, 4))
L = FastChain(FastDense(3, 20, tanh),FastDense(20, 20, tanh), FastDense(20, 2))
p = initial_params(L)


# Define the experimental parameter
tspan = (0.0,17.0)
u0 = Float32[0.0,0.0,0.0]

# Xₙ = n_tc1 #load data
Xₙ = [0.02020270731751947
 0.0
 0.06187540371808745
 0.10536051565782635
 0.083381608939051
 0.19845093872383823
 0.3011050927839216
 0.3011050927839216
 0.3566749439387324
 0.38566248081198473
 0.5108256237659907
 0.6931471805599453
 0.8209805520698303
 0.7339691750802005
 0.8209805520698303
 0.7339691750802005]
 t = [0.25
  0.25
  0.5
  0.5
  1.0
  1.0
  2.0
  2.0
  4.0
  4.0
  6.0
  6.0
  8.0
  8.0
 16.0
 16.0]
densities_tc1 = [15.0,38.0,10.0,1.0]
cons_tc = [ 0.0006692541890287495,0.8662361534770547,1.169965568192585e-6,0.4]

f = (du,u,p,t) -> triNN!(du,u, p,t,densities_tc1,cons_tc)
prob_nn = ODEProblem(f,u0, tspan, p)
sol_nn = concrete_solve(prob_nn, Tsit5(), u0, p, saveat = t)

# plot(solution)
plot(sol_nn)
# summ = reduce(vcat,sum(sol_nn,dims=1))
# h = plot!(sol_nn.t,summ,linecolor=:black)

function predict(θ)

    tmp_prob = remake(prob_nn,u0=u0,p=θ)
    tmp_sol =  solve(tmp_prob, Tsit5(), saveat = t,
                  abstol=1e-5, reltol=1e-5,
                  sensealg = BacksolveAdjoint())
                  # backsolve=true)
                  
    Σ_sol = sum(Array(tmp_sol),dims=1)
end

function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end

# Test
@time loss(p)

const losses = []

callback(θ,l,pred) = begin
    push!(losses, l)
    @show l
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

# First train with ADAM for better convergence
res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.1), cb=callback, maxiters = 50)

# Train with BFGS
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01),
                                cb=callback, maxiters = 1000)

println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")

# Plot the data and the approximation
NNsolution = predict(res2.minimizer)
# Trained on noisy data vs real solution
plot(t, NNsolution')
scatter!(t, Xₙ)


Also, how come the ODE solvers are not interchangeable? If I switch from Tsit5 to Rodas4 or Rosenbrock23 (to take care of possible stiffness), it errors out:
TypeError: in typeassert, expected Float32, got ForwardDiff.Dual{Nothing,Float32,12}

What about TRBDF2 or KenCarp4?

I still get the same TypeError: in typeassert, expected Float32, got ForwardDiff.Dual{Nothing,Float32,12} when I use TRBDF2 or KenCarp4 instead.

Hey, sorry for the late response but I finally found the time to have a look at this. Were you counting compile time? I did a few tweaks, mainly just changing to VCABM and not choosing a sensealg (InterpolatingAdjoint default was a bit faster) and it’s doing 100 iterations of ADAM in <10 seconds:

9.231829 seconds (14.17 M allocations: 1.611 GiB, 3.69% gc time)

The original code wasn’t far off from that, so I assume you were precompiling packages or something?
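
One way to check whether compile time is being counted, as a minimal sketch: time the same call twice in a fresh session. The function `work` below is a hypothetical stand-in for `loss(p)` and is not code from this thread. The first measurement includes JIT compilation, the second does not.

```julia
# Hypothetical stand-in for an expensive function such as loss(p)
work(x) = sum(abs2, x .- sum(x) / length(x))

x = rand(Float32, 1_000)

t_first  = @elapsed work(x)   # includes JIT compilation for this argument type
t_second = @elapsed work(x)   # runs the already compiled code

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

If the second number is much smaller than the first, most of the reported time was compilation rather than the computation itself. The same pattern applies to a full sciml_train call: run it once with a small maxiters to compile, then time the real run.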

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()

function triNN!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ,nₓ,n₃ = u
    k₁,k₋₁,k₂,k₋₂ = cons
    mᵣ,mₗ,mₓ,A = dens
    z = L(u,p)
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ + z[1]
    du[2] = dnₓ = A*k₂*mₓ*mₗ - k₋₂*nₓ + z[2]
    du[3] = dn₃ = z[1] + z[2]

end
# L = FastChain(FastDense(3, 32, tanh),FastDense(32, 32, tanh), FastDense(32, 4))
L = FastChain(FastDense(3, 20, tanh),FastDense(20, 20, tanh), FastDense(20, 2))
p = initial_params(L)


# Define the experimental parameter
tspan = (0.0,17.0)
u0 = Float32[0.0,0.0,0.0]

# Xₙ = n_tc1 #load data
Xₙ = [0.02020270731751947
 0.0
 0.06187540371808745
 0.10536051565782635
 0.083381608939051
 0.19845093872383823
 0.3011050927839216
 0.3011050927839216
 0.3566749439387324
 0.38566248081198473
 0.5108256237659907
 0.6931471805599453
 0.8209805520698303
 0.7339691750802005
 0.8209805520698303
 0.7339691750802005]
 t = [0.25
  0.25
  0.5
  0.5
  1.0
  1.0
  2.0
  2.0
  4.0
  4.0
  6.0
  6.0
  8.0
  8.0
 16.0
 16.0]
densities_tc1 = [15.0,38.0,10.0,1.0]
cons_tc = [ 0.0006692541890287495,0.8662361534770547,1.169965568192585e-6,0.4]

f = (du,u,p,t) -> triNN!(du,u, p,t,densities_tc1,cons_tc)
prob_nn = ODEProblem(f,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = t)

# plot(solution)
plot(sol_nn)
# summ = reduce(vcat,sum(sol_nn,dims=1))
# h = plot!(sol_nn.t,summ,linecolor=:black)

function predict(θ)

    tmp_prob = remake(prob_nn,u0=u0,p=θ)
    tmp_sol =  solve(tmp_prob, VCABM(), saveat = t,
                  abstol=1e-5, reltol=1e-5)
                  # backsolve=true)

    Σ_sol = sum(Array(tmp_sol),dims=1)
end

function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end

# Test
@time loss(p)

const losses = []

callback(θ,l,pred) = begin
    # push! and @show are disabled for timing, so guard against an empty `losses`
    #push!(losses, l)
    #@show l
    if !isempty(losses) && length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

# First train with ADAM for better convergence
@time res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.05), cb=callback, maxiters = 100)

and as expected, things like TRBDF2 and KenCarp4 just worked out of the box:

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()

function triNN!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ,nₓ,n₃ = u
    k₁,k₋₁,k₂,k₋₂ = cons
    mᵣ,mₗ,mₓ,A = dens
    z = L(u,p)
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ + z[1]
    du[2] = dnₓ = A*k₂*mₓ*mₗ - k₋₂*nₓ + z[2]
    du[3] = dn₃ = z[1] + z[2]

end
# L = FastChain(FastDense(3, 32, tanh),FastDense(32, 32, tanh), FastDense(32, 4))
L = FastChain(FastDense(3, 20, tanh),FastDense(20, 20, tanh), FastDense(20, 2))
p = initial_params(L)


# Define the experimental parameter
tspan = (0.0,17.0)
u0 = Float32[0.0,0.0,0.0]

# Xₙ = n_tc1 #load data
Xₙ = [0.02020270731751947
 0.0
 0.06187540371808745
 0.10536051565782635
 0.083381608939051
 0.19845093872383823
 0.3011050927839216
 0.3011050927839216
 0.3566749439387324
 0.38566248081198473
 0.5108256237659907
 0.6931471805599453
 0.8209805520698303
 0.7339691750802005
 0.8209805520698303
 0.7339691750802005]
 t = [0.25
  0.25
  0.5
  0.5
  1.0
  1.0
  2.0
  2.0
  4.0
  4.0
  6.0
  6.0
  8.0
  8.0
 16.0
 16.0]
densities_tc1 = [15.0,38.0,10.0,1.0]
cons_tc = [ 0.0006692541890287495,0.8662361534770547,1.169965568192585e-6,0.4]

f = (du,u,p,t) -> triNN!(du,u, p,t,densities_tc1,cons_tc)
prob_nn = ODEProblem(f,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = t)

# plot(solution)
plot(sol_nn)
# summ = reduce(vcat,sum(sol_nn,dims=1))
# h = plot!(sol_nn.t,summ,linecolor=:black)

function predict(θ)

    tmp_prob = remake(prob_nn,u0=u0,p=θ)
    tmp_sol =  solve(tmp_prob, TRBDF2(), saveat = t,
                  abstol=1e-5, reltol=1e-5)
                  # backsolve=true)

    Σ_sol = sum(Array(tmp_sol),dims=1)
end

function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred), pred
end

# Test
@time loss(p)

const losses = []

callback(θ,l,pred) = begin
    push!(losses, l)
    @show l
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

# First train with ADAM for better convergence
@time res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.05), cb=callback, maxiters = 100)

(but of course it’s not very fast because the equation isn’t stiff). Are you not using a recent version of Julia, i.e. Julia v1.3+? Is your DiffEqFlux version something like DiffEqFlux v1.21.0? I think you must be just using an ancient version of something.

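BTW, your loss function was implemented incorrectly. sum(Array(tmp_sol),dims=1) gives back a 1×N row vector, so if you broadcast it against a length-N (column) vector you get an N×N matrix of all pairwise differences, not the N elementwise residuals that the squared error needs. For example: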

julia> collect(1:10) .* collect(1:10)'
10×10 Array{Int64,2}:
  1   2   3   4   5   6   7   8   9   10
  2   4   6   8  10  12  14  16  18   20
  3   6   9  12  15  18  21  24  27   30
  4   8  12  16  20  24  28  32  36   40
  5  10  15  20  25  30  35  40  45   50
  6  12  18  24  30  36  42  48  54   60
  7  14  21  28  35  42  49  56  63   70
  8  16  24  32  40  48  56  64  72   80
  9  18  27  36  45  54  63  72  81   90
 10  20  30  40  50  60  70  80  90  100
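
Applied to a squared-error loss, the same effect can be sketched with made-up numbers (the values below are illustrative only and are not the data from this thread):

```julia
data = [1.0, 2.0, 3.0]   # length-3 column vector, like Xₙ (made-up values)
pred = [1.0 2.0 3.0]     # 1×3 row vector, like the result of sum(A, dims=1)

# Column .- row broadcasts to a 3×3 matrix of all pairwise differences
pairwise = data .- pred

wrong = sum(abs2, pairwise)         # sums 9 terms, nonzero although pred matches data
right = sum(abs2, data .- pred')    # transpose first: 3 elementwise residuals
also  = sum(abs2, data .- vec(pred))

println(size(pairwise))   # (3, 3)
println(wrong)            # 12.0
println(right)            # 0.0
println(also)             # 0.0
```

Either `pred'` or `vec(pred)` restores the elementwise comparison; the point is only that both operands must have the same orientation before broadcasting.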

So I fixed your example by transposing it to a vector, and fixed it so that you’d get double saves everywhere:

using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, DiffEqSensitivity, Optim
using DiffEqFlux, Flux
using Plots
gr()

function triNN!(du,u,p,t,dens,cons)
    # unpack rates and constants
    nᵣ,nₓ,n₃ = u
    k₁,k₋₁,k₂,k₋₂ = cons
    mᵣ,mₗ,mₓ,A = dens
    z = L(u,p)
    # model
    du[1] = dnᵣ = A*k₁*mᵣ*mₗ - k₋₁*nᵣ + z[1]
    du[2] = dnₓ = A*k₂*mₓ*mₗ - k₋₂*nₓ + z[2]
    du[3] = dn₃ = z[1] + z[2]

end
L = FastChain(FastDense(3, 32, tanh),FastDense(32, 32, tanh), FastDense(32, 4))
#L = FastChain(FastDense(3, 20, tanh),FastDense(20, 20, tanh), FastDense(20, 2))
p = initial_params(L)

# Define the experimental parameter
tspan = (0.0,16.1) # A little bit longer because there are two values to save at final t?
u0 = Float32[0.0,0.0,0.0]

# Xₙ = n_tc1 #load data
Xₙ = [0.02020270731751947
 0.0
 0.06187540371808745
 0.10536051565782635
 0.083381608939051
 0.19845093872383823
 0.3011050927839216
 0.3011050927839216
 0.3566749439387324
 0.38566248081198473
 0.5108256237659907
 0.6931471805599453
 0.8209805520698303
 0.7339691750802005
 0.8209805520698303
 0.7339691750802005]
 t = [0.25
  0.25
  0.5
  0.5
  1.0
  1.0
  2.0
  2.0
  4.0
  4.0
  6.0
  6.0
  8.0
  8.0
 16.0
 16.0]
densities_tc1 = [15.0,38.0,10.0,1.0]
cons_tc = [ 0.0006692541890287495,0.8662361534770547,1.169965568192585e-6,0.4]

f = (du,u,p,t) -> triNN!(du,u, p,t,densities_tc1,cons_tc)
prob_nn = ODEProblem(f,u0, tspan, p)
sol_nn = solve(prob_nn, Tsit5(), saveat = t)

# plot(solution)
plot(sol_nn)
# summ = reduce(vcat,sum(sol_nn,dims=1))
# h = plot!(sol_nn.t,summ,linecolor=:black)

function predict(θ)

    tmp_prob = remake(prob_nn,u0=u0,p=θ)
    tmp_sol =  solve(tmp_prob, VCABM(), saveat = t,
                  abstol=1e-5, reltol=1e-5)
                  # backsolve=true)
    Σ_sol = sum(Array(tmp_sol),dims=1) # Note: this returns a row vector!
end

function loss(θ)
    pred = predict(θ)
    sum(abs2, Xₙ .- pred'), pred
end

# Test
@time loss(p)

const losses = []

callback(θ,l,pred) = begin
    push!(losses, l)
    @show l
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    p = plot(t, pred')
    scatter!(p, t, Xₙ)
    display(p)
    false
end

# First train with ADAM for better convergence
@time res1 = DiffEqFlux.sciml_train(loss, p, ADAM(0.03), cb=callback, maxiters = 10)

# Train with BFGS
res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, BFGS(initial_stepnorm=0.01),
                                cb=callback, maxiters = 1000)

println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
plot(losses, yaxis = :log, xaxis = :log, xlabel = "Iterations", ylabel = "Loss")

# Plot the data and the approximation
NNsolution = predict(res2.minimizer)
# Trained on noisy data vs real solution
plot(t, NNsolution')
scatter!(t, Xₙ)

It finds a solution in like 5 seconds

[Figure: finalfit]

and you can tell that, given the structural constraints you’ve imposed, that’s fairly optimal because that curve shape has to exist. The final loss was:

l = 0.15476138182324645

and in case you wanted the parameters:

julia> println(res2.minimizer)
Float32[0.12764063, -0.51904184, 0.19782981, 0.3315497, 0.13397248, -0.21715562, 0.02866412, -0.30318484, 0.3065577, 0.13548774, 0.04629074, -0.15954407, 0.31648758, -0.015262135, -0.10808937, -0.10684164, -0.0474205, -0.04280915, 0.16332868, 0.3572281, 0.07582084, 0.07100157, 0.1650918, -0.15487944, 0.33461252, 0.094434716, 0.514186, -0.29102844, -0.008031259, 0.08430971, -0.2531316, -0.023445265, 0.44810703, -0.17825657, -0.4443464, -0.08810473, -0.4751105, 0.2581244, -0.21277674, 0.071047425, -0.26755327, 0.43826663, -0.30399424, -0.09008295, 0.122793354, -0.41176003, 0.18107809, -0.12313707, -0.22901058, -0.466065, 0.31890023, 0.23187979, -0.14291202, 0.24903817, -0.35219705, 0.06834008, 0.48332825, -0.19518521, 0.45203757, 0.1520131, -0.2421104, -0.030290017, 0.045389462, -0.348286, 0.093995035, -0.2520377, 0.084871806, 0.3512352, 0.01625438, 0.11524874, 0.13903937, 0.008842226, -0.25907663, -0.12792666, 0.4053956, -0.07770932, -0.03037885, -0.15427223, -0.0040439013, 0.45078498, 0.51911587, -0.4092807, 0.41359138, -0.021582456, -0.03709335, -0.23641047, -0.21546419, 0.55139935, 0.08197822, 0.15208867, 0.14941044, 0.5192426, -0.12801008, -0.4846889, 0.09696112, 0.0044896062, 0.11075601, -0.12915611, -0.080162376, 0.055965576, -0.16089696, 0.06892738, 0.060303405, -0.08030645, 0.18070206, 0.08103386, 0.038639612, 0.077904075, -0.07786285, -0.10607368, 0.07267877, 0.03547609, 0.12903398, 0.007945972, 0.05227977, 0.09987849, 0.10491361, 0.1564354, -0.10742239, 0.11402035, 0.06353248, -0.07932009, 0.14895174, 0.049661007, 0.15181251, -0.11095221, 0.07396178, -0.1417276, -0.2214894, 0.26904988, -0.0047135903, 0.2025713, -0.078114085, 0.07920332, 0.011230715, 0.049954537, -0.35861844, 0.1333887, -0.22364241, -0.054590687, 0.24059054, -0.085329324, -0.28976697, 0.16571443, 0.047707636, 0.03452969, 0.034991704, 0.09826056, 0.30632654, -0.37078938, 0.33247164, -0.26286575, 0.2592442, 0.31244162, 0.24888079, -0.20151804, -0.056455415, -0.10938111, -0.29357794, 
-0.39555904, 0.08915791, -0.2672956, 0.07299817, 0.35309166, -0.01598734, 0.040030736, 0.11953154, -0.14113928, 0.26021478, 0.20126462, 0.09602096, 0.11820305, -0.23482758, -0.3683619, -0.15234233, 0.23183763, -0.07527195, 0.2524045, -0.14825247, 0.37072745, -0.14319427, 0.34842438, -0.038710117, 0.36170238, -0.17799227, 0.023478188, -0.22921033, 0.30302468, 0.30983692, 0.19196364, 0.23539026, 0.30347046, -0.17456244, -0.1521192, 0.31875795, -0.37635672, 0.026462799, -0.12960283, -0.022344517, 0.121768236, 0.11277555, 0.23235708, 0.08553905, -0.2059214, -0.06169973, 0.36167932, -0.10620528, -0.13227624, 0.17628556, -0.17612197, -0.09457604, -0.20363939, -0.09253764, 0.20489523, 0.24962926, -0.35462168, -0.18084237, -0.18987122, 0.2855642, -0.33114037, -0.05087434, 0.3492996, -0.1292126, 0.22317275, -0.033160653, 0.35011202, 0.10996813, 0.15145105, -0.3374461, 0.046763107, -0.0320004, -0.37957016, -0.11821983, -0.12598577, -0.19373852, 0.029072985, 0.3057559, -0.072463244, -0.25005138, -0.2909804, -0.14207987, 0.040344898, 0.07213363, -0.14318444, 0.06205844, 0.09103409, 0.055981405, -0.17914492, -0.18740852, 0.07841523, -0.19152726, -0.31005484, -0.18299955, -0.0748914, -0.1940295, -0.14914283, -0.15169735, 0.20112877, 0.15159573, -0.16079667, 0.008982445, -0.103261955, -0.01475296, 0.16622765, 0.11757536, 0.28897172, -0.37728614, 0.27725914, -0.21720566, 0.025802115, -0.42453575, 0.10101333, 0.044375867, 0.16483949, -0.10538648, 0.21162663, -0.16095318, 0.073193304, 0.04701516, -0.27181908, 0.19436733, 0.35291308, 0.28157625, -0.3763109, -0.22234863, 0.17517066, -0.18924646, -0.2692099, 0.23769835, -0.14674705, 0.18419321, 0.275707, -0.13655876, -0.39617765, 0.2287443, 0.056185883, 0.018949054, -0.07003914, 0.1799526, -0.08725843, -0.122439146, -0.31838706, 0.09261438, 0.0063319304, -0.11891576, -0.21257448, 0.10406629, 0.17203847, -0.35979, -0.059775967, -0.17184964, 0.12656574, 0.11042608, -0.13298681, 0.03482386, 0.011754372, -0.17818744, 0.0011383872, 
-0.012356292, -0.024153639, 0.25954208, -0.22242348, -0.25004616, 0.2700861, -0.1757636, -0.2614088, 0.049063366, 0.08083176, -0.14413643, -0.29307604, -0.22606035, -0.2759965, 0.13527209, -0.2321874, 0.24789064, -0.13482733, 0.061353493, 0.3292204, -0.28952435, 0.061915778, -0.1043062, 0.004408786, 0.14001296, 0.048512474, -0.2912039, 0.1703778, 0.0104383305, 0.121538684, 0.23159824, -0.011427064, -0.2220551, -0.0066278432, 0.09692748, 0.028959846, -0.03596206, -0.16218236, -0.1695371, 0.16115716, -0.03920116, 0.10824167, -0.08551167, 0.046640366, -0.062175497, -0.13558151, -0.107566744, 0.15765233, -0.2614961, -0.4397671, 0.07283186, -0.01947967, 0.08140747, -0.3363905, 0.04541828, 0.010307739, -0.15262318, -0.00411632, 0.13083778, -0.0811692, 0.26280627, -0.037434336, -0.015015555, 0.09534455, 0.051954314, -0.42512083, -0.16444406, 0.13836542, 0.10069935, -0.116963714, 0.17395224, -0.3445682, -0.19291219, -0.05206494, -0.17201999, -0.11340488, 0.2347324, 0.05098681, -0.23252861, -0.25410825, 0.16518736, 0.07516079, -0.09139485, 0.43169996, 0.22813539, 0.1307354, -0.14429799, 0.0805763, -0.35344365, -0.17967914, 0.022510061, -0.41102213, 0.13516422, 0.044297926, 0.08186679, 0.03533948, 0.14363994, 0.37252206, 0.35917833, -0.30963558, -0.04108892, 0.121735014, 0.1888546, -0.06179019, -0.11294965, 0.205444, 0.31555718, 0.11941912, 0.23203902, 0.019453924, 0.025178451, 0.07113684, -0.07433441, 0.21850002, -0.05063764, -0.042092644, 0.3019152, 0.21435574, -0.014558449, 0.009934034, 0.11619258, 0.12624268, -0.41773173, -0.0012811159, -0.14131914, 0.20337018, 0.13497671, 0.004714905, 0.4126909, -0.0976716, -0.32710204, -0.12450581, -0.225299, -0.12437748, -0.30387318, 0.2780857, 0.11051922, 0.07175723, -0.045327816, 0.12620358, -0.3565432, 0.22448577, 0.2278121, 0.27102324, -0.33980355, 0.10331872, -0.032887075, -0.05799463, 0.070839524, -0.26593643, -0.01193367, -0.3267408, 0.3147133, 0.009396369, 0.010949426, 0.21595442, 0.077239744, -0.3791073, -0.34433946, 
-0.20442681, -0.29677123, 0.13480258, -0.16535735, -0.098813474, -0.09749093, 0.44019863, -0.21284562, -0.16799109, 0.40906537, -0.038319685, 0.15755013, 0.12179234, 0.28633535, 0.14698723, -0.3051876, -0.00822484, 0.059591893, 0.10581278, 0.21221077, -0.02806588, -0.20171465, -0.1919542, -0.054579817, -0.025581906, -0.2984991, 0.21314383, -0.30210808, -0.2832243, -0.26448342, 0.12987587, 0.16565384, 0.056087885, 0.003760007, -0.17050028, 0.12933162, 0.31427944, 0.21952291, -0.05469233, -0.36703756, -0.12910526, 0.16914415, -0.31145522, -0.15726668, -0.14061859, 0.13056578, -0.11227086, 0.3717334, -0.043949682, -0.17861722, -0.3799728, 0.048755378, -0.021247007, 0.18959692, -0.14034663, 0.41547248, 0.09899976, -0.14813133, 0.07512428, 0.22856604, -0.12451845, -0.12289147, -0.17805994, 0.21716873, 0.37867376, -0.33299896, -0.09873342, 0.27493647, 0.055765845, 0.10142621, -0.015416154, 0.24034962, -0.34223056, 0.17939214, 0.36165258, 0.2667329, -0.35912463, 0.41052535, 0.040770918, -0.12780431, 0.109395154, 0.3747686, -0.09568499, -0.1901028, 0.3033764, 0.06054539, 0.15317136, -0.19520281, 0.17724574, -0.21699816, 0.14295678, -0.15630491, 0.03995944, -0.07508831, 0.4395659, -0.057187878, -0.045236003, 0.18166654, -0.10239086, 0.40498814, 0.061890893, 0.19536069, 0.108240515, 0.0023875046, -0.3875909, -0.03046326, 0.11236697, -0.08390719, -0.0066388873, -0.18772168, 0.07728145, -0.089608446, -0.11308471, 0.0034926974, 0.10460939, -0.21761584, -0.17522499, 0.27143174, -0.01003124, -0.41058642, 0.20331898, -0.15285791, 0.15441944, 
0.027293831, -0.1746031, -0.023278628, -0.14225185, -0.09756187, -0.09687166, 0.17323166, 0.30856907, -0.016437516, 0.15379728, 0.17819633, 0.15069957, 0.121199384, 0.114319585, -0.046179034, -0.06682746, 0.043540765, 0.057911627, -0.34814864, 0.38015896, 0.2921212, -0.090283245, -0.28011975, -0.30499107, -0.128161, -0.26182476, -0.071150854, -0.39192906, -0.050559755, -0.22137071, 0.43844223, -0.13818452, 0.3321835, 0.106311224, -0.1559834, -0.21243581, -0.3552272, 0.1485909, -0.06883646, -0.12720662, -0.17607884, 0.22216623, 0.4238738, -0.09264778, -0.011824622, 0.26276436, 0.19886325, -0.23088543, -0.1837929, 0.3850065, 0.08779815, -0.11843672, 0.29625288, 0.3389107, -0.1923542, -0.16599081, -0.15162794, -0.42449048, 0.04191046, -0.412945, 0.2290505, -0.085040435, 0.26539147, -0.037280053, 0.21191974, 0.018777706, -0.11720193, -0.121589035, -0.1681476, -0.0024750012, -0.025160732, -0.10973551, -0.06921049, 0.15165828, -0.18722302, -0.098160915, 0.24467993, 0.19656579, 0.18447332, 0.22563353, -0.20507245, -0.39962733, 0.23614213, 0.15380748, -0.03008222, -0.14542446, 0.16005152, -0.022790123, -0.16909668, 0.16580142, 0.040366463, 0.16948862, 0.045220397, 0.3277363, -0.15750112, 0.27455088, 0.12919593, -0.1789775, -0.0022467473, -0.060246907, 0.16236146, 0.081230395, 0.4373, 0.16427633, -0.3045268, 0.16572572, 0.27862203, 0.05762366, -0.3095561, -0.11352508, -0.25441477, -0.35644975, 0.11599576, 0.2118972, -0.0003260287, -0.13512327, 0.13660721, -0.116464354, -0.16024873, -0.12329313, 0.35374278, -0.3317176, 0.25016016, -0.11147078, -0.12051944, -0.34366626, -0.044625584, -0.41693, 0.04497777, 0.27076802, -0.092689246, -0.054244936, -0.18836677, 0.33865628, 0.12485699, -0.22362384, -0.20595635, 0.15002698, -0.0307091, 0.22672576, -0.12270178, 0.038155746, -0.13968426, 0.18087034, -0.19278212, -0.16927938, 0.30623126, -0.10391087, -0.24520588, -0.11143173, 0.24651334, 0.031446554, -0.2635201, -0.011383196, -0.0042675193, -0.10835279, -0.06273374, 0.103100374, 
-0.13036351, 0.108532436, -0.22261076, -0.24399605, -0.034284778, 0.3328241, 0.066890635, -0.3992421, 0.007643493, 0.15220505, -0.14786837, -0.21153738, -0.3274781, 0.43243644, 0.3591852, -0.10899618, 0.3719475, 0.092119515, 0.122046605, 0.17238754, 0.03033855, -0.27274343, 0.17248766, -0.020201907, 0.3023988, -0.18804607, -0.34292555, -0.053096414, 0.10075402, 0.1988121, 0.08860179, 0.11112658, 0.008969883, 0.24061614, -0.14631262, 0.047013126, -0.19098388, 0.27426916, 0.17923138, -0.34291998, -0.20554294, -0.07378916, 0.13575043, -0.17010042, 0.16898288, 0.04096703, 0.23059247, 0.13800788, -0.13748047, -0.028642679, -0.04907769, 0.12029269, 0.2874934, -0.25953794, -0.31794205, 0.12934129, -0.006429749, -0.09204989, -0.24960738, 0.023794537, -0.32340166, 0.13708775, -0.35071996, 0.25896454, 0.14582348, 0.33431345, -0.1554625, 0.043164346, 0.025860881, 0.012208344, 0.16301271, -0.35487518, 0.21653983, 0.038739588, 0.241877, -0.27940243, -0.22341305, 0.251298, 0.2840363, 0.07376154, 0.11909827, -0.096951105, 0.118042015, -0.01670288, 0.079482265, -0.29448694, 0.090289034, -0.057404403, 0.32402107, 0.13109502, -0.10369189, 0.19085675, 0.36422023, 0.04942259, -0.005629941, 0.097935304, -0.2881629, 0.35770568, -0.3478608, -0.18953533, -0.094246335, 0.085280925, -0.10875575, 0.08138465, -0.11190597, 0.02630017, 0.010465298, -0.13156603, -0.06619988, -0.079493806, -0.034132253, 0.341511, 0.19336097, -0.18527369, 0.120424, 0.31554234, 0.027456736, 0.15019149, 0.046136346, -0.058114506, -0.37255806, -0.14027794, -0.08035668, -0.35771984, 0.0071923356, 0.057534464, 0.17789015, -0.3064446, 0.031412423, 0.014840789, 0.0968599, 0.41537592, 0.36673236, -0.18722887, -0.22626218, 0.3901142, -0.1599011, -0.11856758, -0.09649303, 0.2662409, -0.011249872, -0.03285034, 0.15252964, 0.13789773, -0.424675, -0.3019665, -0.20796652, 0.29264477, -0.28944182, 0.2551902, 0.14323758, 0.22218251, -0.32271966, -0.0258034, 0.3430538, -0.19453998, 0.091606416, -0.20287967, 0.30290112, 0.04699859, 
0.21614692, -0.13476472, 0.0455639, 0.35662985, 0.02568031, 0.13802825, -0.18616185, -0.02508608, -0.118130505, -0.15959877, 0.10922279, -0.16640514, 0.030099448, -0.40190455, 0.20971754, 0.13624063, -0.28194466, 0.05870842, -0.17020485, 0.14856201, -0.14710233, -0.033728115, -0.200229, 0.2240697, 0.15575711, 0.084087804, 0.091723435, -0.003829883, 0.01687899, -0.2049391, 0.16272408, -0.28658313, 0.20104168, -0.048586205, 0.32305983, 0.116998434, 0.098434076, 0.18383916, 0.044377502, 0.11879716, -0.011844871, -0.13333893, -0.15381603, 0.056732308, -0.07026257, -0.27401358, -0.11737549, 0.20570752, -0.16321467, -0.1266627, -0.24985413, 0.12403986, 0.062190652, 0.2861398, -0.110541806, 0.37320194, -0.41959953, -0.26079437, 0.14753583, -0.1558337, 0.40863413, -0.033202328, 0.36392024, -0.20062055, -0.17084718, 0.10680356, 0.24506792, 0.1384317, 0.34620994, 0.14950976, -0.027452154, -0.14568752, -0.088221446, -0.054036837, 0.07285882, 0.35463342, 0.18919803, -0.27513278, -0.21996196, 0.03344171, -0.37026426, 0.11906389, 0.12986551, 0.1255663, -0.117194556, 0.27056885, 0.41804242, 0.21084899, -0.22359622, -0.30021447, 0.015602041, -0.09369994, 0.091409616, -0.29459265, -0.08942035, -0.06539932, 0.09976683, 0.045337763, -0.037236515, 0.17789467, -0.13041162, 0.0058310227, 0.049703516, 0.16926146, -0.35114634, -0.14224645, -0.1115619, 0.16809985, -0.014194635, -0.16389176, -0.025029687, -0.19728197, 0.13154957, -0.06673251, 0.13347457, 0.09244986, 0.22284809, 0.1324537, -0.058734197, 0.029092152, 0.33319533, -0.15459599, 0.023264758, -0.11380043, -0.11282811, -0.06674237, -0.22392774, 0.07982238, -0.37248835, -0.07884955, -0.16718043, -0.14016917, -0.116600014, -0.057178546, -0.097981334, 0.22180091, 0.015663732, 0.2565985, 0.33812463, -0.35986936, 0.061467446, -0.14026313, 0.36941397, -0.3053845, 0.14611055, 0.16012682, 0.21117152, 0.0053824773, 0.42786965, -0.22374451, -0.3773153, -0.19369228, 0.016665703, -0.045552116, -0.20615134, 0.15217736, -0.03311196, 0.1378622, 
-0.18821257, -0.0071505406, -0.14751063, 0.28764904, 0.00057546236, -0.18597709, -0.122557245, 0.34144607, -0.065558836, -0.23855644, -0.10209668, 0.09452937, 0.14436905, -0.40559867, 0.05439406, -0.078939736, 0.07310023, -0.031457655, 0.43610767, 0.31216717, -0.15230018, -0.33595437, -0.032151714, -0.20111376, -0.15750346, -0.10574821, 0.027871674, 0.09695485, 0.06389919, 0.06107047, 0.09037926, 0.15977053, -0.1874724, 0.13201822, 0.07585743, -0.01042511, -0.3710986, -0.20251532, 0.19607045, -0.1557796, -0.16554888, 0.0052897413, -0.12809913, 0.08060133, 0.17132962, -0.019448161, 0.25859955, -0.2828511, -0.13289702, 0.047522776, -0.22220229, -0.28326496, -0.4019694, -0.19472429, 0.03733487, 0.1544096, -5.395923f-5, 0.16802989, 0.09942447, -0.30840647, 0.35371345, -0.17081094, 0.1417528, 0.054072734, -0.25201565, 0.1616256, 0.26335713, 
-0.16459851, 0.10073876, -0.006048863, 0.24060808, -0.091617025, 0.040157694, 0.28187278, -0.0029710403, 0.043349624, -0.14378013, 0.1479583, 0.07086338, -0.13736011, -0.046150517, 0.14080334, -0.14387235, -0.14811924, -0.14167704, 0.06459302, -0.048843287, 0.061319247, 0.06710165, 0.15069851, -0.13083015, -0.0765374, 0.06517677, -0.0660682, 0.1483683, -0.067613594, 0.146892, -0.06790586, 0.049529392, -0.13350125, 0.068149954, 0.0669475, 0.078118086, -0.06930747, -0.07339384, 0.075408265, -0.06492583, -0.06631159, 0.102842234, -0.19362949, 0.15762927, 0.039961927, -0.32127964, 0.35974953, 0.13946413, 0.048426382, 0.054979134, -0.11568598, 0.1365553, 0.34217566, 0.4171052, -0.11416641, 0.22973311, -0.008464747, 0.5310343, 0.5312507, -0.05628065, -0.19015361, 0.092775896, 0.075131, -0.21427998, 0.08837624, 0.117420614, -0.2658205, 0.2490142, 0.29603818, 0.058667827, -0.28419802, -0.03706303, -0.24228199, 0.16472355, -0.19465648, 0.3229012, 0.056370877, -0.26393357, -0.18443711, 0.11416858, 0.15660863, 0.32444972, 0.5358852, 0.37647265, -0.066232175, -0.31366712, -0.23600298, 0.22645696, 0.008077553, -0.3837643, -0.103478946, 0.24691285, -0.36345276, -0.40471476, 0.4085933, -0.33851776, -0.1426592, 0.27822012, 0.047153514, -0.16652374, 0.07417784, 0.14815368, -0.041149046, -0.34610036, -0.13583638, 0.12837888, -0.25202876, 0.40279254, 0.30266672, 0.44115114, 0.11985864, -0.13522269, 0.37195703, -0.32150096, 0.29908076, 0.26044083, -0.011767192, 0.4783825, 0.07643443, -0.14994961, 0.13756484, -0.38714096, 0.3421218, -0.25598165, 0.34403795, 0.3739817, 0.09069328, -0.1915779, 0.019494927, -0.29595897, -0.5176583, -0.041229505, 0.103251986, 0.2462919, 0.0071001993, -0.27127165, -0.100124255, -0.13554497, -0.124653384, 0.13704897, -0.28340745, -0.29908496, -0.122147195, -0.06464057, 0.07822411, -0.11803203, 0.07949784, -0.18039267, 0.2605619, 0.4211523, 0.08283859, 0.13419095, -0.21640362, 0.03709688, 0.05322013, 0.1189467, 0.24719824, -0.16734408, -0.008114711, 
0.055327848, -0.26808923, 0.33366865, 0.17214242, -0.23962137, -0.05233278, -0.19633555, 0.24315508, -0.38059622, 0.077255346, -0.06450808, -0.030319043, 0.0, 0.0]

Different initial conditions give different solutions of course, and this was a fairly good one (and you can play around with the optimization, but from the quick animation it’s fairly clear you won’t do much better than this).

Cheers!

Hi Chris, Thanks so much for the help and figuring out the bug in the loss function. It all runs smoothly and faster now.
One last thing: I was trying to swap BFGS for a different algorithm, like NelderMead, so the second training call becomes:

res2 = DiffEqFlux.sciml_train(loss, res1.minimizer, NelderMead(),
                                cb=callback, maxiters = 1000)

This results in an error:

ERROR: KeyError: key "x" not found

Do you have any idea what could be causing this? Is there any example on how to interface with Optim solvers or BBO solvers?

Thanks again!

@pkofod does NelderMead have different things in the callback dictionary than the other arguments?

https://github.com/SciML/DiffEqFlux.jl/blob/master/test/layers_sciml.jl#L118-L125

That’s NLopt and Fminbox. For BBO, just use BBO()

Yes, it has x_centroid instead, which is a point that’s not really part of the simplex, so it’s explicitly named that way to avoid confusion (and obviously to generate confusion in other cases :) ).
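
As a rough illustration of why the KeyError appears, here is a sketch of a callback-side helper that tolerates both naming schemes. Everything in it is an assumption: `current_point` is a hypothetical helper, the dictionaries are made up, and the key name "x_centroid" is taken from the reply above rather than checked against Optim's source.

```julia
# Hypothetical helper: read the current point from an optimizer's metadata,
# falling back to "x_centroid" (assumed key name) when "x" is absent.
function current_point(metadata::AbstractDict)
    haskey(metadata, "x") && return metadata["x"]
    haskey(metadata, "x_centroid") && return metadata["x_centroid"]
    throw(KeyError("x"))
end

bfgs_like   = Dict("x" => [1.0, 2.0])            # made-up metadata
nelder_like = Dict("x_centroid" => [0.5, 1.5])   # made-up metadata

println(current_point(bfgs_like))     # [1.0, 2.0]
println(current_point(nelder_like))   # [0.5, 1.5]
```

Code that indexes metadata["x"] unconditionally would fail on the second dictionary with the same KeyError reported above.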

Got it thanks. I’ll take this to PRs and we can munge it to the way DiffEqFlux wants there: https://github.com/SciML/DiffEqFlux.jl/pull/416
