Hi,
I’m using Flux and Enzyme for a project, and the entire architecture is built on top of these two libraries. So unless strictly necessary, I’d like to stay within the Flux/Enzyme ecosystem.
For both libraries I’m using the latest stable version.
The issue arises when I use my custom layer together with any native Flux layer and compute a loss that involves the derivative of the model output with respect to its inputs (physically, these represent atomic forces in my system).
Here is the code for my custom layer:
"""
    G1Layer(W_eta, W_Fs, cutoff, charge)

Parameters of a layer of Gaussian features.

# Fields
- `W_eta::Vector{Float32}`: per-feature coefficient multiplying the squared distance
  inside the exponential of the forward pass.
- `W_Fs::Vector{Float32}`: per-feature shift subtracted from each input value.
- `cutoff::Float32`: cutoff passed to `fc` in the forward pass.
- `charge::Float32`: scalar factor applied to every output of the forward pass.
"""
struct G1Layer
    W_eta::Vector{Float32}
    W_Fs::Vector{Float32}
    cutoff::Float32
    charge::Float32
end
# Register G1Layer with Flux. Only `W_eta` and `W_Fs` are listed as trainable;
# `cutoff` and `charge` are not part of the trainable set.
Flux.@layer G1Layer trainable = (W_eta,W_Fs,)
"""
    G1Layer(N_G1::Int, cutoff::Float32, charge::Float32; seed=nothing)

Build a `G1Layer` with `N_G1` randomly initialised features.

The shifts `W_Fs` are spaced evenly from `0.1` to `cutoff` and each is jittered by a
random amount in `[0, 0.01)`. The coefficients `W_eta` are drawn as
`(0.8 + 0.4 * rand) / spacing^2`, where `spacing` is the mean difference between
consecutive shifts.

# Arguments
- `N_G1`: number of features; must be at least 2 so that a spacing exists.
- `cutoff`: upper end of the range of shifts; stored in the layer.
- `charge`: stored unchanged in the layer.

# Keywords
- `seed`: if an `Int`, initialisation uses `MersenneTwister(seed)` and is
  reproducible; if `nothing`, the default task-local RNG is used.

# Throws
- `ArgumentError` if `N_G1 < 2`.
"""
function G1Layer(N_G1::Int, cutoff::Float32, charge::Float32; seed::Union{Int,Nothing}=nothing)
    # With fewer than two shifts there is no spacing to derive `eta` from; without this
    # check the code below would fail obscurely or yield NaN coefficients.
    N_G1 >= 2 || throw(ArgumentError("N_G1 must be at least 2, got $N_G1"))

    rng = isnothing(seed) ? Random.default_rng() : MersenneTwister(seed)

    # Avoid Fs too close to zero to prevent huge contributions.
    r_min = 0.1f0
    W_Fs = range(r_min, cutoff, length=N_G1) .+ 0.01f0 .* rand(rng, Float32, N_G1)

    # Compute the average spacing and set eta proportional to 1 / spacing^2.
    avg_spacing = mean(diff(W_Fs))
    eta_base = 1.0f0 / (avg_spacing^2)
    W_eta = eta_base .* (0.8f0 .+ 0.4f0 .* rand(rng, Float32, N_G1))

    return G1Layer(W_eta, W_Fs, cutoff, charge)
end
"""
    (layer::G1Layer)(x::AbstractMatrix{Float32})

Evaluate the layer on `x`, read as one row per sample and one column per neighbor
(`size(x) == (n_batch, n_neighbors)`).

For sample `b` and feature `f` the result is

    0.1 * charge * sum over n of fc(x[b, n], cutoff) * exp(-W_eta[f] * (x[b, n] - W_Fs[f])^2)

# Returns
A `Matrix{Float32}` of size `(length(layer.W_eta), n_batch)`.

# Throws
- `ArgumentError` if `x` is not 1-based indexed.
- `DimensionMismatch` if `layer.W_eta` and `layer.W_Fs` differ in length.
"""
function (layer::G1Layer)(x::AbstractMatrix{Float32})
    # The loops below index with `1:n` under `@inbounds`; that is only safe when `x`
    # is 1-based and both parameter vectors have `n_features` entries, so check first
    # instead of silently reading out of bounds.
    Base.require_one_based_indexing(x)
    n_features = size(layer.W_eta, 1)
    if length(layer.W_Fs) != n_features
        throw(DimensionMismatch("W_eta and W_Fs must have the same length"))
    end

    n_batch, n_neighbors = size(x)
    cutoff = layer.cutoff
    output = zeros(Float32, n_features, n_batch)
    @inbounds for b in 1:n_batch
        for f in 1:n_features
            center = layer.W_Fs[f]
            eta = layer.W_eta[f]
            s = 0f0
            for n in 1:n_neighbors
                dx = x[b, n] - center
                s += fc(x[b, n], cutoff) * exp(-eta * dx * dx)
            end
            output[f, b] = 0.1f0 * layer.charge * s
        end
    end
    return output
end
"""
    fc(Rij, Rc)

Cutoff function: return `exp(1 - 1 / (1 - (Rij / Rc)^2))` when `Rij < Rc`, and zero
when `Rij >= Rc`. The value is `1` at `Rij == 0` and decays to `0` as `Rij`
approaches `Rc`.

Both arguments must share the same floating-point type `T`; the result is a `T`.

# Arguments
- `Rij`: value at which to evaluate the function.
- `Rc`: cutoff value.
"""
function fc(Rij::T, Rc::T) where {T<:AbstractFloat}
    Rij >= Rc && return zero(T)

    denom = one(T) - (Rij / Rc)^2
    # Right below the cutoff `1 / denom` blows up; the function value is
    # effectively zero there, so return it directly.
    denom < eps(T) && return zero(T)

    return exp(one(T) - one(T) / denom)
end
And here’s an example of a loss function similar to the one I use (not exactly the same, but it reproduces the problem — the issue seems related to taking second derivatives):
"""
    losss(m, x, y)

Loss made of two terms: the mean squared difference between `m(x)` and `y`, plus the
norm of the gradient of the first entry of `m(x)` with respect to `x`.

# Arguments
- `m`: callable model.
- `x`: model input; also the variable the inner gradient is taken with respect to.
- `y`: target compared against `m(x)`.
"""
function losss(m, x, y)
    prediction = m(x)
    energy_loss = mean((prediction .- y) .^ 2)

    # Inner reverse-mode gradient of the first model output. The model is passed as
    # `Const`, so entry 2 of the returned tuple is the derivative with respect to `x`.
    first_output(mm, xx) = mm(xx)[1]
    input_grad = Enzyme.gradient(Reverse, first_output, Const(m), x)[2]

    return energy_loss + norm(input_grad)
end
When I compute the gradient of the loss with respect to the model parameters:
If the model is composed only of native Flux layers → works fine. I just get some warnings, but the resulting gradient can be used to train the model.
# Case 1: model made only of a native Flux layer.
x = rand(Float32, 5)
# Scalar target: element-wise product of random weights with `x`, summed.
# `x` is a plain vector here, so multiplying by `x'` would broadcast to a 5×5
# outer product instead.
y = sum(rand(Float32, 5) .* x)
model = Dense(5, 1)
println("The output of the model is: ", model(x))
o = OptimiserChain(ClipNorm(1.0), Adam(0.1))
opt = Flux.setup(o, model)
# Gradient of the loss with respect to the model; input and target are constants.
grad = Enzyme.gradient(
    set_runtime_activity(Reverse),
    (m, xx, yy) -> losss(m, xx, yy),
    model,
    Const(x),
    Const(y),
)[1]
This is the output
┌ Warning: Using fallback BLAS replacements for (["sasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\LMVya\src\compiler.jl:4351
freeing without malloc %15 = extractvalue { i8*, float*, float*, float*, i64 } %tapeArg, 0
freeing without malloc %19 = extractvalue { i8*, float*, float*, float*, i64 } %tapeArg, 1, !dbg !321
freeing without malloc %15 = extractvalue { i8*, float*, float*, float*, i64 } %tapeArg, 0
freeing without malloc %19 = extractvalue { i8*, float*, float*, float*, i64 } %tapeArg, 1, !dbg !321
If the model is composed only of my custom layer (e.g. Chain(MyLayer())) → works fine.
# Case 2: model made only of the custom layer. `x` is a 1×5 matrix.
x = rand(Float32, 1, 5)
y = sum(rand(Float32, 5) .* x')
model = G1Layer(2, 5.0f0, 2.0f0)
println("The output of the model is: ", model(x))
o = OptimiserChain(ClipNorm(1.0), Adam(0.1))
opt = Flux.setup(o, model)
# Gradient of the loss with respect to the model; input and target are constants.
grad = Enzyme.gradient(
    set_runtime_activity(Reverse),
    (m, xx, yy) -> losss(m, xx, yy),
    model,
    Const(x),
    Const(y),
)[1]
I still get this warning:
┌ Warning: Using fallback BLAS replacements for (["sasum_64_"]), performance may be degraded
└ @ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\LMVya\src\compiler.jl:4351
But if I combine my layer and any native Flux layer in the same chain, like this
# Case 3: custom layer followed by a native Flux layer in one chain.
# `x` and `y` are not set in this snippet; they must already be defined.
model = Chain(
    G1Layer(2, 5.0f0, 2.0f0),
    Dense(2, 1),
)
println("The output of the model is: ", model(x))
o = OptimiserChain(ClipNorm(1.0), Adam(0.1))
opt = Flux.setup(o, model)
# Gradient of the loss with respect to the model; input and target are constants.
grad = Enzyme.gradient(
    set_runtime_activity(Reverse),
    (m, xx, yy) -> losss(m, xx, yy),
    model,
    Const(x),
    Const(y),
)[1]
Flux.update!(opt, model, grad)
If I run it in Jupyter, the kernel just crashes with this error message:
The Kernel crashed while executing code in the current cell or a previous cell.
Please review the code in the cell(s) to identify a possible cause of the failure.
Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info.
View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.
When I run it from a standard .jl file instead, I get this monstrosity:
ERROR: LoadError: AssertionError: Enzyme internal error unsupported got(fname)
inst=LLVM.LoadInst(%utf8proc_toupper = load atomic ptr, ptr @jlplt_utf8proc_toupper_37464_got unordered, align 8, !dbg !84)
fname=
FT=LLVM.FunctionType(i32 (i32))
fn_got=LLVM.GlobalVariable("jlplt_utf8proc_toupper_37464_got")
init=define private i32 @jlplt_utf8proc_toupper_37464(i32 %0) #13 {
top:
%utf8proc_toupper.cached = load atomic ptr, ptr @ccall_utf8proc_toupper_37463 unordered, align 8
%is_cached = icmp ne ptr %utf8proc_toupper.cached, null
br i1 %is_cached, label %ccall, label %dlsym
dlsym: ; preds = %top
%utf8proc_toupper.found = call ptr @ijl_load_and_lookup(ptr inttoptr (i64 3 to ptr), ptr @_j_str_utf8proc_toupper_15, ptr @jl_libjulia_internal_handle)
store atomic ptr %utf8proc_toupper.found, ptr @ccall_utf8proc_toupper_37463 release, align 8
br label %ccall
ccall: ; preds = %dlsym, %top
%utf8proc_toupper = phi ptr [ %utf8proc_toupper.cached, %top ], [ %utf8proc_toupper.found, %dlsym ]
store atomic ptr %utf8proc_toupper, ptr @jlplt_utf8proc_toupper_37464_got release, align 8
%1 = musttail call i32 %utf8proc_toupper(i32 %0)
ret i32 %1
}
opv=@ccall_utf8proc_toupper_37463 = global ptr null
found= %utf8proc_toupper.found = call ptr @ijl_load_and_lookup(ptr inttoptr (i64 3 to ptr), ptr @_j_str_utf8proc_toupper_15, ptr @jl_libjulia_internal_handle)
fname=@_j_str_utf8proc_toupper_15 = private unnamed_addr constant [17 x i8] c"utf8proc_toupper\00", align 1
Stacktrace:
[1] check_ir!(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module)
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler\validation.jl:385
[2] check_ir!(interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing}, job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{Base.StackTraces.StackFrame}, Any}}, mod::LLVM.Module)
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler\validation.jl:212
[3] check_ir
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler\validation.jl:181 [inlined]
[4] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget{GPUCompiler.NativeCompilerTarget}, Enzyme.Compiler.EnzymeCompilerParams{Enzyme.Compiler.PrimalCompilerParams}})
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:4515
[5] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler C:\Users\Gianmarco\.julia\packages\GPUCompiler\Gp8bZ\src\driver.jl:67
[6] compile
@ C:\Users\Gianmarco\.julia\packages\GPUCompiler\Gp8bZ\src\driver.jl:55 [inlined]
[7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget{GPUCompiler.NativeCompilerTarget}, Enzyme.Compiler.EnzymeCompilerParams{Enzyme.Compiler.PrimalCompilerParams}}, postopt::Bool)
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:5956
[8] _thunk
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:5954 [inlined]
[9] cached_compilation
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:6011 [inlined]
[10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{<:Annotation}, A::Type{<:Annotation}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, edges::Vector{Any})
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:6127
[11] thunk_generator(world::UInt64, source::Union{LineNumberNode, Method}, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{N, Bool} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, StrongZero::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type, strongzero::Type)
@ Enzyme.Compiler C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\compiler.jl:6271
[12] autodiff
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\Enzyme.jl:502 [inlined]
[13] autodiff
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\Enzyme.jl:542 [inlined]
[14] macro expansion
@ C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\sugar.jl:286 [inlined]
[15] gradient(::ReverseMode{false, true, false, FFIABI, false, false}, ::var"#4#5", ::Chain{Tuple{G1Layer, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ::Const{Matrix{Float32}}, ::Const{Float32})
@ Enzyme C:\Users\Gianmarco\.julia\packages\Enzyme\Op3Un\src\sugar.jl:273
[16] top-level scope
@ C:\Users\Gianmarco\OneDrive\SymmLearn\src\bug.jl:95
[17] include(mod::Module, _path::String)
@ Base .\Base.jl:306
[18] exec_options(opts::Base.JLOptions)
@ Base .\client.jl:317
[19] _start()
@ Base .\client.jl:550
in expression starting at C:\Users\Gianmarco\OneDrive\SymmLearn\src\bug.jl:95
Is this a bug, or am I doing something wrong?