I’m trying to use Enzyme to autodiff a function with respect to a neural network’s parameters. I’m running into this error:
No augmented forward pass found for jl_genericmemory_slice
at context: %93 = call nonnull "enzyme_type"="{[-1]:Pointer}" {} addrspace(10)* @jl_genericmemory_slice({} addrspace(10)* nonnull %76, i64 %92, i64 %40) #25, !dbg !234
Stacktrace:
[1] reshape
@ ./reshapedarray.jl:55
[2] reshape
@ ./reshapedarray.jl:121
[3] _getat
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:102
Stacktrace:
[1] reshape
@ ./reshapedarray.jl:55 [inlined]
[2] reshape
@ ./reshapedarray.jl:121 [inlined]
[3] _getat
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:102
[4] #57
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:97 [inlined]
[5] ExcludeWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:144 [inlined]
[6] CachedWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:195 [inlined]
[7] CachedWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:0 [inlined]
[8] augmented_julia_CachedWalk_44270_inner_9wrap
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:0
[9] macro expansion
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:5218 [inlined]
[10] enzyme_call
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4764 [inlined]
[11] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4700 [inlined]
[12] runtime_generic_augfwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, ModifiedBetween::Val{…}, RT::Val{…}, f::Functors.CachedWalk{…}, df::Functors.CachedWalk{…}, primal_1::Functors.var"#recurse#26"{…}, shadow_1_1::Functors.var"#recurse#26"{…}, primal_2::Matrix{…}, shadow_2_1::Nothing, primal_3::Int64, shadow_3_1::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/QVjE5/src/rules/jitrules.jl:480
[13] recurse
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:52 [inlined]
[14] #59
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:115 [inlined]
[15] map
@ ./tuple.jl:406 [inlined]
[16] map
@ ./namedtuple.jl:266 [inlined]
[17] _trainmap
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:114 [inlined]
[18] _Trainable_biwalk
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:110 [inlined]
[19] ExcludeWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:144 [inlined]
[20] CachedWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:195 [inlined]
[21] recurse
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:52 [inlined]
[22] #59
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:115 [inlined]
[23] map
@ ./tuple.jl:406 [inlined]
[24] _trainmap
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:114
[25] _Trainable_biwalk
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:110 [inlined]
[26] ExcludeWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:144 [inlined]
[27] CachedWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:195 [inlined]
[28] recurse
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:52 [inlined]
[29] #59
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:115 [inlined]
[30] map
@ ./tuple.jl:406 [inlined]
[31] map
@ ./namedtuple.jl:266 [inlined]
[32] _trainmap
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:114 [inlined]
[33] _Trainable_biwalk
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:110 [inlined]
[34] ExcludeWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:144 [inlined]
[35] CachedWalk
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:195 [inlined]
[36] execute
@ ~/.julia/packages/Functors/wOtRi/src/walks.jl:53 [inlined]
[37] #fmap#40
@ ~/.julia/packages/Functors/wOtRi/src/maps.jl:11 [inlined]
[38] fmap
@ ~/.julia/packages/Functors/wOtRi/src/maps.jl:3 [inlined]
[39] #_rebuild#56
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:96 [inlined]
[40] _rebuild
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:94 [inlined]
[41] Restructure
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:59 [inlined]
[42] Restructure
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:0 [inlined]
[43] augmented_julia_Restructure_43766_inner_1wrap
@ ~/.julia/packages/Optimisers/V8kHf/src/destructure.jl:0
[44] macro expansion
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:5218 [inlined]
[45] enzyme_call
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4764 [inlined]
[46] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4700 [inlined]
[47] runtime_generic_augfwd(activity::Type{Val{…}}, runtimeActivity::Val{false}, width::Val{1}, ModifiedBetween::Val{(true, true)}, RT::Val{@NamedTuple{…}}, f::Optimisers.Restructure{Chain{…}, @NamedTuple{…}}, df::Nothing, primal_1::Vector{Float64}, shadow_1_1::Vector{Float64})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/QVjE5/src/rules/jitrules.jl:480
[48] #73
@ ~/nuclear-diffprog/MWEs/coreloop_ez_nn.jl:137 [inlined]
[49] augmented_julia__73_42667wrap
@ ~/nuclear-diffprog/MWEs/coreloop_ez_nn.jl:0
[50] macro expansion
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:5218 [inlined]
[51] enzyme_call
@ ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4764 [inlined]
[52] (::Enzyme.Compiler.AugmentedForwardThunk{Ptr{Nothing}, Const{var"#73#74"}, DuplicatedNoNeed{Any}, Tuple{Duplicated{Vector{Float64}}}, 1, false, @NamedTuple{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}})(fn::Const{var"#73#74"}, args::Duplicated{Vector{Float64}})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/QVjE5/src/compiler.jl:4700
[53] #130
@ ~/.julia/packages/Enzyme/QVjE5/src/sugar.jl:928 [inlined]
[54] ntuple
@ ./ntuple.jl:49 [inlined]
[55] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#73#74", x::Vector{Float64}; n_outs::Val{(2,)}, chunk::Nothing)
@ Enzyme ~/.julia/packages/Enzyme/QVjE5/src/sugar.jl:924
[56] jacobian
@ ~/.julia/packages/Enzyme/QVjE5/src/sugar.jl:841 [inlined]
[57] #jacobian#129
@ ~/.julia/packages/Enzyme/QVjE5/src/sugar.jl:856 [inlined]
[58] jacobian(mode::ReverseMode{false, false, FFIABI, false, false}, f::var"#73#74", x::Vector{Float64})
@ Enzyme ~/.julia/packages/Enzyme/QVjE5/src/sugar.jl:841
[59] top-level scope
@ ~/nuclear-diffprog/MWEs/coreloop_ez_nn.jl:137
[60] include(fname::String)
@ Main ./sysimg.jl:38
in expression starting at /vast/home/daningburg/nuclear-diffprog/MWEs/coreloop_ez_nn.jl:137
Some type information was truncated. Use `show(err)` to see complete types.
I was able to recreate this error in my MWE:
# Test coreloop diff with Enzyme
begin # Packages
using SpecialFunctions
using BenchmarkTools
using SphericalHarmonics
using Enzyme
using Flux
end
begin # Functions
# Coulomb funcs
# Irregular radial function: G_L(k,r) = -k·r·y_L(k·r), where y_L is the
# spherical Bessel function of the second kind (free-field limit, no Coulomb).
function GL(k, r, L)
    ρ = k * r
    return -ρ * sphericalbessely(L, ρ)
end
# Regular radial function: F_L(k,r) = k·r·j_L(k·r), where j_L is the
# spherical Bessel function of the first kind (free-field limit, no Coulomb).
function FL(k, r, L)
    ρ = k * r
    return ρ * sphericalbesselj(L, ρ)
end
# Spherical Hankel functions
# Incoming spherical Hankel-like function H⁻ = G_L - i·F_L.
function Hminus(k, r, L)
    g = GL(k, r, L)
    f = FL(k, r, L)
    return complex(g, -f)
end
# Outgoing spherical Hankel-like function H⁺ = G_L + i·F_L.
function Hplus(k, r, L)
    g = GL(k, r, L)
    f = FL(k, r, L)
    return complex(g, f)
end
# Derivatives
# Radial derivatives of H⁻ and H⁺, built from Enzyme derivatives of G_L and
# F_L with respect to r. `enzR_*` use reverse mode, `enzF_*` forward mode;
# both compute the same values.
function enzR_Hminusprime(k, r, L)
    dG = Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1]
    dF = Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1]
    return complex(dG, -dF)
end
function enzR_Hplusprime(k, r, L)
    dG = Enzyme.gradient(Reverse, x -> GL(k, x, L), r)[1]
    dF = Enzyme.gradient(Reverse, x -> FL(k, x, L), r)[1]
    return complex(dG, dF)
end
function enzF_Hminusprime(k, r, L)
    dG = Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1]
    dF = Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1]
    return complex(dG, -dF)
end
function enzF_Hplusprime(k, r, L)
    dG = Enzyme.gradient(Forward, x -> GL(k, x, L), r)[1]
    dF = Enzyme.gradient(Forward, x -> FL(k, x, L), r)[1]
    return complex(dG, dF)
end
"""
    enzSL_0f0(U, L, μ, k, r, Ecm)

Integrate the radial wave equation for partial wave `L` with Numerov's method
and return the S-matrix element as `[real(SL), imag(SL)]`.

# Arguments
- `U`: potential matrix; column 1 real part, column 2 imaginary part, one row
  per grid point (indexed up to `length(r)` inside the loop).
- `L`: orbital angular momentum.
- `μ`: reduced mass.
- `k`: asymptotic wavenumber.
- `r`: uniform radial grid; the matching radius is `a = r[end-2]`.
- `Ecm`: center-of-mass energy.

Relies on the global constant `ħ`. Real and imaginary parts of the wave
function are carried in separate `Float64` accumulators rather than a single
`Complex` variable (presumably to ease autodiff — kept as-is).
"""
function enzSL_0f0(U, L, μ, k, r, Ecm)
    dr = r[2] - r[1]
    len = length(r) - 1
    # u at grid points i-1, i, i+1 (real/imag parts) and their derivatives.
    ur1, ur2, ur3 = 0.0, 0.0, 0.0
    ui1, ui2, ui3 = 0.0, 0.0, 0.0
    dur1, dur2, dur3 = 0.0, 0.0, 0.0
    dui1, dui2, dui3 = 0.0, 0.0, 0.0
    a = r[end-2]
    # Small starting values near the origin to seed the recursion.
    ur2 = 1e-6
    ui1 = 1e-12 # ideally these are all always Float32, or all always Float64
    ui2 = 1e-6
    for i in 3:len
        # w(r) = 2μ/ħ²·(Ecm - U(r)) - L(L+1)/r², evaluated at i, i-1, i+1.
        vreal = Ecm - U[i,1]
        vimag = -U[i,2]
        w = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i]^2
        vreal = Ecm - U[i-1,1]
        vimag = -U[i-1,2]
        # BUG FIX: centrifugal term must use r[i-1] here (was r[i]).
        wmo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i-1]^2
        vreal = Ecm - U[i+1,1]
        vimag = -U[i+1,2]
        # BUG FIX: centrifugal term must use r[i+1] here (was r[i]).
        wpo = 2*μ/ħ^2*complex(vreal, vimag) - L*(L+1)/r[i+1]^2
        # Numerov step: u_{i+1} from u_i and u_{i-1}.
        uval = (2*complex(ur2,ui2)-complex(ur1,ui1)-(dr^2/12)*(10*w*complex(ur2,ui2)+wmo*complex(ur1,ui1)))/(1+(dr^2/12)*wpo)
        ur3 = real(uval)   # scalar — no broadcast dot needed
        dur3 = 0.5*(ur3-ur1)/dr
        ui3 = imag(uval)
        dui3 = 0.5*(ui3-ui1)/dr
        ur1, ur2 = ur2, ur3
        dur1, dur2 = dur2, dur3
        ui1, ui2 = ui2, ui3
        dui1, dui2 = dui2, dui3
    end
    # Log-derivative matching at r = a: R_L = u(a)/u'(a).
    ua = complex(ur2,ui2)
    dua = complex(dur3,dui3)
    RL = ua / dua
    # Forward-mode derivative helpers are used here; reverse-mode variants
    # (enzR_*) exist and compute the same values.
    SLtop = Hminus(k, a, L) - RL*enzF_Hminusprime(k, a, L)
    SLbot = Hplus(k, a, L) - RL*enzF_Hplusprime(k, a, L)
    SL = SLtop/SLbot
    return [real(SL), imag(SL)]
end
"""
    build_model(n_in, n_out, n_layers, n_nodes, act_fun=relu, last_fun=relu)

Build a dense feed-forward `Chain` with `n_in` inputs, `n_layers` hidden
layers of `n_nodes` units with activation `act_fun`, and a linear output
layer with `n_out` units, converted to Float64 weights via `f64`.

BUG FIX: `n_layers` was previously ignored — the hidden-layer count was
hard-coded to 4. For the file's call `build_model(4, 2, 4, 16)` the result is
unchanged. NOTE(review): `last_fun` is accepted but never applied (the output
layer is linear, matching the original behavior) — confirm whether it should
wrap the last layer.
"""
function build_model(n_in, n_out, n_layers, n_nodes, act_fun=relu, last_fun=relu)
    first_layer = Flux.Dense(n_in, n_nodes, act_fun)
    hidden_layers = [Flux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
    last_layer = Flux.Dense(n_nodes => n_out)
    m = Chain(first_layer, hidden_layers..., last_layer) |> f64
    return m
end
"""
    eval_model(m, x)

Apply model `m` to input `x` and return the raw output (no normalization
or denormalization is performed).
"""
function eval_model(m, x)
    output = m(x)
    return output
end
# Combine x with r for use by neural network
"""
    combex(x, r)

Tile each row of `x` across the radial grid `r`: the output has one row per
(row-of-`x`, radius) pair, with the radius in column 1 and the corresponding
row of `x` in the remaining columns. Returns the transpose, i.e. a
`(size(x, 2) + 1) × (size(x, 1) * length(r))` matrix, for network input.
"""
function combex(x, r)
    nrows = size(x, 1)
    nr = length(r)
    out = zeros(eltype(x), nrows * nr, size(x, 2) + 1)
    for i in 1:nrows
        rowspan = (i - 1) * nr + 1 : i * nr
        out[rowspan, 1] .= r
        for j in rowspan
            out[j, 2:end] .= x[i, :]
        end
    end
    return out'
end
end
# Set up particular scattering problem
# NOTE(review): A=65, Z=29 — presumably nucleon + ⁶⁵Cu at E = 10 MeV; the
# kinematic values (Ecm, μ, k) below are precomputed constants.
A = 65.
Z = 29.
N = A - Z   # neutron number (not used further in this script)
E = 10.
L = 30      # partial-wave angular momentum
Ecm = 9.848393154293218
μ = 925.3211722114523
k = 0.6841596644044445
r = Vector(LinRange(0, 20, 2000))   # uniform radial grid with 2000 points
dr = r[2] - r[1]
const global ħ = 197.3269804        # used as a global inside enzSL_0f0
# Model
x = [A Z E]
X = combex(x, r)                    # 4 × 2000 network input: (r, A, Z, E) per radius
m = build_model(4, 2, 4, 16)
params, re = Flux.destructure(m)
# This call reproduces the reported error: per the stack trace above, Enzyme
# hits `jl_genericmemory_slice` inside Optimisers' `_getat`/`reshape` while
# differentiating through the `re(p)` Restructure on Julia 1.11, for which
# no augmented forward rule exists in Enzyme v0.13.23.
Enzyme.jacobian(Reverse, p -> enzSL_0f0(eval_model(re(p), X)', L, μ, k, r, Ecm), params)
Note: this is on Julia 1.11.1 and Enzyme v0.13.23.