Recently, I have tried to use a NODE in a BVP to learn the parameters (the code is below). I have found that DiffEqFlux.sciml_train does train, but it prints a warning: "AD methods failed, using numerical differentiation. To debug, try ForwardDiff.gradient(loss, θ) or Zygote.gradient(loss, θ)." I wonder why the AD methods failed.
When I check with Zygote.gradient(loss, θ), I get the error "No adjoint rules exist. Check that you added using DiffEqSensitivity", but adding using DiffEqSensitivity does not fix it. Besides, Flux.train! fails with the same error, so I have to use DiffEqFlux.sciml_train. Could anyone give me some advice?
using DifferentialEquations
using BoundaryValueDiffEq
using Flux
using DiffEqFlux
using Statistics, LinearAlgebra, IterTools
const g = 9.81
L = 1.0
tspan = (0.0,pi/2)
function simplependulum!(du, u, p, t)
    g, L = p
    θ = u[1]
    dθ = u[2]
    du[1] = dθ
    du[2] = -(g / L) * sin(θ)
end
# the boundary conditions I want to impose; I also fix the initial condition of u[1], so only the initial u[2] is free to change
function bc(a, p)
    function bc1!(residual, u, p, t)
        residual[1] = u[end ÷ 2][1] + pi / 2 # the solution at the middle of the time span should be -pi/2
        residual[2] = u[1][1] - a            # fix the initial condition of u[1], so only the initial u[2] changes
    end
    return bc1!
end
function predict_n_ode(a, p)
    bvp1 = BVProblem(simplependulum!, bc(a, p), [a, pi / 2], tspan)
    Array(solve(bvp1, GeneralMIRK4(), p = p, dt = 0.05))[:, end]
end
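# Quick sanity check of the forward solve before any training (my own check, not
# part of the training loop); it assumes GeneralMIRK4 converges here with dt = 0.05:
# predict_n_ode(1.0, [9.81, 1.0])  # should return the state [θ, dθ] at t = pi/2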
pr = [9.81, 1.0]               # true parameters
u0 = [i for i in 1.0:0.1:2.0]  # different initial values of θ to input
data = predict_n_ode.(u0, Ref(pr)) # training data; Ref keeps pr as a single argument under broadcasting
function loss_n_ode(p)
    pred = predict_n_ode.(u0, Ref(p))
    loss = mean(norm.(pred .- data) .^ 2)
    @show loss
    return loss
end
ph = rand(2) # random initial guess for the parameters
result_bvp = DiffEqFlux.sciml_train(loss_n_ode, ph, maxiters = 5)
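# To reproduce the AD failure directly, as the warning suggests, I check the
# gradients by hand (this assumes using ForwardDiff and using Zygote at the top):
# ForwardDiff.gradient(loss_n_ode, ph)
# Zygote.gradient(loss_n_ode, ph)  # this is the call that throws "No adjoint rules exist"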
## using Flux.train!
datax = reshape(u0, 1, :)
datay = reshape(data, 1, :)
datat = Flux.Data.DataLoader((datax, datay), batchsize = 1, shuffle = true)
function loss_n_ode1(x, y, p)
    loss = mean(sum((predict_n_ode(x[i], p) - y[i]) .^ 2) for i in 1:length(x))
    @show loss
    return loss
end
# this call fails with the same "No adjoint rules exist" error
result_bvp1 = Flux.train!((x, y) -> loss_n_ode1(x, y, ph), Flux.params(ph), ncycle(datat, 1), ADAM(0.05))
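For reference, the numerical-differentiation fallback mentioned in the warning can be reproduced by hand. This is a minimal sketch, assuming the FiniteDiff package is installed; it differentiates the same loss without AD:

using FiniteDiff
# finite-difference gradient of the same loss, the kind of fallback the warning refers to
FiniteDiff.finite_difference_gradient(loss_n_ode, ph)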