MWE (as minimal as I could make it):
using Flux
import ForwardDiff, Zygote
g = Dense(2, 1, softplus)  # model under test: a single Dense layer (2 -> 1) with softplus activation
loss(m, x) = sum(abs2, Zygote.hessian(x->sum(m(x)), x))  # sum of squared entries of the Hessian of sum(m(x)) w.r.t. x
gradient(g, rand(2)) do m, x  # outer gradient of the Hessian-based loss w.r.t. model and input; the error is raised from this call
loss(m, x)
end
Both `Zygote.hessian` and `ForwardDiff.hessian` throw the following error:
ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#399#400")(::Nothing) at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/array.jl:58
[3] (::Zygote.var"#2265#back#401"{Zygote.var"#399#400"})(::Nothing) at /home/arun/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] forward_jacobian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/forward.jl:25 [inlined]
[5] (::typeof(∂(forward_jacobian)))(::Tuple{Nothing,Array{Float64,2}}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
[6] forward_jacobian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/forward.jl:38 [inlined]
[7] (::typeof(∂(forward_jacobian)))(::Tuple{Nothing,Array{Float64,2}}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
[8] hessian_dual at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/grad.jl:74 [inlined]
[9] (::typeof(∂(hessian_dual)))(::Array{Float64,2}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
[10] hessian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/grad.jl:72 [inlined]
[11] loss at /tmp/test.jl:5 [inlined]
[12] (::typeof(∂(loss)))(::Float64) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
[13] #67 at /tmp/test.jl:7 [inlined]
[14] (::Zygote.var"#41#42"{typeof(∂(#67))})(::Float64) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:41
[15] gradient(::Function, ::Dense{typeof(softplus),Array{Float32,2},Array{Float32,1}}, ::Vararg{Any,N} where N) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:59
[16] top-level scope at /tmp/test.jl:6
[17] include(::String) at ./client.jl:457
[18] top-level scope at REPL[9]:1
in expression starting at /tmp/test.jl:6
Any ideas how I can go about solving this issue? Thanks!