Avoid storing intermediate results from the forward pass by default

Following your suggestions, I finally changed the function used by `Lux.Chain` (`applychain`). Since it holds the information about the layers and is the one that invokes `Lux.apply`, I wrapped each `Lux.apply` call inside it in `Zygote.checkpointed`, like this:

"""
    applychain(layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields})

Feed `x` through every entry of `layers` in field order, wrapping each `Lux.apply`
call in `Zygote.checkpointed`.

The layer stored under a given field is called with the entries of `ps` and `st` stored
under that same field, and its output becomes the input of the next layer.

Returns the output of the last layer (or `x` itself when `fields` is empty) together
with a `NamedTuple{fields}` holding the state returned by each layer.
"""
@generated function applychain(
    layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
    depth = length(fields)

    # One fresh name per layer output and per updated layer state, so the unrolled
    # statements below never clash with each other or with the arguments.
    out_names = [gensym() for _ in 1:depth]
    state_names = [gensym() for _ in 1:depth]

    # Layer `i` reads `in_names[i]`: the argument `x` for the first layer, otherwise
    # the output of layer `i - 1`.
    in_names = vcat([:x], out_names)

    body = Expr(:block)
    for (i, name) in enumerate(fields)
        # (out_i, st_i) = Zygote.checkpointed(Lux.apply, layer, input, ps_i, st_i)
        stmt = :(($(out_names[i]), $(state_names[i])) = Zygote.checkpointed(
            Lux.apply, layers.$(name), $(in_names[i]), ps.$(name), st.$(name)))
        push!(body.args, stmt)
    end

    # Collect the per-layer states under the same field names as the input state.
    push!(body.args, :(st = NamedTuple{$fields}((($(Tuple(state_names)...),)))))
    push!(body.args, :(return $(in_names[depth + 1]), st))
    return body
end

So now all layers in this chain have gradient checkpointing.