Hi,
I am trying to calculate the Hessian of a function using DifferentiationInterface.jl with Mooncake as the backend. Computing the Jacobian this way is very fast, both in the preparation stage and in the actual evaluation. (I have been unable to use Enzyme.jl for this: in forward mode it throws an error pointing to a matrix multiplication implemented by `matmul!` from LinearAlgebra, and in reverse mode I waited for over an hour before killing the process.) So I am hopeful that I can keep using Mooncake to get the Hessian as well.
To boil down the issue, I have come up with a simple function:
"""
    real_laughlin_wavefunction_1(X::AbstractVector{<:Real})

Return `1.5 * Σ_{i<j} log(d_ij^2)` over the `N = div(length(X), 2)` points encoded in
`X`, where point `k` is the pair `(X[2k-1], X[2k])`, `d_ij = sin(acos(c_ij) / 2)` and

    c_ij = cos(X[2i-1]) cos(X[2j-1]) + cos(X[2i] - X[2j]) sin(X[2i-1]) sin(X[2j-1]).

Because `acos(c) ∈ [0, π]`, `sin(acos(c)/2)^2 == (1 - c)/2`. The pair term is therefore
evaluated as `log((1 - c)/2)`, which is mathematically identical to the direct form but:

- never throws a `DomainError` when rounding pushes `c` slightly outside `[-1, 1]`
  (`c` is clamped to that interval instead);
- keeps first and second derivatives finite for pairs with `c == -1`. Differentiating the
  direct form there multiplies `acos'(-1) = -Inf` by `cos(π/2) ≈ 0`, giving Inf/NaN.

Coincident points (`c == 1`) still contribute `-Inf`, as in the direct form.

If `length(X)` is odd, the last entry is ignored. `X` must use 1-based indexing.

The accumulator type is `promote_type(eltype(X), Float64)`, so `Float32` input returns a
`Float64` (as before). Non-`Float32` element types, such as dual numbers from forward-mode
AD, are accepted and accumulated with a concrete type.
"""
function real_laughlin_wavefunction_1(X::AbstractVector{<:Real})
    # The 2k-1 / 2k indexing below assumes indices start at 1.
    Base.require_one_based_indexing(X)
    N = div(length(X), 2)
    res = zero(promote_type(eltype(X), Float64))
    for i in 1:N-1
        # Hoist point i's coordinates out of the inner loop.
        ai, bi = X[2i-1], X[2i]
        @simd for j in i+1:N
            aj, bj = X[2j-1], X[2j]
            c = cos(ai) * cos(aj) + cos(bi - bj) * sin(ai) * sin(aj)
            # Guard against |c| > 1 caused by rounding; the exact value is in [-1, 1].
            c = clamp(c, -one(c), one(c))
            # abs2(sin(acos(c) / 2)) == (1 - c) / 2 on [-1, 1]; avoids acos entirely.
            res += 1.50 * log((1 - c) / 2)
        end
    end
    return res
end
When I try to prepare the Hessian as follows:
# NOTE(review): with a plain `AutoMooncake` backend, DI prepares the Hessian-vector
# product in reverse-over-reverse mode (`ReverseOverReverse` in the stack trace).
# Mooncake is then asked to build a rule for its own `prepare_pullback_cache`.
# Compiling that rule is what fails with the StackOverflowError in
# `tangent_type(::Type{Core.CodeInstance})`.
# Presumably this is avoided by passing a `DI.SecondOrder(outer, inner)` backend whose
# outer backend is forward mode, so Mooncake's reverse mode is not nested inside itself.
# TODO: confirm which combination your DI/Mooncake versions support.
prep_hess_1 = DI.prepare_hessian(real_laughlin_wavefunction_1, backend, X)
I get the following error:
[ Info: Compiling rule for Tuple{typeof(DifferentiationInterface.shuffled_gradient), Vector{Float32}, typeof(real_laughlin_wavefunction_1), ADTypes.AutoMooncake{Mooncake.Config}, DifferentiationInterface.Rewrap{0, Tuple{}}} in debug mode. Disable for best performance.
ERROR: MooncakeRuleCompilationError: an error occured while Mooncake was compiling a rule to differentiate something. If the `caused by` error message below does not make it clear to you how the problem can be fixed, please open an issue at github.com/compintell/Mooncake.jl describing your problem.
To replicate this error run the following:
Mooncake.build_rrule(Mooncake.MooncakeInterpreter(), Tuple{Mooncake.var"##prepare_pullback_cache#697", Base.Pairs{Symbol, Bool, Tuple{Symbol, Symbol}, @NamedTuple{debug_mode::Bool, silence_debug_messages::Bool}}, typeof(Mooncake.prepare_pullback_cache), Function, Vararg{Any}}; debug_mode=true)
Note that you may need to `using` some additional packages if not all of the names printed in the above signature are available currently in your environment.
Stacktrace:
[1] build_rrule(interp::Mooncake.MooncakeInterpreter{…}, sig_or_mi::Core.MethodInstance; debug_mode::Bool, silence_debug_messages::Bool)
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1108
[2] build_rrule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1049 [inlined]
[3] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1799
[4] LazyDerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
[5] RRuleZeroWrapper
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
[6] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[7] prepare_gradient
@ ~/.julia/packages/DifferentiationInterface/srtnM/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:118 [inlined]
[8] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[9] DerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
[10] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[11] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1800
[12] LazyDerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
[13] RRuleZeroWrapper
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
[14] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[15] gradient
@ ~/.julia/packages/DifferentiationInterface/srtnM/src/fallbacks/no_prep.jl:48 [inlined]
[16] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[17] DerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
[18] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[19] _build_rule!(rule::Mooncake.LazyDerivedRule{…}, args::Tuple{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1800
[20] LazyDerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:1795 [inlined]
[21] RRuleZeroWrapper
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
[22] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[23] shuffled_gradient
@ ~/.julia/packages/DifferentiationInterface/srtnM/src/first_order/gradient.jl:129 [inlined]
[24] (::Tuple{…})(none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…}, none::Mooncake.CoDual{…})
@ Base.Experimental ./<missing>:0
[25] DerivedRule
@ ~/.julia/packages/Mooncake/ST1qF/src/interpreter/s2s_reverse_mode_ad.jl:938 [inlined]
[26] (::Mooncake.DebugRRule{…})(::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…}, ::Mooncake.CoDual{…})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/debug_mode.jl:89
[27] prepare_pullback_cache(::Function, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
@ Mooncake ~/.julia/packages/Mooncake/ST1qF/src/interface.jl:193
[28] prepare_pullback_cache
@ ~/.julia/packages/Mooncake/ST1qF/src/interface.jl:183 [inlined]
[29] prepare_pullback(::typeof(DifferentiationInterface.shuffled_gradient), ::ADTypes.AutoMooncake{…}, ::Vector{…}, ::Tuple{…}, ::DifferentiationInterface.FunctionContext{…}, ::DifferentiationInterface.BackendContext{…}, ::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/srtnM/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:12
[30] _prepare_hvp_aux(::DifferentiationInterface.ReverseOverReverse, ::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{…}, ::Vector{…}, ::Tuple{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hvp.jl:435
[31] prepare_hvp
@ ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hvp.jl:74 [inlined]
[32] _prepare_hessian_aux(::DifferentiationInterface.BatchSizeSettings{…}, ::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{…}, ::Vector{…})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hessian.jl:99
[33] prepare_hessian(::typeof(real_laughlin_wavefunction_1), ::ADTypes.AutoMooncake{Mooncake.Config}, ::Vector{Float32})
@ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/srtnM/src/second_order/hessian.jl:83
[34] top-level scope
@ REPL[31]:1
caused by: StackOverflowError:
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Mooncake/ST1qF/src/tangents.jl:435 [inlined]
[2] macro expansion
@ ./none:0 [inlined]
[3] tangent_type(::Type{Core.CodeInstance})
@ Mooncake ./none:0
--- the last 3 lines are repeated 79982 more times ---
[239950] macro expansion
@ ~/.julia/packages/Mooncake/ST1qF/src/tangents.jl:435 [inlined]
[239951] macro expansion
@ ./none:0 [inlined]
Some type information was truncated. Use `show(err)` to see complete types.
What does this error mean? I get exactly the same error when I try to compute the Hessian of my actual function.