Avoid storing intermediate results from the forward pass by default

Hello, I tried to approach this from the ChainRules perspective, as suggested by @Tomas_Pevny. Writing rules for every function would be a mess, but what I really need is to checkpoint each Lux.jl layer in a neural network. So I tried to achieve it like this:

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)
    
    function pullback_checkpointed(Δy)
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end
    
    y, pullback_checkpointed
end

The rule gets invoked during backpropagation. However, for some reason it tries to recursively backpropagate through the first line

 y = Lux.apply(l, x, ps, st)

so I get a stack overflow error. How can I correct it?