Flux.jl MethodError ... (::Dense{typeof(identity),... on simple Chain model

Hi, I’m new to using Flux and I keep getting a data type error when running the following code:

using LinearAlgebra, Flux, Base.Iterators, Statistics
using CuArrays
using Flux: onehotbatch, onecold, crossentropy, throttle, Tracker

# Random 19x348 training inputs and targets. NOTE(review): the trace below shows both
# ended up as plain `Array{Float64,2}` (inside the `Zip`), so in this run neither
# `Tracker.data` nor `gpu` changed them. The Dense weights are `Float32` (see the
# `TrackedArray{…,Array{Float32,2}}` types in the trace), so inputs and weights differ
# in element type.
trainX = Tracker.data(rand(19,348)) |> gpu
trainY = Tracker.data(rand(19,348)) |> gpu

# 19 -> 32 -> 19 dense network with a softmax output. The LSTM layer is commented out;
# re-enabling it is what produces the "`back!` was already used" error further down.
m = Chain(
  Dense(19,32),
  #LSTM(32,32),
  Dense(32,19),
  softmax) |> gpu

# Mean-squared error between the model output for `x` and the targets `y`.
loss(x, y) = Flux.mse(m(x), y)

# Fraction of samples whose `onecold` label for the prediction matches the label for
# the target. Defined here but not used by the training call below.
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

# BUG: `zip` over two matrices iterates element by element, so every item is a
# `(Float64, Float64)` pair (see `loss(::Float64, ::Float64)` in the trace), and
# `Dense` has no method for a scalar. `train!` needs each element of the dataset to be
# a tuple of arguments for `loss`, e.g. `[(trainX, trainY)]` or
# `Iterators.repeated((trainX, trainY), n)`.
dataset = zip(trainX, trainY)
# Prints the loss over the whole training set. NOTE(review): it is never passed to
# `train!` below (e.g. via the `cb` keyword), so it does not run during training.
evalcb = () -> @show(loss(trainX, trainY))
opt = ADAM()

# One pass over `dataset`: `loss` is called with each element's fields as its arguments
# (hence the scalar call in the trace), and its gradient w.r.t. the model parameters is
# applied with ADAM.
Flux.train!(loss, params(m), dataset, opt)

Running train! results in:

MethodError: no method matching (::Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}})(::Float64)
Closest candidates are:
  Dense(!Matched::AbstractArray{T<:Union{Float32, Float64},N} where N) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /home/m_tucci/.julia/packages/Flux/dkJUV/src/layers/basic.jl:110
  Dense(!Matched::AbstractArray{#s107,N} where N where #s107<:AbstractFloat) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /home/m_tucci/.julia/packages/Flux/dkJUV/src/layers/basic.jl:113
  Dense(!Matched::AbstractArray) at /home/m_tucci/.julia/packages/Flux/dkJUV/src/layers/basic.jl:98
applychain(::Tuple{Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},typeof(softmax)}, ::Float64) at basic.jl:31
(::Chain{Tuple{Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Dense{typeof(identity),TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},typeof(softmax)}})(::Float64) at basic.jl:33
loss(::Float64, ::Float64) at Player_Flux.jl:45
#14 at train.jl:72 [inlined]
gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss),Tuple{Float64,Float64}}, ::Tracker.Params) at back.jl:97
#gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at back.jl:164
gradient at back.jl:164 [inlined]
macro expansion at train.jl:71 [inlined]
macro expansion at progress.jl:119 [inlined]
#train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at train.jl:69
train!(::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Float64,2},Array{Float64,2}}}, ::ADAM) at train.jl:67
top-level scope at Player_Flux.jl:53
include_string(::Module, ::String, ::String, ::Int64) at eval.jl:30
(::getfield(Atom, Symbol("##124#129")){String,Int64,String})() at eval.jl:94
withpath(::getfield(Atom, Symbol("##124#129")){String,Int64,String}, ::String) at utils.jl:30
withpath at eval.jl:46 [inlined]
#123 at eval.jl:93 [inlined]
with_logstate(::getfield(Atom, Symbol("##123#128")){String,Int64,String}, ::Base.CoreLogging.LogState) at logging.jl:395
with_logger at logging.jl:491 [inlined]
#122 at eval.jl:92 [inlined]
hideprompt(::getfield(Atom, Symbol("##122#127")){String,Int64,String}) at repl.jl:77
macro expansion at eval.jl:91 [inlined]
macro expansion at dynamic.jl:24 [inlined]
(::getfield(Atom, Symbol("##121#126")))(::Dict{String,Any}) at eval.jl:86
handlemsg(::Dict{String,Any}, ::Dict{String,Any}) at comm.jl:164
(::getfield(Atom, Symbol("##19#21")){Array{Any,1}})() at task.jl:268

I feel like I’m missing something obvious here, but I’m too new to this to figure it out. I found another thread that seemed similar and suggested applying `Tracker.data()` to the training data sets, but that (obviously) didn’t work.
Any suggestions?

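If you look at `first(dataset)`, you’ll see a tuple of two numbers: `zip(trainX, trainY)` iterates over the two matrices element by element, not over whole arrays. If you pass such a tuple into `loss`, or just call `m(0.5)`, you get the same error — `Dense` acts on an array (a vector or a matrix), not on a single number.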

What you need is something of the form `dataset = [ (trainX, trainY), (trainX, trainY) ]`, so that each element of `dataset` is a tuple of valid arguments for `loss`.

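Mine explicitly writes the pair out twice (the vector only holds references to the same two arrays, so no data is actually copied). You can avoid writing it out by using an iterator instead, such as `Iterators.repeated((trainX, trainY), n)`, as some examples do. But it may be clearer just to write out the training loop yourself; reading the source code of `train!` is honestly the clearest way to see what it expects.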

You’re definitely right. I got a chain with Dense layers to work, but when I uncomment the LSTM layer I get the following:

`back!` was already used
back_(::Tracker.Call{Missing,Tuple{}}, ::Array{Float64,1}, ::Bool) at back.jl:42
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
#13 at back.jl:38 [inlined]
foreach at abstractarray.jl:1921 [inlined]
back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){4,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(*)},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(*)},typeof(+)},NTuple{4,TrackedArray{…,Array{Float32,1}}}},NTuple{4,Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,1}, ::Bool) at back.jl:38
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
#13 at back.jl:38 [inlined]
foreach at abstractarray.jl:1921 [inlined]
back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))},typeof(tanh)}},typeof(*)},Tuple{TrackedArray{…,Array{Float32,1}},TrackedArray{…,Array{Float32,1}}}},Tuple{Tracker.Tracked{Array{Float32,1}},Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,1}, ::Bool) at back.jl:38
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
(::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}) at back.jl:38
foreach(::Function, ::Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,1}}}, ::Tuple{Array{Float64,2},Array{Float64,1}}) at abstractarray.jl:1921
back_(::Tracker.Call{getfield(Tracker, Symbol("##509#510")){TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}},Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,1}, ::Bool) at back.jl:38
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
#13 at back.jl:38 [inlined]
foreach at abstractarray.jl:1921 [inlined]
back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(+)},typeof(identity)},Tuple{TrackedArray{…,Array{Float32,1}},TrackedArray{…,Array{Float32,1}}}},Tuple{Tracker.Tracked{Array{Float32,1}},Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,1}, ::Bool) at back.jl:38
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
foreach at back.jl:38 [inlined]
back_(::Tracker.Call{getfield(Tracker, Symbol("##511#512")){TrackedArray{…,Array{Float32,1}}},Tuple{Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float64,1}, ::Bool) at back.jl:38
back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float64,1}, ::Bool) at back.jl:58
#13 at back.jl:38 [inlined]
foreach at abstractarray.jl:1921 [inlined]
back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){4,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##1#3"))},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base....

I am getting a similar "`back!` was already used" error.
Were you able to figure this out?