Problems with AD inference on wrapper function

I’m trying to fix inference issues within a Zygote gradient for DynamicExpressions.jl, with the goal of using fast AD in SymbolicRegression.jl. Right now, the return type of `Zygote.gradient` is inferred as `Any`, and I can’t figure out why. The weird thing is that inference works fine on the internal functions (which have custom chain rules); only the outermost wrapper function fails to infer.

For context, you can install the exact version of the package I’m debugging with:

```
pkg> add https://github.com/SymbolicML/DynamicExpressions.jl#c9eaedf63e36a227db702b4ea3257938892447d5
```

Basically, I have a recursive tree structure, `Node{T}`. I don’t want Zygote to walk through the whole tree, so instead I defined a custom `NodeTangent` type (in src/ChainRules.jl):

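```julia
# Imports assumed from the surrounding module (not shown in the original excerpt);
# AbstractExpressionNode is DynamicExpressions' own abstract node type.
import ChainRulesCore as CRC
using ChainRulesCore: AbstractTangent, NoTangent, ZeroTangent

struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N
    gradient::A
end
Base.:+(a::NodeTangent, b::NodeTangent) = NodeTangent(a.tree, a.gradient + b.gradient)
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
# `Type{<:NodeTangent}` also matches concrete types such as NodeTangent{Float64,...}
Base.zero(::Union{Type{<:NodeTangent},NodeTangent}) = ZeroTangent()
```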
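A note on the `zero` method: `Type{T}` is invariant in `T`, so a signature written with `Type{NodeTangent}` matches only the bare `UnionAll` `NodeTangent`, never a concrete type such as `NodeTangent{Float64, Node{Float64}, Vector{Float64}}`, whereas `Type{<:NodeTangent}` matches both. A stdlib-only check of that dispatch rule, using a hypothetical stand-in type `Holder{T}` and hypothetical functions `exact` and `covariant`:

```julia
# Hypothetical stand-in for a parametric tangent type such as NodeTangent
struct Holder{T}
    x::T
end

# Invariant: matches only the UnionAll `Holder` itself
exact(::Type{Holder}) = :exact
# Covariant: matches `Holder` and every concrete `Holder{T}`
covariant(::Type{<:Holder}) = :covariant

exact_on_unionall = exact(Holder)
exact_on_concrete = applicable(exact, Holder{Float64})  # false: no method matches
covariant_on_concrete = covariant(Holder{Float64})
```

So a `zero` method meant to catch every `NodeTangent{...}` type is safest written against `Type{<:NodeTangent}`.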

I then define a chain rule for evaluation whose pullback returns this `NodeTangent`:

```julia
function CRC.rrule(
    ::typeof(eval_tree_array),
    tree::AbstractExpressionNode,
    X::AbstractMatrix,
    operators::OperatorEnum;
    kws...,
)
    primal, complete = eval_tree_array(tree, X, operators; kws...)

    if !complete
        primal .= NaN
    end

    return (primal, complete), EvalPullback(tree, X, operators)
end

# Use a callable struct rather than a closure so captured variables are never boxed
struct EvalPullback{N,A,O} <: Function
    tree::N
    X::A
    operators::O
end

# TODO: Preferable to use the primal in the pullback somehow
function (e::EvalPullback)((dY, _))
    _, dX_constants_dY, complete = eval_grad_tree_array(
        e.tree, e.X, e.operators; variable=Val(:both)
    )

    if !complete
        dX_constants_dY .= NaN
    end

    nfeatures = size(e.X, 1)
    dX_dY = @view dX_constants_dY[1:nfeatures, :]
    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]

    dtree = NodeTangent(
        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
    )

    dX = dX_dY .* reshape(dY, 1, length(dY))

    return (NoTangent(), dtree, dX, NoTangent())
end
```
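The constants gradient in `EvalPullback` accumulates `dconstants_dY[:, j] * dY[j]` over all output rows `j`, which is the matrix-vector product `dconstants_dY * dY`; likewise `dX` scales column `j` of `dX_dY` by `dY[j]`. A stdlib-only equivalence check with made-up arrays (3 constants, 2 features, 4 data rows; not the package's data), assuming `dY` arrives as a plain `Vector`:

```julia
# Made-up Jacobian blocks: rows are constants or features, columns are data rows
dconstants_dY = [1.0 2.0 3.0 4.0;
                 0.5 -1.0 0.0 2.0;
                 1.0 1.0 1.0 1.0]
dX_dY = [1.0 0.0 2.0 -1.0;
         3.0 1.0 0.0 0.5]
dY = [0.1, -0.2, 0.3, 0.4]  # incoming cotangent, one entry per data row

# As written in the pullback: one scaled column per output element, then summed
dconstants_loop = sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
# The same contraction as a single matrix-vector product
dconstants_mv = dconstants_dY * dY

# dX as written (dY reshaped into a row) versus broadcasting against the adjoint
dX_reshape = dX_dY .* reshape(dY, 1, length(dY))
dX_adjoint = dX_dY .* dY'
```

The matrix-vector form avoids allocating a temporary column for every output element; it is a possible simplification of the pullback, not a fix for the inference problem discussed here.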

This works fine: the derivatives are correct, and inference succeeds:

```julia
using DynamicExpressions
using Zygote
using Test

const operators = OperatorEnum(;
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)

tree = x1 * cos(x2 - 3.2)
```

```julia
julia> Test.@inferred Zygote.gradient(
           t -> eval_tree_array(t, ones(2, 1), operators)[1][1],
           tree
       )
(NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]),)
```

It returns a `NodeTangent`, which prevents Zygote from walking the tree.

However, things change when I try to use my new `Expression` type, which is nothing but a `Node{T}` plus a named tuple of operators and variable names:

```julia
struct Expression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
    tree::N
    metadata::Metadata{D}
end
```

Inference no longer succeeds. Here is the wrapper evaluation method:

```julia
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```

So it basically just unpacks ex -> ex.tree and ex -> ex.metadata.operators.

Now, say that I try to take the gradient of this wrapper instead. Unlike the internal `eval_tree_array` method, I do not define a custom chain rule for this one (since the wrapper is so simple).

```julia
julia> ex = Expression(tree; operators, variable_names=["x1", "x2"])
x1 * cos(x2 - 3.2)

julia> Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
((tree = NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]), metadata = nothing),)

julia> Test.@inferred Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
ERROR: return type Tuple{@NamedTuple{tree::NodeTangent{Float64, Node{Float64}, Vector{Float64}}, metadata::Nothing}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[25]:1
```

Even though all `eval_tree_array(::AbstractExpression, ...)` does is a few `getproperty` calls before forwarding to a method that I know works, inference on this wrapper fails.

Questions:

1. Any guesses as to where the issue comes from?
2. Do I need to define a custom tangent type for `Expression`, and what’s the actual interface for `AbstractTangent`?
    - For the record, I did try defining a `zero_tangent(::Expression)` method, but it didn’t seem to fix the issue. Maybe there’s some other function I need to define?
    - Perhaps I need to declare `NodeTangent` for some other functions so that Zygote doesn’t try descending at the outermost call? And if so, which methods need to be implemented?
3. (General) How does one go about debugging type-inference issues in Zygote when type inference on the primal is fine? I can’t seem to use Cthulhu.jl effectively, though perhaps I am descending the wrong tree.

(cc @gdalle in case you have any ideas)


Should I post a GitHub issue for this instead? It’s so niche that I wonder if I should ask a specific person for help.

I saw your post but I have no clue, sorry

The answer is relatively straightforward, but it does require some background on how keyword-argument dispatch works under the hood (see the “Julia Functions” page of the Julia developer documentation, in particular its section on keyword arguments). As the example there shows, the auto-generated “keyword sorter” function (a method of `Core.kwcall` on newer versions of Julia) uses conditionals to handle the presence or absence of particular kwargs. Because Zygote unconditionally generates type-unstable code when it encounters branching control flow, all calls with keyword arguments end up type-unstable under AD.
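On Julia 1.9 and newer, the keyword sorter can also be called directly: a call like `foo(A; dims=1)` is lowered to `Core.kwcall(nt, foo, A)`, where `nt` is a NamedTuple holding the keywords. A minimal stdlib-only check, with a placeholder `foo` that forwards its keywords to `sum` and arbitrary data `A`:

```julia
# Placeholder function that forwards all of its keywords to `sum`
foo(x; kws...) = sum(x; kws...)

A = [1.0 2.0; 3.0 4.0]

via_syntax = foo(A; dims=1)                   # ordinary keyword call
via_kwcall = Core.kwcall((; dims=1), foo, A)  # the call it lowers to (Julia >= 1.9)
```

Both produce the column sums `[4.0 6.0]`; the branching that Zygote trips over lives inside the code reached through `Core.kwcall`, not in `foo`'s visible body.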

As you’ve noted, this can be worked around by defining an rrule for the function in question. This is how, e.g., `sum(...; dims=...)` can be type-stable with Zygote: ChainRules.jl has rules for it. Unfortunately, having an rrule does preclude differentiating with respect to keyword arguments. @oxinabox has a far more in-depth series of posts about this, and about possible workarounds, in the thread “Rrule (or frule) with kwargs”.


Thanks! So, the function I am differentiating is just an “alias”:

```julia
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```

In other words, it does not handle any kwargs explicitly. The function it passes the kwargs to, which does unpack them, has an rrule.

In this case, are there any more general workarounds I can exploit, rather than defining an rrule for each wrapper?

This code is part of an interface that users can add methods to, so I would like them to be able to avoid writing their own rrule for every method. For example, the corresponding method for `ParametricExpression` (src/ParametricExpression.jl at commit 69b20d5cbb29252e58c61f3af0b7de12e73e45fc of SymbolicML/DynamicExpressions.jl on GitHub) is again just an alias.

Unfortunately, it doesn’t have to handle them explicitly: the keyword-argument handling logic has no shortage of other branches. To demonstrate with `@code_warntype` (Cthulhu inexplicably breaks on this example):

```julia
julia> foo(x; kws...) = sum(x; kws...);

julia> @code_warntype foo(ones(1); dims=1)
MethodInstance for Core.kwcall(::@NamedTuple{dims::Int64}, ::typeof(foo), ::Vector{Float64})
  from kwcall(::NamedTuple, ::typeof(foo), x) @ Main REPL[10]:1
Arguments
  _::Core.Const(Core.kwcall)
  @_2::@NamedTuple{dims::Int64}
  @_3::Core.Const(foo)
  x::Vector{Float64}
Locals
  kws...::Base.Pairs{Symbol, Int64, Tuple{Symbol}, @NamedTuple{dims::Int64}}
Body::Vector{Float64}
1 ─      (kws... = Base.pairs(@_2))
│   %2 = Main.:(var"#foo#3")(kws...::Core.PartialStruct(Base.Pairs{Symbol, Int64, Tuple{Symbol}, @NamedTuple{dims::Int64}}, Any[@NamedTuple{dims::Int64}, Core.Const((:dims,))]), @_3, x)::Vector{Float64}
└──      return %2

julia> @code_warntype Main.:(var"#foo#1")(pairs((; dims=1)), foo, ones(1))
MethodInstance for var"#foo#1"(::Base.Pairs{Symbol, Int64, Tuple{Symbol}, @NamedTuple{dims::Int64}}, ::typeof(foo), ::Vector{Float64})
  from var"#foo#1"(kws::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}}, ::typeof(foo), x) @ Main REPL[2]:1
Arguments
  #foo#1::Core.Const(var"#foo#1")
  kws::Base.Pairs{Symbol, Int64, Tuple{Symbol}, @NamedTuple{dims::Int64}}
  @_3::Core.Const(foo)
  x::Vector{Float64}
Body::Vector{Float64}
1 ─      nothing
│   %2 = Base.NamedTuple()::Core.Const(NamedTuple())
│   %3 = Base.merge(%2, kws)::@NamedTuple{dims::Int64}
│   %4 = Base.isempty(%3)::Core.Const(false)
└──      goto #3 if not %4  # <-- notice the branch!
2 ─      Core.Const(:(Main.sum(x)))
└──      Core.Const(:(return %6))
3 ┄ %8 = Core.kwcall(%3, Main.sum, x)::Vector{Float64}
└──      return %8
```
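The branch in this output comes from the call site `sum(x; kws...)`: splatting keyword arguments lowers to an `isempty` check that chooses between a plain positional call and `Core.kwcall`, so any wrapper that merely forwards `kws...` contains it. A stdlib-only way to see this is to lower call expressions with `Meta.lower`. Nothing is executed, so the placeholder names `g`, `x`, `kws`, and `opts` never need to be defined; `has_branch` is a hypothetical helper that looks for `Core.GotoIfNot` statements in the lowered code:

```julia
# Hypothetical helper: lower a call expression (without running it) and report
# whether the lowered code contains a conditional jump (`Core.GotoIfNot`).
function has_branch(call::Expr)
    lowered = Meta.lower(Main, call)
    # Very simple expressions may not lower to a thunk; those contain no branches
    (lowered isa Expr && lowered.head === :thunk) || return false
    code = (lowered.args[1]::Core.CodeInfo).code
    return any(stmt -> stmt isa Core.GotoIfNot, code)
end

splatted   = has_branch(:(g(x; kws...)))  # forwarding a keyword splat, as in `foo`
positional = has_branch(:(g(x, opts)))    # passing options as a positional argument
```

Here `splatted` is `true` and `positional` is `false`, which matches the point above: the wrapper does not need to inspect its kwargs for a branch to appear.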

Beyond the suggestions in the linked thread (create Zygote-specific rules where necessary, or tweak the AD itself), I can’t think of many workarounds. You could consider creating custom holder type(s) for common keyword arguments. Generalized sufficiently, and given rrules or non-branching logic, those could work around the current handling of built-in keyword dispatch. But if the library interface doesn’t allow for this, it’s a non-starter.

Thanks. I guess I could even allow a NamedTuple to be passed as an optional final positional argument, defaulting to `(;)`, and that would also work, right? I like the idea of a custom holder type, although if a user defines both a custom `AbstractExpressionNode` and a custom `AbstractExpression`, they wouldn’t be able to customize the available keywords. So maybe accepting any NamedTuple would work nicely.
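A minimal sketch of the positional-NamedTuple idea, assuming the options are forwarded positionally all the way down (splatting them back into keywords inside the wrapper would reintroduce the keyword-call branch). `MyExpr`, `inner_eval`, `eval_wrapper`, and the `:scale` option are hypothetical stand-ins, not DynamicExpressions API. The sketch only checks primal inference and the wrapper's lowered code; it does not run Zygote, so whether Zygote then infers the gradient is an assumption to verify:

```julia
using Test

# Hypothetical stand-in for an expression type (not DynamicExpressions' Expression)
struct MyExpr{T}
    coeffs::Vector{T}
end

# Hypothetical inner evaluator: options arrive as a positional NamedTuple and are
# read with `get` plus a default, so no keyword sorter is involved
function inner_eval(coeffs::AbstractVector, X::AbstractMatrix, opts::NamedTuple)
    scale = get(opts, :scale, one(eltype(X)))
    return scale .* (X' * coeffs)
end

# Thin wrapper that forwards `opts` unchanged; the `(;)` default generates a
# separate two-argument method rather than a runtime branch
eval_wrapper(ex::MyExpr, X::AbstractMatrix, opts::NamedTuple=(;)) =
    inner_eval(ex.coeffs, X, opts)

ex = MyExpr([1.0, 2.0])
X = [1.0 2.0 3.0; 4.0 5.0 6.0]  # 2 features, 3 data rows

y_default = @inferred eval_wrapper(ex, X)
y_scaled = @inferred eval_wrapper(ex, X, (; scale=2.0))

# The lowered body of the three-argument wrapper contains no conditional jump
wrapper_ci = only(code_lowered(eval_wrapper, (MyExpr{Float64}, Matrix{Float64}, typeof((;)))))
wrapper_branches = any(stmt -> stmt isa Core.GotoIfNot, wrapper_ci.code)
```

Reading options with `get(opts, :key, default)` keeps every method signature fixed while still letting user-defined methods accept extra entries in the same NamedTuple.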

I’ve run into other issues with keyword arguments before, where they affect specialization: JuliaLang/julia issue #54661, “The use of `NamedTuple` in `Core.kwcall` prevents specialization of keyword arguments”.

I feel like maybe people may want to avoid keyword arguments in general? They hurt specialization, and now it sounds like they force runtime dispatch under autodiff… It’s too bad, because they are awfully convenient.

In fairness, I think this is more of a Zygote-specific problem. Generating type-stable reverse-mode AD code in the presence of branching is completely impractical (though theoretically possible) under the constraints Zygote has to deal with. Newer ADs have addressed this by making more effective use of newer compiler internals and by imposing constraints on user input. In particular, I think Tapir.jl (compintell/Tapir.jl on GitHub) “does right” a lot of the ideas and mechanisms originally proposed and developed for Zygote.