Let’s start with a simple example:
```julia
julia> using Integrals

julia> using SciMLSensitivity

julia> using BenchmarkTools

julia> using Zygote

julia> f(x,h) = sin(x)*exp(h*cos(x))
f (generic function with 1 method)

julia> g(h) = solve(IntegralProblem(f, 0, pi, h), QuadGKJL(), abstol=1e-16)
g (generic function with 1 method)

julia> x0, h0 = 0.8, 0.5
(0.8, 0.5)

julia> @btime f(x0, h0)
  38.130 ns (1 allocation: 16 bytes)
1.0163018820591236

julia> @btime g(h0)
  5.396 μs (83 allocations: 2.59 KiB)
2.084381221974989

julia> @btime gradient(g, h0)
  266.164 μs (832 allocations: 32.97 KiB)
(0.34174141687554416,)
```
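For comparison, since `h` is a single scalar here, the same derivative can also be computed in forward mode, which avoids storing a reverse-mode tape entirely. This is only a minimal sketch, assuming the QuadGK.jl and ForwardDiff.jl packages rather than the Integrals.jl interface above:

```julia
using QuadGK       # quadgk: adaptive Gauss–Kronrod quadrature
using ForwardDiff  # forward-mode AD via dual numbers

# Same integral as above, written directly with quadgk (which returns
# the tuple (integral, error estimate), hence the [1]).
g2(h) = quadgk(x -> sin(x) * exp(h * cos(x)), 0, pi; atol=1e-16)[1]

# The dual number is simply pushed through the quadrature, so no
# operation tape has to be recorded or stored.
dg = ForwardDiff.derivative(g2, 0.5)
# dg ≈ 0.34174141687554416, matching the Zygote result above
```

Forward mode scales with the number of *inputs* (one here), not with the number of operations, so its memory use stays flat.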
We can see that a function which allocates only 16 bytes per call allocates over 160 times as much once it is wrapped in an integral, and computing the gradient multiplies the allocations by more than ten again. Of course, these amounts are still small in absolute terms. But once loops are involved, computing the gradient takes a great deal of memory. Something like the following:
```julia
using Integrals
using SciMLSensitivity
using Zygote

function test(t::Float64, D::Int64)
    T = Zygote.Buffer(zeros(D, D), D, D)
    for i in 1:D
        for j in 1:D
            T[i, j] = 2*pi * solve(IntegralProblem((x, t) -> sin(i*x)*exp(t*cos(j*x)), 0, pi, t),
                                   QuadGKJL(), abstol=1e-16)
        end
    end
    T = copy(T)
    return sum(T)
end

function main(D::Int64)
    t0 = 1.5
    f(t) = test(t, D)
    @time f(t0)
    @time g = gradient(t -> f(t), t0)
    return nothing
end

@time main(100)
```
The output of `@time f(t0)` is:
```
1298.039107 seconds (15.45 G allocations: 319.864 GiB, 4.15% gc time, 0.42% compilation time)
```
GC is triggered during the inner loop, so `f(t)` does not actually need as much physical memory as the allocation total above suggests. Things are different with Zygote, though: reverse-mode AD typically uses memory proportional to the number of operations in the program, so the peak machine memory needed to take the derivative is even larger than the figure shown above. With many loop iterations, the memory required for differentiation becomes very large. How can I reduce the memory footprint in this case?
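One workaround I am considering (a sketch only, assuming QuadGK.jl, and with `dtest` a name I made up): differentiate under the integral sign by hand, so the loop only ever evaluates plain `Float64` quadratures and Zygote never has to trace through the solver. For the `test` function above, the `t`-derivative of each integrand `sin(i*x)*exp(t*cos(j*x))` is `sin(i*x)*cos(j*x)*exp(t*cos(j*x))`, and the derivative of the sum is the sum of the per-entry derivatives:

```julia
using QuadGK

# d/dt of sum_{i,j} 2π ∫₀^π sin(i*x) * exp(t*cos(j*x)) dx,
# obtained by moving the derivative inside each integral.
function dtest(t::Float64, D::Int)
    s = 0.0
    for i in 1:D, j in 1:D
        s += 2pi * quadgk(x -> sin(i*x) * cos(j*x) * exp(t*cos(j*x)),
                          0, pi; atol=1e-16)[1]
    end
    return s
end
```

Each quadrature here only allocates transiently and is freed by GC, so the peak memory stays flat in `D` instead of growing with the number of recorded operations, at the cost of writing the derivative integrand manually.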