I just looked at the code posted in the current README on GitHub, and I still get an error about a missing adjoint. First, I note that the blog implements `predict_rd()` using `diffeq_rd`, but the current README implements a function `predict_adjoint` using `diffeq_adjoint`.
Since the error message I got referenced a missing adjoint, I was hoping this change would fix the issue. However, I get the same failure, again with a message indicating a missing adjoint.
First I’ll post my Pkg “status” output, then I’ll paste the full error message and stack trace. I’m using Julia 1.3 with only a small number of packages installed:
(v1.3) pkg> st
Status `~/.julia/environments/v1.3/Project.toml`
[c52e3926] Atom v0.11.3
[aae7a2af] DiffEqFlux v0.8.1
[0c46a032] DifferentialEquations v6.9.0
[587475ba] Flux v0.9.0
[e5e0dc1b] Juno v0.7.2
[91a5bcdd] Plots v0.28.2
Here’s the error message and stack trace:
julia> Flux.train!(loss_adjoint, params, data, opt, cb = cb)
ERROR: Need an adjoint for constructor ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats}. Gradient is of type Array{Float64,2}
Stacktrace:
[1] (::Zygote.Jnew{ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing,false})(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:265
[2] (::Zygote.var"#316#back#172"{Zygote.Jnew{ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats},Nothing,false}})(::Array{Float64,2}) at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[3] ODESolution at /Users/rick/.julia/packages/DiffEqBase/DqkH4/src/solutions/ode_solutions.jl:2 [inlined]
[4] (::typeof(∂(ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem},Tsit5,OrdinaryDiffEq.InterpolationData{ODEFunction{true,typeof(lotka_volterra),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Tsit5ConstantCache{Float64,Float64}}},DiffEqBase.DEStats})))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[5] solution_new_retcode at /Users/rick/.julia/packages/DiffEqBase/DqkH4/src/solutions/ode_solutions.jl:93 [inlined]
[6] (::typeof(∂(solution_new_retcode)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[7] solve! at /Users/rick/.julia/packages/OrdinaryDiffEq/BhP0W/src/solve.jl:372 [inlined]
[8] (::typeof(∂(solve!)))(::Nothing) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[9] #__solve#345 at /Users/rick/.julia/packages/OrdinaryDiffEq/BhP0W/src/solve.jl:5 [inlined]
[10] (::typeof(∂(#__solve#345)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[11] #153 at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142 [inlined]
[12] #283#back at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[13] #__solve at ./none:0 [inlined]
[14] (::typeof(∂(#__solve)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[15] #153 at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142 [inlined]
[16] (::Zygote.var"#283#back#155"{Zygote.var"#153#154"{typeof(∂(#__solve)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[17] #solve#381 at /Users/rick/.julia/packages/DiffEqBase/DqkH4/src/solve.jl:39 [inlined]
[18] (::typeof(∂(#solve#381)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[19] (::Zygote.var"#153#154"{typeof(∂(#solve#381)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142
[20] #283#back at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[21] #solve at ./none:0 [inlined]
[22] (::typeof(∂(#solve)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[23] #153 at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142 [inlined]
[24] (::Zygote.var"#283#back#155"{Zygote.var"#153#154"{typeof(∂(#solve)),Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[25] #diffeq_adjoint#21 at /Users/rick/.julia/packages/DiffEqFlux/UcpUz/src/Flux/layers.jl:54 [inlined]
[26] (::typeof(∂(#diffeq_adjoint#21)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[27] (::Zygote.var"#153#154"{typeof(∂(#diffeq_adjoint#21)),Tuple{NTuple{5,Nothing},Tuple{Nothing}}})(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142
[28] (::Zygote.var"#283#back#155"{Zygote.var"#153#154"{typeof(∂(#diffeq_adjoint#21)),Tuple{NTuple{5,Nothing},Tuple{Nothing}}}})(::Array{Float64,2}) at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[29] #diffeq_adjoint at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0 [inlined]
[30] (::typeof(∂(#diffeq_adjoint)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[31] predict_adjoint at /Users/rick/src/julia_misc/ODE_from_blog/Tests_from_ReadMe.jl:23 [inlined]
[32] (::typeof(∂(predict_adjoint)))(::Array{Float64,2}) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[33] loss_adjoint at /Users/rick/src/julia_misc/ODE_from_blog/Tests_from_ReadMe.jl:26 [inlined]
[34] (::typeof(∂(loss_adjoint)))(::Float64) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
[35] #153 at /Users/rick/.julia/packages/Zygote/8dVxG/src/lib/lib.jl:142 [inlined]
[36] #283#back at /Users/rick/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49 [inlined]
[37] #14 at /Users/rick/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:69 [inlined]
[38] (::Zygote.var"#38#39"{Zygote.Params,Zygote.Context,typeof(∂(#14))})(::Float64) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:101
[39] gradient(::Function, ::Zygote.Params) at /Users/rick/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:47
[40] macro expansion at /Users/rick/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:68 [inlined]
[41] macro expansion at /Users/rick/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
[42] #train!#12(::var"#31#32", ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at /Users/rick/.julia/packages/Flux/oX9Pi/src/optimise/train.jl:66
[43] (::Flux.Optimise.var"#kw##train!")(::NamedTuple{(:cb,),Tuple{var"#31#32"}}, ::typeof(Flux.Optimise.train!), ::Function, ::Zygote.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, ::ADAM) at ./none:0
[44] top-level scope at none:0