Avoid storing intermediate results from the forward pass by default

Hello,

Is it possible in Zygote to change the default behavior for computing derivatives during backpropagation to use gradient checkpointing?

I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation. Additionally, I have created some custom rules with ChainRules for some functions that use Enzyme and custom gradient checkpointing.

I know that Zygote supports gradient checkpointing via Zygote.checkpointed, but is there a way to use it by default? Specifically, can I override the main backpropagation function and wrap it with Zygote.checkpointed [1] or something similar?

[1] Utilities · Zygote

1 Like

This would be difficult, since checkpointing requires recomputing the intermediate results, which means you compute the forward pass twice. But Zygote does contain support for checkpointing:

https://fluxml.ai/Zygote.jl/dev/adjoints/#Checkpointing-1
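For reference, basic usage looks roughly like this (a minimal sketch with a made-up function f; Zygote.checkpointed(f, xs...) discards the intermediates of f on the forward pass and recomputes them when the pullback runs):

using Zygote

f(x) = sum(abs2, tanh.(x))  # hypothetical example function

# Standard: intermediates of f are stored for the backward pass.
g1 = gradient(x -> f(x), rand(3))

# Checkpointed: f is recomputed inside the pullback,
# trading extra compute for lower memory use.
g2 = gradient(x -> Zygote.checkpointed(f, x), rand(3))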

3 Likes

This question inspired me to revisit the following issue. Perhaps we could support checkpointing in ADTypes.jl and DifferentiationInterface.jl? Don’t hesitate to pitch in.

1 Like

Hello, I tried to approach this from the ChainRules perspective. As @Tomas_Pevny suggested, covering all functions would be a mess, but what I really need is to checkpoint each Lux.jl layer in the neural network. So I tried to achieve it like this:

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)
    
    function pullback_checkpointed(Δy)
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end
    
    y, pullback_checkpointed
end

The rule gets invoked during backpropagation. However, the issue is that for some reason it recursively tries to do backpropagation of the first line,

 y = Lux.apply(l, x, ps, st)

so I get a StackOverflowError. How can I correct it?

Can you try Zygote.checkpointed instead? What you're doing here is type piracy, since you're defining a method without owning either the function or the argument types.

Thanks for the answer, @gdalle! I do not see how to use Zygote.checkpointed in this situation. In Lux I just define a list of structs that are later invoked by the library. So I can use Zygote.checkpointed on my own functions, but not on Lux functions, at least when I try something like

function Lux.apply(l::TensorOpLayer_str, x, ps, st::NamedTuple)
    return Zygote.checkpointed(Lux.apply(l, x, ps, st))
end

It is also type piracy and also leads to a StackOverflowError :slight_smile:

I am not at all familiar with Lux; I am more used to Flux.

Can you show the stack trace please?

I’ve never used it myself but I don’t think you need to redefine a method? Just use Zygote.checkpointed(Lux.apply, args...) in your training loop?
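i.e. wherever the forward pass currently calls Lux.apply(layer, x, ps, st), something like this (a sketch with hypothetical variable names):

y, st = Zygote.checkpointed(Lux.apply, layer, x, ps, st)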

1 Like

Thanks for the idea! But Lux.apply is not called explicitly by the user per layer; it is invoked inside the library, as it is used on each layer and in each layer container.

Stack trace:

ERROR: StackOverflowError:
Stacktrace:
     [1] macro expansion
       @ /usr/local/share/julia/packages/CUDA/Tl08O/lib/utils/call.jl:218 [inlined]
     [2] macro expansion
       @ /data/packages/cuTENSOR/uwns2/src/libcutensor.jl:430 [inlined]
     [3] #50
       @ /usr/local/share/julia/packages/CUDA/Tl08O/lib/utils/call.jl:35 [inlined]
     [4] retry_reclaim
       @ /usr/local/share/julia/packages/CUDA/Tl08O/src/memory.jl:434 [inlined]
     [5] check
       @ /data/packages/cuTENSOR/uwns2/src/libcutensor.jl:24 [inlined]
     [6] cutensorCreatePlan
       @ /usr/local/share/julia/packages/CUDA/Tl08O/lib/utils/call.jl:34 [inlined]
     [7] cuTENSOR.CuTensorPlan(desc::Ptr{…}, pref::Ptr{…}; workspacePref::cuTENSOR.cutensorWorksizePreference_t)
       @ cuTENSOR /data/packages/cuTENSOR/uwns2/src/types.jl:160
     [8] CuTensorPlan
       @ /data/packages/cuTENSOR/uwns2/src/types.jl:149 [inlined]
     [9] plan_contraction(A::AbstractArray, Ainds::Vector{…}, opA::cuTENSOR.cutensorOperator_t, B::AbstractArray, Binds::Vector{…}, opB::cuTENSOR.cutensorOperator_t, C::AbstractArray, Cinds::Vector{…}, opC::cuTENSOR.cutensorOperator_t, opOut::cuTENSOR.cutensorOperator_t; jit::cuTENSOR.cutensorJitMode_t, workspace::cuTENSOR.cutensorWorksizePreference_t, algo::cuTENSOR.cutensorAlgo_t, compute_type::Nothing)
       @ cuTENSOR /data/packages/cuTENSOR/uwns2/src/operations.jl:340
    [10] plan_contraction
       @ /data/packages/cuTENSOR/uwns2/src/operations.jl:301 [inlined]
    [11] #contract!#83
       @ /data/packages/cuTENSOR/uwns2/src/operations.jl:272 [inlined]
    [12] contract!(alpha::Number, A::AbstractArray, Ainds::Vector{…}, opA::cuTENSOR.cutensorOperator_t, B::AbstractArray, Binds::Vector{…}, opB::cuTENSOR.cutensorOperator_t, beta::Number, C::AbstractArray, Cinds::Vector{…}, opC::cuTENSOR.cutensorOperator_t, opOut::cuTENSOR.cutensorOperator_t)
       @ cuTENSOR /data/packages/cuTENSOR/uwns2/src/operations.jl:259
    [13] tensorcontract!(C::StridedViews.StridedView{…}, A::StridedViews.StridedView{…}, pA::Tuple{…}, conjA::Bool, B::StridedViews.StridedView{…}, pB::Tuple{…}, conjB::Bool, pAB::Tuple{…}, α::One, β::Zero, backend::TensorOperations.cuTENSORBackend, allocator::TensorOperations.CUDAAllocator{…})
       @ TensorOperationscuTENSORExt /data/packages/TensorOperations/Dlx7i/ext/TensorOperationscuTENSORExt.jl:228
    [14] tensorcontract!(C::CuArray{…}, A::CuArray{…}, pA::Tuple{…}, conjA::Bool, B::CuArray{…}, pB::Tuple{…}, conjB::Bool, pAB::Tuple{…}, α::One, β::Zero, backend::TensorOperations.cuTENSORBackend, allocator::TensorOperations.CUDAAllocator{…})
       @ TensorOperationscuTENSORExt /data/packages/TensorOperations/Dlx7i/ext/TensorOperationscuTENSORExt.jl:96
    [15] (::TensorOpLayer_str)(x::CuArray{…}, ps::@NamedTuple{…}, st::@NamedTuple{…})
       @ Main /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/lin_sampl/model/util_layers.jl:52
    [16] apply
       @ /usr/local/share/julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
    [17] rrule
       @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/lin_sampl/model/util_layers.jl:65 [inlined]
    [18] rrule
       @ /usr/local/share/julia/packages/ChainRulesCore/I1EbV/src/rules.jl:134 [inlined]
    [19] chain_rrule
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:223 [inlined]
    [20] macro expansion
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
    [21] _pullback
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87 [inlined]
    [22] pullback
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90 [inlined]
    [23] pullback
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
    [24] (::var"#pullback_checkpointed#173"{…})(Δy::Tangent{…})
       @ Main /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/lin_sampl/model/util_layers.jl:68
    [25] (::Zygote.ZBack{var"#pullback_checkpointed#173"{…}})(dy::Tangent{Any, Tuple{…}})
       @ Zygote /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211
    [26] (::Zygote.var"#75#76"{Zygote.ZBack{…}})(Δ::Tangent{Any, Tuple{…}})
       @ Zygote /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
    [27] (::var"#pullback_checkpointed#173"{…})(Δy::Tangent{…})
       @ Main /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/lin_sampl/model/util_layers.jl:69
--- the last 3 lines are repeated 7123 more times ---
 [21397] ZBack
       @ /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [21398] loss_function_dummy
       @ /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/test_util_layers.jl:64 [inlined]
 [21399] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float32, Nothing, Nothing})
       @ Zygote /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [21400] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Tuple{Float32, Nothing, Nothing})
       @ Zygote /usr/local/share/julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [21401] compute_gradients(::AutoZygote, objective_function::var"#loss_function_dummy#174", data::CuArray{…}, ts::Lux.Training.TrainState{…})
       @ LuxZygoteExt /usr/local/share/julia/packages/Lux/a2Wcp/ext/LuxZygoteExt/training.jl:5
 [21402] single_train_step!(backend::AutoZygote, obj_fn::var"#loss_function_dummy#174", data::CuArray{…}, ts::Lux.Training.TrainState{…})
       @ Lux.Training /usr/local/share/julia/packages/Lux/a2Wcp/src/helpers/training.jl:281
 [21403] single_train_step!(::AutoZygote, ::Vararg{Any}; kwargs::@Kwargs{})
       @ Lux.Experimental ./deprecated.jl:105
 [21404] single_train_step!(::AutoZygote, ::Vararg{Any})
       @ Lux.Experimental ./deprecated.jl:103
 [21405] test3()
       @ Main /workspaces/superVoxelJuliaCode_lin_sampl/superVoxelJuliaCode/src/tests/test_util_layers.jl:67
Some type information was truncated. Use `show(err)` to see complete types.

I think (but of course I am not sure) that you introduce the recursion in your pullback.
You have written the rrule as

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)
    
    function pullback_checkpointed(Δy)
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end
    
    y, pullback_checkpointed
end

and I think the recursion is caused by this line

 y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)

because this line will call rrule (if you peek into the function) for Lux.apply, l, x, ps, st, which is exactly the rrule you have just defined, hence the infinite recursion.

In the implementation I have sent you, I created a new type which wrapped the original function, and therefore I avoided the recursion (see the sketch below).

It is true that I had to add the checkpointing manually to the code, but I think this is necessary. Performance-wise, I also do not think you want to checkpoint every function, but rather bigger pieces of code. In my case I was checkpointing Llama2, and I checkpointed the transformer blocks.
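A minimal sketch of that wrapper-type idea (names such as Checkpointed are made up for illustration; this is not the exact code from the linked issue):

using Zygote
import ChainRulesCore: rrule, NoTangent

# Wrapper that carries the function whose call we want to checkpoint.
struct Checkpointed{F}
    f::F
end

(c::Checkpointed)(args...) = c.f(args...)

# The rrule is defined on a type we own, so calling Zygote.pullback
# on the *unwrapped* c.f inside it does not recurse back into this rule.
function rrule(c::Checkpointed, args...)
    y = c.f(args...)  # forward pass without storing intermediates
    function checkpointed_pullback(Δy)
        _, pb = Zygote.pullback(c.f, args...)  # recompute on the backward pass
        return (NoTangent(), pb(Δy)...)
    end
    return y, checkpointed_pullback
end

You then wrap only the blocks you want to checkpoint, e.g. Checkpointed(transformer_block)(x), rather than every small function.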

Thanks, @Tomas_Pevny! What do you mean by the implementation you had sent; do you mean Custom Adjoints · Zygote?
I suppose you refer to some other code with LLM usage, but I cannot find the link to what you are referring to.

I meant this one Adding support for checkpointing · Issue #149 · chengchingwen/Transformers.jl · GitHub

1 Like

If you want to do this, try something like Dispatching on Custom Input Types | Lux.jl Docs

struct CheckpointMe
    x
end

function Lux.apply(l::AbstractLuxLayer, x::CheckpointMe, ps, st::NamedTuple)
    Zygote.checkpointed(Lux.apply, l, x.x, ps, st)
end

But @gdalle is correct: you should call it on Lux.apply. It will forward your calls to your callable struct.
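For completeness, a usage sketch under those assumptions (hypothetical model and sizes; the wrapper is unwrapped inside the overloaded method, so the whole apply call runs under Zygote.checkpointed):

using Lux, Zygote, Random

model = Chain(Dense(4 => 8, tanh), Dense(8 => 2))
ps, st = Lux.setup(Random.default_rng(), model)
x = rand(Float32, 4, 16)

# Wrapping the input makes dispatch pick the checkpointed Lux.apply method.
y, st_new = Lux.apply(model, CheckpointMe(x), ps, st)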

1 Like

Using your suggestions, I finally changed the chain-application function used by Lux.Chain. Since it holds the information about the layers and invokes Lux.apply on each of them, I wrapped each Lux.apply call inside it with Zygote.checkpointed, like this:

@generated function applychain(
    layers::NamedTuple{fields}, x, ps, st::NamedTuple{fields}) where {fields}
    N = length(fields)
    # One symbol for each intermediate output and each updated layer state.
    x_symbols = vcat([:x], [gensym() for _ in 1:N])
    st_symbols = [gensym() for _ in 1:N]
    # Wrap every layer call in Zygote.checkpointed, so each layer's
    # intermediates are recomputed during the backward pass instead of stored.
    calls = [:(($(x_symbols[i + 1]), $(st_symbols[i])) = Zygote.checkpointed(Lux.apply,
                layers.$(fields[i]), $(x_symbols[i]), ps.$(fields[i]), st.$(fields[i])))
            for i in 1:N]
    # Collect the per-layer states back into a NamedTuple and return.
    push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
    push!(calls, :(return $(x_symbols[N + 1]), st))
    return Expr(:block, calls...)
end

So now all layers in this chain have gradient checkpointing.