Speeding up Zygote autodiff for numerical loop

I’m using Zygote to auto-differentiate the (small) output of a numerical loop. It is quite slow, probably due to the naive way I’ve implemented it. I’m interested in advice on (a) vectorizing the loop to get better performance or (b) using foldl to achieve the same. I’m also open to using other packages if something is better suited to this.

# Try to speed up the core loop differentiation with either vectorization or a fold

begin # Packages
    using SpecialFunctions
    using BenchmarkTools
    using Zygote
    using SphericalHarmonics
end

begin # Functions
    function ZygoteSL(U, L, μ, k, r, Ecm)
        dr = r[2] - r[1]
        len = size(r)[1]-1
        ur = Zygote.bufferfrom(zeros(3))
        ui = Zygote.bufferfrom(zeros(3))
        dur = Zygote.bufferfrom(zeros(3))
        dui = Zygote.bufferfrom(zeros(3))
        a = r[end-2]
        ur[2] = 1e-6
        ui[1] = 1e-12
        ui[2] = 1e-6
        for i in 3:len
            vreal = Ecm - U[i,1]
            vimag = -U[i,2]
            w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm -U[i-1,1]
            vimag = -U[i-1,2]
            wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm - U[i+1,1]
            vimag = -U[i+1,2]
            wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            uval = (2*complex(ur[2],ui[2])-complex(ur[1],ui[1])-(dr^2/12.)*(10*w*complex(ur[2],ui[2])+wmo*complex(ur[1],ui[1])))/(1+(dr^2/12)*wpo)
            
            ur[3] = real.(uval)
            dur[3] = 0.5*(ur[3]-ur[1])/dr
            ui[3] = imag.(uval)
            dui[3] = 0.5*(ui[3]-ui[1])/dr

            ur[1:2] = ur[2:3]
            dur[1:2] = dur[2:3]
            ui[1:2] = ui[2:3]
            dui[1:2] = dui[2:3]
        end
        ua = complex(ur[2],ui[2])
        dua = complex(dur[3],dui[3])
        
        RL = ua / dua
        SLtop = Hminus(k, a, L) - RL*Hminusprime(k, a, L)
        SLbot = Hplus(k, a, L) - RL*Hplusprime(k, a, L)

        SL = SLtop/SLbot
        return [copy(real(SL)), copy(imag(SL))]
    end

    # Coulomb funcs
    function GL(k, r, L)
        return -k*r*sphericalbessely(L, k*r)
    end
    function FL(k, r, L)
        return k*r*sphericalbesselj(L, k*r)
    end 

    # Spherical Hankel functions
    function Hminus(k, r, L)
        return complex(GL(k, r, L), -FL(k, r, L))
    end
    function Hplus(k, r, L)
        return complex(GL(k, r, L), FL(k, r, L))
    end

    # Derivatives
    function  Hminusprime(k, r, L)
        return complex(Zygote.gradient(x -> GL(k, x, L), r)[1], -Zygote.gradient(x -> FL(k, x, L), r)[1])
    end
    function Hplusprime(k, r, L)
        return complex(Zygote.gradient(x -> GL(k, x, L), r)[1], Zygote.gradient(x -> FL(k, x, L), r)[1])
    end

    # Koning-Delaroche Potential
    function kd_params(A, Z, E)
        N = A - Z
        Ef = -11.23814 + 0.02646*A
        v1 = 59.3 - 21.0*(N-Z)/A - 0.024*A
        v2 = 0.007228 - (1.48e-6)*A
        v3 = 1.994e-5 - (2.e-8)*A
        v4 = 7.e-9
        Vo = v1*(1-v2*(E-Ef)+v3*(E-Ef)^2-v4*(E-Ef)^3)
        ro = (1.3039 - 0.4054/A^(1/3))*A^(1/3)
        ao = 0.6778 - (1.487e-4)*A
        w1 = 12.195 + 0.0167*A
        w2 = 73.55 + 0.0795*A
        Wro = w1*(E-Ef)^2/((E-Ef)^2+w2^2)
        d1 = 16 - 16*(N-Z)/A
        d2 = 0.0180 + 0.003802/(1+exp((A-156.)/8.))
        d3 = 11.5
        rw = (1.3424 - 0.01585*A^(1/3))*A^(1/3)
        aw = 0.5446 - (1.656e-4)*A
        Wso = d1*(E-Ef)^2*exp(-d2*(E-Ef))/((E-Ef)^2+d3^2)
        return Vo, ro, ao, Wro, Wso, rw, aw
    end
    function kd_pots(A, Z, E, r)
        Vo, ro, ao, Wro, Wso, rw, aw = kd_params(A, Z, E)
        Vr = -Vo ./(1 .+ exp.(-(ro.-r)./ao))
        W = -Wro ./(1 .+ exp.(-(ro.-r)./ao))
        Ws = -4 .* Wso .* exp.(-(rw.-r)./aw) ./(1 .+exp.(-(rw.-r)./aw)).^2
        return Vr, W, Ws
    end
end

# Set up particular scattering problem
A = 65
Z = 29
N = A - Z
E = 10
L = 14
Ecm = 9.848393154293218
μ = 925.3211722114523
k = 0.6841596644044445
r = Vector(LinRange(0, 20, 2000))
dr = r[2] - r[1]
const global ħ = 197.3269804

# Generate a potential from K-D
Vreal, Wv, Ws = kd_pots(A, Z, E, r);
U = Float32.(hcat(Vreal, Wv + Ws))

begin # Zygote test
    fz(U) = ZygoteSL(U, L, μ, k, r, Ecm)

    # Try to calculate gradient
    @btime global fz_val = fz(U)
    @btime global dfz_val = Zygote.jacobian(U -> fz(U), U)
end

This gives me the following output:

401.193 μs (8223 allocations: 632.94 KiB)
326.888 ms (2576395 allocations: 210.58 MiB)

The differentiation is roughly 800x slower than the actual calculation!

Note: I’m on Julia 1.9.4 and Zygote v0.6.72 for compatibility reasons (maybe worth a separate post for those, as I’d love to work with newer versions).

Is Hminusprime_zyg == Hminusprime?

Indexing like this is Zygote’s worst nightmare: every U[i,1] allocates dU = zero(U), and then these are all added together as dU + dUprev:

        for i in 3:len  # that's 3:1999
            vreal = Ecm - U[i,1]  # U is a 2000×2 Matrix{Float32}
            vimag = -U[i,2]
            w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm -U[i-1,1]
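To make the cost concrete, here is a rough plain-Julia model of what the reverse pass has to do for each scalar read. It is an illustration only, not Zygote's actual implementation; `getindex_pullback` and `accumulate_reads` are hypothetical names introduced for this sketch.

```julia
# Illustrative model only: the reverse pass of `y = U[i, j]` must return a
# gradient shaped like `U`, which is zero everywhere except at (i, j).
function getindex_pullback(U::AbstractMatrix, i::Int, j::Int, dy)
    dU = zero(U)          # one full-size allocation per scalar read
    dU[i, j] = dy
    return dU
end

# Summing the contributions of n scalar reads costs O(n * length(U)) work.
function accumulate_reads(U::AbstractMatrix, reads, dy)
    total = zero(U)
    for (i, j) in reads
        # each addition creates yet another full-size array
        total = total + getindex_pullback(U, i, j, dy)
    end
    return total
end

U_demo = rand(Float32, 2000, 2)
reads = [(i, 1) for i in 3:1999]      # one read per loop iteration
dU_demo = accumulate_reads(U_demo, reads, 1f0)
```

The loop in the question performs six such reads per iteration over 3:1999, so under this model the reverse pass would create on the order of ten thousand temporary 2000×2 arrays, which is consistent in spirit with the allocation counts reported above.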

There are things you can do… perhaps I’d start by replacing ur = Zygote.bufferfrom(zeros(3)) with ur1, ur2, ur3 = 0f0, 0f0, 0f0.
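A minimal sketch of that change, using a simplified three-term recurrence as a stand-in for the real update formula (the names `recur_buffer` and `recur_scalars` are hypothetical). It also keeps every variable at a single element type `T`, rather than mixing `0f0` with Float64 literals such as `1e-6`.

```julia
# Simplified stand-in recurrence: u[n+1] = 2u[n] - u[n-1] - h^2 * w[n] * u[n].
# Version 1 keeps the state in a 3-slot array, as in the original code.
function recur_buffer(w::AbstractVector{T}, h::T) where {T<:AbstractFloat}
    u = zeros(T, 3)
    u[2] = T(1e-6)
    for n in eachindex(w)
        u[3] = 2 * u[2] - u[1] - h^2 * w[n] * u[2]
        u[1:2] = u[2:3]          # allocates a temporary slice every iteration
    end
    return u[2]
end

# Version 2 keeps the state in rolling scalars: no array, no mutation.
function recur_scalars(w::AbstractVector{T}, h::T) where {T<:AbstractFloat}
    u1, u2 = zero(T), T(1e-6)    # same element type throughout
    for n in eachindex(w)
        u3 = 2 * u2 - u1 - h^2 * w[n] * u2
        u1, u2 = u2, u3
    end
    return u2
end

w_demo = fill(0.5, 100)
same = recur_buffer(w_demo, 0.01) == recur_scalars(w_demo, 0.01)
```

Both versions perform the same floating-point operations in the same order, so they should agree exactly; the scalar version simply gives the AD tool nothing array-shaped to track inside the loop.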

It’s not quite a loop over eachrow(U) but indexing that might be more efficient?

 Urows = eachrow(U)
 for i in 3:len 
     Ui1, Ui2 = Urows[i]
     Uim1, Uim2 = Urows[i-1]
     ri = r[i]

     vreal = Ecm - Ui1
     vimag = -Ui2
     w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/ri^2
     vreal = Ecm - Uim1

I get a similar time from ForwardDiff (over Zygote) changing only ur = zeros(eltype(U), 3) etc, although the answers look different?

I get errors from Enzyme, but perhaps teaching it enough about sphericalbessely would be quicker than fighting Zygote.

julia> Enzyme.jacobian(Reverse, U -> enzSL(U, L, μ, k, r, Ecm), U)
ERROR: 
No augmented forward pass found for zbesy_
 at context:   call void @zbesy_(i8* noundef nonnull %22, i8* noundef nonnull %25, i8* noundef nonnull %28, i8* noundef nonnull %31, i8* noundef nonnull %34, i8* noundef nonnull %4, i8* noundef nonnull %7, i8* noundef nonnull %10, i8* noundef nonnull %16, i8* noundef nonnull %19, i8* noundef nonnull %13) #27 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !44


Stacktrace:
  [1] _bessely
    @ ~/.julia/packages/SpecialFunctions/npKKV/src/bessel.jl:283
  [2] bessely
    @ ~/.julia/packages/SpecialFunctions/npKKV/src/bessel.jl:429

Yes, my mistake; Hminusprime_zyg == Hminusprime, I edited out that typo.

Interestingly, I got a reasonable speedup from your change for ui1,2,3 = 0f0… etc, a factor of 5 for the function evaluation time and 1.25 for the gradient. However, the eachrow version performed much worse for AD:

    function ZygoteSL_0f0(U, L, μ, k, r, Ecm)
        dr = r[2] - r[1]
        len = size(r)[1]-1
        ur = Zygote.bufferfrom(zeros(3))
        ur1, ur2, ur3 = 0f0, 0f0, 0f0
        ui1, ui2, ui3 = 0f0, 0f0, 0f0
        dur1, dur2, dur3 = 0f0, 0f0, 0f0
        dui1, dui2, dui3 = 0f0, 0f0, 0f0
        a = r[end-2]
        ur2 = 1e-6
        ui1 = 1e-12
        ui2 = 1e-6
        for i in 3:len
            vreal = Ecm - U[i,1]
            vimag = -U[i,2]
            w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm -U[i-1,1]
            vimag = -U[i-1,2]
            wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm - U[i+1,1]
            vimag = -U[i+1,2]
            wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            uval = (2*complex(ur2,ui2)-complex(ur1,ui1)-(dr^2/12.)*(10*w*complex(ur2,ui2)+wmo*complex(ur1,ui1)))/(1+(dr^2/12)*wpo)
            
            ur3 = real.(uval)
            dur3 = 0.5*(ur3-ur1)/dr
            ui3 = imag.(uval)
            dui3 = 0.5*(ui3-ui1)/dr

            ur1, ur2 = ur2, ur3
            dur1, dur2 = dur2, dur3
            ui1, ui2 = ui2, ui3
            dui1, dui2 = dui2, dui3
        end
        ua = complex(ur2,ui2)
        dua = complex(dur3,dui3)
        
        RL = ua / dua
        SLtop = Hminus(k, a, L) - RL*Hminusprime(k, a, L)
        SLbot = Hplus(k, a, L) - RL*Hplusprime(k, a, L)

        SL = SLtop/SLbot
        return [copy(real(SL)), copy(imag(SL))]
    end

    function ZygoteSL_er(U, L, μ, k, r, Ecm)
        dr = r[2] - r[1]
        len = size(r)[1]-1
        ur = Zygote.bufferfrom(zeros(3))
        ur1, ur2, ur3 = 0f0, 0f0, 0f0
        ui1, ui2, ui3 = 0f0, 0f0, 0f0
        dur1, dur2, dur3 = 0f0, 0f0, 0f0
        dui1, dui2, dui3 = 0f0, 0f0, 0f0
        a = r[end-2]
        ur2 = 1e-6
        ui1 = 1e-12
        ui2 = 1e-6
        Urows = eachrow(U)
        for i in 3:len
            Ui1, Ui2 = Urows[i]
            Uim1, Uim2 = Urows[i-1]
            Uip1, Uip2 = Urows[i+1]
            ri = r[i]

            vreal = Ecm - Ui1
            vimag = -Ui2
            w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/ri^2
            vreal = Ecm -Uim1
            vimag = -Uim2
            wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/ri^2
            vreal = Ecm - Uip1
            vimag = -Uip2
            wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/ri^2
            uval = (2*complex(ur2,ui2)-complex(ur1,ui1)-(dr^2/12.)*(10*w*complex(ur2,ui2)+wmo*complex(ur1,ui1)))/(1+(dr^2/12)*wpo)
            
            ur3 = real.(uval)
            dur3 = 0.5*(ur3-ur1)/dr
            ui3 = imag.(uval)
            dui3 = 0.5*(ui3-ui1)/dr

            ur1, ur2 = ur2, ur3
            dur1, dur2 = dur2, dur3
            ui1, ui2 = ui2, ui3
            dui1, dui2 = dui2, dui3
        end
        ua = complex(ur2,ui2)
        dua = complex(dur3,dui3)
        
        RL = ua / dua
        SLtop = Hminus(k, a, L) - RL*Hminusprime(k, a, L)
        SLbot = Hplus(k, a, L) - RL*Hplusprime(k, a, L)

        SL = SLtop/SLbot
        return [copy(real(SL)), copy(imag(SL))]
    end

begin # Zygote test: 0f0
    println("\n0f0 version:")
    fz_0f0(U) = ZygoteSL_0f0(U, L, μ, k, r, Ecm)

    # Try to calculate gradient
    println("Func eval:")
    @btime global fz0_val = fz_0f0(U)
    println("Gradient eval:")
    @btime global dfz0_val = Zygote.jacobian(U -> fz_0f0(U), U)
end

begin # Zygote test: Eachrow
    println("\nEachrow version:")
    fz_er(U) = ZygoteSL_er(U, L, μ, k, r, Ecm)

    # Try to calculate gradient
    println("Func eval:")
    @btime global fzer_val = fz_er(U)
    println("Gradient eval:")
    @btime global dfzer_val = Zygote.jacobian(U -> fz_er(U), U)
end

Results:

OG:
Func eval:
  384.465 μs (8223 allocations: 632.94 KiB)
Gradient eval:
  318.956 ms (2576397 allocations: 210.58 MiB)

0f0 version:
Func eval:
  78.108 μs (228 allocations: 8.52 KiB)
Gradient eval:
  251.755 ms (2014924 allocations: 193.54 MiB)

Eachrow version:
Func eval:
  87.975 μs (228 allocations: 8.52 KiB)
Gradient eval:
  1.476 s (3335835 allocations: 1.19 GiB)

It may be worth writing a custom rule for Enzyme if it can do a much better job, but I’m still hoping to find a better method than the naive for loop.
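For option (b) from the original question, the loop state can be carried through `foldl` as a tuple. The sketch below uses a simplified three-term recurrence as a stand-in for the real update, and `recurrence_step` and `recur_foldl` are hypothetical names. Whether Zygote differentiates this faster than the explicit loop is not benchmarked here.

```julia
# Returns a function (state, n) -> new_state for use with foldl.
# The state is the tuple (u_prev, u_curr).
recurrence_step(w, h) = (state, n) -> begin
    u1, u2 = state
    u3 = 2 * u2 - u1 - h^2 * w[n] * u2
    (u2, u3)
end

function recur_foldl(w::AbstractVector, h)
    T = promote_type(eltype(w), typeof(h))
    init = (zero(T), T(1e-6))            # consistent element type for the state
    _, u_last = foldl(recurrence_step(w, h), eachindex(w); init = init)
    return u_last
end

u_end = recur_foldl(fill(0.5, 100), 0.01)
```

The fold has no mutation and no rolling reassignment, so the whole recurrence is a single pure expression. The reads `w[n]` inside the step are still scalar indexing, though, so this alone does not remove the per-read gradient cost discussed earlier in the thread.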

If I’m not mistaken, those gradients are actually scalar derivatives, so the code will be much quicker if you replace them with ForwardDiff.derivative.
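A sketch of that substitution. To avoid assuming anything about how the Bessel functions interact with either AD package, it uses the L = 0 closed forms as stand-ins: k*r*j0(k*r) = sin(k*r) and -k*r*y0(k*r) = cos(k*r). The names `FL0`, `GL0`, `Hplusprime_zygote` and `Hplusprime_fd` are hypothetical.

```julia
using ForwardDiff, Zygote

# Stand-ins for FL and GL at L = 0 (closed forms of the Riccati-Bessel functions).
FL0(k, r) = sin(k * r)
GL0(k, r) = cos(k * r)

# Reverse mode: Zygote.gradient returns a 1-tuple holding the scalar derivative.
Hplusprime_zygote(k, r) =
    complex(Zygote.gradient(x -> GL0(k, x), r)[1],
            Zygote.gradient(x -> FL0(k, x), r)[1])

# Forward mode: ForwardDiff.derivative returns the scalar directly.
Hplusprime_fd(k, r) =
    complex(ForwardDiff.derivative(x -> GL0(k, x), r),
            ForwardDiff.derivative(x -> FL0(k, x), r))

d_zy = Hplusprime_zygote(0.7, 2.0)
d_fd = Hplusprime_fd(0.7, 2.0)
```

For a function of one real variable, forward mode needs only a single pass, which is why it tends to be the cheaper choice for derivatives like these.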


Ok, pity about the eachrow, sorry!

I tried replacing the Bessel functions with atan, and Enzyme is then 50x faster (and uses 1/50 the memory), as it’s happy to handle loops with indexing.
(Edit: doing this with the ur1, ur2, ur3 = 0.0, 0.0, 0.0 code, it’s 400x faster, 248 μs from 100 ms.)

Somehow on first reading, I missed that these functions are outside the main loop, so I doubt they account for much of the run time.

Zygote has a (slightly confusingly named) function forwarddiff(f, x) == f(x) which doesn’t perform a derivative itself, but does map any Zygote differentiation to ForwardDiff instead. It’s for exactly such things, patching one function call that ForwardDiff can handle and Zygote cannot. It would be useful to have such a function for Enzyme too, to patch in one function.
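A minimal sketch of how `Zygote.forwarddiff` is used, modelled on the example in its docstring. The functions `pow_loop` and `outer` are hypothetical and stand in for "one inner call that is better handled by ForwardDiff".

```julia
using Zygote

# Hypothetical inner function: a plain loop with repeated reassignment.
function pow_loop(x, n)
    acc = one(x)
    for _ in 1:n
        acc *= x
    end
    return acc
end

# Zygote differentiates the outer `3 * ...` in reverse mode, while the wrapped
# call is differentiated by ForwardDiff.
outer(x) = 3 * Zygote.forwarddiff(y -> pow_loop(y, 4), x)

val = outer(2.0)                       # forwarddiff(f, x) just evaluates f(x)
grad = Zygote.gradient(outer, 2.0)[1]
```

Here the forward value is 3 * 2^4 and the derivative is 3 * 4 * 2^3, so the wrapper changes only how the derivative is computed, not what is computed.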

What Enzyme does have is import_frule & import_rrule, which may be enough to make it work here. Edit: it’s not quite as simple as Enzyme.@import_rrule(typeof(sphericalbessely), Real, Real), as that’s not where the ChainRules rule is: rrule(sphericalbessely, 1, 2) === nothing, i.e. there is no rule for this signature.

Ah good point. Zygote over ForwardDiff will I think be transformed (by a Zygote rule) into ForwardDiff over ForwardDiff, which should be efficient.


There is something with similar potential in DifferentiationInterface.jl, namely DifferentiateWith. It doesn’t yet have Enzyme rules, but PRs are welcome.


So I just released Enzyme 0.13.20 which should add support for complex-valued bessely directly (we already supported the real ones).

It’ll probably take the Julia package servers ~1 hour or so to propagate, but try it out and see if that works for you?


Since I have it, here’s my attempt, with the error:

using Enzyme  # Enzyme v0.13.20
 
enzR_Hminusprime(k, r, L) =
    complex(Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1], -Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1])
enzR_Hplusprime(k, r, L) =
    complex(Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1], Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1])

enzF_Hminusprime(k, r, L) =
    complex(Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1], -Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1])
enzF_Hplusprime(k, r, L) =
    complex(Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1], Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1])

function enzSL_0f0(U, L, μ, k, r, Ecm)
    dr = r[2] - r[1]
    len = size(r)[1]-1
    # ur = Zygote.bufferfrom(zeros(3))  # not used
    ur1, ur2, ur3 = 0.0, 0.0, 0.0
    ui1, ui2, ui3 = 0.0, 0.0, 0.0
    dur1, dur2, dur3 = 0.0, 0.0, 0.0
    dui1, dui2, dui3 = 0.0, 0.0, 0.0
    a = r[end-2]
    ur2 = 1e-6
    ui1 = 1e-12  # ideally these are all always Float32, or all always Float64
    ui2 = 1e-6
    for i in 3:len
        vreal = Ecm - U[i,1]
        vimag = -U[i,2]
        w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
        vreal = Ecm -U[i-1,1]
        vimag = -U[i-1,2]
        wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
        vreal = Ecm - U[i+1,1]
        vimag = -U[i+1,2]
        wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
        uval = (2*complex(ur2,ui2)-complex(ur1,ui1)-(dr^2/12)*(10*w*complex(ur2,ui2)+wmo*complex(ur1,ui1)))/(1+(dr^2/12)*wpo)
        
        ur3 = real.(uval)
        dur3 = 0.5*(ur3-ur1)/dr
        ui3 = imag.(uval)
        dui3 = 0.5*(ui3-ui1)/dr

        ur1, ur2 = ur2, ur3
        dur1, dur2 = dur2, dur3
        ui1, ui2 = ui2, ui3
        dui1, dui2 = dui2, dui3
    end
    ua = complex(ur2,ui2)
    dua = complex(dur3,dui3)
    
    RL = ua / dua
    SLtop = Hminus(k, a, L) - RL*enzR_Hminusprime(k, a, L)
    SLbot = Hplus(k, a, L) - RL*enzR_Hplusprime(k, a, L)
    # SLtop = Hminus(k, a, L) - RL*enzF_Hminusprime(k, a, L)
    # SLbot = Hplus(k, a, L) - RL*enzF_Hplusprime(k, a, L)

    SL = SLtop/SLbot
    return [real(SL), imag(SL)]  # no need for copy
end

Zygote.jacobian(U -> ZygoteSL_0f0(U, L, μ, k, r, Ecm), U)[1]

reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), U)[1], 2, :)
reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), Float64.(U))[1], 2, :)  # all the rest is Float64


#= 

(jl_FVANqM) pkg> st Enzyme
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_FVANqM/Project.toml`
  [7da242da] Enzyme v0.13.20 `https://github.com/EnzymeAD/Enzyme.jl.git#main`

julia> reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), Float64.(U))[1], 2, :)
ERROR: 
No create nofree of empty function (zbesy_) zbesy_)
 at context:   %207 = call [1 x [1 x double]] @diffejulia__5_80273_inner_5wrap({ double, i64 } %.fca.1.insert79, double %71, double noundef 1.000000e+00) #49, !dbg !296 (diffejulia__5_80273_inner_5wrap)

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201
 [2] enzyme_call
   @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747
 [3] CombinedAdjointThunk
   @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4619
 [4] autodiff_deferred
   @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:785
 [5] autodiff
   @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:524
 [6] macro expansion
   @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:326
 [7] gradient
   @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:258
 [8] enzR_Hminusprime
   @ ./REPL[7]:1
 [9] enzSL_0f0
   @ ./REPL[45]:39


Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201 [inlined]
  [2] enzyme_call
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747 [inlined]
  [3] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4619 [inlined]
  [4] autodiff_deferred
    @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:785 [inlined]
  [5] autodiff
    @ ~/.julia/packages/Enzyme/aViNX/src/Enzyme.jl:524 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:326 [inlined]
  [7] gradient
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:258 [inlined]
  [8] enzR_Hminusprime
    @ ./REPL[7]:1 [inlined]
  [9] enzSL_0f0
    @ ./REPL[45]:39 [inlined]
 [10] augmented_julia_enzSL_0f0_80110wrap
    @ ./REPL[45]:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747 [inlined]
 [13] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4683 [inlined]
 [14] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(enzSL_0f0), df::Nothing, primal_1::Matrix{…}, shadow_1_1::Matrix{…}, primal_2::Int64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Nothing, primal_4::Float64, shadow_4_1::Nothing, primal_5::Vector{…}, shadow_5_1::Nothing, primal_6::Float64, shadow_6_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aViNX/src/rules/jitrules.jl:480
 [15] #41
    @ ./REPL[48]:1 [inlined]
 [16] augmented_julia__41_81432wrap
    @ ./REPL[48]:0
 [17] macro expansion
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:5201 [inlined]
 [18] enzyme_call
    @ ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4747 [inlined]
 [19] (::Enzyme.Compiler.AugmentedForwardThunk{…})(fn::Const{…}, args::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aViNX/src/compiler.jl:4683
 [20] #130
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:928 [inlined]
 [21] ntuple
    @ ./ntuple.jl:49 [inlined]
 [22] jacobian(mode::ReverseMode{…}, f::var"#41#42", x::Matrix{…}; n_outs::Val{…}, chunk::Nothing)
    @ Enzyme ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:924
 [23] jacobian
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:841 [inlined]
 [24] #jacobian#129
    @ ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:856 [inlined]
 [25] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#41#42", x::Matrix{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/aViNX/src/sugar.jl:841
 [26] top-level scope
    @ REPL[48]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> VERSION
v"1.11.0"

=#

Oh lol, Enzyme proved it didn’t need to compute the derivative, and I didn’t add that case fully. I’ll fix this after lunch.


Okay, fixed in “Nofree for math methods” (Pull Request #2184, EnzymeAD/Enzyme.jl); will release in a bit.

Perf is good, though the results look weird for the second row: the values are either tiny or giant (in both the Enzyme and the Zygote output)? For the normal-sized outputs (like the end of the first row) things look right.

That said, @mcabbott, I’m not sure if I’m computing the same thing in your code as above (I didn’t see ZygoteSL_0f0 defined in your snippet). Derivatives work though, and it’s fast, which is nice.

julia> begin
         efn(U) = enzSL_0f0(U, L, μ, k, r, Ecm)
         @btime reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), U)[1], 2, :)
       end
  244.292 μs (65 allocations: 66.09 KiB)
2×4000 reshape(PermutedDimsArray(::Array{Float32, 3}, (3, 1, 2)), 2, 4000) with eltype Float32:
 0.0   1.1307f-27   3.55407f-22   2.75509f-22  -4.72553f-22  -1.57256f-20   7.95213f-21  …  0.000476027  0.000479498  0.000482988  0.000486495  0.000469456  0.000245895  2.0565f-5
 0.0  -1.2638f-27  -2.09308f-20  -1.62252f-20   2.78289f-20   9.25912f-19  -4.68337f-19     9.98074f-7   1.00535f-6   1.01267f-6   1.02002f-6   9.84297f-7   5.15561f-7   4.31182f-8

julia> begin # Zygote test
          fz(U) = ZygoteSL(U, L, μ, k, r, Ecm)

          # Try to calculate gradient
          @btime global fz_val = fz(U)
          @btime global dfz_val = Zygote.jacobian(fz, U)
       end
  245.417 μs (8223 allocations: 632.94 KiB)
  236.153 ms (2582180 allocations: 211.74 MiB)
([0.0 7.863046184058266e-28 … 0.0002458945964463055 2.056503581115976e-5; 0.0 1.478219661312e12 … -3.711214846463957e20 4.311818457836125e-8],)

julia> begin # Zygote test
          fz(U) = ZygoteSL(U, L, μ, k, r, Ecm)

          # Try to calculate gradient
          @btime global fz_val = fz(U)
          @btime global dfz_val = Zygote.jacobian(fz, U)[1]
       end
  245.166 μs (8223 allocations: 632.94 KiB)
  242.611 ms (2582180 allocations: 211.74 MiB)
2×4000 Matrix{Float64}:
 0.0  7.86305e-28  7.1804e-21  5.56613e-21  -9.54683e-21  -3.1764e-19  1.60665e-19  1.27534e-19  …   0.000479498   0.000482988   0.000486495   0.000469456   0.000245895  2.0565e-5
 0.0  1.47822e12   1.09202e20  8.46516e19   -1.45191e20   -4.83074e21  2.44345e21   1.93958e21      -1.75584e22   -1.32165e22   -8.84292e21   -4.43744e21   -3.71121e20   4.31182e-8
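One way to decide which of the differing second-row values to trust is a finite-difference spot check on a few entries. The helper below is a hypothetical sketch, not part of any package; it promotes the input to Float64 so that the step `h` is not lost to Float32 rounding, and it is demonstrated on a toy function with a known Jacobian.

```julia
# Central finite difference of a vector-valued f with respect to one entry of x,
# addressed by linear index `idx`. Returns one column of the Jacobian.
function fd_column(f, x::AbstractArray, idx::Integer; h = 1e-6)
    xp = Float64.(x)      # promote so that x + h differs from x
    xm = Float64.(x)
    xp[idx] += h
    xm[idx] -= h
    return (f(xp) .- f(xm)) ./ (2h)
end

# Toy check: f(x) = [sum of squares, product of entries].
f_toy(x) = [sum(abs2, x), prod(x)]
x_toy = Float32[1 2; 3 4]
col = fd_column(f_toy, x_toy, 2)   # linear index 2 is the entry equal to 3
```

For the problem in this thread, the analogous call would be `fd_column(U -> enzSL_0f0(U, L, μ, k, r, Ecm), U, idx)`, compared against column `idx` of the reshaped 2×4000 Jacobian. Given how small some of the entries are, the step size would likely need tuning, so treat this as a sanity check rather than a reference value.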

Thanks so much for adding this! I am getting a new error on the latest Enzyme, however:

ERROR: LoadError: LLVM error: Duplicate definition of symbol 'libname_zbesy__20710'
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/LLVM/wMjUU/src/executionengine/utils.jl:28 [inlined]
  [2] add!
    @ ~/.julia/packages/LLVM/wMjUU/src/orc.jl:433 [inlined]
  [3] add!(mod::LLVM.Module)
    @ Enzyme.Compiler.JIT ~/.julia/packages/Enzyme/6C71q/src/compiler/orcv2.jl:264
  [4] _link(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, adjoint_name::String, primal_name::Union{…}, TapeType::Any, prepost::String)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5239
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5325 [inlined]
  [6] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5434
  [7] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5601
  [8] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(enzSL_0f0), df::Nothing, primal_1::Matrix{…}, shadow_1_1::Matrix{…}, primal_2::Int64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Nothing, primal_4::Float64, shadow_4_1::Nothing, primal_5::Vector{…}, shadow_5_1::Nothing, primal_6::Float64, shadow_6_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/rules/jitrules.jl:465
  [9] #35
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:134 [inlined]
 [10] augmented_julia__35_12983wrap
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5204 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:4750 [inlined]
 [13] (::Enzyme.Compiler.AugmentedForwardThunk{…})(fn::Const{…}, args::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:4686
 [14] #130
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:928 [inlined]
 [15] ntuple
    @ ./ntuple.jl:49 [inlined]
 [16] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#35#36", x::Matrix{Float32}; n_outs::Val{(2,)}, chunk::Nothing)
    @ Enzyme ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:924
 [17] jacobian
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:841 [inlined]
 [18] #jacobian#129
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:856 [inlined]
 [19] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#35#36", x::Matrix{Float32})
    @ Enzyme ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:841
 [20] top-level scope
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:134
 [21] include(fname::String)
    @ Main ./sysimg.jl:38
 [22] top-level scope
    @ REPL[3]:1
in expression starting at /vast/home/daningburg/nuclear-diffprog/MWEs/coreloop_enz.jl:134
Some type information was truncated. Use `show(err)` to see complete types.

For completeness, here is the full code for Enzyme testing I’m using:

# Test coreloop diff with Enzyme

begin # Packages
    using SpecialFunctions
    using BenchmarkTools
    using SphericalHarmonics
    using Enzyme
end

begin # Functions
    # Coulomb funcs
    function GL(k, r, L)
        return -k*r*sphericalbessely(L, k*r)
    end
    function FL(k, r, L)
        return k*r*sphericalbesselj(L, k*r)
    end 

    # Spherical Hankel functions
    function Hminus(k, r, L)
        return complex(GL(k, r, L), -FL(k, r, L))
    end
    function Hplus(k, r, L)
        return complex(GL(k, r, L), FL(k, r, L))
    end

    # Derivatives
    enzR_Hminusprime(k, r, L) =
        complex(Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1], -Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1])
    enzR_Hplusprime(k, r, L) =
        complex(Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1], Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1])

    enzF_Hminusprime(k, r, L) =
        complex(Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1], -Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1])
    enzF_Hplusprime(k, r, L) =
        complex(Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1], Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1])

    function enzSL_0f0(U, L, μ, k, r, Ecm)
        dr = r[2] - r[1]
        len = size(r)[1]-1
        ur1, ur2, ur3 = 0.0, 0.0, 0.0
        ui1, ui2, ui3 = 0.0, 0.0, 0.0
        dur1, dur2, dur3 = 0.0, 0.0, 0.0
        dui1, dui2, dui3 = 0.0, 0.0, 0.0
        a = r[end-2]
        ur2 = 1e-6
        ui1 = 1e-12  # ideally these are all always Float32, or all always Float64
        ui2 = 1e-6
        for i in 3:len
            vreal = Ecm - U[i,1]
            vimag = -U[i,2]
            w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm -U[i-1,1]
            vimag = -U[i-1,2]
            wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            vreal = Ecm - U[i+1,1]
            vimag = -U[i+1,2]
            wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
            uval = (2*complex(ur2,ui2)-complex(ur1,ui1)-(dr^2/12)*(10*w*complex(ur2,ui2)+wmo*complex(ur1,ui1)))/(1+(dr^2/12)*wpo)
            
            ur3 = real.(uval)
            dur3 = 0.5*(ur3-ur1)/dr
            ui3 = imag.(uval)
            dui3 = 0.5*(ui3-ui1)/dr
    
            ur1, ur2 = ur2, ur3
            dur1, dur2 = dur2, dur3
            ui1, ui2 = ui2, ui3
            dui1, dui2 = dui2, dui3
        end
        ua = complex(ur2,ui2)
        dua = complex(dur3,dui3)
        
        RL = ua / dua
        SLtop = Hminus(k, a, L) - RL*enzR_Hminusprime(k, a, L)
        SLbot = Hplus(k, a, L) - RL*enzR_Hplusprime(k, a, L)
        # SLtop = Hminus(k, a, L) - RL*enzF_Hminusprime(k, a, L)
        # SLbot = Hplus(k, a, L) - RL*enzF_Hplusprime(k, a, L)
    
        SL = SLtop/SLbot
        return [real(SL), imag(SL)]
    end

    # Koning-Delaroche Potential
    function kd_params(A, Z, E)
        N = A - Z
        Ef = -11.23814 + 0.02646*A
        v1 = 59.3 - 21.0*(N-Z)/A - 0.024*A
        v2 = 0.007228 - (1.48e-6)*A
        v3 = 1.994e-5 - (2.e-8)*A
        v4 = 7.e-9
        Vo = v1*(1-v2*(E-Ef)+v3*(E-Ef)^2-v4*(E-Ef)^3)
        ro = (1.3039 - 0.4054/A^(1/3))*A^(1/3)
        ao = 0.6778 - (1.487e-4)*A
        w1 = 12.195 + 0.0167*A
        w2 = 73.55 + 0.0795*A
        Wro = w1*(E-Ef)^2/((E-Ef)^2+w2^2)
        d1 = 16 - 16*(N-Z)/A
        d2 = 0.0180 + 0.003802/(1+exp((A-156.)/8.))
        d3 = 11.5
        rw = (1.3424 - 0.01585*A^(1/3))*A^(1/3)
        aw = 0.5446 - (1.656e-4)*A
        Wso = d1*(E-Ef)^2*exp(-d2*(E-Ef))/((E-Ef)^2+d3^2)
        return Vo, ro, ao, Wro, Wso, rw, aw
    end
    function kd_pots(A, Z, E, r)
        Vo, ro, ao, Wro, Wso, rw, aw = kd_params(A, Z, E)
        Vr = -Vo ./(1 .+ exp.(-(ro.-r)./ao))
        W = -Wro ./(1 .+ exp.(-(ro.-r)./ao))
        Ws = -4 .* Wso .* exp.(-(rw.-r)./aw) ./(1 .+exp.(-(rw.-r)./aw)).^2
        return Vr, W, Ws
    end
end

# Set up particular scattering problem
A = 65
Z = 29
N = A - Z
E = 10
L = 14
Ecm = 9.848393154293218
μ = 925.3211722114523
k = 0.6841596644044445
r = Vector(LinRange(0, 20, 2000))
dr = r[2] - r[1]
const global ħ = 197.3269804

# Generate a potential from K-D
Vreal, Wv, Ws = kd_pots(A, Z, E, r);
U = Float32.(hcat(Vreal, Wv + Ws))

reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), U)[1], 2, :)
reshape(Enzyme.jacobian(Reverse, U -> enzSL_0f0(U, L, μ, k, r, Ecm), Float64.(U))[1], 2, :)  # all the rest is Float64

This is on Julia 1.11.1 and Enzyme v0.13.21.
Also thanks to @mcabbott for the assist!

Oh, Julia 1.11 and its changes to library loading.

Post the whole runnable code as an issue on Enzyme.jl and I will fix it.

In the interim try 1.10?


I’m actually getting the same error in 1.10.6:

ERROR: LoadError: LLVM error: Duplicate definition of symbol 'libname_zbesy__3764'
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/LLVM/wMjUU/src/executionengine/utils.jl:28 [inlined]
  [2] add!
    @ ~/.julia/packages/LLVM/wMjUU/src/orc.jl:433 [inlined]
  [3] add!(mod::LLVM.Module)
    @ Enzyme.Compiler.JIT ~/.julia/packages/Enzyme/6C71q/src/compiler/orcv2.jl:264
  [4] _link(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, adjoint_name::String, primal_name::Union{…}, TapeType::Any, prepost::String)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5239
  [5] cached_compilation
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5325 [inlined]
  [6] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5434
  [7] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5601
  [8] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::typeof(enzSL_0f0), df::Nothing, primal_1::Matrix{…}, shadow_1_1::Matrix{…}, primal_2::Int64, shadow_2_1::Nothing, primal_3::Float64, shadow_3_1::Nothing, primal_4::Float64, shadow_4_1::Nothing, primal_5::Vector{…}, shadow_5_1::Nothing, primal_6::Float64, shadow_6_1::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/rules/jitrules.jl:465
  [9] #17
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:132 [inlined]
 [10] augmented_julia__17_2308wrap
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:0
 [11] macro expansion
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:5204 [inlined]
 [12] enzyme_call
    @ ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:4750 [inlined]
 [13] (::Enzyme.Compiler.AugmentedForwardThunk{…})(fn::Const{…}, args::Duplicated{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/6C71q/src/compiler.jl:4686
 [14] #130
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:928 [inlined]
 [15] ntuple
    @ ./ntuple.jl:49 [inlined]
 [16] jacobian(mode::ReverseMode{…}, f::var"#17#18", x::Matrix{…}; n_outs::Val{…}, chunk::Nothing)
    @ Enzyme ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:924
 [17] jacobian
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:841 [inlined]
 [18] #jacobian#129
    @ ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:856 [inlined]
 [19] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#17#18", x::Matrix{Float32})
    @ Enzyme ~/.julia/packages/Enzyme/6C71q/src/sugar.jl:841
 [20] top-level scope
    @ ~/nuclear-diffprog/MWEs/coreloop_enz.jl:132
 [21] include(fname::String)
    @ Base.MainInclude ./client.jl:494
 [22] top-level scope
    @ REPL[3]:1
in expression starting at /vast/home/daningburg/nuclear-diffprog/MWEs/coreloop_enz.jl:132
Some type information was truncated. Use `show(err)` to see complete types.

It’s working due to your update (Enzyme v0.13.23), thanks!