Hello, I tried to approach this from the ChainRules perspective, as suggested by @Tomas_Pevny. Writing rules for every function would be a mess, but what I really need is to checkpoint each Lux.jl layer in the neural network. So I tried to achieve it like this:
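```julia
using ChainRulesCore, Lux, Zygote

# Custom rule meant to checkpoint every Lux layer: run the forward pass
# without keeping intermediates, then recompute them in the pullback.
function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)
    function pullback_checkpointed(Δy)
        # Recompute the forward pass and differentiate it
        _, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        # No tangent for the function itself, then one tangent per argument
        return (NoTangent(), pb(Δy)...)
    end
    return y, pullback_checkpointed
end
```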
The rule does get invoked during backpropagation. However, for some reason it recursively tries to backpropagate through its first line,

```julia
y = Lux.apply(l, x, ps, st)
```

so I get a stack overflow error. How can I correct it?
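One likely source of the recursion (an assumption, not verified against Lux itself): inside the pullback, `Zygote.pullback(Lux.apply, ...)` dispatches to this same `rrule` again, which creates a new checkpointed pullback, which calls `Zygote.pullback(Lux.apply, ...)` again, and so on. The sketch below shows the pattern without Lux, using only ChainRulesCore and Zygote. The names `layer_body` and `checkpointed` are hypothetical stand-ins for a layer's computation and the function carrying the rule. The key point is that the pullback differentiates the *inner* function, which has no custom rule, rather than the function the rule is attached to.

```julia
using ChainRulesCore, Zygote

# Hypothetical stand-in for an expensive layer computation (no custom rule).
layer_body(x, w) = tanh.(w * x)

# Hypothetical wrapper that carries the checkpointing rule.
checkpointed(x, w) = layer_body(x, w)

function ChainRulesCore.rrule(::typeof(checkpointed), x, w)
    # Plain forward pass: no pullback (and no intermediates) is stored.
    y = layer_body(x, w)
    function checkpointed_pullback(Δy)
        # Recompute the forward pass by differentiating the inner function.
        # Calling Zygote.pullback(checkpointed, x, w) here instead would hit
        # this rrule again and recurse until the stack overflows.
        _, pb = Zygote.pullback(layer_body, x, w)
        # NoTangent() for `checkpointed` itself, then tangents for x and w.
        return (NoTangent(), pb(unthunk(Δy))...)
    end
    return y, checkpointed_pullback
end

w = [1.0 2.0; 3.0 4.0]
x = [0.5, -0.5]

# Gradients through the checkpointed wrapper and through the plain function.
g_ckpt = Zygote.gradient((x, w) -> sum(checkpointed(x, w)), x, w)
g_ref = Zygote.gradient((x, w) -> sum(layer_body(x, w)), x, w)
```

Both gradients should agree; the checkpointed version simply recomputes `layer_body` during the backward pass. Applied to Lux, the analogous idea would be to differentiate the layer's own call, e.g. `l(x, ps, st)`, inside the pullback instead of `Lux.apply`, though how this interacts with nested layers such as `Chain` is untested here.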