Error using Enzyme to autodiff wrt Flux neural net parameters

I noticed that when I use set_runtime_activity(Reverse) as you do, the runtime is about 20x longer than in my earlier attempts that did not use Flux. The code from my previous post (which you both helped with, as you may recall) is reposted here for clarity:

# Test coreloop diff with Enzyme

begin # Packages
    using SpecialFunctions
    using BenchmarkTools
    using SphericalHarmonics
    using Enzyme
end

begin # Functions
    # Coulomb funcs
    # Irregular radial function: -(k*r) times the spherical Bessel function of
    # the second kind of order `L`, evaluated at k*r.
    GL(k, r, L) = -k*r*sphericalbessely(L, k*r)
    # Regular radial function: (k*r) times the spherical Bessel function of
    # the first kind of order `L`, evaluated at k*r.
    FL(k, r, L) = k*r*sphericalbesselj(L, k*r)

    # Spherical Hankel functions
    # Incoming combination GL - i*FL, returned as a complex number.
    Hminus(k, r, L) = complex(GL(k, r, L), -FL(k, r, L))
    # Outgoing combination GL + i*FL, returned as a complex number.
    Hplus(k, r, L) = complex(GL(k, r, L), FL(k, r, L))

    # Derivatives
    # d/dr of Hminus at `r`, built from reverse-mode Enzyme derivatives of
    # GL (real part) and FL (imaginary part, negated).
    function enzR_Hminusprime(k, r, L)
        dG = Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1]
        dF = Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1]
        return complex(dG, -dF)
    end
    # d/dr of Hplus at `r`, built from reverse-mode Enzyme derivatives of
    # GL (real part) and FL (imaginary part).
    function enzR_Hplusprime(k, r, L)
        dG = Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1]
        dF = Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1]
        return complex(dG, dF)
    end

    # d/dr of Hminus at `r`, built from forward-mode Enzyme derivatives of
    # GL (real part) and FL (imaginary part, negated).
    function enzF_Hminusprime(k, r, L)
        dG = Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1]
        dF = Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1]
        return complex(dG, -dF)
    end
    # d/dr of Hplus at `r`, built from forward-mode Enzyme derivatives of
    # GL (real part) and FL (imaginary part).
    function enzF_Hplusprime(k, r, L)
        dG = Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1]
        dF = Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1]
        return complex(dG, dF)
    end

    # Integrate the radial equation outward with a Numerov recurrence and
    # return the resulting S-matrix element for partial wave `L` as the
    # two-element vector [real(SL), imag(SL)].
    #
    # Arguments, as used below:
    #   U   - two-column array on the same grid as `r`; U[i,1] is subtracted
    #         from Ecm (real part) and -U[i,2] is used as the imaginary part
    #   L   - angular momentum entering the centrifugal term L(L+1)/r^2
    #   μ   - mass entering the prefactor 2μ/ħ^2 (ħ is a file-level constant)
    #   k   - wave number forwarded to the Hankel-function helpers
    #   r   - radial grid; only r[2]-r[1] is used as the step, so the grid is
    #         assumed to be uniformly spaced
    #   Ecm - energy the potential is subtracted from
    #
    # The wave function is carried as separate real/imaginary Float64 scalars,
    # as in the original formulation that Enzyme differentiates.
    function enzSL_0f0(U, L, μ, k, r, Ecm)
        dr = r[2] - r[1]
        len = size(r, 1) - 1
        a = r[end-2]

        # Loop-invariant factors.
        pref = 2*μ/ħ^2
        cent = L*(L+1)
        h12 = dr^2/12

        # Starting values for the two most recent points of the recurrence
        # (ideally these are all always Float32, or all always Float64).
        ur1, ui1 = 0.0, 1e-12
        ur2, ui2 = 1e-6, 1e-6
        dur3, dui3 = 0.0, 0.0

        for i in 3:len
            # Each w is evaluated entirely at one grid point: the centrifugal
            # term uses the same index as the potential (previously r[i] was
            # used for all three).
            w = pref*complex(Ecm - U[i,1], -U[i,2]) - cent/r[i]^2
            wmo = pref*complex(Ecm - U[i-1,1], -U[i-1,2]) - cent/r[i-1]^2
            wpo = pref*complex(Ecm - U[i+1,1], -U[i+1,2]) - cent/r[i+1]^2

            uprev = complex(ur1, ui1)
            ucur = complex(ur2, ui2)
            uval = (2*ucur - uprev - h12*(10*w*ucur + wmo*uprev))/(1 + h12*wpo)

            ur3 = real(uval)
            ui3 = imag(uval)
            # Centered difference across the two neighbouring points.
            dur3 = 0.5*(ur3 - ur1)/dr
            dui3 = 0.5*(ui3 - ui1)/dr

            ur1, ur2 = ur2, ur3
            ui1, ui2 = ui2, ui3
        end

        # NOTE(review): after the loop, (ur2, ui2) is the newest point while
        # (dur3, dui3) is the centered difference about the point before it,
        # and the matching radius is r[end-2]. Confirm these are meant to refer
        # to the same grid point.
        ua = complex(ur2, ui2)
        dua = complex(dur3, dui3)

        RL = ua/dua
        # Forward-mode derivatives are used here; enzR_Hminusprime and
        # enzR_Hplusprime are the reverse-mode equivalents.
        SLtop = Hminus(k, a, L) - RL*enzF_Hminusprime(k, a, L)
        SLbot = Hplus(k, a, L) - RL*enzF_Hplusprime(k, a, L)

        SL = SLtop/SLbot
        return [real(SL), imag(SL)]
    end

    # Koning-Delaroche Potential
    # Energy- and mass-dependent optical-potential parameters for mass number
    # `A`, charge `Z` and energy `E`.
    # Returns the tuple (Vo, ro, ao, Wro, Wso, rw, aw): real depth with its
    # radius and diffuseness, the two imaginary depths, and the radius and
    # diffuseness used with Wso. Units are presumably MeV and fm; confirm.
    function kd_params(A, Z, E)
        N = A - Z
        excess = N - Z
        A13 = A^(1/3)

        # NOTE(review): the published Koning-Delaroche neutron Fermi energy is
        #, to my recollection, -11.2814 + 0.02646*A; check whether -11.23814
        # is a typo.
        Ef = -11.23814 + 0.02646*A
        ΔE = E - Ef
        ΔE2 = ΔE^2

        # Real depth: cubic polynomial in (E - Ef).
        v1 = 59.3 - 21.0*excess/A - 0.024*A
        v2 = 0.007228 - (1.48e-6)*A
        v3 = 1.994e-5 - (2.e-8)*A
        v4 = 7.e-9
        Vo = v1*(1 - v2*ΔE + v3*ΔE2 - v4*ΔE^3)
        ro = (1.3039 - 0.4054/A13)*A13
        ao = 0.6778 - (1.487e-4)*A

        # First imaginary depth.
        w1 = 12.195 + 0.0167*A
        w2 = 73.55 + 0.0795*A
        Wro = w1*ΔE2/(ΔE2 + w2^2)

        # Second imaginary depth and its geometry.
        d1 = 16 - 16*excess/A
        d2 = 0.0180 + 0.003802/(1 + exp((A - 156.)/8.))
        d3 = 11.5
        Wso = d1*ΔE2*exp(-d2*ΔE)/(ΔE2 + d3^2)
        rw = (1.3424 - 0.01585*A13)*A13
        aw = 0.5446 - (1.656e-4)*A

        return Vo, ro, ao, Wro, Wso, rw, aw
    end
    # Evaluate the three potential terms on the radial grid `r` using the
    # parameters from kd_params(A, Z, E).
    # Returns (Vr, W, Ws): the real and imaginary Woods-Saxon terms, which
    # share the (ro, ao) geometry, and the derivative-shaped term built from
    # (rw, aw). All three broadcast over `r`.
    function kd_pots(A, Z, E, r)
        Vo, ro, ao, Wro, Wso, rw, aw = kd_params(A, Z, E)

        # Shared Woods-Saxon denominator 1 + exp((r - ro)/ao).
        vol = 1 .+ exp.(-(ro .- r) ./ ao)
        Vr = -Vo ./ vol
        W = -Wro ./ vol

        # Surface-peaked shape 4*e/(1 + e)^2 with e = exp((r - rw)/aw).
        es = exp.(-(rw .- r) ./ aw)
        Ws = -4 .* Wso .* es ./ (1 .+ es).^2

        return Vr, W, Ws
    end
end

# Set up a particular scattering problem
A = 65
Z = 29
N = A - Z
E = 10
L = 30
Ecm = 9.848393154293218
μ = 925.3211722114523
k = 0.6841596644044445
r = Vector(LinRange(0, 20, 2000))
dr = r[2] - r[1]
# Read as a global inside enzSL_0f0; the value matches ħc in MeV·fm.
const ħ = 197.3269804

# Generate a potential from K-D
Vreal, Wv, Ws = kd_pots(A, Z, E, r);
U = Float32.(hcat(Vreal, Wv + Ws))

# Globals are interpolated with `$` so the benchmark does not measure access to
# untyped, non-const global bindings. The Float64 conversion in the second
# benchmark is deliberately kept inside the timed expression.
@btime reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, $L, $μ, $k, $r, $Ecm), $U)[1], 2, :)
@btime reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, $L, $μ, $k, $r, $Ecm), Float64.($U))[1], 2, :)  # all the rest is Float64

This code is very fast:

288.222 μs (65 allocations: 65.91 KiB)
308.659 μs (70 allocations: 159.58 KiB)

@mcabbott’s version with set_runtime_activity(Reverse) applied to the neural network version is slower:

6.788 ms (599 allocations: 6.38 MiB)

The new NN version computes the derivative with respect to 1202 parameters, while the older “U” version differentiates with respect to the 4000 entries of U (a 2×4000 Jacobian, i.e. 8000 values). It’s not an apples-to-apples comparison, but I’d like to know whether there is a better approach than invoking set_runtime_activity, one that may be faster with Flux.

Effectively, I’m asking how to apply a) below to this problem:

ERROR: LoadError: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
 a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
 b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.