[DiffEqFlux] How do I compose layers?

How can I compose layers in DiffEqFlux?

For example, I want to compose a neural ODE layer with an ordinary NN layer:

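using Flux, DiffEqFlux, OrdinaryDiffEq

datasize = 30   # 30 saved time points, matching the 2×30 output below
dudt = Chain(Dense(2,40,tanh), Dense(40,2))
tspan = (0.0f0,1.0f0)
t = range(tspan[1],tspan[2],length=datasize)
n_ode(x) = neural_ode(dudt, x, tspan, Tsit5(), saveat=t)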

my_nn = Chain(Dense(2, 10, tanh), Dense(10,2))

Then this works:

my_nn(n_ode([1.0, 2.0]))

Tracked 2×30 Array{Float32,2}:
 -0.432643  -0.432427  -0.432226  …  -0.430269  -0.43025  -0.430231
  1.0084     1.01155    1.01469       1.09137    1.0943    1.09721 
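To make explicit what each composition order computes, here is a short worked statement, assuming (as the 2×30 output suggests) that n_ode returns the state at each of the N = datasize saved times in t, with t_1 = 0 and t_N = 1, and that dudt does not depend on t. In the first order my_nn just post-processes the saved columns; in the second order the output of my_nn becomes the initial condition, so gradients for my_nn's weights have to flow through u(t_1) = u0.

```latex
% What n_ode computes (N = datasize saved points)
\frac{du}{dt} = \mathrm{dudt}(u), \qquad u(t_1) = x, \qquad
\mathrm{n\_ode}(x) = \big[\, u(t_1),\ u(t_2),\ \dots,\ u(t_N) \,\big] \in \mathbb{R}^{2 \times N}

% NN after ODE: my_nn acts on each saved column
(\mathrm{my\_nn} \circ \mathrm{n\_ode})(x) = \big[\, \mathrm{my\_nn}(u(t_1)),\ \dots,\ \mathrm{my\_nn}(u(t_N)) \,\big]

% ODE after NN: the NN output is the initial condition
(\mathrm{n\_ode} \circ \mathrm{my\_nn})(x): \quad u(t_1) = \mathrm{my\_nn}(x)
```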

However, the other way around fails:

n_ode(my_nn([1.0, 2.0]))

u0 is not currently differentiable.

Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] #neural_ode#23(::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}, ::Function, ::Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}, ::TrackedArray{…,Array{Float32,1}}, ::Tuple{Float32,Float32}, ::Tsit5) at /Users/po/.julia/packages/DiffEqFlux/1w1tX/src/Flux/neural_de.jl:3
 [3] (::getfield(DiffEqFlux, Symbol("#kw##neural_ode")))(::NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}, ::typeof(neural_ode), ::Chain{Tuple{Dense{typeof(tanh),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}, ::TrackedArray{…,Array{Float32,1}}, ::Tuple{Float32,Float32}, ::Tsit5) at ./none:0
 [4] n_ode(::TrackedArray{…,Array{Float32,1}}) at ./In[270]:4
 [5] top-level scope at In[275]:1

Does this mean that neural_ode does not currently support composition, or is there a mistake in my code?

I think we might need to tag a new release. We already fixed that u0 issue.

Great. I updated DiffEqFlux to master, and it works!
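Conceptually, both orders are ordinary function composition, which Julia writes with ∘. The sketch below shows this in plain Julia without Flux or DiffEqFlux: euler_ode is a hypothetical fixed-step forward Euler stand-in for neural_ode (not its actual solver), and dudt, n_ode and my_nn are toy versions with made-up, untrained weights. In Flux itself, Chain also accepts plain functions alongside layers, so the same compositions should be expressible with Chain, though this sketch does not exercise Flux.

```julia
# Hypothetical stand-in for `neural_ode`: fixed-step forward Euler that
# returns the saved states as the columns of a matrix (2×n for a 2-vector).
function euler_ode(f, u0, tspan; n = 30)
    t0, t1 = tspan
    dt = (t1 - t0) / (n - 1)
    sol = zeros(float(eltype(u0)), length(u0), n)
    u = float.(u0)
    sol[:, 1] = u
    for k in 2:n
        u = u .+ dt .* f(u)   # one explicit Euler step
        sol[:, k] = u
    end
    return sol
end

# Toy layers with fixed, made-up weights (not trained, not Flux layers).
W1, b1 = [0.5 -0.2; 0.1 0.3], [0.0, 0.1]
dudt(u) = tanh.(W1 * u .+ b1)               # plays the role of `dudt`
n_ode(x) = euler_ode(dudt, x, (0.0, 1.0))   # plays the role of `n_ode`
W2, b2 = [1.0 0.0; 0.0 2.0], [0.0, 0.0]
my_nn(x) = W2 * x .+ b2                     # affine map; accepts a vector or a 2×n matrix

nn_after_ode = my_nn ∘ n_ode   # x -> my_nn(n_ode(x)): a 2×30 matrix
ode_after_nn = n_ode ∘ my_nn   # x -> n_ode(my_nn(x)): also 2×30
```

In ode_after_nn, the first saved column is exactly my_nn(x), i.e. the NN output is used as the initial condition.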

However, I'm now playing with some custom layers and have hit some cases I need help with.

It works fine when I use a predefined Dense layer:

m = Dense(2,2)
n_ode1(x) = neural_ode(m, x, (0.0, 1.0), BS3())

u = zeros(2)
n_ode1(u)
Tracked 2×8 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

u = param(zeros(2))
n_ode1(u)
Tracked 2×8 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
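A side note on the all-zero outputs: with zeros(2) as the initial state, the right-hand side is zero at u = 0 (for Dense(2,2) because, in Flux versions of that era, its bias starts at zeros; for a bias-free W * x trivially), so u = 0 is an equilibrium and every saved column stays at 0.0. The plain-Julia sketch below (no Flux or Tracker) shows that a Dense-style callable struct, along the lines of the Affine example in Flux's basics docs, and a closure like the linear layer in this post compute the same thing for the same weights. Affine and linear_closure here are illustrative names. In Flux of that era, the struct would additionally wrap W and b in param and be registered with Flux.@treelike so that params can find them; a closure's captured W is likely not visible to params (the trace in this post does show diffeq_adjoint receiving an Array{Any,1} for the parameters). That is an assumption about where the two cases diverge, not a diagnosis of the error.

```julia
# Dense-style layer as a callable struct (plain Julia, no Flux/Tracker).
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Random weights, zero bias (the zero bias mirrors Dense's default initialization).
Affine(in::Integer, out::Integer) = Affine(randn(out, in), zeros(out))

# Calling the struct applies the layer.
(m::Affine)(x) = m.W * x .+ m.b

# Closure-style layer in the spirit of `linear` from the post, minus `param`.
linear_closure(W) = x -> W * x
```

Both layers map u = 0 to 0, and with a zero bias they agree everywhere.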

However, if I build my own layer (following the example in Flux's Basic.md), things do not go well. Did I miss something that messes up the type?

function linear(in, out)
  W = param(randn(out, in))
  x -> W * x
end
linear1 = linear(2, 2)
n_ode2(x) = neural_ode(linear1, x, (0.0, 1.0), BS3())

u = zeros(2)
n_ode2(u)
2×8 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

u = param(zeros(2))
n_ode2(u)

MethodError: no method matching OrdinaryDiffEq.BS3Cache(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Flux.Tracker.TrackedReal{Float64},1}, ::OrdinaryDiffEq.BS3ConstantCache{Flux.Tracker.TrackedReal{Float64},Float32})
Closest candidates are:
  OrdinaryDiffEq.BS3Cache(::#66#uType, ::#66#uType, ::#67#rateType, ::#67#rateType, ::#67#rateType, ::#67#rateType, !Matched::#66#uType, !Matched::#66#uType, ::#68#uNoUnitsType, ::#69#TabType) where {#66#uType, #67#rateType, #68#uNoUnitsType, #69#TabType} at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/caches/low_order_rk_caches.jl:116

Stacktrace:
 [1] alg_cache(::BS3, ::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Type, ::Type, ::Type, ::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Float32, ::Float32, ::Flux.Tracker.TrackedReal{Float64}, ::Array{Any,1}, ::Bool, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/caches/low_order_rk_caches.jl:137
 [2] #__init#258(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Nothing, ::Bool, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float32, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Int64, ::Rational{Int64}, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::Float32, ::Float32, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:235
 [3] __init(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:62
 [4] #__solve#257(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:6
 [5] __solve(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/po/.julia/packages/OrdinaryDiffEq/miOSH/src/solve.jl:6 (repeats 5 times)
 [6] #solve#429(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqBase/cPqrj/src/solve.jl:39
 [7] solve(::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqBase/cPqrj/src/solve.jl:27
 [8] #diffeq_adjoint#15(::TrackedArray{…,Array{Float64,1}}, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Array{Any,1}, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/layers.jl:48
 [9] diffeq_adjoint(::Array{Any,1}, ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float32,Float32},true,Array{Any,1},ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){getfield(Main, Symbol("##27#28")){TrackedArray{…,Array{Float64,2}}}},UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing,DiffEqBase.StandardODEProblem}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/layers.jl:47
 [10] #neural_ode#23(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Function, ::TrackedArray{…,Array{Float64,1}}, ::Tuple{Float32,Float32}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/neural_de.jl:7
 [11] neural_ode(::Function, ::TrackedArray{…,Array{Float64,1}}, ::Tuple{Float32,Float32}, ::BS3) at /Users/po/.julia/packages/DiffEqFlux/umAKa/src/Flux/neural_de.jl:3
 [12] n_ode2(::TrackedArray{…,Array{Float64,1}}) at ./In[174]:8
 [13] top-level scope at In[179]:2

No, this is something I’m trying to work out with @MikeInnes. Essentially, operations like similar on a TrackedArray don’t return a TrackedArray, whereas for every other mutable array type in Julia they preserve the type. That makes it difficult to handle TrackedArray generically. Array{TrackedReal} works fine, though I’m not sure that’s the right route here.
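One plausible reading of the BS3Cache error above: u and uprev arrive as TrackedArray, while the remaining cache arrays are plain Array{Float64,1}, and the candidate constructor requires arguments 7 and 8 to share uType (the ones marked !Matched). That is what you would expect if the solver allocates its scratch arrays with similar(u) and similar drops the array type. The plain-Julia sketch below illustrates the convention generic code relies on, using a hypothetical Wrapped array type (not Tracker's TrackedArray) whose similar returns a Wrapped, and a hypothetical make_cache that allocates the way a solver cache might. It also shows that an immutable range's similar returns a plain Vector, which is why the convention is stated for mutable arrays.

```julia
# Hypothetical wrapper array type, standing in for an array type that
# should survive generic code (this is NOT Tracker's TrackedArray).
struct Wrapped{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

# Minimal AbstractArray interface.
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int...) = w.data[i...]
Base.setindex!(w::Wrapped, v, i::Int...) = (w.data[i...] = v)

# Make `similar` return the wrapper type instead of a plain Array.
Base.similar(w::Wrapped, ::Type{S}, dims::Dims) where {S} =
    Wrapped(Array{S}(undef, dims))

# Generic code in the style of a solver cache: scratch arrays come from
# `similar(u)`, so they inherit whatever array type `u` has.
make_cache(u) = (tmp = similar(u), k = similar(u))
```

For example, make_cache(Wrapped(zeros(2))).tmp is a Wrapped{Float64,1}, while similar(1:3) is a Vector{Int}.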

I see.

For the custom layer, how should I hack it to make it work?