Another error with defining Flux.jl gradients

first-steps
flux

#1

Here is the MWE

# Load the packages this example relies on: Flux provides `param` and the
# Tracker machinery, StatsBase provides the `mean`/`std` calls used in `model`.
using Flux, StatsBase

# Random 12×1000 input matrix and a length-1000 target vector.
x = rand(12, 1000)
y = rand(1000)
# Trainable parameters wrapped with Flux's `param`: twelve weights starting at
# one, plus two single-element arrays `a` and `b` starting at 0.5.
W = param([1 for i=1:12])
a = param([0.5])
b = param([0.5])

# Forward pass. For every column of `x`, multiply the column element-wise by
# the weights `W` and keep the largest product. The resulting per-column maxima
# are then standardised (mean subtracted, divided by their standard deviation),
# scaled by `b` and shifted by `a`.
model(x) = begin
    # One maximum per column of `x`, gathered with an array comprehension.
    x1 = [maximum(x[:,j] .* W) for j in 1:size(x, 2)]
    # `./` and `.*` associate left to right: ((x1 - mean) / std) * b, then + a.
    (a .+ (x1 .- StatsBase.mean(x1))./StatsBase.std(x1) .* b)
end

# Sum of squared differences between the model output for `x` and the targets `y`.
loss(x,y) = sum((model(x) .- y).^2)

model(x) # works
loss(x, y) # works

# Collect the trainable parameters and ask Tracker for the gradient of the loss
# with respect to them. The loss is evaluated on the `x` and `y` defined at the
# top of this example so that the snippet does not depend on outside variables.
params = Flux.Params([W, a, b])
grads = Tracker.gradient(() -> loss(x, y), params)

This gives the following error:

MethodError: no method matching Float64(::Flux.Tracker.TrackedReal{Float64})
Closest candidates are:
Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:185
Float64(::T<:Number) where T<:Number at boot.jl:725
Float64(!Matched::Int8) at float.jl:60

Stacktrace:
[1] convert(::Type{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at ./number.jl:7
[2] setindex!(::Array{Float64,1}, ::Flux.Tracker.TrackedReal{Float64}, ::Int64) at ./array.jl:769
[3] (::getfield(Flux.Tracker, Symbol("##319#320")){Colon,TrackedArray{…,Array{Float64,1}}})(::Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:257
[4] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##319#320")){Colon,TrackedArray{…,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:103
[5] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:116
[6] #4 at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106 [inlined]
[7] foreach at ./abstractarray.jl:1836 [inlined]
[8] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##202#203")),Tuple{Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[9] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
… (the last 4 lines are repeated 998 more times)
[4002] #4 at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106 [inlined]
[4003] foreach at ./abstractarray.jl:1836 [inlined]
[4004] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##194#195")){Flux.Tracker.TrackedReal{Float64},Int64},Tuple{Flux.Tracker.Tracked{Float64},Nothing}}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4005] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:116
[4006] #4 at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106 [inlined]
[4007] foreach at ./abstractarray.jl:1836 [inlined]
[4008] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##230#231")),Tuple{Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4009] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4010] foreach at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106 [inlined]
[4011] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#353")){1,typeof(abs2),Tuple{Flux.Tracker.TrackedReal{Float64}}},Tuple{Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4012] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4013] #4 at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106 [inlined]
[4014] foreach at ./abstractarray.jl:1836 [inlined]
… (the last 4 lines are repeated 2 more times)
[4023] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#353")){1,typeof(sqrt),Tuple{Flux.Tracker.TrackedReal{Float64}}},Tuple{Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4024] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4025] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4026] foreach(::Function, ::Tuple{Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Array{Float64,1}}}, ::Tuple{TrackedArray{…,Array{Flux.Tracker.TrackedReal{Float64},1}},TrackedArray{…,Array{Flux.Tracker.TrackedReal{Float64},1}},Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}},Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}},TrackedArray{…,Array{Flux.Tracker.TrackedReal{Float64},1}}}) at ./abstractarray.jl:1836
… (the last 4 lines are repeated 1 more time)
[4031] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#353")){4,getfield(Base.Broadcast, Symbol("##26#28")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##27#29")){typeof(-),getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##11#12"))}},getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##15#16"))}},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##3#4"))}}}}},typeof(Base.literal_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{…,Array{Flux.Tracker.TrackedReal{Float64},1}},Array{Float64,1},Base.RefValue{Val{2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Flux.Tracker.TrackedReal{Float64},1}},Nothing,Nothing}}, ::Array{Flux.Tracker.TrackedReal{Float64},1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4032] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Flux.Tracker.TrackedReal{Float64},1}}, ::Array{Flux.Tracker.TrackedReal{Float64},1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4033] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Flux.Tracker.TrackedReal{Float64},1}}, ::Array{Flux.Tracker.TrackedReal{Float64},1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4034] foreach at ./abstractarray.jl:1836 [inlined]
[4035] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##299#300")){TrackedArray{…,Array{Flux.Tracker.TrackedReal{Float64},1}}},Tuple{Flux.Tracker.Tracked{Array{Flux.Tracker.TrackedReal{Float64},1}}}}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[4036] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Flux.Tracker.TrackedReal{Float64}}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4037] (::getfield(Flux.Tracker, Symbol("##6#7")){Params,Flux.Tracker.TrackedReal{Flux.Tracker.TrackedReal{Float64}}})(::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:131
[4038] gradient(::Function, ::Params) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:152


#2

This isn’t runnable yet. I’m trying

using Flux, StatsBase

# Random 12×1000 input matrix and a length-1000 target vector.
x = rand(12, 1000)
y = rand(1000)

# Trainable parameters wrapped with Flux's `param`: twelve integer ones for
# the weights, and one-element vectors holding 0.5 for `a` and `b`.
W = param(fill(1, 12))
a = param(fill(0.5, 1))
b = param(fill(0.5, 1))

# Forward pass: take, for each column of `x`, the maximum of the column
# multiplied element-wise by `W`; then standardise those maxima, scale them by
# `b` and add `a`.
model(x) = begin
    # Array comprehension over the column indices of `x`, one maximum each.
    x1 = [maximum(x[:,j] .* W) for j in 1:size(x, 2)]
    # Evaluated as ((x1 - mean(x1)) / std(x1)) * b, then + a, all element-wise.
    (a .+ (x1 .- StatsBase.mean(x1))./StatsBase.std(x1) .* b)
end

model(x) # works
# NOTE: `loss` is not defined anywhere in this snippet, so this call cannot
# run until a definition is supplied.
loss(x, y)

# Gather the trainable parameters and request the gradient of the loss with
# respect to them, evaluated on the `x` and `y` defined above.
params = Flux.Params([W, a, b])
grads = Tracker.gradient(() -> loss(x, y), params)

but what’s your loss function?


#3

Oops. Updated.


#4

I couldn’t figure out the root cause of the error, but I think I figured out a way to rewrite the code such that it works:

using Flux
using StatsBase

# Random 12×1000 input matrix and a length-1000 target vector.
x = rand(12, 1000)
y = rand(1000)

# Trainable parameters wrapped with Flux's `param`: twelve integer ones for
# the weights, and one-element vectors holding 0.5 for `a` and `b`.
W = param(fill(1, 12))
a = param(fill(0.5, 1))
b = param(fill(0.5, 1))

# Forward pass without a comprehension: broadcast the weights over the columns
# of `x`, reduce each column to its maximum, then standardise the maxima,
# scale them by `b` and shift them by `a`.
function model(x)
    # `W` has one entry per row of `x`, so this scales every column by `W`.
    weighted = x .* W
    # `maximum(...; dims=1)` yields a single row; `[1, :]` takes it as a vector.
    colmax = maximum(weighted, dims=1)[1, :]
    mu = StatsBase.mean(colmax)
    sigma = StatsBase.std(colmax)
    return a .+ (colmax .- mu) ./ sigma .* b
end

# Sum of squared differences between the model output for `x` and the targets `y`.
function loss(x, y)
    predictions = model(x)
    return sum((predictions .- y) .^ 2)
end

model(x) # works
loss(x, y) # works

# Collect the trainable parameters and compute the gradient of the loss with
# respect to each of them.
params = Flux.Params([W, a, b])
grads = Tracker.gradient(() -> loss(x, y), params)

#5

How did you figure this out? If we run into similar problems, what should we look at?


#6

I figured that the comprehension over vectors, which turns into a matrix, would be the trickiest part for the backward pass, so I tried removing it and found that it worked. I guessed this based on the stack trace, which shows that the error comes from trying to write a TrackedReal into an array of Floats.

I’m a Julia novice, so my intuition above is probably bad.