Understanding Flux and Zygote: mutation not supported, and a new error


I am having trouble understanding the backward pass in Flux with Zygote. I have a model with a user-defined loss function. The problem I am facing is that, when calculating the loss, the output of the model first has to be converted to a vector.

To do that, I wrote the code below (for reproducibility, I replaced my original function with built-in functions):

using Statistics: mean
using LinearAlgebra: norm
using Zygote

function wsdrLoss(x, ŷ, y, eps=1e-8)

    w, h, c, n =  size(y)

    # x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> istft  <- this is the original line
    x = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    ŷ = eachslice(reshape(ŷ, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect
    y = eachslice(reshape(y, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect

    z = x .- y
    ẑ = x .- ŷ

    nd  = sum(reduce(hcat, y) .^2, dims=1)
    dom = sum(reduce(hcat, z) .^2, dims=1)
    aux  = nd ./ (nd .+ dom .+ eps)
    wSDR = aux .* sdr(ŷ, y) .+ (1 .- aux) .* sdr(ẑ, z)
end

function sdr(ypred, ygold; eps=1e-8)
    num = sum(reduce(hcat, ygold) .* reduce(hcat, ypred), dims=1)
    den = norm.(ygold, 2) .* norm.(ypred, 2)
    -(num ./ (den' .+ eps))
end

x = rand(ComplexF32, (513, 321, 1, 2));
y = rand(ComplexF32, (513, 321, 1, 2));
ŷ = rand(ComplexF32, (513, 321, 1, 2));

g = gradient(wsdrLoss, x, ŷ, y)

The code above takes 3 arguments: the original input, the predicted output, and the reference output.
Each argument contains 2 samples, so each array has size 513x321x1x2. Each sample has to be converted first into a matrix, and then the istft function converts each matrix into a vector.
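For reference, this is the per-sample quantity the code computes, read directly off wsdrLoss and sdr (not taken from any paper). Here x, y and ŷ stand for the flattened imaginary parts of one sample, ε is `eps`, and sdr(ypred, ygold) is written with ygold = u and ypred = û:

```latex
\operatorname{sdr}(\hat u, u) = -\frac{\langle u, \hat u \rangle}{\lVert u \rVert_2 \, \lVert \hat u \rVert_2 + \epsilon}
\qquad
\alpha = \frac{\lVert y \rVert_2^2}{\lVert y \rVert_2^2 + \lVert x - y \rVert_2^2 + \epsilon}

\mathrm{wSDR} = \alpha \, \operatorname{sdr}(\hat y, y) + (1 - \alpha) \, \operatorname{sdr}(x - \hat y, \; x - y)
```

In the code, nd is ‖y‖², dom is ‖x − y‖², and aux is α.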

If I understand correctly, Zygote does not support mutating variables, hence it complains about the .|> syntax.
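As a sanity check, the slice/flatten/collect pipeline from the first attempt produces the same numbers as a single reshape of imag.(x) into a (w*h)×n matrix whose column j is sample j. The sketch below checks this on small random data (the sizes are arbitrary, chosen only for the check). Working on one matrix avoids building an array of slices, which is generally easier for Zygote, though this check itself does not involve Zygote.

```julia
# Compare the slice-based pipeline from the question with a single reshape.
# Both turn a w×h×1×n complex array into n flattened vectors of imaginary parts.
w, h, n = 4, 3, 2                      # small sizes, chosen only for this check
x = rand(ComplexF32, w, h, 1, n)

# Pipeline from the question: one Vector per sample.
slices = eachslice(reshape(x, w, h, n), dims=3) .|> Matrix .|> imag .|> Iterators.flatten .|> collect

# Same data as one (w*h)×n matrix: column j holds sample j (column-major order).
X = reshape(imag.(x), w * h, n)

for j in 1:n
    @assert slices[j] == X[:, j]
end
```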

For my second try, I passed only one sample at a time. To do that, I changed the code above to:

using Zygote 
using LinearAlgebra: norm
using Statistics: mean 

function wsdrLoss2(x, ŷ, y, eps=1e-8)

    w, h, c, n =  size(y)

    x = reshape(x, w, h)  |> imag |> Iterators.flatten |> collect
    ŷ = reshape(ŷ, w, h)  |> imag |> Iterators.flatten |> collect
    y = reshape(y, w, h)  |> imag |> Iterators.flatten |> collect
    z = x - y
    ẑ = x - ŷ
    nd  = sum(y .^2)
    dom = sum(z .^2)
    aux = nd / (nd + dom + eps)
    wSDR = aux * sdr(ŷ, y) + (1 - aux) * sdr(ẑ, z)
end


function sdr(ypred, ygold; eps=1e-8)
    num = sum(ygold .* ypred)
    den = norm(ygold, 2) .* norm(ypred, 2)
    -(num / (den + eps))
end

x = rand(ComplexF32, (513, 321, 1, 1)); # reducing sample size from 2 to 1.
y = rand(ComplexF32, (513, 321, 1, 1));
ŷ = rand(ComplexF32, (513, 321, 1, 1));

g = gradient(wsdrLoss2, x , ŷ , y )

But this time I get a different error:

ERROR: MethodError: no method matching reshape(::IRTools.Inner.Undefined, ::Int64, ::Int64)
Closest candidates are:
  reshape(::FillArrays.AbstractFill, ::Union{Colon, Int64}...) at /opt/.julia/packages/FillArrays/NjFh2/src/FillArrays.jl:206
  reshape(::OffsetArrays.OffsetArray, ::Union{Colon, Int64}...) at /opt/.julia/packages/OffsetArrays/ExQCD/src/OffsetArrays.jl:240
  reshape(::AbstractArray, ::Int64...) at reshapedarray.jl:116
 [1] adjoint at /opt/.julia/packages/Zygote/xBjHw/src/lib/array.jl:93 [inlined]
 [2] _pullback at /opt/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [3] wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:161 [inlined]
 [4] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}, ::Float64) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [5] wsdrLoss at /home/kg/Projects/speech-enh/src/network.jl:153 [inlined]
 [6] _pullback(::Zygote.Context, ::typeof(wsdrLoss), ::Array{Complex{Float32},4}, ::Array{Complex{Float64},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [7] loss at ./REPL[29]:1 [inlined]
 [8] _pullback(::Zygote.Context, ::typeof(loss), ::Array{Complex{Float32},4}, ::Array{Complex{Float32},4}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [9] #5 at ./REPL[31]:5 [inlined]
 [10] _pullback(::Zygote.Context, ::var"#5#6"{typeof(loss),Array{Complex{Float32},4},Array{Complex{Float32},4}}) at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface2.jl:0
 [11] pullback at /opt/.julia/packages/Zygote/xBjHw/src/compiler/interface.jl:172 [inlined]
 [12] my_custom_train!(::typeof(loss), ::UNet, ::DataLoaders.BufferGetObsParallel{Tuple{Array{Complex{Float32},4},Array{Complex{Float32},4}},DataLoaders.BatchViewCollated{Tuple{Data,Data}}}, ::ADAM) at ./REPL[31]:5
 [13] top-level scope at REPL[32]:1

Could anyone help, please? I am doing something wrong, but what?


Do you have to use slices? What you have written could be done with operations like

    nd = vec(sum(val -> imag(val)^2, y; dims=(1,2,3)))
    dom = vec(sum(@. (imag(x) - imag(y))^2; dims=(1,2,3)))

which are likely to be both faster and easier for Zygote to digest. Can this istft function be made to accept a 3-array, not just a matrix?
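A minimal sketch of the whole loss written in that style, assuming the same w×h×c×n layout and the imaginary-part stand-in from the question. The names sdr_dims, wsdr_loss_dims and the helper sumsample are hypothetical, and averaging the per-sample values with mean is an assumption (gradient needs a scalar, while the original returns a 1×n matrix).

```julia
using Statistics: mean

# Hypothetical helper: sums over the w, h and c dimensions, leaving one value per sample.
sumsample(a) = sum(a; dims=(1, 2, 3))

# Per-sample SDR term; inputs are real arrays of size w×h×c×n.
function sdr_dims(ypred, ygold; eps=1e-8)
    num = sumsample(ygold .* ypred)
    den = sqrt.(sumsample(ygold .^ 2)) .* sqrt.(sumsample(ypred .^ 2))  # per-sample 2-norms
    return vec(-(num ./ (den .+ eps)))                                # length-n vector
end

# Whole-array version of wsdrLoss: no slicing, no Iterators.flatten/collect.
function wsdr_loss_dims(x, ŷ, y; eps=1e-8)
    xi, ŷi, yi = imag.(x), imag.(ŷ), imag.(y)   # imaginary parts, as in the question
    z = xi .- yi
    ẑ = xi .- ŷi
    nd  = vec(sumsample(yi .^ 2))
    dom = vec(sumsample(z .^ 2))
    aux = nd ./ (nd .+ dom .+ eps)
    wsdr = aux .* sdr_dims(ŷi, yi) .+ (1 .- aux) .* sdr_dims(ẑ, z)
    return mean(wsdr)   # assumption: average over samples to get a scalar loss
end
```

The function only uses broadcasting, sum with dims, and vec, which Zygote generally handles, so something like gradient(wsdr_loss_dims, x, ŷ, y) on the 513×321×1×2 arrays is the intended use; that call is not exercised here.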

I think your second example works if you delete |> Iterators.flatten |> collect. That step isn’t needed, since sum and norm already work on the whole matrix, and it seems to confuse Zygote.
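The reason the flatten step can go: sum, and norm from LinearAlgebra, treat a matrix as a flat collection of its entries, so flattening first gives the same numbers. A small check on random data (sizes arbitrary):

```julia
using LinearAlgebra: norm

A = rand(Float32, 5, 4)
B = rand(Float32, 5, 4)
a = collect(Iterators.flatten(A))    # what `|> Iterators.flatten |> collect` produces
b = collect(Iterators.flatten(B))

@assert a == vec(A)                  # same entries, same (column-major) order
@assert sum(A .* B) ≈ sum(a .* b)    # inner product is unchanged
@assert norm(A, 2) ≈ norm(a, 2)      # norm(::Matrix, 2) is entrywise; opnorm is the operator norm
```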


Thank you for the reply.

Probably istft can be changed to accept a 3-array. So you suggest transforming x, y and ŷ directly, and storing the results in new variables?

I didn’t know that an anonymous function (->) could be passed to sum. I really appreciate it 🙂 Is there any resource where I can learn these kinds of patterns?


The sum(f, x) pattern may be a bit of a waste, as IIRC Zygote turns it back into sum(f.(x)) (whereas normally it would save allocating f.(x)).
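Both forms compute the same per-sample sums; the difference is only in allocation. sum(f, y; dims=...) applies f during the reduction, while sum(f.(y); dims=...) first builds the broadcast array. A small check (sizes arbitrary):

```julia
y = rand(ComplexF32, 5, 4, 1, 2)

# Mapped reduction: applies the function while summing, without materializing it first.
s1 = sum(v -> imag(v)^2, y; dims=(1, 2, 3))

# Broadcast first, then reduce: allocates the intermediate array imag.(y) .^ 2.
s2 = sum(imag.(y) .^ 2; dims=(1, 2, 3))

@assert size(s1) == (1, 1, 1, 2)     # one value per sample
@assert s1 ≈ s2
```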

I think so? I mean, I suggest working with “solid” arrays as much as possible, rather than making slices. So instead of (or in addition to) istft(x::AbstractMatrix), you would define istft(xs::AbstractArray{T,3}) where T, with every step inside this function keeping track of one more dimension.
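A minimal sketch of that pattern, using a hypothetical toy_transform in place of istft (it only takes imaginary parts and flattens, which is not what an inverse STFT does). The matrix method handles one sample; the 3-array method does the same work for all samples at once by carrying the sample dimension through. A real istft would need each of its internal steps (inverse FFT, windowing, overlap-add) to carry that extra dimension in the same way.

```julia
# Hypothetical stand-in for istft, used only to show the dispatch pattern.
toy_transform(x::AbstractMatrix) = vec(imag.(x))   # one w×h sample -> length-w*h vector

function toy_transform(xs::AbstractArray{T,3}) where {T}
    w, h, n = size(xs)
    return reshape(imag.(xs), w * h, n)           # all samples at once -> (w*h)×n matrix
end

xs = rand(ComplexF32, 6, 5, 3)
Y = toy_transform(xs)

# Column j of the batched result matches the single-sample method on slice j.
for j in axes(xs, 3)
    @assert Y[:, j] == toy_transform(xs[:, :, j])
end
```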

Thank you