I am trying to define some simple custom graidents in Flux.jl but am running into issues. Here is an MWE
# start random parameters
W = param(rand(12))
model(x) = begin
x*W
end
loss(x, y) = sum((model(x) .- y).^2)
x = rand(12)
y = rand(12)
params = Flux.Params([W])
grads = Flux.gradient(() -> loss(x,y), params)
gives error which I canβt seem to understand. I donβt se how I am doing anything different to the [documentation here]https://github.com/FluxML/Flux.jl/blob/master/docs/src/models/basics.md)
MethodError: *(::LinearAlgebra.Transpose{Float64,Array{Float64,2}}, ::TrackedArray{β¦,Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:281
*(transA::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/LinearAlgebra/src/matmul.jl:83
Possible fix, define
*(::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, ::TrackedArray{S,1,A} where A)Stacktrace:
[1] (::getfield(Flux.Tracker, Symbol(β##326#327β)){Array{Float64,2},TrackedArray{β¦,Array{Float64,1}}})(::TrackedArray{β¦,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:289
[2] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol(β##326#327β)){Array{Float64,2},TrackedArray{β¦,Array{Float64,1}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{β¦,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:103
[3] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{β¦,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4] (::getfield(Flux.Tracker, Symbol(β##4#5β)){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{β¦,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[5] foreach(::Function, ::Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}, ::Tuple{Flux.Tracker.Tracked{Nothing},TrackedArray{β¦,Array{Float64,1}},TrackedArray{β¦,Array{Float64,1}},Flux.Tracker.Tracked{Nothing}}) at ./abstractarray.jl:1836
[6] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol(β#back#353β)){4,getfield(Base.Broadcast, Symbol(β##26#28β)){getfield(Base.Broadcast, Symbol(β##5#6β)){getfield(Base.Broadcast, Symbol(β##27#29β)){typeof(-),getfield(Base.Broadcast, Symbol(β##9#10β)){getfield(Base.Broadcast, Symbol(β##9#10β)){getfield(Base.Broadcast, Symbol(β##11#12β))}},getfield(Base.Broadcast, Symbol(β##13#14β)){getfield(Base.Broadcast, Symbol(β##13#14β)){getfield(Base.Broadcast, Symbol(β##15#16β))}},getfield(Base.Broadcast, Symbol(β##5#6β)){getfield(Base.Broadcast, Symbol(β##5#6β)){getfield(Base.Broadcast, Symbol(β##5#6β)){getfield(Base.Broadcast, Symbol(β##3#4β))}}}}},typeof(Base.literal_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{β¦,Array{Float64,1}},Array{Float64,1},Base.RefValue{Val{2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[7] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[8] (::getfield(Flux.Tracker, Symbol(β##4#5β)){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[9] foreach at ./abstractarray.jl:1836 [inlined]
[10] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol(β##299#300β)){TrackedArray{β¦,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[11] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[12] (::getfield(Flux.Tracker, Symbol(β##6#7β)){Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}})(::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:131
[13] gradient(::Function, ::Flux.Tracker.Params) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:152
[14] top-level scope at In[139]:1