Hello,
I am having trouble understanding the backward pass in Flux using Zygote. I have a model with a user-defined loss function. The problem I am facing is that, when calculating the loss, the output of the model first has to be converted to a vector.
To do that I wrote the code below (for the sake of reproducibility I replaced the original function with built-in functions):
using Statistics: mean
using LinearAlgebra: norm
using Zygote
function wsdrLoss(x, ŷ, y, eps=1e-8)
    w, h, c, n = size(y)
    # x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> istft <- this is original
    x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    ŷ = eachslice(reshape(ŷ, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    y = eachslice(reshape(y, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    z = x .- y
    ẑ = x .- ŷ
    nd = sum(reduce(hcat, y) .^2, dims=1)
    dom = sum(reduce(hcat, z) .^2, dims=1)
    aux = nd ./ (nd .+ dom .+ eps)
    wSDR = aux .* sdr(ŷ, y) .+ (1 .- aux) .* sdr(ẑ, z)
    mean(wSDR)
end
function sdr(ypred, ygold; eps=1e-8)
    num = sum(reduce(hcat, ygold) .* reduce(hcat, ypred), dims=1)
    den = norm.(ygold, 2) .* norm.(ypred, 2)
    -(num ./ (den' .+ eps))
end
x = rand(ComplexF32, (513, 321, 1, 2));
y = rand(ComplexF32, (513, 321, 1, 2));
ŷ = rand(ComplexF32, (513, 321, 1, 2));
g = gradient(wsdrLoss, x, ŷ, y)
The code above accepts 3 arguments: the original input, the predicted output and the reference output.
Each argument is composed of 2 samples, and each one (x, for example) has 513x321x1x2 dimensions. Each sample inside x has to be converted first into a matrix, and then the istft function converts each matrix to a vector.
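To make the shapes concrete, here is a minimal sketch of that per-sample conversion in plain Julia. The small sizes are stand-ins I chose for illustration, and, as in the code above, imag plus flattening replaces istft:

```julia
# Stand-in sizes (assumption: the real data is 513x321x1x2).
w, h, c, n = 4, 3, 1, 2
x = rand(ComplexF32, w, h, c, n)

# Drop the singleton channel dimension and take one w-by-h matrix per sample.
x3 = reshape(x, w, h, n)
samples = [x3[:, :, i] for i in 1:n]

# Flatten each matrix to a vector (column-major order).
# vec gives the same elements as Iterators.flatten |> collect here.
flat = [vec(imag(m)) for m in samples]

@assert length(flat) == n
@assert length(flat[1]) == w * h
@assert flat[1] == collect(Iterators.flatten(imag(samples[1])))
```

So each argument ends up as a collection of n vectors, each of length w*h.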
If I understand correctly, Zygote does not like variable mutation, hence it complains about the .|> syntax.
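As a sanity check on that assumption, the sketch below shows what .|> does in plain Julia: it broadcasts function application over a collection, so it should match an equivalent map and leave its input untouched. The variable names are hypothetical, and this says nothing about how Zygote itself treats the operator:

```julia
# A small collection of complex matrices (stand-in for the per-sample slices).
mats = [rand(ComplexF32, 3, 2) for _ in 1:2]
before = deepcopy(mats)

# Broadcasted pipe: apply imag, then vec, to every element.
piped = mats .|> imag .|> vec

# Same computation written with map.
mapped = map(m -> vec(imag(m)), mats)

@assert piped == mapped   # both forms give the same vectors
@assert mats == before    # the input collection is not modified
```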
For the second attempt, I tried passing only one sample at a time. To do that I changed the above code to:
using Zygote
using LinearAlgebra: norm
using Statistics: mean
function wsdrLoss2(x, ŷ, y, eps=1e-8)
    w, h, c, n = size(y)
    x = reshape(x, w, h) |> imag |> Iterators.flatten |> collect
    ŷ = reshape(ŷ, w, h) |> imag |> Iterators.flatten |> collect
    y = reshape(y, w, h) |> imag |> Iterators.flatten |> collect
    z = x - y
    ẑ = x - ŷ
    nd = sum(y .^2)
    dom = sum(z .^2)
    aux = nd / (nd + dom + eps)
    wSDR = aux * sdr(ŷ, y) + (1 - aux) * sdr(ẑ, z)
end
function sdr(ypred, ygold; eps=1e-8)
    num = sum(ygold .* ypred)
    den = norm(ygold, 2) .* norm(ypred, 2)
    -(num / (den + eps))
end
x = rand(ComplexF32, (513, 321, 1, 1)); # reducing sample size from 2 to 1.
y = rand(ComplexF32, (513, 321, 1, 1));
ŷ = rand(ComplexF32, (513, 321, 1, 1));
g = gradient(wsdrLoss2, x , ŷ , y )
But this time I get a different error:
ERROR: MethodError: no method matching reshape(::IRTools.Inner.Undefined, ::Int64, ::Int64)
Closest candidates are:
reshape(::FillArrays.AbstractFill, ::Union{Colon, Int64}...) at /opt/.julia/packages/FillArrays/NjFh2/src/FillArrays.jl:206
reshape(::OffsetArrays.OffsetArray, ::Union{Colon, Int64}...) at /opt/.julia/packages/OffsetArrays/ExQCD/src/OffsetArrays.jl:240
reshape(::AbstractArray, ::Int64...) at reshapedarray.jl:116
...
Stacktrace:
[1] adjoint at /opt/.julia/packages/Zygote/xBjHw/src/lib/array.jl:93 [inlined]
[2] _pullback at /opt/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[3] wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:161 [inlined]
[4] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}, ::Float64) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[5] wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:153 [inlined]
[6] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[7] loss at ./REPL[29]:1 [inlined]
[8] _pullback(::Zygote.Context, ::typeof(loss), ::Array{Complex{Float32},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[9] #5 at ./REPL[31]:5 [inlined]
[10] _pullback(::Zygote.Context, ::var"#5#6"{typeof(loss),Array{Complex{Float32},4},Array{Complex{Float32},4}}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
[11] pullback at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:172 [inlined]
[12] my_custom_train!(::typeof(loss), ::UNet, ::DataLoaders.BufferGetObsParallel{Tuple{Array{Complex{Float32},4},Array{Complex{Float32},4}},DataLoaders.BatchViewCollated{Tuple{Data,Data}}}, ::ADAM) at ./REPL[31]:5
[13] top-level scope at REPL[32]:1
Could anyone help, please? I am doing something wrong, but what?
B.R.