Error using example for NeuralODE

I'm getting an error in the last optimization step:

result_neuralode2 = Optimization.solve(
    optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false)

**Error message:**

ERROR: Output should be scalar; gradients are not defined for output (0.16004305f0, Float32[2.0 1.9136947 1.7503117 1.2632618 0.606817 -0.059315965 -0.72503924 -1.3093071 -1.615921 -1.730963 -1.7286443 -1.7037328 -1.6729282 -1.5816362 -1.3799951 -1.04659 -0.65896845 -0.25841323 0.13459843 0.52109444 0.8948952 1.1780137 1.3319784 1.4162561 1.4492586 1.4433123 1.4268215 1.4080843 1.3636386 1.2272812; 0.0 0.7158319 1.3774631 1.7345353 1.8257833 1.8375918 1.8398279 1.7201138 1.370624 0.93096924 0.40284145 -0.12198502 -0.6398629 -1.1274402 -1.4788264 -1.6393745 -1.6478968 -1.6037871 -1.5609413 -1.5248863 -1.463131 -1.3198652 -1.1052285 -0.85746014 -0.56652415 -0.24201494 0.082114115 0.39672565 0.71082836 1.0200887])

The code is from the Julia NeuralODE documentation:
Link: Neural Ordinary Differential Equations · DiffEqFlux.jl

Code:

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
      OptimizationOptimisers, Random, Plots

rng = Xoshiro(0)
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

# Set doplot = false to suppress the plots from the callback
callback = function (p, l, pred; doplot = true)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        display(plot(plt))
    end
    return false
end

pinit = ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot = true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback = callback, maxiters = 300)

optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false)

callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)

This is the versioninfo()

julia> versioninfo()
Julia Version 1.11.1
Commit 8f5b7ca12a (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 16 × 11th Gen Intel(R) Core(TM) i7-11850H @ 2.50GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, tigerlake)
Threads: 1 default, 0 interactive, 1 GC (on 16 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =

I have seen similar posts but can't really figure out how to fix it.

Has anyone tried the code with Julia version 1.11.1?

It’s related to the Optimization.jl v4.0 update: MNIST Example Throws Exception at Optimization.solve (in v4.0) · Issue #954 · SciML/DiffEqFlux.jl · GitHub
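
From what I can tell, two things changed in v4: the objective passed to `OptimizationFunction` must return a single scalar (no more `(loss, pred)` tuples), and the callback now receives an optimization state plus the loss, with the current parameters in `state.u`, instead of `(p, l, pred...)`. A minimal sketch of the new callback contract (the `predict_neuralode` call is just illustrative, matching the full code below):

callback = function (state, l)
    # state.u holds the current parameters; recompute the prediction if needed
    pred = predict_neuralode(state.u)
    println("loss = ", l)
    return false  # return true to stop the optimization early
end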

Try this code

using Lux, DiffEqFlux, OrdinaryDiffEq, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots
rng = Random.Xoshiro(0)

# True ODE: cube u elementwise, then multiply by a matrix
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

# Generate data from the true function
u0 = [2.0; 0.0]
datasize = 31
tspan = (0.0, 1.5)
tsteps = range(tspan[begin], tspan[end], length = datasize)
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

# Define a `NeuralODE` problem with a neural network from `Lux.jl`.
dudt2 = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, dudt2) |> f64
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

# Predicted output
predict_neuralode(p) = Array(prob_neuralode(u0, p, st)[1])

# Loss function
# Optimization.jl v4 only accepts a scalar loss output
function loss_neuralode(p)
    pred = predict_neuralode(p)
    l2loss = sum(abs2, ode_data .- pred)
    return l2loss
end

# Callback function
anim = Animation()
lossrecord = Float64[]
callback = function (state, l; doplot = true)
    if doplot
        pred = predict_neuralode(state.u)
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        frame(anim)
        push!(lossrecord, l)
    else
        println(l)
    end
    return false
end

# Try the callback function to see if it works.
pinit = ComponentArray(p)
callback((; u = pinit), loss_neuralode(pinit); doplot=false)

# Use https://github.com/SciML/Optimization.jl to solve the problem and https://github.com/FluxML/Zygote.jl for automatic differentiation (AD).
adtype = Optimization.AutoZygote()

# Define a [function](https://docs.sciml.ai/Optimization/stable/API/optimization_function/) to optimize with AD.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)

# Define an `OptimizationProblem`
optprob = Optimization.OptimizationProblem(optf, pinit)

# Solve the `OptimizationProblem` using the ADAM optimizer first to get a rough estimate.
result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.Adam(0.05),
    callback = callback,
    maxiters = 300
)

println("Loss is: ", loss_neuralode(result_neuralode.u))

# Use another optimizer (BFGS) to refine the solution.
optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS(; initial_stepnorm = 0.01),
    callback = callback,
    allow_f_increases = false
)

println("Loss is: ", loss_neuralode(result_neuralode2.u))

# Visualize the fitting process
mp4(anim, fps=15)

# Plot the loss history
lossrecord
plot(lossrecord[1:300], xlabel="Iters", ylabel="Loss", lab="Adam", yscale=:log10)
plot!(300:length(lossrecord), lossrecord[300:end], lab="BFGS")

Thanks!
It works well 🙂
So the trick was to return only the loss from the `loss_neuralode` function?
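That is, if I'm reading the two versions right, the loss-side change boils down to (using the definitions from the posts above):

# Old (Optimization.jl v3): returned a (loss, prediction) tuple
function loss_neuralode(p)
    pred = predict_neuralode(p)
    return sum(abs2, ode_data .- pred), pred
end

# New (Optimization.jl v4): return only the scalar loss
function loss_neuralode(p)
    pred = predict_neuralode(p)
    return sum(abs2, ode_data .- pred)
end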
I also learnt new ways of plotting and creating animations.
Thanks for sharing @Wen-Wei_Tseng
