Using a neural network inside a function for DifferentialEquations.jl solve

I’m trying to use a neural network as a component of a differential equation to be solved with DifferentialEquations.jl’s `solve`:

begin
	# Import needed packages
	using Flux
	using ForwardDiff
	using DifferentialEquations
end

begin
	function f2(u, p, t)
		m = re(p)
		return 0.98*m(t').*u
	end
	function eval_model(model, t, p)
		u0 = eltype(t)(1.0)
		tspan = eltype(t).((0.0, 1.0))
		prob = DifferentialEquations.ODEProblem(f2, u0, tspan, p)
		sol = DifferentialEquations.solve(prob, abstol=1e-8, reltol=1e-8, saveat=t)
		return Array(sol.u)
	end

	n_in = 1
	n_out = 1
	model = Chain(
		Dense(n_in, 10, relu),
		Dense(10, 10, relu),
		Dense(10, n_out));
	# Test the eval_model function
	t = Vector{Float32}(LinRange(0, 1, 10))
	p, re = Flux.destructure(model)
	eval_model(model, t, p)
	# m = re(p)
	# m(t')
end

This produces the following error:

MethodError: no method matching (::Flux.Dense{typeof(NNlib.relu), Matrix{Float32}, Vector{Float32}})(::Float32)
The object of type `Flux.Dense{typeof(NNlib.relu), Matrix{Float32}, Vector{Float32}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Closest candidates are:
  (::Flux.Dense)(::AbstractVecOrMat)
   @ Flux ~/.julia/packages/Flux/3711C/src/layers/basic.jl:196
  (::Flux.Dense)(::AbstractArray)
   @ Flux ~/.julia/packages/Flux/3711C/src/layers/basic.jl:202

Here is what happened, the most recent locations are first (expanded type and keyword listings truncated to `{…}`):

macro expansion @ Flux basic.jl:68
_applychain(layers::Tuple{…}, x::Float32) @ Flux basic.jl:68
Chain(x::Float32) @ Flux basic.jl:65
f2(u::Float32, p::Vector{Float32}, t::Float32) @ this cell, line 4: `return 0.98*m(t').*u`
ODEFunction @ scimlfunctions.jl:2468
ode_determine_initdt(u0::Float32, t::Float32, tdir::Float32, dtmax::Float32, abstol::Float64, reltol::Float64, …) @ OrdinaryDiffEqCore initdt.jl:248
auto_dt_reset!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}) @ OrdinaryDiffEqCore integrator_interface.jl:424
handle_dt!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}) @ OrdinaryDiffEqCore solve.jl:647
__init @ OrdinaryDiffEqCore solve.jl:609
__solve / solve_call / solve_up / solve @ OrdinaryDiffEqCore, DiffEqBase solve.jl (default-algorithm dispatch)
eval_model(model::Flux.Chain{…}, t::Vector{…}, p::Vector{…}) @ this cell, line 10: `sol = DifferentialEquations.solve(prob, abstol=1e-8, reltol=1e-8, saveat=t)`

I’ve seen working examples that use neural networks inside functions passed to `solve`, and it’s not clear where I’m misstepping here.

See this part: `return 0.98*m(t').*u`. Flux’s `Dense` layers require an array input, but inside `f2` the argument `t` is the scalar current time that the integrator passes in (the `saveat=t` vector only sets the output times), so `t'` is still a scalar. Wrap it in a one-element vector and index the result: `m([t])[1]`.
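A minimal sketch of the fixed right-hand side, keeping the rest of the cell unchanged:

```julia
function f2(u, p, t)
	m = re(p)                     # rebuild the network from the flat parameter vector
	return 0.98 * m([t])[1] * u   # Dense needs an array: pass [t], then index the scalar back out
end
```

With that change, `eval_model(model, t, p)` should run and return the solution at the ten `saveat` points. You can also check the layer behavior in isolation: `re(p)(0.5f0)` throws the same `MethodError`, while `re(p)([0.5f0])` returns a 1-element vector.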
