I have found only scarce material on combining Numba and Enzyme, e.g. this Numba GitHub issue and this talk at the Enzyme conference 2023.
I have this simple cosine function in Python, decorated with Numba's @cfunc decorator. From it, I get the LLVM IR code:
import numpy as np
from numba import cfunc, types

sig = types.double(
    types.double,
    types.double,
    types.double,
    types.double,
)

@cfunc(sig)
def func(a, w, p, t):
    return a * np.cos(w * t + p)

if __name__ == "__main__":
    _ = func.address
    ir_map = func.inspect_llvm()
    with open("func.ll", "w") as f:
        f.write(ir_map)
The output file func.ll looks like this:
; ModuleID = 'func'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"
@_ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd = common local_unnamed_addr global i8* null
; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn writeonly
define i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) local_unnamed_addr #0 {
for.end:
%.8 = fmul double %arg.w, %arg.t
%.9 = fadd double %.8, %arg.p
%.17 = tail call double @llvm.cos.f64(double %.9)
%.22 = fmul double %.17, %arg.a
store double %.22, double* %retptr, align 8
ret i32 0
}
; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn
declare double @llvm.cos.f64(double) #1
; Function Attrs: mustprogress nofree nosync nounwind willreturn writeonly
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
%.6 = alloca double, align 8
store double 0.000000e+00, double* %.6, align 8
%.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
%.20 = load double, double* %.6, align 8
ret double %.20
}
attributes #0 = { argmemonly mustprogress nofree nosync nounwind willreturn writeonly }
attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn }
attributes #2 = { mustprogress nofree nosync nounwind willreturn writeonly }
attributes #3 = { noinline }
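The IR above defines two functions: Numba's internal one, which returns an i32 status and writes the result through %retptr, and the `cfunc.`-prefixed wrapper, which has the plain C signature double(double, double, double, double). A minimal sketch for listing the defined symbols and their return types follows. `IR_TEXT` is a shortened stand-in for func.ll with abbreviated, hypothetical mangled names, and `defined_functions` is a helper name made up for this illustration.

```python
import re

# Shortened stand-in for func.ll. Only the `define` lines matter here, and the
# mangled names are abbreviated placeholders, not the real symbols.
IR_TEXT = """
define i32 @_ZN8__main__4funcB2v1Edddd(double* %retptr, i8** %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) {
declare double @llvm.cos.f64(double)
define double @cfunc._ZN8__main__4funcB2v1Edddd(double %.1, double %.2, double %.3, double %.4) {
"""

# Matches "define <return type> @<name>(" at the start of a line.
DEFINE_RE = re.compile(r"^define\s+(.*?)\s*@([\w.$]+)\(", re.MULTILINE)


def defined_functions(ir_text):
    """Return (symbol name, return type) for every function defined in the IR."""
    return [(m.group(2), m.group(1)) for m in DEFINE_RE.finditer(ir_text)]


funcs = defined_functions(IR_TEXT)
for name, ret in funcs:
    print(f"{ret:>8}  {name}")
```

With the real file, `IR_TEXT` would be replaced by `open("func.ll").read()`. Note that the listing above contains `B2v1` in the mangled names while the Julia snippet further down uses `B2v3`, so it may be worth reading the name from the file that was actually compiled.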
Now, using llvm-toolset, I assemble the bitcode func.bc and then build the shared library libfunc.so via clang:
llvm-as func.ll -o func.bc
clang -O3 -fPIC -shared func.bc -o libfunc.so
From here on, I load the .so file and can call the function from Julia:
using Libdl
const lib = Libdl.dlopen("./libfunc.so")
const f_ptr = Libdl.dlsym(
    lib, :_ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd)
f(a, w, p, t) = ccall(f_ptr, Float64, (Float64, Float64, Float64, Float64), a, w, p, t)
f(1.0, 1.0, 1.0, 1.0) # returns -0.4161468365471424
But, as expected, I cannot take the gradient of this:
julia> using Enzyme
julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
ERROR:
No reverse pass found for _ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd
at context: %6 = call double @_ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %0, double %1, double %2, double %3) #5, !dbg !14
I wonder whether I can use Enzyme to generate the reverse pass and include it in libfunc.so during the build process.
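For reference, any derivative obtained through Enzyme can be checked against the analytic one: for func(a, w, p, t) = a * cos(w * t + p), the partial derivative with respect to a is cos(w * t + p). The following is a minimal sketch in plain Python, independent of Numba, Julia and Enzyme. The helper names `dfunc_da` and `central_difference` are made up for this illustration.

```python
import math


def func(a, w, p, t):
    """Pure-Python copy of the function compiled with Numba above."""
    return a * math.cos(w * t + p)


def dfunc_da(a, w, p, t):
    """Analytic partial derivative of func with respect to a."""
    return math.cos(w * t + p)


def central_difference(f, args, index, h=1e-6):
    """Approximate the partial derivative of f with respect to args[index]."""
    hi = list(args)
    lo = list(args)
    hi[index] += h
    lo[index] -= h
    return (f(*hi) - f(*lo)) / (2.0 * h)


point = (1.0, 1.0, 1.0, 1.0)
analytic = dfunc_da(*point)
numeric = central_difference(func, point, 0)
print(analytic, numeric)
```

At the point (1, 1, 1, 1) the derivative with respect to a equals cos(2), which coincides with the function value there because a = 1.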