Following your suggestions, I changed `applychain`, the function used by `Lux.Chain`. It holds the information about the layers and is where `Lux.apply` is invoked for each of them, so I wrapped every `Lux.apply` call inside it with `Zygote.checkpointed`, like this:
"""
    applychain(layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}

Run the layers of a chain one after another, feeding each layer's output into the
next. Every per-layer `Lux.apply(layer, input, ps_k, st_k)` call is routed through
`Zygote.checkpointed`.

Because the function is `@generated`, the iteration over `fields` happens at compile
time and the generated method body is a straight-line sequence of calls, one per
field. Returns `(y, st_new)`: `y` is the output of the last layer (or `x` itself when
`layers` is empty), and `st_new` is a `NamedTuple` keyed by `fields` containing the
state returned for each layer.
"""
@generated function applychain(
        layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
    nlayers = length(fields)
    # outs[k] is the input to layer k and outs[k + 1] its output; outs[1] is `x`.
    outs = [:x; [gensym() for _ in 1:nlayers]]
    states = [gensym() for _ in 1:nlayers]

    body = Expr[]
    for (k, name) in enumerate(fields)
        push!(body,
            :(($(outs[k + 1]), $(states[k])) = Zygote.checkpointed(
                Lux.apply, layers.$name, $(outs[k]), ps.$name, st.$name)))
    end
    # Collect the per-layer states back under the original field names.
    push!(body, :(st = NamedTuple{$fields}(($(states...),))))
    push!(body, :(return $(last(outs)), st))

    return Expr(:block, body...)
end
Now every layer in the chain uses gradient checkpointing.