Great. I update DiffEqFlux to master, it works!
However, I am playing with some custom layers, and I hit some cases I need some help!
It works fine when I use a pre-defined Dense layer
m = Dense(2,2)
n_ode1(x) = neural_ode(m, x, (0.0, 1.0), BS3())
u = zeros(2)
n_ode1(u)
Tracked 2×8 Array{Float64,2}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
u = param(zeros(2))
n_ode1(u)
Tracked 2×8 Array{Float64,2}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
However, if I build my own layer (following example in Flux Basic.md)
Things do not go well… Did I miss something messing up the type?
function linear(in, out)
W = param(randn(out, in))
x -> W * x
end
linear1 = linear(2, 2)
n_ode2(x) = neural_ode(linear1, x, (0.0, 1.0), BS3())
u = zeros(2)
n_ode2(u)
2×8 Array{Float64,2}:
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
u = param(zeros(2))
n_ode2(u)
MethodError: no method matching OrdinaryDiffEq.BS3Cache(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Flux.Tracker.TrackedReal{Float64},1}, ::OrdinaryDiffEq.BS3ConstantCache{Flux.Tracker.TrackedReal{Float64},Float32})
Closest candidates are:
OrdinaryDiffEq.BS3Cache(::#66#uType, ::#66#uType, ::#67#rateType, ::#67#rateType, ::#67#rateType, ::#67#rateType, !Matched::#66#uType, !Matched::#66#uType, ::#68#uNoUnitsType, ::#69#TabType) where {#66#uType, #67#rateType, #68#uNoUnitsType, #69#TabType} at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/caches/low_order_rk_caches.jl:116
Stacktrace:
[1] alg_cache(::BS3, ::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Type, ::Type, ::Type, ::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Float32, ::Float32, ::Flux.Tracker.TrackedReal{Float64}, ::Array{Any,1}, ::Bool, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/caches/low_order_rk_caches.jl:137
[2] #__init#258(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float32, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Int64, ::Rational{Int64}, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::Float32, ::Float32, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:235
[3] __init(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:62
[4] #__solve#257(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:6
[5] __solve(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:6 (repeats 5 times)
[6] #solve#429(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqBase/cPqrj/src/solve.jl:39
[7] solve(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqBase/cPqrj/src/solve.jl:27
[8] #diffeq_adjoint#15(::TrackedArray{…,Array{Float64,1}}, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Array{Any,1}, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/layers.jl:48
[9] diffeq_adjoint(::Array{Any,1}, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/layers.jl:47
[10] #neural_ode#23(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Function, ::TrackedArray{…,Array{Float64,1}}, ::Tuple{Float32,Float32}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/neural_de.jl:7
[11] neural_ode(::Function, ::TrackedArray{…,Array{Float64,1}}, ::Tuple{Float32,Float32}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/neural_de.jl:3
[12] n_ode2(::TrackedArray{…,Array{Float64,1}}) at ./In[174]:8
[13] top-level scope at In[179]:2