Zygote: product with a (constant) sparse boolean

How can I make products with (constant) sparse boolean matrices and vectors work with Zygote? Example:

julia> using Zygote, SparseArrays, LinearAlgebra

julia> gradient(v -> dot(sparse([true;false]),v), [1.;0.])
ERROR: MethodError: no method matching zero(::Nothing)

Closest candidates are:
  zero(::Union{Type{P}, P}) where P<:Dates.Period
   @ Dates ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/Dates/src/periods.jl:51
   @ Base irrationals.jl:151
  zero(::FillArrays.Ones{T, N}) where {T, N}
   @ FillArrays ~/.julia/packages/FillArrays/yjfkJ/src/FillArrays.jl:572

Stacktrace:
  [1] iszero(x::Nothing)
    @ Base ./number.jl:42
  [2] _iszero(x::Nothing)
    @ SparseArrays ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/SparseArrays/src/SparseArrays.jl:37
  [3] _noshapecheck_map(::typeof(Zygote.wrap_chainrules_output), ::SparseVector{ChainRulesCore.NoTangent, Int64})
    @ SparseArrays.HigherOrderFns ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/SparseArrays/src/higherorderfns.jl:181
  [4] map
    @ ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/SparseArrays/src/higherorderfns.jl:152 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:127 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:110 [inlined]
  [7] map
    @ ./tuple.jl:275 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:111 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/SuKWp/src/compiler/chainrules.jl:211 [inlined]
 [10] Pullback
    @ ./REPL[4]:1 [inlined]
 [11] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#9#10", Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#dot_pullback#1947"{SparseVector{Bool, Int64}, Vector{Float64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{SparseVector, NamedTuple{(:element, :nzind, :axes), Tuple{ChainRulesCore.ProjectTo{ChainRulesCore.NoTangent, NamedTuple{(), Tuple{}}}, Vector{Int64}, Tuple{Base.OneTo{Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{ChainRulesCore.NoTangent, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{ChainRulesCore.NoTangent, NamedTuple{(), Tuple{}}}}, Tuple{Tuple{}, Tuple{}}, Val{1}}}, Zygote.Pullback{Tuple{typeof(sparse), Vector{Bool}}, Tuple{Zygote.Pullback{Tuple{typeof(sparsevec), Vector{Bool}}, Tuple{Zygote.ZBack{ChainRules.var"#sparse_pullback#2153"}}}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
 [12] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
 [13] top-level scope
    @ REPL[4]:1

I did not expect any issues, because this is just a constant used as a bit mask. The same product works if I change the element type to Int64.

julia> gradient(v -> sum(v[sparse([true;false]).nzind]), [1.;0.])
([1.0, 0.0],)
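In case it helps to see why this works: for a sparse boolean vector, `.nzind` holds the indices of the stored `true` entries, so indexing with it reproduces the boolean dot product. A quick stdlib-only check:

```julia
using SparseArrays, LinearAlgebra

# nzind stores the indices of the `true` (stored) entries,
# so indexing v with it selects exactly the masked elements
mask = sparse([true; false])
v = [1.0; 0.0]

@assert sum(v[mask.nzind]) == dot(mask, v)  # both compute the masked sum
```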

Thanks! And for a matrix? For instance:

P = sparse([true false; false true])
gradient(v -> v'*(P*v), [1.;0.])

I don’t see a way to generalize that trick unfortunately.

You could always write an rrule, since this kind of expression is trivial to differentiate by hand.

For example, we can define a QuadraticForm wrapper around the matrix:

using LinearAlgebra
import ChainRulesCore
using ChainRulesCore: ProjectTo, @not_implemented

struct QuadraticForm{T<:LinearAlgebra.HermOrSym{<:Real}}
    A::T
end

(q::QuadraticForm)(x::Vector) = dot(x, q.A, x)

function ChainRulesCore.rrule(q::QuadraticForm, x::AbstractVector{<:Real})
    project_x = ProjectTo(x)
    Ax = q.A * x
    y = x'Ax
    pullback(∂y) = @not_implemented("A assumed constant"), project_x((2∂y) * Ax)
    return y, pullback
end

and then you have:

julia> q = QuadraticForm(Symmetric(sparse([true false; false true])))
QuadraticForm{Symmetric{Bool, SparseMatrixCSC{Bool, Int64}}}(Bool[1 0; 0 1])

julia> gradient(q, [1,0])
([2.0, 0.0],)

(Note that the above works for any A, whether it is sparse or boolean or not.)
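As a sanity check on the hand derivative used in the pullback (∇ₓ x'Ax = 2Ax for symmetric, constant A), one can compare against a finite difference using only the standard library; the helper `e` below is just for illustration:

```julia
using LinearAlgebra, SparseArrays

A = Symmetric(sparse([true false; false true]))
x = [1.0, 0.0]

g = 2 .* (A * x)   # hand-derived gradient of x'Ax (A symmetric, constant)

# central finite-difference check, one coordinate at a time
h = 1e-6
e(i) = Matrix{Float64}(I, 2, 2)[:, i]   # i-th standard basis vector
fd = [(dot(x + h*e(i), A, x + h*e(i)) - dot(x - h*e(i), A, x - h*e(i))) / (2h) for i in 1:2]

@assert isapprox(fd, g; atol=1e-6)
```

For a quadratic form the central difference is exact up to rounding, so the agreement here is tight.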

(It does seem like there should be a built-in ChainRule for dot(x, A, y)… however, there is currently relatively little support for sparse matrices in Zygote and ChainRules, and I end up writing my own rrule much of the time.)


There is a rule here, but it’s not clear to me why gradient(v -> dot(v, P, v), [1.;0.]) fails. An issue on ChainRules.jl about this would be useful.

Yes. What special code there is was a quick prototype before CR v1.0, and could use attention from someone who needs it.

It looks like the rule is also computing the derivative with respect to P, which fails because the derivative is not sparse and so it can’t be “projected” back onto the type of P.

This is a general difficulty with writing these sorts of rules. To make it fully general you have to support derivatives with respect to everything, but this may be inefficient (or fail completely, as it does here). My understanding is that Zygote relies on the compiler’s dead-code elimination to remedy the inefficiency, but that means it can fail when an unneeded portion of the rule is not sufficiently general. Enzyme instead asks you to provide a combinatorial explosion of rules (depending on which arguments are constant), but that makes life harder in a different way.

But I’m not sure whether any AD system, in any language, does a good job of differentiating through sparse-matrix construction. The farther you stray from neural networks, the more trouble AD seems to run into.


Enzyme supports the various combinations of activities without a combinatorial explosion of work to write rules. This was an open design question prior to release, but it has since been resolved (though further improvements are of course welcome, and we also want to provide a slightly easier-to-use layer that registers with the more powerful API below)!

In essence, within a single rule you can write something like:

if !(a isa Const)
    # compute the derivative with respect to a
end
if !(b isa Const)
    # compute the derivative with respect to b
end

See our docs for more info (Custom rules · Enzyme.jl).

Moreover, if you write a rule that only supports a specific activity state (or a subset of them), that’s also fine. If a different activity combination is required, you’ll get a runtime error naming the activity that wasn’t implemented and requesting that it be added.


Yes, it computes the derivative for P. It is delayed by @thunk, but Zygote at present ignores the thunk.

It should explicitly (if inefficiently) always project dense results back to sparse, e.g. gradient(x -> sum(Array(x).^2), spdiagm([1.0, 2.0])). I don’t know whether that’s what fails here.
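To illustrate what “project dense results back to sparse” would mean here, a related stdlib-only sketch (hypothetical helper code, not Zygote’s actual implementation): the dense gradient of sum(Array(x)) is a matrix of ones, but a projection onto x’s stored sparsity pattern keeps only the diagonal entries:

```julia
using SparseArrays

x = spdiagm([1.0, 2.0])   # stored pattern: the diagonal
dense_grad = ones(2, 2)   # dense gradient of sum(Array(x)) w.r.t. every entry

# "projecting back to sparse": keep only entries on x's stored pattern
Is, Js, _ = findnz(x)
projected = sparse(Is, Js, [dense_grad[i, j] for (i, j) in zip(Is, Js)], 2, 2)

@assert projected == spdiagm([1.0, 1.0])
```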

It should also notice that since P is a boolean array, it never has a derivative, as in e.g. gradient(x -> sum(Array(x).^2), P). Perhaps this check ought to happen earlier.