Hi all, thanks for the great work on Flux and Zygote!
Quick summary:
I’m training a simple (feedforward, relu) neural network on MNIST.
- If I use Flux.Losses.mse as the loss function, my code works.
- If I use Flux.Losses.logitcrossentropy, it crashes.
So I’m wondering whether this is a bug in Flux/Zygote rather than in my code.
Stacktrace
(when using logitcrossentropy as my loss function)
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#407#408")(#unused#::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/lib/array.jl:61
[3] (::Zygote.var"#2269#back#409"{Zygote.var"#407#408"})(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ./broadcast.jl:894 [inlined]
[5] Pullback
@ ./broadcast.jl:891 [inlined]
[6] Pullback
@ ./broadcast.jl:887 [inlined]
[7] Pullback
@ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:123 [inlined]
[8] (::typeof(∂(#∇logsoftmax!#60)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:123 [inlined]
[10] (::typeof(∂(∇logsoftmax!##kw)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[11] Pullback
@ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:113 [inlined]
[12] (::typeof(∂(#∇logsoftmax#56)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:113 [inlined]
[14] (::typeof(∂(∇logsoftmax##kw)))(Δ::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/NNlib/LiXUT/src/softmax.jl:128 [inlined]
[16] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Zygote/i1R8y/src/compiler/chainrules.jl:77 [inlined]
[18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[19] Pullback
@ ~/.julia/packages/Zygote/i1R8y/src/compiler/chainrules.jl:103 [inlined]
[20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float32}})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[21] Pullback
@ ~/.julia/packages/Flux/0c9kI/src/losses/functions.jl:244 [inlined]
[22] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[23] Pullback
@ ~/.julia/packages/Flux/0c9kI/src/losses/functions.jl:244 [inlined]
[24] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Nothing, Matrix{Float32}, Nothing})
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[25] Pullback
@ ~/.julia/dev/GenError/scripts/size_comparison.jl:38 [inlined]
[26] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[27] Pullback
@ ~/.julia/dev/GenError/src/updateInfo.jl:126 [inlined]
[28] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[29] Pullback
@ ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:255 [inlined]
[30] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[31] Pullback
@ ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:59 [inlined]
[32] (::typeof(∂(gradient)))(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[33] Pullback
@ ~/.julia/dev/GenError/src/updateInfo.jl:125 [inlined]
[34] (::typeof(∂(sumgrad)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[35] Pullback
@ ~/.julia/dev/GenError/src/updateInfo.jl:133 [inlined]
[36] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface2.jl:0
[37] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(λ)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:255
[38] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/i1R8y/src/compiler/interface.jl:59
[39] (::GenError.var"#curvature#29"{GenError.var"#sumgrad#26"})(ps::Zygote.Params, us::Vector{Array{Float32, N} where N}, lossf::Function, data::Tuple{Matrix{Float32}, Matrix{Float32}})
MWE
I’ll make one on request, but I’m hoping there is a simple answer that doesn’t require an MWE. The code that generates the error is, I hope, fairly understandable:
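function _update(u::UpdateCurvature, store, datum)
    # First-order term: gradient of the loss w.r.t. ps, contracted with us.
    function sumgrad(ps, us, lossf, data)
        gs = gradient(ps) do
            lossf(data[1], data[2])
        end
        # Pair each update array with a gradient and sum the elementwise products.
        return sum(sum(g .* u) for (u, g) in zip(us, gs))
    end

    # Second-order term: differentiate sumgrad w.r.t. ps (a nested gradient call),
    # then contract with us again.
    function curvature(ps, us, lossf, data)
        gs = gradient(ps) do
            sumgrad(ps, us, lossf, data)
        end
        return sum(sum(g .* u) for (u, g) in zip(us, gs))
    end

    data = which_data(u, store, datum)
    return curvature(store[:params], store[:update], store[:lossf], data)
end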
Here, store[:params] is params(my_model), and store[:update] is an array of arrays of the same size as the weights of the model (it’s the change in weights over a timestep).
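To spell out what the nested gradient calls are meant to compute, here is a sketch in my own notation, which does not appear in the code: θ_i are the parameter arrays in store[:params], u_i the matching arrays in store[:update], L the loss, and ⟨·,·⟩ the sum of elementwise products. I am assuming the u_i are treated as constants and that the gradients pair up with the updates in order.

```latex
\begin{aligned}
\texttt{sumgrad}(\theta)   &= \sum_i \big\langle \nabla_{\theta_i} L(\theta),\; u_i \big\rangle
                            = \nabla L(\theta)^{\top} u, \\
\texttt{curvature}(\theta) &= \sum_i \big\langle \nabla_{\theta_i}\,\texttt{sumgrad}(\theta),\; u_i \big\rangle
                            = u^{\top}\, \nabla^{2} L(\theta)\, u .
\end{aligned}
```

So curvature needs a second derivative of the loss, which means Zygote has to differentiate through the gradient code of the loss itself. That seems consistent with the stack trace above, where the failing pullback passes through ∇logsoftmax! in NNlib.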
Help greatly appreciated, thanks!