Sparsity of SimpleChains Errors

I’ve been playing around with SimpleChains.jl for potential use in a larger sparse model, but I’m having some issues figuring out how to obtain a sparsity pattern. After some testing with both SparseConnectivityTracer.jl and Symbolics.jl, I’ve figured out that my problem is probably with the SimpleChain itself, but I’m not sure if I’m doing things wrong or missing something obvious.

I had no similar issues with Lux.jl, so I was surprised my code didn’t just work. Is there a way to obtain the sparsity pattern through a SimpleChain or is this incompatibility part of why SimpleChains is so fast in the first place?

Thanks!

Example code is below, along with the errors I got.

using SimpleChains, SparseConnectivityTracer, Symbolics

#Build simple chain
chain = SimpleChain(3, TurboDense(x->x^3, 1))
params = SimpleChains.init_params(chain)

#Test that it works (it does)
chain(rand(3), params)

#SparseConnectivityTracer sparsity detection
detector = TracerLocalSparsityDetector()
#First error here
SparseConnectivityTracer.jacobian_sparsity(x->chain(x, params), rand(3), detector)

#Build inplace function for Symbolics.jl
function inplacechain!(y, x)
    y.= chain(x, params)
end

#Compare results of inplace chain to make sure it's the same (it is)
a = ones(1)
inplacechain!(a, [1.0, 2, 3])
b = chain([1.0, 2, 3], params)
a==b

#Set up Symbolics.jl sparsity
output = Vector{Float64}(undef, 1)
input = Vector{Float64}(undef,3)

#Second error here
Symbolics.jacobian_sparsity(inplacechain!,output, input)

First Error:

ERROR: MethodError: no method matching dense!(::var"#11#12", ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::Static.True, ::Static.False)
The function `dense!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  dense!(::F, ::StrideArraysCore.PtrArray{D}, ::AbstractMatrix, ::StrideArraysCore.PtrArray, ::BT, ::FF) where {F, BT<:Static.StaticBool, FF, T, P, D<:(ForwardDiff.Dual{<:Any, T, P})}
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\forwarddiff_matmul.jl:585
  dense!(::Any, ::Any, ::AbstractMatrix, ::AbstractVector, ::AbstractMatrix, ::Any)
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:1035
  dense!(::F, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractMatrix{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::Static.True, ::Static.False) where F
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:207
  ...

Stacktrace:
  [1] TurboDense
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:184 [inlined]
  [2] __chain
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:330 [inlined]
  [3] _chain
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:407 [inlined]
  [4] SArrayOutput
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:153 [inlined]
  [5] with_stack_memory
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\memory.jl:25 [inlined]
  [6] (::SimpleChain{…})(arg::Vector{…}, params::StrideArraysCore.StaticStrideArray{…})
    @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:182
  [7] (::var"#13#14")(x::Vector{SparseConnectivityTracer.Dual{Float64, SparseConnectivityTracer.GradientTracer{…}}})
    @ Main .\Untitled-1:13
  [8] trace_function(::Type{SparseConnectivityTracer.Dual{…}}, f::var"#13#14", x::Vector{Float64})
    @ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\trace_functions.jl:48
  [9] _local_jacobian_sparsity(f::Function, x::Vector{Float64}, ::Type{SparseConnectivityTracer.GradientTracer{…}})
    @ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\trace_functions.jl:96
 [10] jacobian_sparsity(f::Function, x::Vector{…}, ::TracerLocalSparsityDetector{…})
    @ SparseConnectivityTracer C:\Users\johnb\.julia\packages\SparseConnectivityTracer\XQsYR\src\adtypes_interface.jl:149
 [11] top-level scope
    @ Untitled-1:13

Second Error:

ERROR: MethodError: no method matching dense!(::var"#11#12", ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::StrideArraysCore.PtrArray{…}, ::Static.True, ::Static.False)
The function `dense!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  dense!(::F, ::StrideArraysCore.PtrArray{D}, ::AbstractMatrix, ::StrideArraysCore.PtrArray, ::BT, ::FF) where {F, BT<:Static.StaticBool, FF, T, P, D<:(ForwardDiff.Dual{<:Any, T, P})}
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\forwarddiff_matmul.jl:585
  dense!(::Any, ::Any, ::AbstractMatrix, ::AbstractVector, ::AbstractMatrix, ::Any)
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:1035
  dense!(::F, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractMatrix{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::AbstractVecOrMat{<:Union{Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8}}, ::Static.True, ::Static.False) where F
   @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:207
  ...

Stacktrace:
  [1] TurboDense
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\dense.jl:184 [inlined]
  [2] __chain
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:330 [inlined]
  [3] _chain
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:407 [inlined]
  [4] SArrayOutput
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:153 [inlined]
  [5] with_stack_memory
    @ C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\memory.jl:25 [inlined]
  [6] (::SimpleChain{…})(arg::Vector{…}, params::StrideArraysCore.StaticStrideArray{…})
    @ SimpleChains C:\Users\johnb\.julia\packages\SimpleChains\mSbJT\src\simple_chain.jl:182
  [7] inplacechain!(y::Vector{Num}, x::Vector{Num})
    @ Main .\Untitled-1:17
  [8] jacobian_sparsity(::Function, ::Vector{Float64}, ::Vector{Float64}; kwargs::@Kwargs{})
    @ Symbolics C:\Users\johnb\.julia\packages\Symbolics\zyJQ0\src\diff.jl:787
  [9] jacobian_sparsity(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Symbolics C:\Users\johnb\.julia\packages\Symbolics\zyJQ0\src\diff.jl:782
 [10] top-level scope
    @ Untitled-1:31

Can you explain why you’re interested in the sparsity pattern here?

In the rest of my project, the ML component is hooked up to a larger model that has sparsity. My actual problem is obtaining the sparsity pattern for the model as a whole. The rest of the model works fine with sparsity detection, including when I replace the SimpleChain component with a Lux model or a manually written function. I get the same type of error from my bigger model as I do from the SimpleChain alone, so I figured reducing it to just the chain would be the best way to ask for advice.

Maybe you can perform sparsity detection separately with the equivalent Lux model, and then provide the known sparsity pattern?
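A minimal sketch of that suggestion, assuming Lux's `Dense(3 => 1, x -> x^3)` is an adequate stand-in for `TurboDense(x -> x^3, 1)`: both map 3 inputs to 1 output through a dense layer with a cubic activation, and a bias term does not change the pattern with respect to the input anyway. It uses SparseConnectivityTracer's global `TracerSparsityDetector`, so the result depends on the program structure rather than on the particular weight values.

```julia
using Lux, Random, SparseConnectivityTracer

# Lux model assumed to mirror SimpleChain(3, TurboDense(x -> x^3, 1)).
model = Chain(Dense(3 => 1, x -> x^3))
ps, st = Lux.setup(Random.default_rng(), model)

# Lux models return (output, state); only the output matters for the pattern.
f_lux(x) = first(model(x, ps, st))

# Global detection: dependencies come from the operations performed,
# not from the values stored in ps.
S = jacobian_sparsity(f_lux, rand(3), TracerSparsityDetector())
```

The resulting `S` is a `SparseMatrixCSC{Bool}` that could be handed to whatever consumes the pattern instead of running detection through the SimpleChain. If that consumer goes through ADTypes, its `KnownJacobianSparsityDetector(S)` is one way to supply a precomputed pattern, though whether that fits depends on how the larger model is set up.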

Yeah, I’ve considered doing that, or just writing a function to generate the sparsity pattern manually, but it would have been convenient if it just worked as is. Sounds like it might not be possible right now.
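For the manual route, a sketch under one assumption: the chain consists only of dense layers with elementwise activations, as `TurboDense` is. Each such layer connects every input to every output, so the chain's Jacobian with respect to its input is structurally dense and needs no tracing at all. The helper names and the toy 4-state model below are hypothetical, and the embedding step assumes the chain's outputs land directly in the given rows of the larger model's Jacobian, which may not match a real model where those outputs pass through further computation.

```julia
using SparseArrays

# Pattern of a stack of dense layers w.r.t. its input: fully dense (n_out x n_in).
dense_chain_pattern(n_out::Int, n_in::Int) = sparse(trues(n_out, n_in))

# Hypothetical helper: OR a block pattern into a copy of the model pattern,
# where `out_rows` are the model rows the block writes to and `in_cols` are
# the model columns (inputs) that feed the block.
function embed_block_pattern(S_model::AbstractMatrix{Bool}, block::AbstractMatrix{Bool},
                             out_rows, in_cols)
    @assert size(block) == (length(out_rows), length(in_cols))
    S = copy(S_model)
    for (jb, j) in enumerate(in_cols), (ib, i) in enumerate(out_rows)
        if block[ib, jb]
            S[i, j] = true
        end
    end
    return S
end

# Toy 4-state model with a tridiagonal pattern (hypothetical); the chain
# reads states 1:3 and its single output feeds row 4.
S_phys = sparse(Bool[1 1 0 0; 1 1 1 0; 0 1 1 1; 0 0 1 1])
S_chain = dense_chain_pattern(1, 3)
S_total = embed_block_pattern(S_phys, S_chain, 4:4, 1:3)
```

`block` can be any `Bool` matrix of the right size, so a pattern obtained some other way (for example from an equivalent Lux model) could be used in place of `dense_chain_pattern`'s output.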

You might also be interested in the DenseSparsityDetector offered by DifferentiationInterface.jl. It sucks and has no formal guarantees, but it always runs.
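A sketch of how that detector might be applied to the chain from the original post, based on DifferentiationInterface's `DenseSparsityDetector(backend; atol, method)` constructor and the ADTypes `jacobian_sparsity` entry point. It leans on the fact that the `Closest candidates` lists in both errors include a `dense!` method for `ForwardDiff.Dual`, so ForwardDiff numbers can plausibly pass through `TurboDense` where the tracer and `Num` types cannot; that is an inference from the error output, not something verified here. The detector computes the full Jacobian at `x0` and keeps entries whose magnitude exceeds `atol`, which is why it carries no formal guarantee.

```julia
using ADTypes, DifferentiationInterface, SimpleChains
import ForwardDiff

chain = SimpleChain(3, TurboDense(x -> x^3, 1))
params = SimpleChains.init_params(chain)

# Dense detection: evaluate the whole Jacobian with ForwardDiff, then treat
# entries with |J[i, j]| > atol as structural nonzeros.
detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-5, method=:direct)

x0 = rand(3)
S = ADTypes.jacobian_sparsity(x -> chain(x, params), x0, detector)
```

Since the cubic activation has zero slope where the pre-activation is zero, an unlucky `x0` can hide a real dependency. Combining the patterns from a few random inputs with `.|` lowers that risk without removing it.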

I hadn’t seen that one, I’ll give it a shot. Thanks for the recommendation!

Edit: I see what you mean by no formal guarantees. I’ll probably go with the Lux or manual configuration options for now, but it’s cool that exists!
