Hessian preparation using DifferentiationInterface and Mooncake throws a StackOverflowError

Hi,
I am trying to calculate the Hessian of a function using DifferentiationInterface.jl with Mooncake.jl as the backend. Computing the Jacobian this way is very fast, both in the preparation stage and in the actual evaluation. (I have been unable to do this with Enzyme.jl: in forward mode it throws an error pointing to a matrix multiplication implemented by matmul! from LinearAlgebra, and in reverse mode I waited for over an hour before killing the process.) So I am hopeful that I can keep using Mooncake to get the Hessian as well.

To boil down the issue, I have come up with a simple function:

function real_laughlin_wavefunction_1(X::Vector{Float32})

    N = div(length(X), 2)
    res = zero(Float64)

    for i in 1:N-1
        @simd for j in i+1:N
            d = sin(acos(cos(X[2*i-1])*cos(X[2*j-1]) + cos(X[2*i] - X[2*j]) * sin(X[2*i-1])*sin(X[2*j-1]))/2)
            res += 1.50 * log(abs2(d))
        end
    end

    return res
end
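
For completeness, the backend and X used in the call below were set up roughly like this (a sketch from memory; the exact Mooncake config options and the input size are not the important part):

using DifferentiationInterface, Mooncake
import DifferentiationInterface as DI
using ADTypes: AutoMooncake

backend = AutoMooncake(; config=Mooncake.Config())  # reverse-mode Mooncake backend; exact config options not reproduced here
X = rand(Float32, 8)                                # small test input with 2N entries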

When I try to prepare the Hessian as follows:
prep_hess_1 = DI.prepare_hessian(real_laughlin_wavefunction_1, backend, X)
I get the following error:

[ Info: Compiling rule for Tuple{typeof(DifferentiationInterface.shuffled_gradient), Vector{Float32}, typeof(real_laughlin_wavefunction_1), ADTypes.AutoMooncake{Mooncake.Config}, DifferentiationInterface.Rewrap{0, Tuple{}}} in debug mode. Disable for best performance.
ERROR: MooncakeRuleCompilationError: an error occured while Mooncake was compiling a rule to differentiate something. If the `caused by` error message below does not make it clear to you how the problem can be fixed, please open an issue at github.com/compintell/Mooncake.jl describing your problem.
To replicate this error run the following:

Mooncake.build_rrule(Mooncake.MooncakeInterpreter(), Tuple{Mooncake.var"##prepare_pullback_cache#697", Base.Pairs{Symbol, Bool, Tuple{Symbol, Symbol}, @NamedTuple{debug_mode::Bool, silence_debug_messages::Bool}}, typeof(Mooncake.prepare_pullback_cache), Function, Vararg{Any}}; debug_mode=true)

Note that you may need to `using` some additional packages if not all of the names printed in the above signature are available currently in your environment.

Stacktrace:
  [1] build_rrule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, silence_debug_messages::Bool)
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1108
  [2] build_rrule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1049 [inlined]
  [3] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1799
  [4] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
  [5] RRuleZeroWrapper
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
  [6] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
  [7] prepare_gradient
    @ ~/.julia/packages/DifferentiationInterface/srtnM/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:118 [inlined]
  [8] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
  [9] DerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
 [10] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
 [11] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1800
 [12] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
 [13] RRuleZeroWrapper
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
 [14] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
 [15] gradient
    @ ~/.julia/packages/DifferentiationInterface/srtnM/src/fallbacks/no_prep.jl:48 [inlined]
 [16] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [17] DerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
 [18] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
 [19] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1800
 [20] LazyDerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
 [21] RRuleZeroWrapper
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
 [22] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
 [23] shuffled_gradient
    @ ~/.julia/packages/DifferentiationInterface/srtnM/src/first_order/gradient.jl:129 [inlined]
 [24] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
    @ Base.Experimental ./<missing>:0
 [25] DerivedRule
    @ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
 [26] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
 [27] prepare_pullback_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
    @ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interface.jl:193
 [28] prepare_pullback_cache
    @ ~/.julia/packages/Mooncake/ST1qF/src/interface.jl:183 [inlined]
 [29] prepare_pullback(::typeof(DifferentiationInterface.shuffled_gradient), ::ADTypes.AutoMooncake{…}, ::Vector{…}, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/srtnM/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:12
 [30] _prepare_hvp_aux(::DifferentiationInterface.ReverseOverReverse, ::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{…}, ::Vector{…}, ::Tuple{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hvp.jl:435
 [31] prepare_hvp
    @ ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hvp.jl:74 [inlined]
 [32] _prepare_hessian_aux(::DifferentiationInterface.BatchSizeSettings{…}, ::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{…}, ::Vector{…})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hessian.jl:99
 [33] prepare_hessian(::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{Mooncake.Config}, ::Vector{Float32})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hessian.jl:83
 [34] top-level scope
    @ REPL[31]:1

caused by: StackOverflowError:
Stacktrace:
      [1] macro expansion
        @ ~/.julia/packages/Mooncake/ST1qF/src/tangents.jl:435 [inlined]
      [2] macro expansion
        @ ./none:0 [inlined]
      [3] tangent_type(::Type{Core.CodeInstance})
        @ Mooncake ./none:0
--- the last 3 lines are repeated 79982 more times ---
 [239950] macro expansion
        @ ~/.julia/packages/Mooncake/ST1qF/src/tangents.jl:435 [inlined]
 [239951] macro expansion
        @ ./none:0 [inlined]
Some type information was truncated. Use `show(err)` to see complete types.

What does this error mean? I get exactly the same error when I try to calculate the Hessian of my actual function as well.

Sadly, Mooncake does not currently support differentiating itself, so it’s not possible to use it to get the Hessian.
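
One option, if you want to keep the Mooncake gradient on the inside, is to pair it with a different outer backend via DifferentiationInterface's SecondOrder. A rough sketch (I have not checked whether this particular combination compiles for your function):

using DifferentiationInterface, ForwardDiff, Mooncake
using ADTypes: AutoForwardDiff, AutoMooncake

# forward-over-reverse: ForwardDiff differentiates the Mooncake gradient.
# Note: the outer ForwardDiff pass feeds dual numbers through your function,
# so the ::Vector{Float32} annotation would have to be relaxed
# (e.g. to AbstractVector{<:Real}) for this to even compile.
backend2 = SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing))
prep_hess_2 = prepare_hessian(real_laughlin_wavefunction_1, backend2, X)
H = hessian(real_laughlin_wavefunction_1, prep_hess_2, backend2, X)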

Oh. I didn’t know that. Thank you!

Hi @Gattu_Mytraya! Can you provide an MWE for your function? How big is the input? Do you expect its Hessian to be sparse?

Hi,

Can you provide an MWE for your function?

My whole code is a bit long, so I am only sharing some of it here.

The final function looks something like this:

function real_laughlin_wavefunction!(X::Vector{Float32}, Y!::Function, Ystorage::Matrix{Float32}, params::Vector{Float32})

    N = div(length(X), 2)
    
    Y!(Ystorage, X, params)

    res = zero(Float64)

    for i in 1:N-1
        
        @inbounds u1 = ComplexF64(Ystorage[1,i], Ystorage[4,i])
        @inbounds v1 = ComplexF64(Ystorage[3,i], Ystorage[2,i])

        @simd for j in i+1:N
            
            @inbounds u2 = ComplexF64(Ystorage[1,j], Ystorage[4,j])
            @inbounds v2 = ComplexF64(Ystorage[3,j], Ystorage[2,j])

            res += 1.50 * log(abs2(u1*v2 - u2*v1))

        end
    end

    return res
end

I would like to take the Hessian of real_laughlin_wavefunction! with respect to X.

For additional context, the function Y! looks something like this:

    function Y!(Ystorage::Matrix{Float32}, X::Vector{Float32}, params::Vector{Float32})
        
        @simd for i in 1:N

            @inbounds Ystorage[1:4, i] .= get_quaternion_form(X[2*i-1], X[2*i])

        end

        @simd for iter in eachindex(node_initiations)

            @inbounds G.nodes[iter, :] .= params[node_initiations[iter]]
            
        end

        pair_iter = 1
        for i in 1:N
            @simd for j in 1:N

                @inbounds G.edges[1:4, pair_iter] .= get_quaternion_form(X[2*i-1], X[2*i], X[2*j-1], X[2*j])
                
                pair_iter += 1
            end 
        end

        @simd for iter in eachindex(edge_initiations)

            @inbounds G.edges[4 + iter, :] .= params[edge_initiations[iter]]

        end

        for i in 1:num_backflows
            
            update_graph!(G, params[graph_params_iter])

        end

        @inbounds @views Ystorage .+= reshape(params[projection_matrix_iters], 4, G.D1) * G.nodes

        @simd for iter in axes(Ystorage, 2)

            @views Ystorage[:, iter] ./= norm(Ystorage[:, iter])

        end
        
        return
    end

with,

function get_quaternion_form(θ::Float32, ϕ::Float32)

    sθ, cθ = sincos(θ/2)
    sϕ, cϕ = sincos(ϕ/2)

    return cθ*cϕ, -sθ*sϕ, sθ*cϕ, cθ*sϕ

end

function get_quaternion_form(θ1::Float32, ϕ1::Float32, θ2::Float32, ϕ2::Float32)

    sΣθ, cΣθ = sincos((θ1+θ2)/2)

    sΔθ, cΔθ = sincos((θ1-θ2)/2)

    sΔϕ, cΔϕ = sincos((ϕ1-ϕ2)/2)

    return cΔθ * cΔϕ, sΣθ * sΔϕ, -sΔθ * cΔϕ, -cΣθ * sΔϕ

end

and

function update_graph!(G::Graph, graph_params)::Nothing

    # @views @inbounds G.queries_keys .= reshape(graph_params[G.WQ_WK_iters], 2*G.D2, G.D2) * G.edges ### Permutation invariant.
    W = reshape(view(graph_params, G.WQ_WK_iters), 2*G.D2, G.D2)
    mul!(G.queries_keys, W, G.edges)

    queries = reshape(view(G.queries_keys, 1:G.D2, :), G.D2, G.N, G.N)
    keys = reshape(view(G.queries_keys, G.D2+1:2*G.D2, :), G.D2, G.N, G.N)

    G.Φ!(G.messages, G.edges, view(graph_params, G.Φ_params_iters))

    G1 = reshape(G.messages, G.D2, G.N, G.N)

    @simd for i in axes(G.messages, 1)

        @views G1[i, :, :] .*= transpose(queries[i, :, :]) * keys[i, :, :]

    end

    ### Messages have been constructed.

    @inbounds G.node_messages[begin:G.D1, :] .= G.nodes ### (D1, N) matrix

    for i in 1:G.D2
        
        for j in 1:G.N

            acc = zero(Float32)
            
            @simd for k in 1:G.N

                if j != k

                    @inbounds acc += G1[i, k, j]

                end

            end

            G.node_messages[G.D1+i, j] = acc

        end

    end

    @inbounds G.f_nodes!(view(G.nodes, 5:G.D1, :), G.node_messages, view(graph_params, G.f_nodes_params_iters))    

    @inbounds G.edge_messages[1:G.D2, :] .= G.edges ### (D2, N*(N-1)/2) matrix
    @inbounds G.edge_messages[1+G.D2:end, :] .= G.messages ### (D2, N*(N-1)/2) matrix

    @inbounds G.f_edges!(view(G.edges, 5:G.D2, :), G.edge_messages, view(graph_params, G.f_edges_params_iters)) ### f_edges is a multi-layer perceptron. f_edges_params are the parameters of the multi-layer perceptron. G.edges_messages is the input to the multi-layer perceptron with columns corresponding to the edges or different batches. Output of f goes to the hidden features at each edge. First 4 features are the initial edges. These remain fixed.

    return nothing
end

Here, Φ!, f_nodes!, and f_edges! are multi-layer perceptrons with fully connected hidden layers and gelu activations everywhere except the output layer (in f_nodes! and f_edges! the output layer has different dimensions than the input). I wrote my own MLPs simply because I wanted them to be non-allocating; I think the ones Flux provides allocate.
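
(For illustration only, here is a hypothetical sketch of what I mean by a non-allocating layer; the names and parameter layout are made up and are not my actual code:)

using LinearAlgebra: mul!

gelu(x) = 0.5f0 * x * (1f0 + tanh(0.7978845608f0 * (x + 0.044715f0 * x^3)))  # tanh approximation of gelu

function mlp_layer!(out::Matrix{Float32}, W::Matrix{Float32}, b::Vector{Float32}, input::Matrix{Float32})
    mul!(out, W, input)     # out = W * input, written in place
    out .+= b               # broadcast the bias over the columns (batches)
    out .= gelu.(out)       # apply the activation in place
    return nothing
end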

I know this isn’t an MWE, but I think it contains all the important details. I will also try to think about how to create one that preserves the essentials.
Alternatively, I can put the whole code on GitHub and share the link.

How big is the input?

I expect the size of the inputs to be between 8 and 100.

Do you expect its Hessian to be sparse?

No, I do not think there will be any zero terms in the Hessian; I expect it to be dense.

As a first heuristic, given that your code is complex but your data is small, I’d probably try FiniteDiff.jl for the Hessian. ForwardDiff.jl would be better, but the numerous explicit conversions to Float64 in your code will probably make it error.
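
Through DifferentiationInterface the backend switch is a one-line change; roughly (using your simplified function as a stand-in, with the backend constructor coming from ADTypes):

using DifferentiationInterface, FiniteDiff
using ADTypes: AutoFiniteDiff

backend_fd = AutoFiniteDiff()
prep_fd = prepare_hessian(real_laughlin_wavefunction_1, backend_fd, X)
H = hessian(real_laughlin_wavefunction_1, prep_fd, backend_fd, X)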

Hi,
I tried FiniteDiff.jl. Although it works, I think it’s too slow to be usable. To put some numbers on it: a single function evaluation takes about 100e-6 s, and gradient evaluation (using Mooncake.jl) takes around 2e-3 s, which is still acceptable since I expect to evaluate the gradient around 10^6-10^7 times. But calculating the Hessian with FiniteDiff.jl takes around 0.20 s, which is quite a lot for my use case. That said, I haven’t yet gone through all the optimizations suggested in FiniteDiff.jl’s documentation.

Hi,
I wanted to post an update: I changed some things to make my code compatible with ForwardDiff.jl, and I can now compute the Hessian. It takes about 0.02 s, which is still on the higher end for my use case, but I will try to see if I can make further optimizations.

If your input is very small, you may benefit from using StaticArrays.jl.
Also, if you’re going through DifferentiationInterface.jl, do you exploit preparation?
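
By preparation I mean paying the setup cost once and reusing it across evaluations, roughly like this (a sketch with illustrative names; the stand-in objective here is trivial, and the real function must accept ForwardDiff's dual numbers):

using DifferentiationInterface, ForwardDiff
using ADTypes: AutoForwardDiff

f(x) = sum(abs2, x)                          # stand-in objective
fwd_backend = AutoForwardDiff()
X = rand(Float32, 8)

prep = prepare_hessian(f, fwd_backend, X)    # pay the setup cost once
H = zeros(Float32, length(X), length(X))     # preallocate the output

for step in 1:10^3
    X .= rand(Float32, length(X))            # stand-in for whatever updates X
    hessian!(f, H, prep, fwd_backend, X)     # reuses prep and fills H in place
end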

Hi,
I don’t know the size of my arrays before runtime. Still, I will try to see if I can incorporate StaticArrays by defining a global size variable for the different use cases.

I am indeed using DifferentiationInterface.jl for preparation.