Neural SDE example no method matching error

Hi, I was following this tutorial for NeuralSDEs:

I tried to run it in Visual Studio and got as far as this point:

using Plots, Flux, DiffEqFlux, DifferentialEquations, StochasticDiffEq, DiffEqBase.EnsembleAnalysis, Random
using Statistics

u0 = Float64[2. ; 0.]
datasize = 30
tspan = (0.0f0, 1.0f0)
t = range(tspan[1], tspan[2], length = datasize)

function trueSDEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

mp = Float32[0.2, 0.2]
function true_noise_func(du, u, p, t)
    du .= mp.*u
end

prob = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)

ensemble_prob = EnsembleProblem(prob)
ensemble_sol = solve(ensemble_prob,SOSRI(),trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)

sde_data,sde_data_vars = Array.(timeseries_point_meanvar(ensemble_sol,t))

drift_dudt = Chain(x -> x.^3,
diffusion_dudt = Chain(Dense(2,2))
n_sde = NeuralDSDE(drift_dudt,diffusion_dudt,tspan,SOSRI(),saveat=t,reltol=1e-1,abstol=1e-1)

pred = n_sde(u0)

drift_(u, p, t) = drift_dudt(u, p[1:n_sde.len])
diffusion_(u, p, t) = diffusion_dudt(u, p[(n_sde.len+1):end])

prob_n_sde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.0f0) , n_sde.p)

ensemble_nprob = EnsembleProblem(prob_n_sde)
ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100, saveat = t)
ensemble_nsum = EnsembleSummary(ensemble_nsol)

Then I got the following error:

LoadError: MethodError: no method matching (::Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})(::Vector{Float64}, ::Vector{Float32})
Closest candidates are:
(::Chain)(::Any) at C:\Users\User\.julia\packages\Flux\EHgZm\src\layers\basic.jl:51

I've spent a few days on this error but still haven't been able to fix it. Any help would be appreciated.

Thank you!

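I didn't test this, but it looks like the example you linked uses FastChain from DiffEqFlux.jl, whereas you used Chain from Flux.jl. A FastChain is called with the input and an explicit parameter vector, f(x, p), while a Flux Chain is called with the input only, c(x). That is why drift_dudt(u, p[1:n_sde.len]) fails with "no method matching (::Chain{...})(::Vector{Float64}, ::Vector{Float32})": the only candidate Flux offers is (::Chain)(::Any).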
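The mismatch can be reproduced on its own. This is a minimal sketch, assuming Flux ≥ 0.13 (for the in => out layer syntax): a Flux Chain holds its parameters internally and is called with the input alone, so passing a parameter vector as a second argument, as drift_dudt(u, p[1:n_sde.len]) does, raises a MethodError.

```julia
using Flux

# A small Flux Chain; `Dense(2 => 2)` assumes Flux >= 0.13 syntax.
c = Chain(Dense(2 => 2))

# Supported call: the parameters live inside the Chain, so only the input is passed.
y = c(rand(Float32, 2))

# FastChain-style call (input, parameters): Flux's Chain defines no such method,
# so this is caught and stored instead of crashing the script.
err = try
    c(rand(2), rand(Float32, 6))
catch e
    e
end

println(typeof(err))
```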

The documentation on it is here: Neural Stochastic Differential Equations With Method of Moments · DiffEqFlux.jl. I highly recommend you use that instead of the link you found.
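If you would rather keep plain Flux Chains instead of switching tutorials, one possible workaround (a sketch, not what the tutorial does) is Flux.destructure, which flattens a model's parameters into a vector and returns a function that rebuilds the model from such a vector. This assumes Flux ≥ 0.13; the hidden width of 50 is an illustrative choice, since the question's code does not show it, and drift_/diffusion_ simply mirror the names used in the question. The concatenated p could then be passed as the parameter argument of SDEProblem.

```julia
using Flux

# Networks with the layer structure shown in the error message
# (x -> x.^3, Dense with tanh, Dense with identity); width 50 is an assumption.
drift_dudt = Chain(x -> x .^ 3, Dense(2 => 50, tanh), Dense(50 => 2))
diffusion_dudt = Chain(Dense(2 => 2))

# destructure returns (flat parameter vector, rebuild function) for each Chain.
p_drift, re_drift = Flux.destructure(drift_dudt)
p_diff, re_diff = Flux.destructure(diffusion_dudt)

# One flat parameter vector for both networks, split by the drift length.
p = vcat(p_drift, p_diff)
n_drift = length(p_drift)

# SDE-style functions: rebuild each Chain from its slice of p, then call it with u only.
drift_(u, p, t) = re_drift(p[1:n_drift])(u)
diffusion_(u, p, t) = re_diff(p[(n_drift + 1):end])(u)

u0 = Float32[2.0, 0.0]
du = drift_(u0, p, 0.0f0)
g = diffusion_(u0, p, 0.0f0)
```

Rebuilding the Chain on every call costs some allocation, which is part of why DiffEqFlux offered FastChain with an explicit (x, p) signature in the first place.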

