I’m trying to fix inference issues within a Zygote gradient for DynamicExpressions.jl, with the goal of using fast AD in SymbolicRegression.jl. Right now `Zygote.gradient` is inferring `Any` as a return value and I can’t figure out why. The weird thing is that inference works fine on the internal functions (which have a custom chain rule). It’s only the outermost wrapper function that fails to infer.

For reference, you can install the same version of the package I’m debugging with:

```
add https://github.com/SymbolicML/DynamicExpressions.jl#c9eaedf63e36a227db702b4ea3257938892447d5
```

Basically, I have a recursive tree structure `Node{T}`. I don’t want Zygote to walk through the whole tree, so instead I have a custom `NodeTangent` type (in `src/ChainRules.jl`):

```
struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N
    gradient::A
end
Base.:+(a::NodeTangent, b::NodeTangent) = NodeTangent(a.tree, a.gradient + b.gradient)
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
```

I then have a chain rule for evaluation which returns this `NodeTangent`, defined as follows:

```
function CRC.rrule(
    ::typeof(eval_tree_array),
    tree::AbstractExpressionNode,
    X::AbstractMatrix,
    operators::OperatorEnum;
    kws...,
)
    primal, complete = eval_tree_array(tree, X, operators; kws...)
    if !complete
        primal .= NaN
    end
    return (primal, complete), EvalPullback(tree, X, operators)
end
# Wrap in a struct rather than a closure so that captured variables are not boxed
struct EvalPullback{N,A,O} <: Function
    tree::N
    X::A
    operators::O
end
# TODO: Preferable to use the primal in the pullback somehow
function (e::EvalPullback)((dY, _))
    _, dX_constants_dY, complete = eval_grad_tree_array(
        e.tree, e.X, e.operators; variable=Val(:both)
    )
    if !complete
        dX_constants_dY .= NaN
    end
    nfeatures = size(e.X, 1)
    dX_dY = @view dX_constants_dY[1:nfeatures, :]
    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
    dtree = NodeTangent(
        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
    )
    dX = dX_dY .* reshape(dY, 1, length(dY))
    return (NoTangent(), dtree, dX, NoTangent())
end
```

This actually works fine. The derivatives are correct and inference seems good:

```
using DynamicExpressions
using Zygote
using Test  # for Test.@inferred
const operators = OperatorEnum(;
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)
tree = x1 * cos(x2 - 3.2)
```

```
julia> Test.@inferred Zygote.gradient(
           t -> eval_tree_array(t, ones(2, 1), operators)[1][1],
           tree
       )
(NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]),)
```

and it returns a `NodeTangent`, which prevents Zygote from walking the tree.

However, when I then try to use my new `Expression` type, which is nothing but a `Node{T}` plus a named tuple of operators and variable names:

```
struct Expression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
    tree::N
    metadata::Metadata{D}
end
```

I no longer get successful inference. Here is the wrapper method for evaluation:

```
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```

So it basically just unpacks `ex -> ex.tree` and `ex -> ex.metadata.operators`.
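For comparison, here is a minimal self-contained toy showing how Zygote represents the gradient of a plain immutable struct when it differentiates through the field access itself, with no custom rule. `Wrapper`, `scale`, and `use_field` are hypothetical stand-ins, not part of DynamicExpressions.jl. The point is only that Zygote builds a `NamedTuple` with one entry per field and `nothing` for fields it never touched, which is the same shape as the `(tree = ..., metadata = nothing)` result for `Expression`.

```julia
using Zygote

# Hypothetical toy wrapper (not the real `Expression`): one differentiable
# field plus non-differentiable metadata.
struct Wrapper{M}
    scale::Float64
    metadata::M
end

# Unpack a field and use it, loosely mimicking the `Expression` wrapper.
use_field(w::Wrapper, x) = w.scale * x

w = Wrapper(2.0, (; variable_names=["x1", "x2"]))

# Zygote traces `getproperty` and returns a NamedTuple with one entry per
# field; `metadata` was never used, so its entry is `nothing`.
g = Zygote.gradient(v -> use_field(v, 3.0), w)[1]
println(g)
```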

Now, say that I try to take the gradient of this instead. Unlike the internal `eval_tree_array` call, I do not define a custom chain rule for this one (since the wrapper is so simple).

```
julia> ex = Expression(tree; operators, variable_names=["x1", "x2"])
x1 * cos(x2 - 3.2)
julia> Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
((tree = NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]), metadata = nothing),)
julia> Test.@inferred Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
ERROR: return type Tuple{@NamedTuple{tree::NodeTangent{Float64, Node{Float64}, Vector{Float64}}, metadata::Nothing}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[25]:1
```

Even though all `eval_tree_array(::AbstractExpression, ...)` does is a few `getproperty` calls before passing the results to a call that I know works, inference on this wrapper fails.
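One thing that may be worth trying is to give the wrapper its own `rrule`, so that Zygote never traces the field accesses itself and instead receives a structural `Tangent` for the whole wrapper. This is only a sketch on a toy, not verified against DynamicExpressions.jl, and it is not known here whether it fixes the inference problem. `ToyExpr`, `params`, and `toy_eval` are hypothetical stand-ins for `Expression`, the tree, and `eval_tree_array`.

```julia
using ChainRulesCore, Zygote

# Hypothetical stand-in for `Expression`: differentiable data plus
# non-differentiable metadata.
struct ToyExpr{T,M}
    params::Vector{T}
    metadata::M
end

# Hypothetical stand-in for the wrapper `eval_tree_array(::AbstractExpression, ...)`.
toy_eval(ex::ToyExpr, x::AbstractVector) = sum(ex.params .* x)

# Rule for the wrapper itself, so Zygote calls this pullback instead of
# tracing the `getproperty` calls inside `toy_eval`.
function ChainRulesCore.rrule(::typeof(toy_eval), ex::ToyExpr, x::AbstractVector)
    y = toy_eval(ex, x)
    function toy_eval_pullback(dy)
        Δ = unthunk(dy)
        # Structural tangent for the wrapper; metadata gets `NoTangent()`.
        dex = Tangent{typeof(ex)}(; params=Δ .* x, metadata=NoTangent())
        return NoTangent(), dex, Δ .* ex.params
    end
    return y, toy_eval_pullback
end

ex = ToyExpr([1.0, 2.0], (; variable_names=["x1", "x2"]))
x = [3.0, 4.0]

# Zygote converts the returned `Tangent` into a NamedTuple, with `nothing`
# in place of `NoTangent()`.
g = Zygote.gradient(e -> toy_eval(e, x), ex)[1]
println(g)
```

For the real wrapper, such a rule would also have to cover the optional `operators` positional argument and `kws...`; how that interacts with Zygote's keyword-argument handling has not been checked here.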

## Questions:

- Any guesses as to where the issue comes from?
- Do I need to define a custom tangent type for `Expression`, and what’s the actual interface for `AbstractTangent`?
  - For the record, I did try creating a `zero_tangent(::Expression)`, but this didn’t seem to fix the issue. Maybe there’s some other function I need to define?
  - Perhaps I need to declare `NodeTangent` for some other function symbols so that Zygote doesn’t try descending at the outermost call? And if so, what methods need to be implemented?
- (General) How does one go about debugging type inference issues in Zygote when type inference on the primal is fine? I can’t seem to use Cthulhu.jl effectively, though perhaps I am descending the wrong tree.
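A generic first step that may help narrow this down: `Zygote.gradient` is a forward pass followed by a reverse pass, so splitting it with `Zygote.pullback` and inspecting each half separately can show whether the `Any` comes from the traced forward pass or from the returned pullback closure. The sketch below uses a hypothetical toy struct `Scaled` and function `toy_loss` in place of the real `Expression` case.

```julia
using Zygote, InteractiveUtils

# Hypothetical toy primal standing in for the closure being differentiated.
struct Scaled
    a::Float64
end
toy_loss(s::Scaled) = s.a * sin(s.a)

s = Scaled(0.5)

# Forward pass: primal value plus a closure for the reverse pass.
y, back = Zygote.pullback(toy_loss, s)

# Inferred return types of each half, checked separately.
fwd_types = Base.return_types(Zygote.pullback, (typeof(toy_loss), typeof(s)))
rev_types = Base.return_types(back, (Float64,))
println("forward: ", fwd_types)
println("reverse: ", rev_types)

# Highlights non-concrete intermediates in the reverse pass. The same call
# can be explored interactively with `Cthulhu.@descend back(1.0)`.
@code_warntype back(1.0)

# The reverse pass returns one gradient per argument (here a NamedTuple).
grads = back(1.0)
println(grads)
```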

(cc @gdalle in case you have any ideas)