I’ve been playing around with SimpleChains.jl for potential use in a larger sparse model, but I’m having some issues figuring out how to obtain a sparsity pattern. After some testing with both SparseConnectivityTracer.jl and Symbolics.jl, I’ve figured out that my problem is probably with the SimpleChain itself, but I’m not sure if I’m doing things wrong or missing something obvious.
I had no similar issues with Lux.jl, so I was surprised my code didn’t just work. Is there a way to obtain the sparsity pattern through a SimpleChain, or is this incompatibility part of why SimpleChains is so fast in the first place?
Thanks!
Example code is below, along with the errors I got.
using SimpleChains, SparseConnectivityTracer, Symbolics
# Build a minimal SimpleChain: a single TurboDense layer with activation x -> x^3.
# Below, the chain is called with 3-element inputs and its result fits a 1-element vector.
chain = SimpleChain(3, TurboDense(x->x^3, 1))
# Initial parameters for the chain.
params = SimpleChains.init_params(chain)
# Sanity check: a forward pass on plain Float64 input works (it does).
chain(rand(3), params)
# Sparsity detection with SparseConnectivityTracer.jl.
detector = TracerLocalSparsityDetector()
# First error here: the closure is called with a vector of SparseConnectivityTracer.Dual
# tracers, and SimpleChains' `dense!` has no method for it (see "First Error" below).
SparseConnectivityTracer.jacobian_sparsity(x->chain(x, params), rand(3), detector)
# In-place wrapper around the chain, in the f!(y, x) form passed to Symbolics.jl:
# evaluates the global `chain` with the global `params` on `x`, writes the result
# into `y` element-wise, and returns `y`.
function inplacechain!(y, x)
    result = chain(x, params)
    y .= result
    return y
end
# Check on plain floats that the in-place wrapper matches the direct call (it does).
a = ones(1)
inplacechain!(a, [1.0, 2, 3])
b = chain([1.0, 2, 3], params)
a==b
# Set up Symbolics.jl sparsity detection. `output` and `input` are uninitialized
# Float64 buffers; per the second stack trace below, `inplacechain!` is actually
# invoked with `Vector{Num}` arguments.
output = Vector{Float64}(undef, 1)
input = Vector{Float64}(undef,3)
# Second error here: the same `dense!` MethodError as the first one, this time
# reached through `inplacechain!` with symbolic (`Num`) inputs.
Symbolics.jacobian_sparsity(inplacechain!,output, input)
First Error:
ERROR: MethodError: no method matching dense!(::var"#11#12", ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::Static.True, ::Static.False)
The function `dense!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
dense!(::F, ::StrideArraysCore.PtrArray{D}, ::AbstractMatrix, ::StrideArraysCore.PtrArray, ::BT, ::FF) where {F, BT<:Static.StaticBool, FF, T, P, D<:(ForwardDiff.Dual{<:Any, T, P})}
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\forwarddiff_matmul.jl:585
dense!(::Any, ::Any, ::AbstractMatrix, ::AbstractVector, ::AbstractMatrix, ::Any)
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:1035
dense!(::F, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractMatrix{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::Static.True, ::Static.False) where F
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:207
...
Stacktrace:
[1] TurboDense
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:184 [inlined]
[2] __chain
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:330 [inlined]
[3] _chain
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:407 [inlined]
[4] SArrayOutput
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:153 [inlined]
[5] with_stack_memory
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\memory.jl:25 [inlined]
[6] (::SimpleChain{…})(arg::Vector{…}, params::StrideArraysCore.StaticStrideArray{…})
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:182
[7] (::var"#13#14")(x::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{…}}})
@ Main .\Untitled-1:13
[8] trace_function(::Type{SparseConnectivityTracer.Dual{…}}, f::var"#13#14", x::Vector{Float64})
@ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\trace_functions.jl:48
[9] _local_jacobian_sparsity(f::Function, x::Vector{Float64}, ::Type{SparseConnectivityTracer.GradientTracer{…}})
@ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\trace_functions.jl:96
[10] jacobian_sparsity(f::Function, x::Vector{…}, ::TracerLocalSparsityDetector{…})
@ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\adtypes_interface.jl:149
[11] top-level scope
@ Untitled-1:13
Second Error:
ERROR: MethodError: no method matching dense!(::var"#11#12", ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::Static.True, ::Static.False)
The function `dense!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
dense!(::F, ::StrideArraysCore.PtrArray{D}, ::AbstractMatrix, ::StrideArraysCore.PtrArray, ::BT, ::FF) where {F, BT<:Static.StaticBool, FF, T, P, D<:(ForwardDiff.Dual{<:Any, T, P})}
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\forwarddiff_matmul.jl:585
dense!(::Any, ::Any, ::AbstractMatrix, ::AbstractVector, ::AbstractMatrix, ::Any)
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:1035
dense!(::F, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractMatrix{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::Static.True, ::Static.False) where F
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:207
...
Stacktrace:
[1] TurboDense
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:184 [inlined]
[2] __chain
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:330 [inlined]
[3] _chain
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:407 [inlined]
[4] SArrayOutput
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:153 [inlined]
[5] with_stack_memory
@ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\memory.jl:25 [inlined]
[6] (::SimpleChain{…})(arg::Vector{…}, params::StrideArraysCore.StaticStrideArray{…})
@ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:182
[7] inplacechain!(y::Vector{Num}, x::Vector{Num})
@ Main .\Untitled-1:17
[8] jacobian_sparsity(::Function, ::Vector{Float64}, ::Vector{Float64}; kwargs::@Kwargs{})
@ Symbolics C:\Users\johnb\.julia\packages\Symbolics\zyJQ0\src\diff.jl:787
[9] jacobian_sparsity(::Function, ::Vector{Float64}, ::Vector{Float64})
@ Symbolics C:\Users\johnb\.julia\packages\Symbolics\zyJQ0\src\diff.jl:782
[10] top-level scope
@ Untitled-1:31