Hello all,

I have been trying to set up a NeuralODE to learn the behavior of a stiff ODE, but in the process I am running into this strange error. I have tried using only Float32 values, but it seems that there is still some issue with the implementation.

I am new to Julia and am trying to learn, so any help would be greatly appreciated.

Thank you

Here is the code:

```
using ModelingToolkit, DifferentialEquations, Plots, Lux, Optimization, OptimizationOptimJL, Random, Plots, DiffEqFlux, ComponentArrays
using ModelingToolkit: t_nounits as t, D_nounits as D
#
# Model parameters and their default values. Units are the ones stated in each
# `description`; several descriptions are "..." placeholders left by the author.
params = @parameters begin
R = 8.31, [description = "Gas constant, in J/(K.mol)"]
T = 1573, [description = "temperature, K"]
Q = 453_000, [description = "heat of..., J/mol"]
M = 135e3, [description = "..., in MPa"]
R°_p = 10^(-0.95), [description = "..., 1/(MPa^5.s)"]
R°_gb = 10^3.53, [description = "..., 1/(MPa^4.s)"]
A° = 10^6.94, [description = "..., 1/(MPa^2.s)"]
d = 400e-6, [description = "grain size, m"]
β = 2, [description = "..., ..."]
μ = 65e3, [description = "..., MPa"]
b = 5e-10, [description = "..., m"]
# Thermal factor exp(-Q/(R*T)); it multiplies every rate term in `eqs` below.
Exp = exp(-Q/(R*T))
# Stress-like quantities used inside `eqs`.
# NOTE(review): σ_T_max and σ_ref are wrapped in Float32 while every other default
# is an Int or Float64 literal, so the defaults are mixed precision — confirm intended.
σ_T_max = Float32(1.8e3)
# Defined from other parameters (β, μ, b, d) rather than given as a number.
σ_d = β*μ*b/d
σ_ref = Float32(3.1e3) # σ∗_p*R*T/Q
# presumably the applied stress in MPa (the solution plot is labelled "Stress in MPa") — confirm
σ = 100
end
#
# Unknowns of the model; both are declared as functions of the independent variable `t`.
vars = @variables begin
σ_T(t) = 1 # default value 1
ϵ°(t) # no default given; the commented-out default was A°*Exp*1^2*sinh((σ - σ_T - σ_d)/σ_ref)
end
#
# Governing equations:
#   1. an algebraic relation defining ϵ° in terms of σ_T (no time derivative);
#   2. the evolution equation for σ_T. Its first term divides by σ_T, so the
#      right-hand side is singular at σ_T = 0.
eqs = [ϵ° ~ A°*Exp*σ_T^2*sinh((σ - σ_T - σ_d)/σ_ref),
D(σ_T) ~ M*( ϵ°*(σ_T + σ_d)/σ_T - abs(ϵ°)*σ_T/σ_T_max - R°_p*Exp*σ_T^5 - R°_gb*Exp*σ_T^3*σ_d )]
# Build the symbolic system from the equations, independent variable, unknowns and parameters.
@named sys_eq = ODESystem(eqs,t,vars,params)
# presumably eliminates the algebraic ϵ° relation so that σ_T is the only
# differential state of `sys` — TODO confirm by inspecting the simplified system
sys = structural_simplify(sys_eq)
# Reference data: integrate the mechanistic model and sample it on a fixed grid.
tspan = (0.0, 20000.0)  # integration window (the plot below labels time in seconds)
datasize = 50           # number of saved sample points
tsteps = range(tspan[1], tspan[2]; length = datasize)
# The empty list supplies no initial-value overrides; presumably the defaults declared
# on the variables/parameters are used — confirm against the ModelingToolkit version.
prob = ODEProblem(sys, [], tspan)
# Keep the solution object itself: `plot(sol, ...)` and `DataFrame(sol)` below both
# need it. Storing only the converted array left `sol` undefined (UndefVarError).
sol = solve(prob, TRBDF2(); saveat = tsteps)
# Float32 copy of the saved states; this is the training target for the neural ODE.
sol_array = Float32.(Array(sol))
plot(sol, lw=2.5, lc=:blue, legend=:topleft, xlabel="time in sec", ylabel="Stress in MPa")
# For saving solution outputs
using DataFrames
df = DataFrame(sol)
using CSV
CSV.write("sol_values.csv", df)
# Shaping information for the NN: just some approximate form of the equation,
# evaluated element-wise as 2x^5 - x^3.
#
# `x` may be a number or an array. The coefficient is converted to the floating-point
# type of `x`, so Float32 input stays Float32 instead of being promoted to Float64 by
# `2.0`/`1.0` literals. Integer input still yields Float64, as before.
function math_law(x)
    two = convert(float(eltype(x)), 2)
    return two .* x .^ 5 .- x .^ 3
end
# Construct the network: the fixed `math_law` transform feeds a 1 -> 50 -> 1 dense
# stack with a tanh hidden activation. The function is passed directly; wrapping it
# in `x -> math_law(x)` adds nothing.
dudt2 = Lux.Chain(math_law, Lux.Dense(1, 50, tanh), Lux.Dense(50, 1))
rng = Random.default_rng()
# Initialization: `p` holds the trainable parameters, `st` the layer states.
p, st = Lux.setup(rng, dudt2)
#_______________________________________________________________________________
# The training state `u0` is Float32, so the time span and save points are converted
# to Float32 as well; mixing a Float64 time axis with a Float32 state gives
# mixed-precision arithmetic inside the solver.
# NOTE(review): Tsit5 is an explicit method, while the reference system above is
# integrated with the implicit TRBDF2 — if the learned dynamics turn out stiff, a
# stiff solver may be needed here too.
prob_neuralode = NeuralODE(dudt2, Float32.(tspan), Tsit5(); saveat = Float32.(tsteps))
#______________________________________________________________________________
# Run the neural ODE from the global initial state `u0` with parameters `p` and layer
# state `st`, and return the first element of its result converted to a plain Array.
function predict_neuralode(p)
    output = prob_neuralode(u0, p, st)
    trajectory = output[1]
    return Array(trajectory)
end
#______________________________________________________________________________
# Sum-of-squares misfit between the reference data `sol_array` and the prediction for
# parameters `p`. Returns the tuple `(loss, prediction)`; the prediction is passed
# along so the callback can plot it.
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    residual = sol_array .- prediction
    return sum(abs2, residual), prediction
end
#______________________________________________________________________________
# Callback taking the current parameters `p`, loss `l` and prediction `pred`.
# With `doplot = true` it scatters the first row of the data and of the prediction
# against `tsteps`. It always returns `false` (presumably "do not stop" to the
# optimizer — confirm against the Optimization.jl version in use).
callback = function (p, l, pred; doplot = false)
    if !doplot
        return false
    end
    # plot current prediction against data
    plt = Plots.scatter(tsteps, sol_array[1, :], label = "data")
    Plots.scatter!(plt, tsteps, pred[1, :], label = "prediction")
    display(plot(plt))
    return false
end
#______________________________________________________________________________
# Wrap the parameters returned by `Lux.setup` in a ComponentArray; this is the
# starting point handed to the OptimizationProblem below.
pinit = ComponentArray(p)
# Initial condition
# Single-component Float32 state, matching the default σ_T = 1 of the reference model.
# `predict_neuralode` reads this global, so it must be defined before the first call.
u0 = Float32[1.0]
# Evaluate the untrained model once and plot it against the data; the `(loss, pred)`
# tuple returned by `loss_neuralode` is splatted into the callback's `l` and `pred`.
callback(pinit, loss_neuralode(pinit)...; doplot=true)
#______________________________________________________________________________
# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()
# The optimization function receives `(x, p)`; only `x` (the trainable parameters) is
# forwarded to `loss_neuralode`, the second argument is ignored.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
# Stage 1: `Adam(0.05)` for at most 100 iterations.
# NOTE(review): `Adam` is used unqualified and none of the visible `using` lines
# obviously exports it — confirm which package provides it (e.g. OptimizationOptimisers).
# NOTE(review): `loss_neuralode` returns a `(loss, pred)` tuple and the callback takes
# `(p, l, pred)`; confirm the installed Optimization.jl version still accepts both.
result_neuralode = Optimization.solve(optprob,Adam(0.05),callback = callback,maxiters = 100)
# Retrain using the BFGS optimizer, starting from the Adam result
optprob2 = remake(optprob,u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(optprob2,Optim.BFGS(initial_stepnorm=0.01),callback=callback,allow_f_increases = false)
# Plot the final fit against the data.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
#_______________________
# Second NeuralODE over `tspan2`, reusing the same network `dudt2` and save points `tsteps`.
# NOTE(review): `tspan2` and `forecastu0` are not defined anywhere in the visible
# code — both must exist before these lines run, otherwise this is an UndefVarError.
# NOTE(review): `saveat = tsteps` was built from `tspan`; confirm it fits `tspan2`.
prob_neuralode2 = NeuralODE(dudt2, tspan2, Tsit5(), saveat = tsteps)
# NOTE(review): same name and signature as the earlier `predict_neuralode(p)`, so this
# definition replaces that method; any later call to `loss_neuralode` will then predict
# with `prob_neuralode2` / `forecastu0` instead of the training setup — confirm intended.
function predict_neuralode(p)
Array(prob_neuralode2(forecastu0, p, st)[1])
end
```