I am trying to define some simple custom graidents in Flux.jl but am running into issues. Here is an MWE
# start random parameters
W = param(rand(12))
model(x) = begin
x*W
end
loss(x, y) = sum((model(x) .- y).^2)
x = rand(12)
y = rand(12)
params = Flux.Params([W])
grads = Flux.gradient(() -> loss(x,y), params)
gives error which I can’t seem to understand. I don’t se how I am doing anything different to the [documentation here]https://github.com/FluxML/Flux.jl/blob/master/docs/src/models/basics.md)
MethodError: *(::LinearAlgebra.Transpose{Float64,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::TrackedArray{T,1,A} where A where T) in Flux.Tracker at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:281
*(transA::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, x::AbstractArray{S,1}) where {T, S} in LinearAlgebra at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.0/LinearAlgebra/src/matmul.jl:83
Possible fix, define
*(::LinearAlgebra.Transpose{#s549,#s548} where #s548<:AbstractArray{T,2} where #s549, ::TrackedArray{S,1,A} where A)Stacktrace:
[1] (::getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}})(::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/array.jl:289
[2] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##326#327")){Array{Float64,2},TrackedArray{…,Array{Float64,1}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}}}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:103
[3] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[4] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[5] foreach(::Function, ::Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}, ::Tuple{Flux.Tracker.Tracked{Nothing},TrackedArray{…,Array{Float64,1}},TrackedArray{…,Array{Float64,1}},Flux.Tracker.Tracked{Nothing}}) at ./abstractarray.jl:1836
[6] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("#back#353")){4,getfield(Base.Broadcast, Symbol("##26#28")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##27#29")){typeof(-),getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##9#10")){getfield(Base.Broadcast, Symbol("##11#12"))}},getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##13#14")){getfield(Base.Broadcast, Symbol("##15#16"))}},getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##3#4"))}}}}},typeof(Base.literal_pow)},Tuple{Base.RefValue{typeof(^)},TrackedArray{…,Array{Float64,1}},Array{Float64,1},Base.RefValue{Val{2}}}},Tuple{Nothing,Flux.Tracker.Tracked{Array{Float64,1}},Nothing,Nothing}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[7] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[8] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Array{Float64,1}}, ::Array{Float64,1}) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[9] foreach at ./abstractarray.jl:1836 [inlined]
[10] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##299#300")){TrackedArray{…,Array{Float64,1}}},Tuple{Flux.Tracker.Tracked{Array{Float64,1}}}}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:106
[11] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:118
[12] (::getfield(Flux.Tracker, Symbol("##6#7")){Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}})(::Int64) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:131
[13] gradient(::Function, ::Flux.Tracker.Params) at /home/jrun/.julia/packages/Flux/UHjNa/src/tracker/back.jl:152
[14] top-level scope at In[139]:1