Enzyme reverse-mode autodiff sometimes fails on repeated evaluation

I’m developing a package that has a function defined as follows:

function emission_radius(pix::Krang.AbstractPixel, θs::T, isindir, n) where {T}
    α, β = screen_coordinate(pix)
    θo = inclination(pix)
    met = metric(pix)
    isincone = θo ≤ θs ≤ (π-θo) || (π-θo) ≤ θs ≤ θo
    if !isincone#cosθs > abs(cosθo)
        αmin = αboundary(met, θs)
        βbound = (abs(α) >= (αmin + eps(T)) ? βboundary(met, α, θo, θs) : zero(T))
        ((abs(β) + eps(T)) < βbound) && return (T(NaN), true, true, 0)
    end

    τ, _, _, _ = Gθ(pix, θs, isindir, n)

    # is θ̇s increasing or decreasing?
    νθ = !isindir
    if isincone 
        νθ = (θo > θs) ⊻ (n % 2 == 1) 
    end
    # is ṙs increasing or decreasing?
    rs, νr, numreals = emission_radius(pix, τ)

    return rs, νr, νθ, numreals, abs(τ)
end

However, I get nondeterministic behaviour when using autodiff in reverse mode on this function. Here is an example where I define the function and differentiate through it five times with the same arguments.

using Enzyme

for _ in 1:5
    function intensity_point(x, y)
        θo = π/4
        metric = Krang.Kerr(-0.94)
        px = Krang.IntensityPixel(metric, x, y, θo)

        return emission_radius(px, π / 2, true, 0)[1]
    end
    println(intensity_point(5.0, 4.0))
    println(autodiff(ReverseWithPrimal, intensity_point, Active, Active(5.0), Active(4.0)))
end

The resulting output looks like this:

5.2302506256226895
 caching call:   %33 = call fastcc double @julia_K_36770(double %22) #95, !dbg !139
 caching call:   %8 = call fastcc double @julia_serf_36810(double %5, double %1) #97, !dbg !102
 caching call:   %27 = call fastcc double @julia_serf_36810(double %22, double %1) #101, !dbg !97
 caching call:   %28 = call fastcc double @julia_K_36770(double %1) #100, !dbg !136
 caching call:   %28 = call fastcc double @julia_K_36770(double %21) #100, !dbg !137
 caching call:   %45 = call fastcc double @julia_K_36770(double %16) #100, !dbg !160
 caching call:   %55 = call fastcc double @julia_atan_36837(double %53) #101, !dbg !173
 caching call:   %9 = call fastcc double @julia_K_36770(double %1) #119, !dbg !106
 caching call:   %10 = call fastcc double @julia_K_36770(double %1) #119, !dbg !106
 caching call:   %17 = call fastcc double @julia_K_36770(double %16) #119, !dbg !122
 caching call:   %18 = call fastcc double @julia_K_36770(double %16) #119, !dbg !122
((NaN, NaN), 5.2302506256226895)
5.2302506256226895
 caching call:   %33 = call fastcc double @julia_K_37083(double %22) #95, !dbg !139
 caching call:   %8 = call fastcc double @julia_serf_37123(double %5, double %1) #97, !dbg !102
 caching call:   %27 = call fastcc double @julia_serf_37123(double %22, double %1) #101, !dbg !97
 caching call:   %28 = call fastcc double @julia_K_37083(double %1) #100, !dbg !136
 caching call:   %28 = call fastcc double @julia_K_37083(double %21) #100, !dbg !137
 caching call:   %45 = call fastcc double @julia_K_37083(double %16) #100, !dbg !160
 caching call:   %55 = call fastcc double @julia_atan_37150(double %53) #101, !dbg !173
 caching call:   %9 = call fastcc double @julia_K_37083(double %1) #119, !dbg !106
 caching call:   %10 = call fastcc double @julia_K_37083(double %1) #119, !dbg !106
 caching call:   %17 = call fastcc double @julia_K_37083(double %16) #119, !dbg !122
 caching call:   %18 = call fastcc double @julia_K_37083(double %16) #119, !dbg !122
((0.8447302895862908, 0.6627147492004235), 5.2302506256226895)
5.2302506256226895
 caching call:   %33 = call fastcc double @julia_K_37396(double %22) #95, !dbg !139
 caching call:   %8 = call fastcc double @julia_serf_37436(double %5, double %1) #97, !dbg !102
 caching call:   %27 = call fastcc double @julia_serf_37436(double %22, double %1) #101, !dbg !97
 caching call:   %28 = call fastcc double @julia_K_37396(double %1) #100, !dbg !136
 caching call:   %28 = call fastcc double @julia_K_37396(double %21) #100, !dbg !137
 caching call:   %45 = call fastcc double @julia_K_37396(double %16) #100, !dbg !160
 caching call:   %55 = call fastcc double @julia_atan_37463(double %53) #101, !dbg !173
 caching call:   %9 = call fastcc double @julia_K_37396(double %1) #119, !dbg !106
 caching call:   %10 = call fastcc double @julia_K_37396(double %1) #119, !dbg !106
 caching call:   %17 = call fastcc double @julia_K_37396(double %16) #119, !dbg !122
 caching call:   %18 = call fastcc double @julia_K_37396(double %16) #119, !dbg !122
((NaN, NaN), 5.2302506256226895)

Enzyme is apparently able to compute the derivative some of the time, but at other times it returns NaN even though the primal value is identical. What can I do to get more predictable behaviour?

Is the function intensity_point fully deterministic to start with?

I don’t know exactly what Krang.Kerr and Krang.IntensityPixel are doing, but the simplest explanation would be that some randomness sends the evaluation down different branches. If so, one could check what goes wrong in those particular branches…
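
One rough way to test that hypothesis is to call the function many times with the same arguments and compare the results. The sketch below is only an illustration: `is_deterministic`, `stable`, and `noisy` are hypothetical names, and the two toy functions stand in for intensity_point, which needs Krang.

```julia
# Hypothetical helper: call `f` repeatedly with the same arguments and report
# whether every result matches the first one. `isequal` is used instead of
# `==` so that a NaN result compares equal to another NaN.
function is_deterministic(f, args...; trials = 100)
    reference = f(args...)
    return all(isequal(reference, f(args...)) for _ in 1:trials)
end

# Stand-in functions (assumptions, not part of Krang)
stable(x, y) = hypot(x, y)          # same inputs always give the same output
noisy(x, y) = hypot(x, y) + rand()  # adds random noise on every call

println(is_deterministic(stable, 5.0, 4.0))
println(is_deterministic(noisy, 5.0, 4.0))
```

If the real function passes a check like this, randomness in the primal evaluation is unlikely to be the cause.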

Part of the issue is probably re-defining the function in a for loop. Does it happen if you just define the function once? Why do you need to keep redefining it?
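
A minimal sketch of that experiment, assuming a toy function in place of intensity_point (`toy_intensity` is a hypothetical stand-in; the real function needs Krang). The function is defined once at top level and only the autodiff call is repeated, using the same call signature as in the original post.

```julia
using Enzyme

# Stand-in for intensity_point (an assumption, not the real function)
toy_intensity(x, y) = x^2 + sin(y)

# Only the autodiff call repeats; the function is never redefined.
results = [
    autodiff(ReverseWithPrimal, toy_intensity, Active, Active(5.0), Active(4.0))
    for _ in 1:5
]

for r in results
    println(r)
end

# `isequal` treats NaN as equal to NaN, so a run that is NaN every time
# still counts as consistent here.
println(all(r -> isequal(r, results[1]), results))
```

If the gradients agree across calls when the function is defined once, the variation is more likely tied to recompilation of each new definition than to the function's values.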

Yes, intensity_point is deterministic. Kerr and IntensityPixel are structs that I use to cache some information.

struct Kerr{T} <: AbstractMetric
    "M = mass"
    mass::T  
    "a = J/M, where J is the angular momentum and M is the mass of the black hole."
    spin::T
    function Kerr(spin::T) where {T}
        new{T}(one(T), spin)
    end
end
struct IntensityPixel{T} <: AbstractPixel
    metric::Kerr{T}
    screen_coordinate::NTuple{2, T}
    "Radial roots"
    roots::NTuple{4,Complex{T}}
    "Radial antiderivative"
    I0_inf::T
    "Angular antiderivative"
    absGθo_Gθhat::NTuple{2,T}
    "Inclination"
    θo::T
    η::T
    λ::T
    function IntensityPixel(met::Kerr{T}, α, β, θo) where {T}
        tempη = Krang.η(met, α, β, θo)
        tempλ = Krang.λ(met, α, θo)
        roots = Krang.get_radial_roots(met, tempη, tempλ)
        numreals = sum(_isreal2.(roots))
        if (numreals == 2) && (abs(imag(roots[4])) < sqrt(eps(T)))
            roots = (roots[1], roots[4], roots[2], roots[3])
        end
        new{T}(
            met,
            (α, β), 
            roots,
            Krang.Ir_inf(met, roots), 
            Krang._absGθo_Gθhat(met, θo, tempη, tempλ), 
            θo, tempη, tempλ
        )
    end
end

Defining the function once causes Enzyme to either always return NaN or always return an actual number. I was redefining the function to show that different answers for the derivative can be returned, even though the function itself always returns the same value.

Open an issue on Enzyme with the MWE, Enzyme version, and Julia/OS version?
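
For reference, the version details can be collected from a Julia session roughly like this (a sketch that assumes Enzyme is installed in the active environment):

```julia
using InteractiveUtils  # provides versioninfo() outside the REPL
using Pkg

versioninfo()           # Julia version, OS, CPU, and LLVM details
Pkg.status("Enzyme")    # Enzyme version in the active environment
```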

Sure thing

Created an issue with an MWE on GitHub.