Provide parameters and user-defined functions to neural_ode in DiffEqFlux.jl

diffeq
flux
#1

Hello,

First, I would like to thank the developers of DiffEqFlux.jl; this is amazing work!

I want to use the function neural_ode to model the right-hand side of an ODE describing the transport of a particle:
$$\frac{dx}{dt} = F\big(V(x), p\big)$$
where V is a given spatial field, F is a nonlinear function of the inputs V(x) and p, and p is a set of parameters of the particle (these parameters should not be optimized).
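For the network to see both inputs, the state x has to be turned into the 2-element vector [V(x), p] that a Dense(2, 1) layer expects. The snippet below is a plain-Julia sketch of that input step only; the names network_input and p_particle are hypothetical, and the state is assumed to be stored as a 1-element vector, which is why V is broadcast with a dot.

```julia
V = x -> cos(x)       # given spatial field
p_particle = 0.5f0    # particle parameter (fixed, not trained)

# The state is a 1-element vector, so V is broadcast elementwise with `.`;
# cos has no elementwise method for a Vector, so V(y) alone would not work here.
network_input(y::AbstractVector) = vcat(V.(y), p_particle)

x = [1.0f0]
z = network_input(x)  # 2-element Vector{Float32}: [cos(1.0f0), 0.5f0]
```

The fixed value p_particle is read from the enclosing scope rather than stored in the network, so nothing about it is exposed to an optimizer.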

I don’t know how to provide these non-learnable parameters to neural_ode.

Below is what I want to do:

using DifferentialEquations
using Flux, DiffEqFlux

u0 = 1.0f0
V = x -> cos(x)

# 0.5 is a parameter of the particle
Input = [u0, V, 0.5f0]


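# intended: p is the non-learnable pair passed via the p keyword below,
# so with p = x[2:3] the field V is p[1] and the particle parameter is p[2]
dudt = Chain(y -> vcat(p[1](y), p[2]),
             Dense(2,1,relu),
             Dense(1,1))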


tspan = (0.0f0, 2.0f0)
n_ode = x-> neural_ode(dudt, x[1], tspan, ABM54(),
    dt = 0.05f0, adaptive = false, p = x[2:3])

Input
3-element Array{Any,1}:
 1.0f0                               
  getfield(Main, Symbol("##25#26"))()
 0.5f0                      

n_ode(Input)
MethodError: no method matching similar(::Float32)
Closest candidates are:
  similar(!Matched::ZMQ.Message, !Matched::Type{T}, !Matched::Tuple{Vararg{Int64,N}} where N) where T at /u/home/m/mleprovo/.julia/packages/ZMQ/ABGOx/src/message.jl:93
  similar(!Matched::DataStructures.IntSet) at deprecated.jl:53
  similar(!Matched::Sundials.NVector) at /u/home/m/mleprovo/.julia/packages/Sundials/KYRgQ/src/nvector_wrapper.jl:69
  ...

Stacktrace:
 [1] alg_cache(::ABM54, ::Float32, ::Float32, ::Type, ::Type, ::Type, ::Float32, ::Float32, ::ODEFunction{true,getfield(DiffEqFlux, Symbol("#dudt_#24")){Chain{Tuple{getfield(Main, Symbol("##27#28")),Dense{typeof(relu),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}}}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Float32, ::Float32, ::Float32, ::Array{Float32,1}, ::Bool, ::Type{Val{true}}) at /u/home/m/mleprovo/.julia/packages/OrdinaryDiffEq/HJOah/src/caches/adams_bashforth_moulton_caches.jl:244
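The trace ends in alg_cache, and the ODEFunction{true,…} in its signature marks an in-place function. A plausible reading (an assumption, not confirmed in this thread) is that the solver allocates its cache with similar on the initial state, which has no method for a scalar Float32 such as u0 = 1.0f0. The snippet below only reproduces that Julia behavior and does not call DiffEqFlux; wrapping u0 in a 1-element array is a hypothetical fix.

```julia
u0_scalar = 1.0f0
u0_vector = [1.0f0]          # hypothetical fix: store the state as a 1-element array

# similar allocates an uninitialized array of the same type and shape
cache = similar(u0_vector)   # Vector{Float32} of length 1

# For a plain Float32 there is no similar method, matching the error above
scalar_failed = try
    similar(u0_scalar)
    false
catch err
    err isa MethodError
end
```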

#2

The more general solution is tracked in https://github.com/JuliaDiffEq/DiffEqFlux.jl/issues/15, which we are currently working on.
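In the meantime, one workaround that avoids passing p at all is to capture the fixed quantities V and p in a closure, so that only the network weights are candidates for training. The package-free sketch below illustrates that idea with hypothetical names (net, euler, predict, θ). It uses explicit Euler with the post's fixed step dt = 0.05 on (0, 2); it is neither DiffEqFlux's API nor the ABM54 method used in the question.

```julia
relu(z) = max(z, zero(z))

# Tiny network with the same shape as Dense(2,1,relu) followed by Dense(1,1);
# θ holds the trainable weights, v and p are the fixed inputs.
net(v, p, θ) = θ.W2 * relu(θ.W1[1] * v + θ.W1[2] * p + θ.b1) + θ.b2

# Fixed-step explicit Euler over tspan; returns the trajectory including x0
function euler(rhs, x0, tspan, dt)
    n = round(Int, (tspan[2] - tspan[1]) / dt)
    xs = Vector{typeof(x0)}(undef, n + 1)
    xs[1] = x0
    for k in 1:n
        xs[k + 1] = xs[k] + dt * rhs(xs[k])
    end
    return xs
end

V = x -> cos(x)          # given field, fixed
p_particle = 0.5f0       # particle parameter, fixed (not trained)

# Only θ varies between calls; V and p_particle are captured by the closure.
predict(θ; x0 = 1.0f0, tspan = (0.0f0, 2.0f0), dt = 0.05f0) =
    euler(x -> net(V(x), p_particle, θ), x0, tspan, dt)

θ = (W1 = (0.3f0, -0.2f0), b1 = 0.1f0, W2 = 0.7f0, b2 = 0.0f0)
traj = predict(θ)        # 41 states from t = 0 to t = 2
```

Because V and p_particle never appear in θ, a gradient-based optimizer given only θ would leave them untouched.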
