Calculate gradient of UDE model (without DiffEqFlux.sciml_train)

Hello everyone, I’m interested in using the SciML ecosystem to train universal differential equation models in Julia. However, I don’t want to rely on the sciml_train function; I’m interested in working explicitly with the gradients that get calculated during optimization. I’m mostly basing my script on the “Custom training loops” tutorial from Flux (https://fluxml.ai/Flux.jl/stable/training/training/#Custom-Training-loops-1) and the “neural ODEs” tutorial from DiffEqFlux (https://diffeqflux.sciml.ai/dev/examples/neural_ode_sciml/).

However, whenever I try to calculate a gradient, I get an error of the form `MethodError: no method matching similar(::Zygote.Params, ::Int64)`. Can someone help me out? Apologies if this is a dumb question; I’m new to Julia.

Edit: Adding that I’m working on a simple pendulum model, taken from the “Classical Physics Models” lecture notes.
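For reference, the pendulum model and the UDE built from it can be written as a first-order system. The equations below are read off the `simplependulum` and `dudt_` functions in this thread, with $\omega$ standing for `u[2]` (called `dθ` in the code) and $\mathrm{NN}_p$ for the neural network with parameters $p$:

$$
\frac{d\theta}{dt} = \omega, \qquad \frac{d\omega}{dt} = -\frac{g}{L}\sin\theta \qquad \text{(true model)}
$$

$$
\frac{d\theta}{dt} = \omega, \qquad \frac{d\omega}{dt} = \mathrm{NN}_p(\theta, \omega) \qquad \text{(UDE)}
$$

Training adjusts $p$ so that the solution of the UDE matches the data generated from the true model.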

Thanks!

Code below:

```julia
using OrdinaryDiffEq
using LinearAlgebra, Optim
using Flux, DiffEqFlux
using DiffEqSensitivity
using Random

# function to generate network
function initialize_model(input_dim, hidden_dim, output_dim, model_seed)
   Random.seed!(model_seed)

   model = Chain(Dense(input_dim, hidden_dim, tanh),
                 Dense(hidden_dim, output_dim))
   weights = params(model)
   return model, weights
end

# function for UDE
function dudt_(u, p, t)
   θ = u[1]
   dθ = u[2]
   z = model(u)[1]

   [dθ, z]
end

# function to set up and solve ODE with current NN state
function predict_nn(weights)
   prob = ODEProblem(dudt_, u0, tspan, weights)
   solution = solve(prob, Vern7(), u0=u0, p=weights, saveat=0.1,
                    sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))
   Array(solution)
end

# evaluate quality of prediction with the sum of squared errors
function loss(weights)
   pred = predict_nn(weights)
   sum(abs2, ode_data' .- pred)
end

# generate data from pendulum example from here
ode_data = ...

# initialize model
input_dim, hidden_dim, output_dim, model_seed = 2, 32, 1, 1234
model, weights = initialize_model(input_dim, hidden_dim, output_dim, model_seed)

# calculate gradient
gs = gradient(weights) do
   training_loss = loss(weights)
   return training_loss
end
```

The full stack trace:

Stacktrace:
 [1] ODEAdjointProblem(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::DiffEqSensitivity.var"#df#134"{Array{Float64,2},Array{Float64,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; checkpoints::Array{Float64,1}, callback::CallbackSet{Tuple{},Tuple{}}, reltol::Float64, abstol::Float64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/interpolating_adjoint.jl:170
 [2] _adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, ::Vern7, ::DiffEqSensitivity.var"#df#134"{Array{Float64,2},Array{Float64,1},Colon}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float64,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:17
 [3] adjoint_sensitivities(::ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},false,Zygote.Params,ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},Vern7,OrdinaryDiffEq.InterpolationData{ODEFunction{false,typeof(dudt_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern7ConstantCache{OrdinaryDiffEq.Vern7Tableau{Float64,Float64}}},DiffEqBase.DEStats}, ::Vern7, ::Vararg{Any,N} where N; sensealg::InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool}, kwargs::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol,Symbol},NamedTuple{(:abstol, :reltol),Tuple{Float64,Float64}}}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:6
 [4] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon})(::Array{Float64,2}) at /.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:144
 [5] (::DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55
 [6] (::Zygote.var"#150#151"{DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
 [7] (::Zygote.var"#1681#back#152"{Zygote.var"#150#151"{DiffEqBase.var"#478#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{Vern7,InterpolatingAdjoint{0,true,Val{:central},ReverseDiffVJP{false},Bool},Array{Float64,1},Zygote.Params,Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [8] #solve#457 at /.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
 [9] (::typeof(∂(#solve#457)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [10] (::Zygote.var"#150#151"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/lib/lib.jl:191
 [11] (::Zygote.var"#1681#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [12] (::typeof(∂(solve##kw)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [13] predict_nn at /udes/pe_train_model_single_2.jl:32 [inlined]
 [14] (::typeof(∂(predict_nn)))(::Array{Float64,2}) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [15] loss at /udes/pe_train_model_single_2.jl:39 [inlined]
 [16] #13 at ./none:2 [inlined]
 [17] (::typeof(∂(#13)))(::Float64) at /.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(#13))})(::Float64) at /.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
 [19] gradient(::Function, ::Zygote.Params) at /.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [20] top-level scope at none:0

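Your issue was that you were using Flux, but not in the style that works here. `params(model)` collects the weights into Flux’s implicit global parameter set (a `Zygote.Params`), and that implicit global state was getting in the way. It’s much easier to use FastChain, where the parameters are an explicit vector passed to the network, as I show here: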

```julia
using OrdinaryDiffEq
using LinearAlgebra, Optim
using Flux, DiffEqFlux
using DiffEqSensitivity
using Random

#Constants
const g = 9.81
const L = 1.0

#Initial Conditions: θ = 0, dθ/dt = π/2
u0 = [0,π/2]
tspan = (0.0,6.3)

#Define the problem
function simplependulum(du,u,p,t)
    θ = u[1]
    dθ = u[2]
    du[1] = dθ
    du[2] = -(g/L)*sin(θ)
end

#Pass to solvers
prob = ODEProblem(simplependulum, u0, tspan)
ode_data = Array(solve(prob,Tsit5(), saveat=0.1))

# function to generate network
function initialize_model(input_dim, hidden_dim, output_dim, model_seed)
   Random.seed!(model_seed)

   model = FastChain(FastDense(input_dim, hidden_dim, tanh),
                     FastDense(hidden_dim, output_dim))
   weights = initial_params(model)
   return model, weights
end

# function for UDE
function dudt_(u, p, t)
   θ = u[1]
   dθ = u[2]
   z = model(u,p)[1]

   [dθ, z]
end

# function to set up and solve ODE with current NN state
function predict_nn(weights)
   prob = ODEProblem(dudt_, u0, tspan, weights)
   solution = solve(prob, Vern7(), saveat=0.1,
                    sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))
   Array(solution)
end

# evaluate quality of prediction with the sum of squared errors
function loss(weights)
   pred = predict_nn(weights)
   sum(abs2, ode_data .- pred)
end

# initialize model
input_dim, hidden_dim, output_dim, model_seed = 2, 32, 1, 1234
model, weights = initialize_model(input_dim, hidden_dim, output_dim, model_seed)

# returns a 1-tuple holding the gradient of the loss w.r.t. weights
Flux.gradient(loss, weights)
```
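To see why the explicit style plays well with DiffEq, here is a pure-Julia sketch of the idea behind `FastChain`/`FastDense`: the network is an ordinary function of the input and one flat parameter vector, so `p` is just a `Vector{Float64}` that the solver and the adjoint code can copy, index and allocate like any other array. The names `nparams`, `mlp` and `dudt`, and the weight layout inside `p`, are hypothetical and need not match DiffEqFlux’s internals.

```julia
using Random

# Hypothetical helper: parameter count of a 2-layer MLP nin -> nhid (tanh) -> nout
nparams(nin, nhid, nout) = nhid * nin + nhid + nout * nhid + nout

# Evaluate the MLP by viewing weights and biases as slices of the flat vector p
# (assumed layout: W1, b1, W2, b2, each stored column-major)
function mlp(x, p; nin = 2, nhid = 32, nout = 1)
    i = 0
    W1 = reshape(view(p, i+1:i+nhid*nin), nhid, nin);   i += nhid * nin
    b1 = view(p, i+1:i+nhid);                            i += nhid
    W2 = reshape(view(p, i+1:i+nout*nhid), nout, nhid); i += nout * nhid
    b2 = view(p, i+1:i+nout)
    return W2 * tanh.(W1 * x .+ b1) .+ b2
end

# Out-of-place UDE right-hand side with the same shape as dudt_ above:
# the parameters arrive explicitly as p, not through global state
dudt(u, p, t) = [u[2], mlp(u, p)[1]]

Random.seed!(1234)
p = 0.1 .* randn(nparams(2, 32, 1))   # flat Vector{Float64}
du = dudt([0.0, π/2], p, 0.0)
```

In the DiffEqFlux code above, the corresponding pieces are `model(u, p)` and `initial_params(model)`.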

If you want to use Flux neural networks directly, you’ll want to use the destructure/restructure approach shown in https://diffeqflux.sciml.ai/dev/examples/neural_ode_flux/, with a full description in https://diffeqflux.sciml.ai/dev/Flux/
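The idea behind destructure/restructure is to turn a structured model into one flat parameter vector that DiffEq can treat as `p`, plus a function that rebuilds the structured model from any such vector (in Flux this is `p, re = Flux.destructure(model)`). A pure-Julia sketch of that idea, assuming the layer arrays are collected in a NamedTuple; `flatten_params` and `rebuild` are hypothetical names, and this is not Flux’s implementation:

```julia
# Hypothetical sketch: flatten a NamedTuple of arrays into one vector and
# return a closure that rebuilds the same NamedTuple from any vector of that length
function flatten_params(layers::NamedTuple)
    arrays = values(layers)
    flat = reduce(vcat, [vec(a) for a in arrays])
    shapes = map(size, arrays)
    names = keys(layers)
    function rebuild(v::AbstractVector)
        offset = 0
        out = Vector{Any}(undef, length(shapes))
        for (k, sz) in enumerate(shapes)
            n = prod(sz)
            out[k] = reshape(v[offset+1:offset+n], sz)   # copy the slice back into shape
            offset += n
        end
        return NamedTuple{names}(Tuple(out))
    end
    return flat, rebuild
end

layers = (W1 = randn(32, 2), b1 = zeros(32), W2 = randn(1, 32), b2 = zeros(1))
p, rebuild = flatten_params(layers)   # p can be handed to ODEProblem as parameters
```

Inside the ODE right-hand side one would call `rebuild(p)` to get the layer arrays back for the current parameter values.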

@ChrisRackauckas Thanks for the response! This does fix my problem.

Three smaller follow-up questions:

  1. Unlike my attempt with the `do` syntax, `Flux.gradient(loss, weights)` returns only the gradient, not the gradient and the loss. Do you just have to calculate the loss separately if you want to log it?
  2. Is the stack trace I posted useful for diagnosing the issue I was having, or is this just something a user has to understand about Flux?
  3. Is Flux’s “global state” described in detail somewhere? I don’t see it mentioned on the page you linked: https://diffeqflux.sciml.ai/dev/examples/neural_ode_flux/

That’s fine. Do it your way, or use Flux.pullback, which returns the loss together with a function that computes the gradient. The only problem was that you didn’t share usable code. Next time, please share code I can copy, paste, and run; it makes the problem much easier to figure out. Here I kind of had to guess at what you were doing.
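With Zygote the call would be `loss_value, back = Zygote.pullback(loss, weights)` followed by `grad = back(one(loss_value))[1]`, which yields both the loss for logging and the gradient from a single forward pass. The sketch below writes a pullback by hand for a hypothetical least-squares loss, only to show the shape of a custom training loop built on that pattern; it is not a replacement for the adjoint sensitivities DiffEq computes for the UDE loss.

```julia
# Hand-written analogue of a pullback for loss(w) = sum(abs2, y .- X * w).
# The forward pass computes the loss once and returns it together with `back`,
# a closure mapping a loss seed to the gradient (as a 1-tuple, mirroring Zygote).
function loss_pullback(w, X, y)
    r = y .- X * w                            # residuals, reused by the backward pass
    value = sum(abs2, r)
    back(seed) = (-2 .* seed .* (X' * r),)    # d/dw sum(abs2, y - X*w) = -2 X' r
    return value, back
end

# Toy data (hypothetical): y is fit exactly by w = [1, 2]
X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]

# Custom training loop: gradient descent that logs the loss without a second forward pass
w = zeros(2)
losses = Float64[]
for epoch in 1:200
    l, back = loss_pullback(w, X, y)
    grad = back(1.0)[1]
    push!(losses, l)
    w .-= 0.1 .* grad
end
```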

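It was useful. It says `MethodError: no method matching similar(::Zygote.Params, ::Int64)`, i.e. there was no way to create a new array from the `Zygote.Params`. If you check the types you were using, `weights = params(model)` is a `Zygote.Params`, which doesn’t act like a normal array. You gave it to DiffEq as the parameters, DiffEq tried to use it as a normal array, and it failed because `Zygote.Params` doesn’t support those operations. The trace also shows where this happened: frame [1] is `ODEAdjointProblem`, and the `ODEProblem` type in its signature has `Zygote.Params` as its parameter type, so the failure occurred while setting up the adjoint (gradient) pass.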
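Checking the types is usually the quickest way to diagnose this kind of `MethodError`. A pure-Julia sketch of that check, using a hypothetical `ParamBag` struct as a stand-in for `Zygote.Params` (a container holding references to arrays that is not itself an array); it is illustrative only and does not reproduce Zygote’s actual type:

```julia
# Hypothetical stand-in for Zygote.Params: it holds references to arrays but is not an array itself
struct ParamBag
    arrays::Vector{Any}
end

W = randn(32, 2)
b = zeros(32)
bag = ParamBag(Any[W, b])

# What an ODE solver expects as p: a single flat numeric vector
flat = vcat(vec(W), b)

println(typeof(bag))                        # the wrapper type, not an array type
println(bag isa AbstractArray)              # false: generic array code has nothing to dispatch on
println(flat isa AbstractVector{Float64})   # true
tmp = similar(flat, length(flat))           # allocating a same-type buffer works for a real vector
```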

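That style avoids using it: it’s the whole `Zygote.Params` thing. With FastChain the parameters are an explicit vector passed in as `p`, instead of being collected implicitly into a `Zygote.Params`.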