When using a matrix function generated from Symbolics.jl variables, I get the error "Mutating arrays is not supported" when trying to autodifferentiate through the matrix function with either Zygote.gradient or Zygote.jacobian. A minimal working example is below, and the full error stacktrace is at the bottom of this post.
MWE
using Symbolics, Zygote
@variables x, y
M = [x y]
M_expr = build_function(M, [x,y])
M_func = eval(M_expr[1])
Zygote.gradient(z->sum(M_func(z)), [1,2]) # Error; full stacktrace at bottom of post
Julia version is 1.6.0.
Symbolics.jl is v0.1.12
Zygote.jl is v0.6.7
Is it possible to create matrix functions using Symbolics.jl that do not use mutation? Otherwise I’d like to know how to properly use Zygote.jl with matrix functions generated by Symbolics.jl.
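For comparison, here is a minimal sketch of a hand-written equivalent of the matrix function from the MWE. It builds the array with a literal instead of writing into a preallocated one, which is the kind of construction Zygote is expected to differentiate. The name M_manual is hypothetical and is not produced by Symbolics.jl, and this does not show how to make build_function emit such code.

```julia
using Zygote

# Hand-written stand-in for the generated function (hypothetical name).
# The literal [a b] lowers to hcat, so no array is mutated in place.
M_manual(z) = [z[1] z[2]]

# Same call pattern as in the MWE, but with floating-point inputs.
g = Zygote.gradient(z -> sum(M_manual(z)), [1.0, 2.0])
```

If this differentiates while the generated M_func does not, that would point at the array construction inside the generated code rather than at the sum or at the inputs.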
Additional Details:
If I apply the sum() function to the symbolic matrix before calling build_function, then Zygote can take the gradient:
M_sum = sum(M)
M_sum_expr = build_function(M_sum, [x,y])
M_sum_func = eval(M_sum_expr)
Zygote.gradient(M_sum_func, [1,2]) # This works
I don’t think I have this option, however: I’m trying to autodifferentiate through a DifferentialEquations.jl ODE of the form \dot{x} = A(p)x + Bu, y = C(p)x, where p contains the parameters I want to differentiate with respect to. As far as I’m aware, the functions A() and C() can’t be passed into an ODE solver in their symbolic form, but I could be wrong.
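To make the target concrete, the right-hand side \dot{x} = A(p)x + Bu can be sketched with a hand-written, non-mutating A(p). Everything here is an assumption for illustration: the entries of A, the values of B, u and x0, and the names A_manual and rhs are hypothetical, and no ODE solver is involved. It only shows the parameter dependence that Zygote would need to differentiate through.

```julia
using Zygote

# Hypothetical 2x2 parameter-dependent matrix, assembled from row literals
# with vcat so that no array is written to in place.
A_manual(p) = vcat([p[1] 1.0], [0.0 p[2]])

# Illustrative constants (not taken from the original problem).
B  = [0.0, 1.0]
u  = 0.5
x0 = [1.0, 2.0]

# Out-of-place right-hand side: xdot = A(p) * x + B * u
rhs(x, p) = A_manual(p) * x + B * u

# Gradient of a scalar function of the right-hand side with respect to p.
gp = Zygote.gradient(p -> sum(rhs(x0, p)), [3.0, 4.0])
```

An out-of-place rhs like this is used here because the in-place form would itself mutate an array, which is the operation the error message complains about.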
Full Stacktrace
ERROR: Mutating arrays is not supported
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#399#400")(#unused#::Nothing)
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/lib/array.jl:58
[3] (::Zygote.var"#2253#back#401"{Zygote.var"#399#400"})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] Pullback
@ ~/.julia/packages/SymbolicUtils/aNxjZ/src/code.jl:383 [inlined]
[5] (::Zygote.var"#178#179"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(_create_array))})(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/lib/lib.jl:194
[6] #1686#back
@ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[7] Pullback
@ ~/.julia/packages/SymbolicUtils/aNxjZ/src/code.jl:394 [inlined]
[8] Pullback
@ ~/.julia/packages/SymbolicUtils/aNxjZ/src/code.jl:371 [inlined]
[9] (::typeof(∂(#11)))(Δ::FillArrays.Fill{Int64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[38]:1 [inlined]
[11] (::typeof(∂(#13)))(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
[12] (::Zygote.var"#41#42"{typeof(∂(#13))})(Δ::Int64)
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:41
[13] gradient(f::Function, args::Vector{Int64})
@ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:59
[14] top-level scope
@ REPL[38]:1