Error with defining custom gradients in Flux.jl

flux

#1

I am trying to define some simple custom gradients in Flux.jl but am running into issues. Here is an MWE:

using Flux

# start with random parameters
W = param(rand(12))

model(x) = begin
    x*W
end

loss(x, y) = sum((model(x) .- y).^2)

x = rand(12)
y = rand(12)

params = Flux.Params([W])

grads = Flux.gradient(() -> loss(x,y), params)

This gives an error which I can’t seem to understand. I don’t see how I am doing anything different from the [documentation here](https://github.com/FluxML/Flux.jl/blob/master/docs/src/models/basics.md).

MethodError: *(::LinearAlgebra.Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:281
*(transA::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/LinearAlgebra/src/matmul.jl:83
Possible fix, define
*(::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, ::TrackedArray{S,1,A} where A)

Stacktrace:
[1] (::getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}})(::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:289
[2] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:103
[3] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[5] foreach(::Function, ::Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}, ::Tuple{Flux.Tracker.Tracked{Nothing},TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}},Flux.Tracker.Tracked{Nothing}}) at ./abstractarray.jl:1836
[6] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#353")){4,getfield(Base.Broadcast, Symbol("##26#28")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##27#29")){typeof(-),getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##11#12"))}},getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##15#16"))}},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##3#4"))}}}}},typeof(Base.literal_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{…,Array{Float64,1}},Array{Float64,1},Base.RefValue{Val{2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[7] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[8] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[9] foreach at ./abstractarray.jl:1836 [inlined]
[10] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##299#300")){TrackedArray{…,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[11] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[12] (::getfield(Flux.Tracker, Symbol("##6#7")){Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}})(::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:131
[13] gradient(::Function, ::Flux.Tracker.Params) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:152
[14] top-level scope at In[139]:1


#2

Wildly guessing, but maybe x*W should be W*x?


#3

Nope, cos loss(x,y) works


#5

I can reproduce the error with the following:

k = 2
n = 12

W = param(rand(k))
model(x) = begin
    x*W
end

loss(x, y) = sum((model(x) .- y).^2)

x = rand(n, k)
y = rand(n)

params = Flux.Params([W])
grads = Flux.Tracker.gradient(() -> loss(x,y), params)

#6

The key point is also that loss(x,y) works even in your example. But when backing out the gradient, it doesn’t work. I am still scratching my head.
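To see where a transposed matrix enters only during the backward pass, here is a minimal sketch with plain arrays (no Flux, no tracking). It assumes the shapes from the reproduction above; `fd_grad` is a hypothetical helper written just for this check, not part of any library. The hand-derived gradient of the loss with respect to W is 2 * transpose(x) * (x*W - y), which is presumably the kind of `Transpose` times vector product that the MethodError complains about.

```julia
using LinearAlgebra

# Plain-array stand-ins for the reproduction above (no TrackedArray involved).
k, n = 2, 12
x = rand(n, k)
y = rand(n)
W = rand(k)

loss(W) = sum((x * W .- y) .^ 2)

# Hand-derived gradient: d/dW sum((x*W - y).^2) = 2 * transpose(x) * (x*W - y).
# The forward pass never needs transpose(x); only the gradient does.
manual_grad = 2 .* (transpose(x) * (x * W .- y))

# Hypothetical helper: central finite differences as an independent check.
function fd_grad(f, w; h = 1e-6)
    g = similar(w)
    for i in eachindex(w)
        e = zeros(length(w))
        e[i] = h
        g[i] = (f(w .+ e) - f(w .- e)) / (2h)
    end
    return g
end

numeric_grad = fd_grad(loss, W)
```

If the two gradients agree, the math is fine and the failure is in method dispatch on the tracked types rather than in the model itself.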


#7

Could be a genuine bug. I’ll go report it.


#8

Is that what you are looking to do?

W = param(rand(12))

predict(x) = W.*x


function loss(x, y)
    ŷ = predict(x)
    sum((y - ŷ).^2)
end

x = rand(12)
y = rand(12)


grads = Tracker.gradient(() -> loss(x,y), Params([W]))

#9

Notice that the example is one-dimensional: predict is a single predictor. Multiplication W*x is defined for scalars, but not for two vectors.

julia> W*x
ERROR: MethodError: no method matching _forward(::typeof(*), ::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1})
Closest candidates are:
  _forward(::typeof(*), ::AbstractArray{T,2} where T, ::Union{AbstractArray{T,1}, AbstractArray{T,2}} where T) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/array.jl:288
  _forward(::typeof(getindex), ::AbstractArray, ::Any...) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/array.jl:74
  _forward(::typeof(vcat), ::Any...) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/array.jl:136
  ...
Stacktrace:
 [1] #track#1(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::typeof(*), ::TrackedArray{…,Array{Float64,1}}, ::Vararg{Any,N} where N) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/Tracker.jl:50
 [2] track(::typeof(*), ::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/Tracker.jl:50
 [3] *(::TrackedArray{…,Array{Float64,1}}, ::Array{Float64,1}) at /Users/berend/.julia/packages/Flux/UHjNa/src/tracker/array.jl:284
 [4] top-level scope at none:0

Instead use W'*x.

julia> W'*x
3.1105292208823814 (tracked)

#10

Also, your loss “worked” because the broadcast x .+ y adds a scalar constant to each element of an array, independent of whether x or y is the array.
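A tiny illustration of that broadcasting behaviour with made-up numbers (the values of `s` and `y` are arbitrary): a scalar combines with every element of a vector regardless of which side it is on, so a loss written with `.-` still evaluates when the prediction is a scalar.

```julia
s = 3.0                  # a scalar, standing in for a scalar prediction
y = [1.0, 2.0, 3.0]

a = s .- y               # scalar on the left:  [2.0, 1.0, 0.0]
b = y .- s               # scalar on the right: [-2.0, -1.0, 0.0]

# The squared-error sum runs without complaint even though s is not a vector.
loss_value = sum((s .- y) .^ 2)   # 4.0 + 1.0 + 0.0 = 5.0
```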


#11

Thanks, but I want a vector product, not an element-wise product.