Ran into this error running code similar to this MWE, and was hoping someone who knows more about the AD ecosystem could help point me where to file an issue–is this a missing ChainRule, or overly-strict typing in SparseArrays? (Also, any suggestions for workarounds are appreciated.) Thanks!
using Distributions
using PDMats
using SparseArrays
using Zygote
function f(u, p)
μ = u[1]
σ = exp(u[2])
Q = spdiagm(fill(σ^2, p.n))
d = MvNormalCanon(fill(μ, p.n), PDSparseMat(Q))
return -logpdf(d, p.x)
end
x = randn(100)
p = (x=x, n=length(x))
Zygote.hessian(u -> f(u, p), rand(2))
ERROR: TypeError: in Sparse, in Tv, expected Tv<:Union{Float64, ComplexF64}, got Type{ForwardDiff.Dual{Nothing, Float64, 2}}
Stacktrace:
[1] SparseArrays.CHOLMOD.Sparse(A::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64}, stype::Int64)
@ SparseArrays.CHOLMOD C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\solvers\cholmod.jl:790
[2] SparseArrays.CHOLMOD.Sparse(A::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
@ SparseArrays.CHOLMOD C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\solvers\cholmod.jl:796
[3] rrule
@ C:\Users\sam.urmy\.julia\packages\ChainRules\aKxNz\src\rulesets\SparseArrays\sparsematrix.jl:17 [inlined]
[4] rrule
@ C:\Users\sam.urmy\.julia\packages\ChainRulesCore\0t04l\src\rules.jl:134 [inlined]
[5] chain_rrule
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\chainrules.jl:223 [inlined]
[6] macro expansion
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:101 [inlined]
[7] _pullback
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:101 [inlined]
[8] _pullback
@ C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\solvers\cholmod.jl:1300 [inlined]
[9] _pullback(::Zygote.Context{false}, ::SparseArrays.CHOLMOD.var"##cholesky#10", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(LinearAlgebra.cholesky), ::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
[10] _pullback
@ C:\Users\sam.urmy\.julia\juliaup\julia-1.9.0+0.x64.w64.mingw32\share\julia\stdlib\v1.9\SparseArrays\src\solvers\cholmod.jl:1300 [inlined]
[11] _pullback
@ C:\Users\sam.urmy\.julia\packages\PDMats\CbBv1\src\pdsparsemat.jl:20 [inlined]
[12] _pullback(ctx::Zygote.Context{false}, f::Type{PDSparseMat}, args::SparseMatrixCSC{ForwardDiff.Dual{Nothing, Float64, 2}, Int64})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
[13] _pullback
@ .\Untitled-1:10 [inlined]
[14] _pullback(::Zygote.Context{false}, ::typeof(f), ::Vector{ForwardDiff.Dual{Nothing, Float64, 2}}, ::NamedTuple{(:x, :n), Tuple{Vector{Float64}, Int64}})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
[15] _pullback
@ .\Untitled-1:16 [inlined]
[16] _pullback(ctx::Zygote.Context{false}, f::var"#17#18", args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface2.jl:0
[17] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:44
[18] pullback
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:42 [inlined]
[19] gradient(f::Function, args::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\compiler\interface.jl:96
[20] (::Zygote.var"#121#122"{var"#17#18"})(x::Vector{ForwardDiff.Dual{Nothing, Float64, 2}})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:64
[21] forward_jacobian(f::Zygote.var"#121#122"{var"#17#18"}, x::Vector{Float64}, #unused#::Val{2})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:29
[22] forward_jacobian(f::Function, x::Vector{Float64}; chunk_threshold::Int64)
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:44
[23] forward_jacobian
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\forward.jl:42 [inlined]
[24] hessian_dual
@ C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:64 [inlined]
[25] hessian(f::Function, x::Vector{Float64})
@ Zygote C:\Users\sam.urmy\.julia\packages\Zygote\HTsWj\src\lib\grad.jl:62
[26] top-level scope
@ Untitled-1:16