I have a problem with Flux. Unfortunately Zygote cannot derive my loss function and I am not sure what it is. It would be nice if someone could help me with this problem.
My loss function looks like this:
function loss_ungelabelt2(X,Y)
a_list = Float32[]
b_list = Float32[]
mix_res = Float32[]
λ::Float32 = 0.3
loss::Float32 = 0
for i in 1:2 #eigentlich size(X,4)
a = X[:,:,:,rand(1:size(X,4))]
b = X[:,:,:,rand(1:size(X,4))]
mix = λ .* a .+ (1 - λ) .* b
println(length(mix))
println(size(mix))
println("NEU")
mix_res=reshape([mix_res...,mix...],size(mix)...,:) # einzelne Elemente
a_list=reshape([a_list...,a...],size(a)...,:)
b_list=reshape([b_list...,b...],size(b)...,:)
println(size(mix_res))
println("neueRunde")
end
y_mix = model(mix_res)
y_a = model(a_list)
y_b = model(b_list)
println("a", size(y_a))
println("b", size(y_b))
y_mix_ab = λ .* y_a .+ (1 - λ) .* y_b
println(length(y_mix_ab))
quadrat_loss = sum(abs2,y_mix .- y_mix_ab)
return quadrat_loss
end
I get this error:
MethodError: no method matching size(::NTuple{784,Float32})
Closest candidates are:
size(::Tuple, !Matched::Integer) at tuple.jl:22
size(!Matched::Flux.OneHotVector) at /Users/lisa/.julia/packages/Flux/IjMZL/src/onehot.jl:8
size(!Matched::ZMQ.Message) at /Users/lisa/.julia/packages/ZMQ/R3wSD/src/message.jl:95
...
Stacktrace:
[1] unbroadcast(::Array{Float32,3}, ::NTuple{784,Float32}) at /Users/lisa/.julia/packages/Zygote/seGHk/src/lib/broadcast.jl:53
[2] (::Zygote.var"#1107#1109"{NTuple{784,Float32}})(::Array{Float32,3}) at /Users/lisa/.julia/packages/Zygote/seGHk/src/lib/broadcast.jl:74
[3] map(::Zygote.var"#1107#1109"{NTuple{784,Float32}}, ::Tuple{Array{Float32,3},Array{Float32,3}}) at ./tuple.jl:158
[4] (::Zygote.var"#1106#1108"{Tuple{Array{Float32,3},Array{Float32,3}}})(::NTuple{784,Float32}) at /Users/lisa/.julia/packages/Zygote/seGHk/src/lib/broadcast.jl:74
[5] (::Zygote.var"#3852#back#1110"{Zygote.var"#1106#1108"{Tuple{Array{Float32,3},Array{Float32,3}}}})(::NTuple{784,Float32}) at /Users/lisa/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[6] loss_ungelabelt2 at ./In[13]:14 [inlined]
[7] (::typeof(∂(loss_ungelabelt2)))(::Float32) at /Users/lisa/.julia/packages/Zygote/seGHk/src/compiler/interface2.jl:0
[8] (::Zygote.var"#41#42"{typeof(∂(loss_ungelabelt2))})(::Float32) at /Users/lisa/.julia/packages/Zygote/seGHk/src/compiler/interface.jl:45
[9] gradient(::Function, ::Array{Float32,4}, ::Vararg{Any,N} where N) at /Users/lisa/.julia/packages/Zygote/seGHk/src/compiler/interface.jl:54
[10] top-level scope at In[14]:1
[11] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091
I would be very grateful for an answer.