(yet another) Zygote: mutating arrays is not supported

Hi all, thanks for the great work on Flux and Zygote!

Quick summary:

I’m training a simple (feedforward, relu) neural network on MNIST.

  • If I use Flux.Losses.mse as a loss function, my code works
  • If I use Flux.Losses.logitcrossentropy, it crashes

So I’m wondering: is this a bug in Flux/Zygote rather than in my code?

Stacktrace

(when using logitcrossentropy as my loss function)

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#407#408")(#unused#::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/lib/array.jl:61
  [3] (::Zygote.var"#2269#back#409"{Zygote.var"#407#408"})(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] Pullback
    @ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:123 [inlined]
  [8] (::typeof(∂(#∇logsoftmax!#60)))(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:123 [inlined]
 [10] (::typeof(∂(∇logsoftmax!##kw)))(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [11] Pullback
    @ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:113 [inlined]
 [12] (::typeof(∂(#∇logsoftmax#56)))(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:113 [inlined]
 [14] (::typeof(∂(∇logsoftmax##kw)))(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:128 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Zygote/i1R8y/src/compiler/chainrules.jl:77 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/.julia/packages/Zygote/i1R8y/src/compiler/chainrules.jl:103 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [21] Pullback
    @ ~/.julia/packages/Flux/0c9kI/src/losses/functions.jl:244 [inlined]
 [22] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/.julia/packages/Flux/0c9kI/src/losses/functions.jl:244 [inlined]
 [24] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float32}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/.julia/dev/GenError/scripts/size_comparison.jl:38 [inlined]
 [26] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [27] Pullback
    @ ~/.julia/dev/GenError/src/updateInfo.jl:126 [inlined]
 [28] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [29] Pullback
    @ ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:255 [inlined]
 [30] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [31] Pullback
    @ ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:59 [inlined]
 [32] (::typeof(∂(gradient)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [33] Pullback
    @ ~/.julia/dev/GenError/src/updateInfo.jl:125 [inlined]
 [34] (::typeof(∂(sumgrad)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [35] Pullback
    @ ~/.julia/dev/GenError/src/updateInfo.jl:133 [inlined]
 [36] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
 [37] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:255
 [38] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:59
 [39] (::GenError.var"#curvature#29"{GenError.var"#sumgrad#26"})(ps::Zygote.Params, us::Vector{Array{Float32, N} where N}, lossf::Function, data::Tuple{Matrix{Float32}, Matrix{Float32}})

MWE

I’ll make one on request, but I’m hoping there’s a simple answer that doesn’t require an MWE. The code that generates the error is fairly readable, I hope:

function _update(u::UpdateCurvature, store, datum)
    # Directional derivative: ∇L(w) ⋅ u, summed over all parameter arrays.
    function sumgrad(ps, us, lossf, data)
        gs = gradient(ps) do
            lossf(data[1], data[2])
        end
        return sum(sum(g .* u) for (u, g) in zip(us, gs))
    end

    # Differentiate sumgrad again and contract with u to get the curvature.
    function curvature(ps, us, lossf, data)
        gs = gradient(ps) do
            sumgrad(ps, us, lossf, data)
        end
        return sum(sum(g .* u) for (u, g) in zip(us, gs))
    end

    data = which_data(u, store, datum)
    return curvature(store[:params], store[:update], store[:lossf], data)
end

Here, store[:params] is params(my_model), and store[:update] is an array of arrays with the same shapes as the model’s weights (it’s the change in the weights over a timestep).
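
In other words, assuming I’ve wired this up correctly, sumgrad is the directional derivative of the loss along the update direction, and curvature is the directional second derivative:

$$ \texttt{sumgrad} = \nabla_w L(w) \cdot u, \qquad \texttt{curvature} = u^\top \nabla_w^2 L(w)\, u. $$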

Help greatly appreciated, thanks!

The long and short of it is that (logit)crossentropy can’t be differentiated twice with Zygote, while mse can. This should be resolved once we have a new AD system (no ETA, unfortunately); in the meantime you can try the approach in calculating 2nd order differentials of a function containing a NN w.r.t. to parameters · Issue #911 · FluxML/Zygote.jl · GitHub.
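
If it helps, here is a minimal sketch of the asymmetry (the weight matrix and data below are purely illustrative, not taken from your post):

using Flux, Zygote

W = randn(Float32, 2, 3)
x = randn(Float32, 3, 5)
y = randn(Float32, 2, 5)

# Second-order gradient through mse: this works.
Zygote.gradient(W) do v
    sum(Zygote.gradient(w -> Flux.Losses.mse(w * x, y), v)[1])
end

# The same with logitcrossentropy throws "Mutating arrays is not supported",
# because the pullback of logsoftmax (∇logsoftmax!, visible in your stacktrace)
# mutates an array in place:
# Zygote.gradient(W) do v
#     sum(Zygote.gradient(w -> Flux.Losses.logitcrossentropy(w * x, y), v)[1])
# end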


I see, thanks!

The approach you linked to calculates the entire Hessian matrix, which would be huge here. I instead want a Hessian-vector product, which ideally can be computed by differentiating a scalar-valued function involving the gradient (see below), without ever materialising the Hessian.
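
Concretely, since ∇L(w) ⋅ u is a scalar, the standard trick is

$$ Hu = \nabla_w \big( \nabla_w L(w) \cdot u \big), $$

so one gradient pass inside another should, in principle, give the Hessian-vector product without ever forming H.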

I tried to do so using the coding style you referenced in Issue #911, but I still get the mutating-arrays error. Is the only workaround to build the full Hessian? For reference, here is an MWE of what I tried; I hope there is a way to make it work…

using Flux, Zygote
using LinearAlgebra: dot

student = Chain(
    Dense(784, 128, relu),
    Dense(128, 10, relu),
)

w, re = Flux.destructure(student)   # length(w) == sum(length, params(student))
lloss(m, x, y) = Flux.Losses.logitcrossentropy(m(x), y)
# lloss(m, x, y) = Flux.Losses.mse(m(x), y)   # this variant works

x = randn(Float32, 784, 50)
y = randn(Float32, 10, 50)
u = deepcopy(w)   # arbitrary vector with the same size as the weights

function all(u, w, x, y)   # NB: shadows Base.all
    # Directional derivative ∇L(w) ⋅ u, via one reverse-mode pass.
    function sumgu(u, w, x, y)
        return dot(Zygote.gradient(v -> lloss(re(v), x, y), w)[1], u)
    end

    # Differentiating it again should give H*u, but this is where Zygote errors.
    function curvature(u, w, x, y)
        return Zygote.gradient(v -> sumgu(u, v, x, y), w)
    end

    return curvature(u, w, x, y)
end

Just use the Hessian-vector product for Zygote as implemented in GalacticOptim.jl. I’ve posted it here many times, so a search should find it, or just look at the GalacticOptim.jl source for the ForwardDiff.Dual adjoints plus the hv definition.

Where your code differs from Zygote.hessian is that it calls Zygote.gradient twice. Simply replacing the second (outer) one with ForwardDiff.gradient may well work.

Edit: now that I’ve tried it, it matches the method in the next message, but is much slower.

julia> function all(u,w,x,y)
           function sumgu(u, w, x, y)
               return dot(Zygote.gradient(v -> lloss(re(v),x,y), w)[1], u)
           end
           # The only change is Zygote -> ForwardDiff here:
           ForwardDiff.gradient(v -> sumgu(u,v,x,y), w)
       end;

julia> Hu ≈ all(u,w,x,y)  # with result from below
true

Thanks!

Maybe I’m searching for the wrong thing, but I couldn’t find a relevant Discourse thread.

However, as you suggested, I copied the relevant GalacticOptim.jl source code, and it works! (I verified the result.) So below is a working example of Hessian-vector products with Zygote and ForwardDiff for my MWE:

using ForwardDiff
using Flux, Zygote

student = Chain(
    Dense(784, 128, relu),
    Dense(128, 10, relu),
)

w, re = Flux.destructure(student)
lloss(student, x, y) = Flux.Losses.logitcrossentropy(student(x), y)

x = randn(Float32, 784, 50)
y = randn(Float32, 10, 50)
u = deepcopy(w) .+ 1f0   # arbitrary vector with the same size as the weights
                         # (1f0 rather than 1. keeps everything Float32)

wloss(w) = lloss(re(w), x, y)

# Seed each weight with a dual part taken from u, then take an ordinary
# reverse-mode gradient through the dual numbers.
_w = ForwardDiff.Dual.(w, u)
rr = Zygote.gradient(wloss, _w)[1]

# Hu is the product of the Hessian of the loss function with u.
Hu = getindex.(ForwardDiff.partials.(rr), 1)

That said, ForwardDiff.partials isn’t really documented, so I’m not wholly confident about why my copied code works…
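
For the record, here is my best guess at why the Dual trick gives H*u, wrapped into a small helper (the name hvp is mine, not GalacticOptim.jl’s): seeding each weight wᵢ with dual part uᵢ means the gradient is evaluated at w + ε·u to first order in ε, and the ε-coefficient of ∇L(w + ε·u) is exactly H·u.

using ForwardDiff, Zygote

# Forward-over-reverse Hessian-vector product (a sketch; `hvp` is my own name).
function hvp(f, w::AbstractVector, u::AbstractVector)
    duals = ForwardDiff.Dual.(w, u)        # wᵢ + ε*uᵢ, one dual per weight
    g = Zygote.gradient(f, duals)[1]       # reverse-mode gradient, dual-valued
    return getindex.(ForwardDiff.partials.(g), 1)   # ε-parts of ∇L, i.e. H*u
end

# Hu = hvp(wloss, w, u)   # should match the script above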

Thanks!

I tried just putting ForwardDiff on top as you suggested, but it didn’t work for me. So I posted what did work with ForwardDiff, which is an adaptation of the GalacticOptim.jl source code.