I’m trying to fix inference issues within a Zygote gradient for DynamicExpressions.jl – with the goal of using fast AD in SymbolicRegression.jl. Right now `Zygote.gradient` is inferring `Any` as a return value and I can’t figure out why. The weird thing is that inference is fine on the internal functions (which have a custom chain rule); it’s only the outermost wrapper function that fails to infer.
Context – for reference, you can install the exact version of the package I’m debugging with:

```julia
pkg> add https://github.com/SymbolicML/DynamicExpressions.jl#c9eaedf63e36a227db702b4ea3257938892447d5
```
Basically I have this recursive tree structure `Node{T}`. I don’t want Zygote to try to walk through the whole tree, so instead, I have this custom `NodeTangent` type (in `src/ChainRules.jl`):
```julia
using ChainRulesCore: AbstractTangent, ZeroTangent

struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N
    gradient::A
end
Base.:+(a::NodeTangent, b::NodeTangent) = NodeTangent(a.tree, a.gradient + b.gradient)
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
```
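(For context on why those arithmetic overloads exist: ChainRules/Zygote needs to be able to accumulate tangents that arrive along multiple paths and to scale them. A quick illustration, with made-up gradient values and `tree` any `Node{Float64}`:)

```julia
g1 = NodeTangent(tree, [1.0, 2.0])
g2 = NodeTangent(tree, [0.5, 0.5])
g1 + g2   # NodeTangent(tree, [1.5, 2.5]): accumulating two paths
2.0 * g1  # NodeTangent(tree, [2.0, 4.0]): scaling by a scalar
```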
I then have a chain rule for evaluation which returns this `NodeTangent`, defined as follows:
```julia
import ChainRulesCore as CRC
using ChainRulesCore: NoTangent

function CRC.rrule(
    ::typeof(eval_tree_array),
    tree::AbstractExpressionNode,
    X::AbstractMatrix,
    operators::OperatorEnum;
    kws...,
)
    primal, complete = eval_tree_array(tree, X, operators; kws...)
    if !complete
        primal .= NaN
    end
    return (primal, complete), EvalPullback(tree, X, operators)
end

# Wrap in struct rather than closure to avoid boxed variables
struct EvalPullback{N,A,O} <: Function
    tree::N
    X::A
    operators::O
end

# TODO: Preferable to use the primal in the pullback somehow
function (e::EvalPullback)((dY, _))
    _, dX_constants_dY, complete = eval_grad_tree_array(
        e.tree, e.X, e.operators; variable=Val(:both)
    )
    if !complete
        dX_constants_dY .= NaN
    end
    nfeatures = size(e.X, 1)
    dX_dY = @view dX_constants_dY[1:nfeatures, :]
    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]
    dtree = NodeTangent(
        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
    )
    dX = dX_dY .* reshape(dY, 1, length(dY))
    return (NoTangent(), dtree, dX, NoTangent())
end
```
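To make the `(dY, _)` destructuring concrete: the primal output of the rrule is the tuple `(primal, complete)`, so Zygote hands the pullback a matching tuple tangent, with nothing useful in the slot for the `Bool`. Calling it by hand would look roughly like this (using `tree`, `X = ones(2, 1)`, and `operators` as in the example below):

```julia
(primal, complete), back = CRC.rrule(eval_tree_array, tree, X, operators)
dself, dtree, dX, dops = back((ones(length(primal)), nothing))
# dself and dops are NoTangent(); dtree is a NodeTangent; dX has the same size as X
```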
This actually works fine. I can get derivatives that are correct and inference seems good:
```julia
using DynamicExpressions
using Zygote
using Test

const operators = OperatorEnum(;
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)
tree = x1 * cos(x2 - 3.2)
```

```julia
julia> Test.@inferred Zygote.gradient(
           t -> eval_tree_array(t, ones(2, 1), operators)[1][1],
           tree
       )
(NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]),)
```
and it returns a `NodeTangent`, which prevents Zygote from walking the tree.
However, when I then try to use my new `Expression` type, which is nothing but a `Node{T}` plus a named tuple of operators and variable names:
```julia
struct Expression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
    tree::N
    metadata::Metadata{D}
end
```
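(Here `Metadata` is just a thin wrapper around the named tuple; roughly, with the field name being my paraphrase:)

```julia
struct Metadata{D<:NamedTuple}
    _data::D  # e.g. (; operators, variable_names)
end
```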
I no longer get successful inference. Here is the wrapper method for evaluation:
```julia
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```
So it basically just unpacks `ex -> ex.tree` and `ex -> ex.metadata.operators`.
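For reference, both accessors are essentially trivial; paraphrasing the package, they are something like:

```julia
get_tree(ex::Expression) = ex.tree
# Fall back to the stored operators when none are passed explicitly:
get_operators(ex::Expression, operators) =
    operators === nothing ? ex.metadata.operators : operators
```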
Now, say that I try to take the gradient of this instead. Unlike the internal `eval_tree_array` call, I do not define a custom chain rule for this one (since the wrapper call is simple).
```julia
julia> ex = Expression(tree; operators, variable_names=["x1", "x2"])
x1 * cos(x2 - 3.2)

julia> Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
((tree = NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]), metadata = nothing),)

julia> Test.@inferred Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
ERROR: return type Tuple{@NamedTuple{tree::NodeTangent{Float64, Node{Float64}, Vector{Float64}}, metadata::Nothing}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[25]:1
```
Even though all `eval_tree_array(::AbstractExpression, ...)` is doing is a couple of `getproperty` calls before passing to a call that I know infers fine, inference on this wrapper fails.
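To be explicit about the “which I know works” part: the primal wrapper itself passes `@inferred`; it is only the gradient call that fails:

```julia
Test.@inferred eval_tree_array(ex, ones(2, 1))  # infers fine: Tuple{Vector{Float64}, Bool}
```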
Questions:

- Any guesses as to what the issue is from?
- Do I need to define a custom tangent type for `Expression`, and what’s the actual interface for `AbstractTangent`?
  - For the record, I did try creating a `zero_tangent(::Expression)`, but this didn’t seem to fix the issue. Maybe there’s some other function I need to define?
  - Perhaps I need to declare `NodeTangent` for some other function symbols so that Zygote doesn’t try descending at the outermost call? And if so, what methods need to be implemented?
- (General) How does one go about debugging type inference issues in Zygote when type inference on the primal is fine? I can’t seem to use Cthulhu.jl effectively, though perhaps I am descending the wrong tree. (See the sketch below for the kind of decomposition I’ve been attempting.)
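To make that last question concrete, the kind of decomposition I’ve been attempting (so far without success) is to split `Zygote.gradient` into the forward pass and the pullback via `Zygote.pullback`, then check each half separately:

```julia
using Cthulhu: @descend

f = ex -> eval_tree_array(ex, ones(2, 1))[1][1]
out, back = Zygote.pullback(f, ex)     # forward pass
Test.@inferred Zygote.pullback(f, ex)  # does the forward pass infer?
Test.@inferred back(1.0)               # or is it the reverse pass that fails?
# @descend back(1.0)                   # then step through the failing half with Cthulhu
```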
(cc @gdalle in case you have any ideas)