# Understanding Flux and Zygote: "Mutating arrays is not supported" and a new error

Hello,

I am having some trouble understanding the backward pass in Flux using Zygote. I have a model with a user-defined loss function. The problem I am facing is that, when calculating the loss, the output of the model first has to be converted to a vector.

To do that, I wrote the code below (for the sake of reproducibility, I replaced the original function with built-in functions):

```julia
using Statistics: mean
using LinearAlgebra: norm
using Zygote

function wsdrLoss(x, ŷ, y, eps=1e-8)
    w, h, c, n = size(y)

    # x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> istft   <- this is the original
    x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    ŷ = eachslice(reshape(ŷ, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    y = eachslice(reshape(y, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect

    z = x .- y
    ẑ = x .- ŷ

    nd  = sum(reduce(hcat, y) .^ 2, dims=1)
    dom = sum(reduce(hcat, z) .^ 2, dims=1)

    aux  = nd ./ (nd .+ dom .+ eps)
    wSDR = aux .* sdr(ŷ, y) .+ (1 .- aux) .* sdr(ẑ, z)
    mean(wSDR)
end

function sdr(ypred, ygold; eps=1e-8)
    num = sum(reduce(hcat, ygold) .* reduce(hcat, ypred), dims=1)
    den = norm.(ygold, 2) .* norm.(ypred, 2)
    -(num ./ (den' .+ eps))
end

x = rand(ComplexF32, (513, 321, 1, 2));
y = rand(ComplexF32, (513, 321, 1, 2));
ŷ = rand(ComplexF32, (513, 321, 1, 2));

g = gradient(wsdrLoss, x, ŷ, y)
```

The code above takes three arguments: the original input, the predicted output, and the reference output.
Each argument contains 2 samples and has dimensions `513×321×1×2`. Each sample has to be converted first into a matrix, and then the `istft` function converts each matrix into a vector.
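
To make the data layout concrete, here is a small sketch that runs the reproducer's slicing pipeline on a shrunken array (sizes reduced from `513×321×1×2` for speed; `imag` stands in for `istft` exactly as in the reproducer). It checks that each flattened sample is the same as one column of `reshape(imag.(x), w * h, n)`, which gives the same numbers without any slicing.

```julia
# Small stand-in sizes (the post uses w=513, h=321, c=1, n=2).
w, h, c, n = 4, 3, 1, 2
x = rand(ComplexF32, w, h, c, n)

# Slice-based pipeline from the post: a Vector of n flattened vectors.
slices = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect

# Slice-free equivalent: column j holds the flattened imaginary part of sample j
# (column-major order, the same order Iterators.flatten walks a matrix in).
cols = reshape(imag.(x), w * h, n)

same = all(slices[j] == cols[:, j] for j in 1:n)
```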

If I understand correctly, Zygote does not support array mutation, which is why it complains about the `.|>` syntax.
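
For reference, `.|>` is the pipe operator `|>` applied with broadcasting, so `v .|> f` computes the same thing as `f.(v)`: it builds a new array and leaves `v` untouched. If that reading is right, the mutation Zygote objects to probably comes from one of the piped functions (for example `collect`, which fills a fresh array element by element as it iterates) rather than from the syntax itself; this is a guess, not something the thread confirms. A quick check of the equivalence:

```julia
v = [1.0, 4.0, 9.0]

# `v .|> sqrt` broadcasts `|>`, i.e. it evaluates `x |> sqrt` for every element x.
piped = v .|> sqrt

# The ordinary dot-call form gives the same new array.
dotted = sqrt.(v)
```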

For my second attempt, I passed only one sample at a time. To do that, I changed the code above to:

```julia
using Zygote
using LinearAlgebra: norm
using Statistics: mean

function wsdrLoss2(x, ŷ, y, eps=1e-8)
    w, h, c, n = size(y)

    x = reshape(x, w, h) |> imag |> Iterators.flatten |> collect
    ŷ = reshape(ŷ, w, h) |> imag |> Iterators.flatten |> collect
    y = reshape(y, w, h) |> imag |> Iterators.flatten |> collect

    z = x - y
    ẑ = x - ŷ

    nd  = sum(y .^ 2)
    dom = sum(z .^ 2)
    aux = nd / (nd + dom + eps)
    wSDR = aux * sdr(ŷ, y) + (1 - aux) * sdr(ẑ, z)
end

function sdr(ypred, ygold; eps=1e-8)
    num = sum(ygold .* ypred)
    den = norm(ygold, 2) .* norm(ypred, 2)
    -(num / (den + eps))
end

x = rand(ComplexF32, (513, 321, 1, 1)); # reducing sample size from 2 to 1.
y = rand(ComplexF32, (513, 321, 1, 1));
ŷ = rand(ComplexF32, (513, 321, 1, 1));

g = gradient(wsdrLoss2, x, ŷ, y)
```

But this time I get a different error:

```
ERROR: MethodError: no method matching reshape(::IRTools.Inner.Undefined, ::Int64, ::Int64)
Closest candidates are:
reshape(::FillArrays.AbstractFill, ::Union{Colon, Int64}...) at /opt/.julia/packages/FillArrays/NjFh2/src/FillArrays.jl:206
reshape(::OffsetArrays.OffsetArray, ::Union{Colon, Int64}...) at /opt/.julia/packages/OffsetArrays/ExQCD/src/OffsetArrays.jl:240
reshape(::AbstractArray, ::Int64...) at reshapedarray.jl:116
...
Stacktrace:
 wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:161 [inlined]
 _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}, ::Float64) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:153 [inlined]
 _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 loss at ./REPL:1 [inlined]
 _pullback(::Zygote.Context, ::typeof(loss), ::Array{Complex{Float32},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 #5 at ./REPL:5 [inlined]
 _pullback(::Zygote.Context, ::var"#5#6"{typeof(loss),Array{Complex{Float32},4},Array{Complex{Float32},4}}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 pullback at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:172 [inlined]
 top-level scope at REPL:1
``````

Could anyone help, please? I am doing something wrong, but what?

B.R.

Do you have to use slices? What you have written could be done with operations like

```julia
nd  = vec(sum(val -> imag(val)^2, y; dims=(1,2,3)))
dom = vec(sum(@. (imag(x) - imag(y))^2; dims=(1,2,3)))
```

which are likely to be both faster and easier for Zygote to digest. Can this `istft` function be made to accept a 3-array, not just a matrix?
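
Putting that suggestion together, here is a sketch of the whole loss written without slices. It keeps the reproducer's assumption that `imag` stands in for `istft`, and the names `wsdr_loss_dense` and `sdr_dense` are made up for this sketch. Every reduction uses `sum(...; dims=(1, 2, 3))`, so each sample's result stays in the last axis and no `reduce(hcat, …)` is needed.

```julia
using Statistics: mean

# Hypothetical slice-free version of wsdrLoss. Assumes all inputs have the
# layout (w, h, c, n), n = number of samples, and imag() stands in for istft.
function wsdr_loss_dense(x, yhat, y; eps=1e-8)
    dims = (1, 2, 3)                         # reduce over everything except the sample axis
    X, P, Y = imag.(x), imag.(yhat), imag.(y)
    Z, Zp = X .- Y, X .- P                   # same roles as z and ẑ in the post

    nd  = sum(Y .^ 2; dims=dims)             # per-sample energy of the reference, size (1,1,1,n)
    dom = sum(Z .^ 2; dims=dims)             # per-sample energy of x - y
    aux = nd ./ (nd .+ dom .+ eps)

    wsdr = aux .* sdr_dense(P, Y, dims; eps=eps) .+ (1 .- aux) .* sdr_dense(Zp, Z, dims; eps=eps)
    return mean(wsdr)
end

# Negative cosine similarity of each sample pair, computed along `dims`.
function sdr_dense(ypred, ygold, dims; eps=1e-8)
    num = sum(ygold .* ypred; dims=dims)
    den = sqrt.(sum(ygold .^ 2; dims=dims)) .* sqrt.(sum(ypred .^ 2; dims=dims))
    return -(num ./ (den .+ eps))
end
```

Because it uses only broadcasting and `sum` with `dims`, it should be something Zygote can differentiate with `gradient(wsdr_loss_dense, x, ŷ, y)`, though that has not been checked in this thread.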

Your second example works if you delete `|> Iterators.flatten |> collect`, which I think isn't doing anything useful here (except confusing Zygote): `sum` and `norm` already treat a matrix as a flat collection of its elements.
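
For concreteness, this is the second example with only that change applied (the thread reports that `gradient(wsdrLoss2, x, ŷ, y)` then works once Zygote is loaded). Nothing else is altered beyond indentation and an explicit `return`.

```julia
using LinearAlgebra: norm

function wsdrLoss2(x, ŷ, y, eps=1e-8)
    w, h, c, n = size(y)                 # expects c == n == 1 here

    # `|> Iterators.flatten |> collect` removed: the matrices are used directly.
    x = reshape(x, w, h) |> imag
    ŷ = reshape(ŷ, w, h) |> imag
    y = reshape(y, w, h) |> imag

    z = x - y
    ẑ = x - ŷ

    nd  = sum(y .^ 2)
    dom = sum(z .^ 2)
    aux = nd / (nd + dom + eps)
    return aux * sdr(ŷ, y) + (1 - aux) * sdr(ẑ, z)
end

function sdr(ypred, ygold; eps=1e-8)
    num = sum(ygold .* ypred)
    # For a matrix, norm(A, 2) is the entrywise 2-norm, same as for vec(A).
    den = norm(ygold, 2) * norm(ypred, 2)
    return -(num / (den + eps))
end

x = rand(ComplexF32, 513, 321, 1, 1)
y = rand(ComplexF32, 513, 321, 1, 1)
ŷ = rand(ComplexF32, 513, 321, 1, 1)

loss = wsdrLoss2(x, ŷ, y)
```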


`istft` can probably be converted to accept a 3-array. So you suggest transforming `x`, `y`, and `ŷ` directly and storing the results in new variables?
I didn't know that an anonymous function (`->`) could be passed to `sum`. I really appreciate it. Is there any resource where I can learn these kinds of patterns?

The `sum(f, x)` pattern may be a bit of a waste, as IIRC Zygote turns it back into `sum(f.(x))` (whereas normally it would save allocating `f.(x)`).
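
To make the distinction concrete: in plain Julia, `sum(f, x; dims)` applies `f` while reducing and never materialises `f.(x)`, whereas `sum(f.(x); dims)` allocates the broadcast result first. Per the comment above, Zygote may turn the first form back into the second, so under AD the saving may disappear. Numerically the two agree:

```julia
y = rand(ComplexF32, 4, 3, 1, 2)

# sum(f, x): f is applied during the reduction, no temporary array in plain Julia.
a = vec(sum(val -> imag(val)^2, y; dims=(1, 2, 3)))

# sum(f.(x)): imag.(y) .^ 2 is allocated first, then reduced.
b = vec(sum(imag.(y) .^ 2; dims=(1, 2, 3)))
```
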

I think so? I mean that I suggest working with “solid” arrays as much as possible, rather than making slices. So instead of (or in addition to) `istft(x::AbstractMatrix)`, you would define `istft(xs::AbstractArray{T,3}) where {T}`, with every step inside the function keeping track of one more dimension.
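
A minimal sketch of that dispatch pattern follows. `toy_transform` is a hypothetical stand-in for `istft` (it just sums the imaginary part over the first axis); the point is the second method, which takes a `w×h×n` stack, carries the sample axis through every step, and returns one column per sample instead of a vector of vectors.

```julia
# Hypothetical stand-in for istft: one w×h complex matrix -> vector of length h.
toy_transform(X::AbstractMatrix) = vec(sum(imag.(X); dims=1))

# Batched method: the same steps on a w×h×n stack, keeping the sample axis.
# Returns an h×n matrix whose k-th column is the result for sample k.
function toy_transform(Xs::AbstractArray{T,3}) where {T}
    s = sum(imag.(Xs); dims=1)                  # size (1, h, n)
    return reshape(s, size(Xs, 2), size(Xs, 3))
end

Xs = rand(ComplexF32, 5, 4, 3)
batched = toy_transform(Xs)
```

Since the result is already a matrix with one column per sample, it can feed straight into column-wise reductions such as `sum(...; dims=1)`, which is what `reduce(hcat, …)` was being used for in the original loss.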