Flux.train error: no method matching back!(::Float64)

I am working with Julia 1.2.0 in JupyterLab on a Windows 10 machine.
I am using the following package versions:
DiffEqFlux: v0.8.1
DifferentialEquations: v6.8.0
Flux: v0.9.0

I am trying to train a neural network that defines the right-hand side of a differential equation using the DiffEqFlux package.
Here is my code:

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots
using Flux: @epochs

cr = 32
Nt = 512
T_cs = zeros(Nt, cr)
T₀ = T_cs[:, 1]

ann = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

p1 = Flux.data(DiffEqFlux.destructure(ann))
ps = Flux.params(ann)

foretold(u,p,t) = DiffEqFlux.restructure(ann,p[1:4192])(u)
prob = ODEProblem(foretold,T₀,tspan_train,p1)
Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6))

function predict_adjoint()
  diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8)
end

opt = ADAM(1)
data = [(T₀, T_cs[:, 1:n_train])]
loss_function(T₀, T_data) = sum((predict_adjoint() .- T_data).^2)
cb = function ()
    loss = loss_function(T₀, T_cs[:, 1:n_train]) # Not very generalizable...
    println("loss = $loss")
end

for _ in 1:100
    Flux.train!(loss_function, ps, data, opt, cb = cb)
end

I get the following error when I run this:

MethodError: no method matching back!(::Float64)
Closest candidates are:
  back!(::Any, !Matched::Any; once) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\back.jl:75
  back!(!Matched::Tracker.TrackedReal; once) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\lib\real.jl:14
  back!(!Matched::TrackedArray) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\lib\array.jl:68

Stacktrace:
 [1] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss_function),Tuple{Array{Float64,1},Array{Float64,2}}}, ::Tracker.Params) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\back.jl:4
 [2] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\back.jl:164
 [3] gradient at C:\Users\daddyj\.julia\packages\Tracker\JhqMQ\src\back.jl:164 [inlined]
 [4] macro expansion at C:\Users\daddyj\.julia\packages\Flux\dkJUV\src\optimise\train.jl:71 [inlined]
 [5] macro expansion at C:\Users\daddyj\.julia\packages\Juno\oLB1d\src\progress.jl:134 [inlined]
 [6] #train!#12(::getfield(Main, Symbol("##27#28")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Array{Tuple{Array{Float64,1},Array{Float64,2}},1}, ::ADAM) at C:\Users\daddyj\.julia\packages\Flux\dkJUV\src\optimise\train.jl:69
 [7] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Main, Symbol("##27#28"))}}, ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Array{Tuple{Array{Float64,1},Array{Float64,2}},1}, ::ADAM) at .\none:0
 [8] top-level scope at .\In[91]:2

I would appreciate any suggestions as to why I am getting this error.

Hey @janak, I had to modify your example a little bit to get it to run (syntax errors and some indexing issues), but I ended up running into the same error.

I’m not super familiar with Flux, but I noticed that the loss function does not return a tracked scalar, so I’m guessing it can’t backpropagate, since Flux can only calculate gradients of tracked values. I’m not exactly sure which array or which value should be tracked, though…
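For what it’s worth, a quick way to confirm this (a rough check, assuming the Tracker-based AD that ships with Flux v0.9) is to evaluate the loss once after running the snippet below and inspect its type:

l = loss_function(T₀, T_cs[1:n_train, :])
println(typeof(l))                       # Float64 here, not a tracked value
println(l isa Flux.Tracker.TrackedReal)  # false, so back!(l) has no method

Since train! ends up calling back! on whatever the loss returns (as the stack trace above shows), a plain Float64 reproduces exactly this MethodError.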

using Flux
using DifferentialEquations
using DiffEqFlux
using LinearAlgebra
using JLD2
using Plots
using Flux: @epochs

cr = 32
Nt = 512
T_cs = zeros(Nt, cr)
T₀ = T_cs[1, :]

ann = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

p1 = Flux.data(DiffEqFlux.destructure(ann))
ps = Flux.params(ann)

n_train = 256
t_train = range(0.0, 1.0; length=n_train)
tspan_train = (0.0,1.0)

foretold(u,p,t) = DiffEqFlux.restructure(ann,p[1:4192])(u)
prob = ODEProblem(foretold,T₀,tspan_train,p1)
Flux.Tracker.collect(diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,abstol=1e-8,reltol=1e-6))

function predict_adjoint()
  diffeq_adjoint(p1,prob,Tsit5(),u0=T₀,saveat=t_train,reltol=1e-6, abstol=1e-8)
end

opt = ADAM(1)
data = [(T₀, T_cs[1:n_train, :])]
loss_function(T₀, T_data) = sum((hcat(predict_adjoint().u...)' .- T_data).^2)
cb = function ()
    loss = loss_function(T₀, T_cs[1:n_train, :]) # Not very generalizable...
    println("loss = $loss")
end

for _ in 1:100
    Flux.train!(loss_function, ps, data, opt, cb = cb)
end

@janak - I took your example and modified it a bit, and it should now run without any errors. I'm not sure I got the loss function to do exactly what you intended it to do, but here is a working example. The main change is that the initial condition u0 and the parameter vector p3 are wrapped in param(...), so the quantities going into diffeq_adjoint are tracked and the loss comes back as a TrackedReal that back! can handle.

using DiffEqFlux, Flux, OrdinaryDiffEq

cr = 2
Nt = 512
T_cs = zeros(Nt, cr)
T₀ = T_cs[1, :]
n_train = 251
u0 = param(Float32[0.8; 0.8])
tspan = (0.0f0,25.0f0)

ann = Chain(Dense(cr, 2cr, tanh),
                Dense(2cr, cr))

p1 = Flux.data(DiffEqFlux.destructure(ann))
p2 = Float32[-2.0,1.1]
p3 = param([p1;p2])
ps = Flux.params(p3,u0)


function dudt_(du,u,p,t)
    x, y = u
    du[1] = DiffEqFlux.restructure(ann,p[1:length(p)-2])(u)[1]
    du[2] = p[end-1]*y + p[end]*x
end

prob = ODEProblem(dudt_,u0,tspan,p3)
diffeq_adjoint(p3,prob,Tsit5(),u0=u0,abstol=1e-8,reltol=1e-6)


function predict_adjoint()
  diffeq_adjoint(p3,prob,Tsit5(),u0=u0,saveat=0.0:0.1:25.0)
end
opt = ADAM(1)
loss_adjoint(T,T_data) = sum(hcat(predict_adjoint()' .- T_data).^2)
data = [(T₀, T_cs[1:n_train, :])]
evalcb = () -> @show(loss_adjoint(T₀, T_cs[1:n_train, :]))

for _ in 1:100
    Flux.train!(loss_adjoint, ps, data, opt, cb = evalcb)
end