Unexpected performance mismatch in gradients for "compiled-tape-in-tape" experiment

Hi all,

I’m working on a way to make selectively compiling parts of a ReverseDiff tape easier and more accessible. The goal is to enable fearless compilation of functions with branches.

One problem I’ve run into relates to caching intermediate values. Ideally, embedding a tape inside another tape shouldn’t cause any slowdown (beyond perhaps the pointer lookup to find the inner tape), but I’m not getting the speedup I’m looking for: despite having fewer allocations, my inner compiled tape is slower than a “vanilla” function without it.

Can anyone help track down the performance snafu? Specifically, why is the fully-compiled example (the last benchmark) so much slower (~700 ns vs ~500 ns) despite doing the same work?

Implementation
import AbstractDifferentiation as AD
using ReverseDiff

using ReverseDiff: @grad, compile, GradientTape
import AbstractDifferentiation: primal_value, pullback_function, value_and_pullback_function

struct CachedReverseDiffBackend{F,T,C} <: AD.AbstractBackend # Could also be parametric in backend type
    func::F
    compiled_tape::T
    output_cache::C
    # Constructor to compile the tape given inputs
    function CachedReverseDiffBackend(f::F, x) where {F}
        compiled_tape = compile(GradientTape(f, x)) # pseudo RD code
        output_cache = similar(ReverseDiff.gradient!(compiled_tape, x))
        T = typeof(compiled_tape)
        C = typeof(output_cache)
        return new{F,T,C}(f, compiled_tape, output_cache)
    end
end

const CRDB = CachedReverseDiffBackend # alias for brevity

(b::CRDB)(x) = call_func(b, x)
call_func(b::CRDB, x) = b.func(x)

function call_func(b::CRDB, x::ReverseDiff.TrackedArray)
    return ReverseDiff.track(call_func, b, x)
end

# ReverseDiff's @grad expects the body to return (value, pullback)
@grad function call_func(b::CRDB, x)
    return value_and_pullback_function(b, x)
end

primal_value(::CRDB, xs, _) = primal_value(xs) # is this ok?

function value_and_pullback_function(cb::CRDB, x)
    xv = ReverseDiff.value(x) # strip tracking before running the compiled tape
    yv = cb.func(xv)          # primal value via the plain function
    function pullback_f(Δ)
        # Replay the compiled inner tape, then scale by the incoming adjoint
        ReverseDiff.gradient!(cb.output_cache, cb.compiled_tape, xv)
        cb.output_cache .*= Δ
        (cb.output_cache,)
    end
    return yv, pullback_f
end
Code exemplifying the problem
using BenchmarkTools
# using Cthulhu
# using Profile
# using PProf

# The function we compile, on some example inputs
g(xs) = sum(abs2, xs)
xs = [1.0, 2.0, 3.0]

# Must be declared const, otherwise it is type-unstable when called from inside other functions
const crdb = CRDB(g, xs)

# An example outer function
f_nocompile(xs) = 2g(xs)
f_compile(xs) = 2crdb(xs)

# Primal values are computed at the same speed, with 0 allocations each
@btime f_nocompile($xs)
@btime f_compile($xs)

# Test the gradients of inner function alone to see if performance drop comes from here
gt = compile(GradientTape(g, xs)) # RD code
out = similar(ReverseDiff.gradient(g, xs .+ 1))
@btime ReverseDiff.gradient!($out, $g, $xs) # Slowest, has allocations
@btime ReverseDiff.gradient!($out, $gt, $xs) # Fastest, no allocs
@btime ReverseDiff.gradient!($out, $(crdb.compiled_tape), $xs) # Fastest, no allocs

# Gradients are computed at wildly different speeds / allocation counts
out_f = similar(ReverseDiff.gradient(f_nocompile, xs))
@btime ReverseDiff.gradient!($out_f, $f_nocompile, $xs)
@btime ReverseDiff.gradient!($out_f, $f_compile, $xs)

# Even when both functions are fully compiled onto a tape
fnc_tape = compile(GradientTape(f_nocompile, xs))
fc_tape = compile(GradientTape(f_compile, xs))
@btime ReverseDiff.gradient!($out_f, $fnc_tape, $xs)
@btime ReverseDiff.gradient!($out_f, $fc_tape, $xs)