Iβm trying to use a neural network as a component of a differential equation to be solved by ODEsolve:
begin
# Import needed packages
using Flux
using ForwardDiff
using DifferentialEquations
end
begin
function f2(u, p, t)
m = re(p)
return 0.98*m(t').*u
end
function eval_model(model, t,p)
u0 = eltype(t)(1.0)
tspan = eltype(t).((0.0, 1.0))
prob = DifferentialEquations.ODEProblem(f2,u0,tspan,p)
sol = DifferentialEquations.solve(prob,abstol=1e-8,reltol=1e-8,saveat=t)
return Array(sol.u)
end
n_in = 1
n_out = 1
model = Chain(
Dense(n_in,10,relu),
Dense(10,10,relu),
Dense(10,n_out));
# Test the eval_model and loss functions
t = Vector{Float32}(LinRange(0, 1, 10))
p, re = Flux.destructure(model)
eval_model(model, t, p)
# m = re(p)
# m(t')
end
This produces the following error:
MethodError: no method matching (::Flux.Dense{typeof(NNlib.relu), Matrix{Float32}, Vector{Float32}})(::Float32)
The object of type `Flux.Dense{typeof(NNlib.relu), Matrix{Float32}, Vector{Float32}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.
Closest candidates are:
(::Flux.Dense)(::AbstractVecOrMat)
@ Flux ~/.julia/packages/Flux/3711C/src/layers/basic.jl:196
(::Flux.Dense)(::AbstractArray)
@ Flux ~/.julia/packages/Flux/3711C/src/layers/basic.jl:202
Here is what happened, the most recent locations are first:
macro expansion
from
basic.jl:68
_applychain(layers::Tuple{β¦}, x::Float32) ...show types...
from
Flux β basic.jl:68
Chain(x::Float32) ...show types...
from
Flux β basic.jl:65
f2(u::Float32, p::Vector{Float32}, t::Float32)
from
This cell: line 4
m = re(p)
return 0.98*m(t').*u
end
ODEFunction
from
scimlfunctions.jl:2468
ode_determine_initdt(u0::Float32, t::Float32, tdir::Float32, dtmax::Float32, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::SciMLBase.ODEProblem{β¦}, integrator::OrdinaryDiffEqCore.ODEIntegrator{β¦}) ...show types...
from
OrdinaryDiffEqCore β initdt.jl:248
auto_dt_reset!(integrator::OrdinaryDiffEqCore.ODEIntegrator{β¦}) ...show types...
from
OrdinaryDiffEqCore β integrator_interface.jl:424
handle_dt!(integrator::OrdinaryDiffEqCore.ODEIntegrator{β¦}) ...show types...
from
OrdinaryDiffEqCore β solve.jl:647
#__init#63(prob::SciMLBase.ODEProblem{β¦}, alg::OrdinaryDiffEqCore.CompositeAlgorithm{β¦}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{β¦}; saveat::Vector{β¦}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Float32, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{β¦}, abstol::Float64, reltol::Float64, qmin::Rational{β¦}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{β¦}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::SciMLBase.ODEAliasSpecifier, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{β¦}) ...show types...
from
OrdinaryDiffEqCore β solve.jl:609
__init
from
solve.jl:11
#__solve#62
from
solve.jl:6
__solve
from
solve.jl:1
#solve_call#35(_prob::SciMLBase.ODEProblem{β¦}, args::OrdinaryDiffEqCore.CompositeAlgorithm{β¦}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{β¦}) ...show types...
from
DiffEqBase β solve.jl:635
solve_call
from
solve.jl:592
#solve_up#44
from
solve.jl:1128
solve_up
from
solve.jl:1106
#solve#42
from
solve.jl:1045
solve
from
solve.jl:1033
#__solve#3
from
default_alg.jl:48
__solve
from
default_alg.jl:47
#__solve#63
from
solve.jl:1437
__solve
from
solve.jl:1428
#solve_call#35(::SciMLBase.ODEProblem{β¦}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{β¦}) ...show types...
from
DiffEqBase β solve.jl:635
solve_call
from
solve.jl:592
#solve_up#44(::SciMLBase.ODEProblem{β¦}, ::Nothing, ::Float32, ::Vector{β¦}; kwargs::@Kwargs{β¦}) ...show types...
from
DiffEqBase β solve.jl:1112
solve_up
from
solve.jl:1106
#solve#42(::SciMLBase.ODEProblem{β¦}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{β¦}, kwargs::@Kwargs{β¦}) ...show types...
from
DiffEqBase β solve.jl:1043
eval_model(model::Flux.Chain{β¦}, t::Vector{β¦}, p::Vector{β¦}) ...show types...
from
This cell: line 10
prob = DifferentialEquations.ODEProblem(f2,u0,tspan,p)
sol = DifferentialEquations.solve(prob,abstol=1e-8,reltol=1e-8,saveat=t)
return Array(sol.u)
from
This cell: line 29
p, re = Flux.destructure(model)
eval_model(model, t, p)
# m = re(p)
Iβve seen working examples with neural networks being used inside functions for ODEsolve, and itβs not clear where Iβm mis-stepping here.