Hello everyone, I’m interested in using the SciML ecosystem to train universal differential equations models in Julia. However, I don’t want to have to rely on the sciml_train
function – I’m interested in explicitly working with the gradients that get calculated during optimization. I’m mostly basing my script off of the “Custom training loops” tutorial from Flux (Training · Flux) and the “neural ODEs” tutorial from DiffEqFlux (https://diffeqflux.sciml.ai/dev/examples/neural_ode_sciml/)
However, whenever I try to calculate a gradient, I get an error of the form MethodError: no method matching similar (::Zygote.Params, ::Int64)
. Can someone help me out; apologies if this is a dumb question, as I’m new to Julia.
Edit: Adding that I’m working on a simple pendulum model, taken from the “Classical Physics Models” lecture notes.
Thanks!
Code below:
using OrdinaryDiffEq
using LinearAlgebra, Optim
using Flux, DiffEqFlux
using DiffEqSensitivity
using Random
# function to generate network
function initialize_model(input_dim, hidden_dim, output_dim, model_seed)
Random.seed!(model_seed)
model = Chain(Dense(input_dim, hidden_dim, tanh),
Dense(hidden_dim, output_dim))
weights = params(model)
return model, weights
end
# function for UDE
function dudt_(u, p, t)
θ = u[1]
dθ = u[2]
z = model(u)[1]
[dθ, z]
end
# function to set up and solve ODE with current NN state
function predict_nn(weights)
prob = ODEProblem(dudt_, u0, tspan, weights)
solution = solve(prob, Vern7(), u0=u0, p=weights, saveat=0.1,
sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())))
Array(solution)
end
# evaluate quality of prediction with MSE
function loss(weights)
pred = predict_nn(weights)
sum(abs2, ode_data' .- pred)
end
# generate data from pendulum example from here
ode_data = ...
# initialize model
input_dim, hidden_dim, output_dim, model_seed = 2, 32, 1, 1234
model, weights = initialize_model(input_dim, hidden_dim, output_dim, model_seed)
# calculate gradient
gs = gradient(weights) do
training_loss = loss(weights)
return training_loss
end
The full stack-trace:
Stacktrace:
[1] ODEAdjointProblem(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::DiffEqSensitivity.var"#df#134"{Array{Float64,2},Array{Float64,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; checkpoints::Array{Float64,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/interpolating_adjoint.jl:170
[2] _adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Vern7, ::DiffEqSensitivity.var"#df#134"{Array{Float64,2},Array{Float64,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float64,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:17
[3] adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::Vern7, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol,Symbol},NamedTuple{(:abstol, :reltol),Tuple{Float64,Float64}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:6
[4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon})(::Array{Float64,2}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:144
[5] (::DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55
[6] (::Zygote.var"#150#151"{DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
[7] (::Zygote.var"#1681#back#152"{Zygote.var"#150#151"{DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[8] #solve#457 at /.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
[9] (::typeof(∂(#solve#457)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[10] (::Zygote.var"#150#151"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
[11] (::Zygote.var"#1681#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[12] (::typeof(∂(solve##kw)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[13] predict_nn at /udes/pe_train_model_single_2.jl:32 [inlined]
[14] (::typeof(∂(predict_nn)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[15] loss at /udes/pe_train_model_single_2.jl:39 [inlined]
[16] #13 at ./none:2 [inlined]
[17] (::typeof(∂(#13)))(::Float64) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[18] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(#13))})(::Float64) at /.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
[19] gradient(::Function, ::Zygote.Params) at /.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
[20] top-level scope at none:0