Get Numba LLVM IR differentiated in Julia using Enzyme.jl

I have found only scarce material on combining Numba and Enzyme, e.g. this Numba GitHub issue or this talk at the Enzyme conference 2023.

I have this simple cosine function in Python, decorated with Numba’s @cfunc, from which I dump the LLVM IR:

import numpy as np
from numba import cfunc, types

sig = types.double(
    types.double,
    types.double,
    types.double,
    types.double,
)

@cfunc(sig)
def func(a, w, p, t):
    return a * np.cos(w * t + p)

if __name__ == "__main__":
    _ = func.address

    ir_map = func.inspect_llvm()

    with open("func.ll", "w") as f:
        f.write(ir_map)

The output func.ll file looks like this

; ModuleID = 'func'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"

@_ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd = common local_unnamed_addr global i8* null

; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn writeonly
define i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) local_unnamed_addr #0 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9)
  %.22 = fmul double %.17, %arg.a
  store double %.22, double* %retptr, align 8
  ret i32 0
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn
declare double @llvm.cos.f64(double) #1

; Function Attrs: mustprogress nofree nosync nounwind willreturn writeonly
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, double* %.6, align 8
  ret double %.20
}

attributes #0 = { argmemonly mustprogress nofree nosync nounwind willreturn writeonly }
attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn }
attributes #2 = { mustprogress nofree nosync nounwind willreturn writeonly }
attributes #3 = { noinline }
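The module above defines two functions whose mangled names differ only by the cfunc. prefix, and whose return types differ (i32 versus double). Listing the definitions before choosing a symbol for dlsym can help avoid calling the wrong one. Below is a rough Python sketch of that; the regex and the helper names list_definitions and pick_c_abi_symbol are my own, and treating the cfunc.-prefixed definition as the C-callable wrapper is an assumption based only on the IR shown here. It is a text heuristic, not an LLVM IR parser.

```python
import re

# Matches lines such as "define i32 @name(" or "define internal fastcc double @name(".
# Whatever sits between "define" and "@" is captured; its last token is taken as
# the return type. Heuristic only: aggregate return types like "{ double }" or
# quoted attributes containing "@" are not handled.
DEFINE_RE = re.compile(r"^define\s+([^@]*?)\s*@([^\s(]+)\(", re.MULTILINE)


def list_definitions(ir_text):
    """Return (return_type, symbol) pairs for each function defined in ir_text."""
    result = []
    for prefix, name in DEFINE_RE.findall(ir_text):
        result.append((prefix.split()[-1], name))
    return result


def pick_c_abi_symbol(ir_text):
    """Assumption: Numba's C-callable wrapper is the definition prefixed 'cfunc.'."""
    for _ret_type, name in list_definitions(ir_text):
        if name.startswith("cfunc."):
            return name
    return None
```

For example, list_definitions(open("func.ll").read()) should report an i32 entry for Numba’s internal function and a double entry for the cfunc. wrapper; only the latter has the plain (double, double, double, double) -> double signature.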

Using the LLVM toolset, I then assemble it into the bitcode file func.bc and build the shared library libfunc.so with clang:

llvm-as func.ll -o func.bc
clang -O3 -fPIC -shared func.bc -o libfunc.so

From there, I can load the .so file and call the function in Julia:

using Libdl

const lib = Libdl.dlopen("./libfunc.so")
const f_ptr = Libdl.dlsym(
    lib, :_ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd)

f(a, w, p, t) = ccall(f_ptr, Float64, (Float64, Float64, Float64, Float64), a, w, p, t)

f(1.0, 1.0, 1.0, 1.0) # returns -0.4161468365471424

But, unsurprisingly, I cannot take its gradient:

julia> using Enzyme
julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
ERROR:
No reverse pass found for _ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd
 at context:   %6 = call double @_ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %0, double %1, double %2, double %3) #5, !dbg !14

I wonder whether I can use Enzyme to include the reverse pass in libfunc.so during the build.


If you add -fembed-bitcode, like in Calling C Code with Automatic Differentiation in Julia - #6 by wsmoses, does it work?

Thanks for the hint, but that didn’t help. To be precise, executing the following

clang -O3 -fPIC -fembed-bitcode -shared func.bc -o libfunc.so

still throws the same error when I want to get the gradient.

ERROR:
No reverse pass found for _ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd
 at context:   %6 = call double @_ZN8__main__4funcB2v3B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %0, double %1, double %2, double %3) #5, !dbg !14


I think you might need to do a ccall like in the linked code rather than call dlopen?

It emitted the same error

julia> using Enzyme
julia> function f(a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)
           @ccall "./libfunc.so"._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(
               a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)::Cdouble
       end
f (generic function with 1 method)

julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
ERROR:
No reverse pass found for ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so
 at context:   %6 = call double @"ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so"(double %0, double %1, double %2, double %3) #5, !dbg !15

Also, when I list the dynamic symbols of libfunc.so, they are the same with or without -fembed-bitcode:

❯ nm -D libfunc.so
0000000000001140 T cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd
                 U cos
                 w __cxa_finalize@GLIBC_2.2.5
                 w __gmon_start__
                 w _ITM_deregisterTMCloneTable
                 w _ITM_registerTMCloneTable
0000000000004020 B _ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd
0000000000001100 T _ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd

I also inspected the shared object to make sure the LLVM bitcode is there, and I could reconstruct the original LLVM IR correctly:

❯ readelf -W -S libfunc.so
There are 28 section headers, starting at offset 0x40b8:

Section Headers:
  [Nr] Name              Type            Address          Off    Size   ES Flg Lk Inf Al
  [ 0]                   NULL            0000000000000000 000000 000000 00      0   0  0
  [ 1] .gnu.hash         GNU_HASH        00000000000001c8 0001c8 000030 00   A  2   0  8
  [ 2] .dynsym           DYNSYM          00000000000001f8 0001f8 0000d8 18   A  3   1  8
  [ 3] .dynstr           STRTAB          00000000000002d0 0002d0 000123 00   A  0   0  1
  [ 4] .gnu.version      VERSYM          00000000000003f4 0003f4 000012 02   A  2   0  2
  [ 5] .gnu.version_r    VERNEED         0000000000000408 000408 000020 00   A  3   1  8
  [ 6] .rela.dyn         RELA            0000000000000428 000428 0000a8 18   A  2   0  8
  [ 7] .rela.plt         RELA            00000000000004d0 0004d0 000030 18  AI  2  18  8
  [ 8] .init             PROGBITS        0000000000001000 001000 00001b 00  AX  0   0  4
  [ 9] .plt              PROGBITS        0000000000001020 001020 000030 10  AX  0   0 16
  [10] .plt.got          PROGBITS        0000000000001050 001050 000008 08  AX  0   0  8
  [11] .text             PROGBITS        0000000000001060 001060 000108 00  AX  0   0 16
  [12] .fini             PROGBITS        0000000000001168 001168 00000d 00  AX  0   0  4
  [13] .eh_frame         PROGBITS        0000000000002000 002000 000004 00   A  0   0  8
  [14] .init_array       INIT_ARRAY      0000000000003df8 002df8 000008 08  WA  0   0  8
  [15] .fini_array       FINI_ARRAY      0000000000003e00 002e00 000008 08  WA  0   0  8
  [16] .dynamic          DYNAMIC         0000000000003e08 002e08 0001c0 10  WA  3   0  8
  [17] .got              PROGBITS        0000000000003fc8 002fc8 000020 08  WA  0   0  8
  [18] .got.plt          PROGBITS        0000000000003fe8 002fe8 000028 08  WA  0   0  8
  [19] .data             PROGBITS        0000000000004010 003010 000008 00  WA  0   0  8
  [20] .bss              NOBITS          0000000000004018 003018 000010 00  WA  0   0  8
  [21] .comment          PROGBITS        0000000000000000 003018 000027 01  MS  0   0  1
  [22] .llvmbc           PROGBITS        0000000000000000 00303f 000974 00      0   0  1
  [23] .llvmcmd          PROGBITS        0000000000000000 0039b3 0000a8 00      0   0  1
  [24] .gnu.build.attributes NOTE            0000000000006028 003a5c 0000d8 00      0   0  4
  [25] .symtab           SYMTAB          0000000000000000 003b38 000288 18     26  19  8
  [26] .strtab           STRTAB          0000000000000000 003dc0 000208 00      0   0  1
  [27] .shstrtab         STRTAB          0000000000000000 003fc8 0000ef 00      0   0  1
Key to Flags:
  W (write), A (alloc), X (execute), M (merge), S (strings), I (info),
  L (link order), O (extra OS processing required), G (group), T (TLS),
  C (compressed), x (unknown), o (OS specific), E (exclude),
  D (mbind), l (large), p (processor specific)

Sections 22 (.llvmbc) and 23 (.llvmcmd) hold the embedded bitcode and the recorded compiler command line; both are absent if I don’t use the -fembed-bitcode flag. I then extracted the bitcode from libfunc.so via

❯ objcopy --dump-section .llvmbc=extracted.bc libfunc.so
❯ llvm-dis extracted.bc -o -
; ModuleID = 'extracted.bc'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"

@_ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd = common local_unnamed_addr global ptr null

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(argmem: write)
define i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(ptr noalias nocapture writeonly %retptr, ptr noalias nocapture readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) local_unnamed_addr #0 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9)
  %.22 = fmul double %.17, %arg.a
  store double %.22, ptr %retptr, align 8
  ret i32 0
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.cos.f64(double) #1

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(write)
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, ptr %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(ptr nonnull %.6, ptr nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, ptr %.6, align 8
  ret double %.20
}

attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(argmem: write) }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(write) }
attributes #3 = { noinline }
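As a quick sanity check before disassembling, the dumped section can also be classified by its header bytes: raw LLVM bitcode starts with the bytes 'B', 'C', 0xC0, 0xDE, and a bitcode wrapper header starts with the little-endian magic 0x0B17C0DE. A minimal Python sketch follows; the helper name bitcode_kind is hypothetical. A valid header only shows that the bytes are bitcode, not that a particular LLVM version can read them.

```python
import struct

# Raw LLVM bitcode begins with 'B', 'C', 0xC0, 0xDE.
RAW_BITCODE_MAGIC = b"BC\xc0\xde"
# A bitcode wrapper header begins with this little-endian uint32.
WRAPPER_MAGIC = 0x0B17C0DE


def bitcode_kind(data: bytes) -> str:
    """Classify a byte buffer by its header only; the module itself is not validated."""
    if data[:4] == RAW_BITCODE_MAGIC:
        return "raw"
    if len(data) >= 4 and struct.unpack_from("<I", data)[0] == WRAPPER_MAGIC:
        return "wrapped"
    return "not bitcode"
```

For instance, reading extracted.bc with open("extracted.bc", "rb") and passing its contents to bitcode_kind would be expected to return "raw" for the section dumped by objcopy above.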

So the LLVM IR is definitely there, but I am afraid Enzyme still fails to compute the gradient:

❯ julia --project=.
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.11.6 (2025-07-09)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> using Enzyme

julia> function f(a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)
           @ccall "./libfunc.so"._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(
               a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)::Cdouble
       end
f (generic function with 1 method)

julia> f(1.0, 1.0, 1.0, 1.0)
-0.4161468365471424

julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
ERROR:
No reverse pass found for ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so
 at context:   %6 = call double @"ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so"(double %0, double %1, double %2, double %3) #5, !dbg !15

So that indicates that Enzyme.jl didn’t find the bitcode to import. Since it seems like you enjoy hacking on things, the part of the Enzyme.jl code that tries to find the embedded bitcode is here: Enzyme.jl/src/compiler/validation.jl at afb9f1d89f6644ba8d3a2992a260618f087e335c · EnzymeAD/Enzyme.jl. Maybe add some prints and see what’s going wrong?


Thanks a lot! Your hint took me to the right spot to hunt down the issue.

Actually, the bitcode was read successfully from the .so file and stored in data as a Vector{UInt8}, i.e. the llvmbc section was returned at line 535 of validation.jl.

However, since it was generated with opaque pointers, Base.parse(LLVM.Module, data) at line 543 threw the following error:

julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
mod = LLVM.Module("start")
flib = "./libfunc.so"
fname = "_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd"
imported = Set{String}()
"entered the try branch" = "entered the try branch"
"entered the catch e2" = "entered the catch e2"
e2 = LLVM.LLVMException("Opaque pointers are only supported in -opaque-pointers mode (Producer: 'LLVM16.0.6' Reader: 'LLVM 16.0.6jl')")
"entered catch e" = "entered catch e"
e = UndefVarError(:LLVMDowngrader_jll, Enzyme.Compiler)
ERROR:
No reverse pass found for ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so
 at context:   %6 = call double @"ejlstr$_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd$./libfunc.so"(double %0, double %1, double %2, double %3) #5, !dbg !15

I used clang and the LLVM tools v16.0.6 to produce the .bc and .so files. I noticed that Numba, through llvmlite, generates the LLVM IR with typed pointers. Here is the func.ll file:

; ModuleID = 'func'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"

@_ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd = common local_unnamed_addr global i8* null

; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn writeonly
define i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) local_unnamed_addr #0 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9)
  %.22 = fmul double %.17, %arg.a
  store double %.22, double* %retptr, align 8
  ret i32 0
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn
declare double @llvm.cos.f64(double) #1

; Function Attrs: mustprogress nofree nosync nounwind willreturn writeonly
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, double* %.6, align 8
  ret double %.20
}

attributes #0 = { argmemonly mustprogress nofree nosync nounwind willreturn writeonly }
attributes #1 = { mustprogress nocallback nofree nosync nounwind readnone speculatable willreturn }
attributes #2 = { mustprogress nofree nosync nounwind willreturn writeonly }
attributes #3 = { noinline }

Under the hood, llvmlite uses LLVM v15, but llvm-as (v16) upgraded the typed pointers to opaque pointers. Here is the output when I disassembled func.bc:

❯ llvm-dis func.bc -o -
; ModuleID = 'func.bc'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-conda-linux-gnu"

@_ZN08NumbaEnv8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd = common local_unnamed_addr global ptr null

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(argmem: write)
define i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(ptr noalias nocapture writeonly %retptr, ptr noalias nocapture readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) local_unnamed_addr #0 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9)
  %.22 = fmul double %.17, %arg.a
  store double %.22, ptr %retptr, align 8
  ret i32 0
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare double @llvm.cos.f64(double) #1

; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(write)
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, ptr %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(ptr nonnull %.6, ptr nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, ptr %.6, align 8
  ret double %.20
}

attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(argmem: write) }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(write) }
attributes #3 = { noinline }
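To tell at a glance which representation a textual dump uses, a rough heuristic is enough: opaque pointers are spelled with the bare type ptr, while typed pointers carry a pointee type followed by * (double*, i8*, { ... }**). This Python sketch (the helper name pointer_style is hypothetical) only scans text, so comments or unusual identifiers can fool it.

```python
import re

# Opaque pointers: the bare type keyword "ptr" (word boundaries exclude %retptr).
OPAQUE_RE = re.compile(r"\bptr\b")
# Typed pointers: a pointee type followed by one or more '*',
# e.g. "double*", "i8*", "i32**", or a struct type "{ ... }**".
TYPED_RE = re.compile(r"(?:\b(?:i\d+|half|float|double)|[}\]>])\*+")


def pointer_style(ir_text: str) -> str:
    """Rough guess at whether textual LLVM IR uses typed or opaque pointers."""
    opaque = OPAQUE_RE.search(ir_text) is not None
    typed = TYPED_RE.search(ir_text) is not None
    if opaque and typed:
        return "mixed"
    if opaque:
        return "opaque"
    if typed:
        return "typed"
    return "no pointers found"
```

Applied to the two listings above, func.ll would be expected to come out as "typed" and the llvm-dis output of func.bc as "opaque".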

So I set the flag JULIA_LLVM_ARGS="-opaque-pointers=1" as suggested here, but to no avail.

I then downgraded the LLVM tools and clang to v15.0.7 and got typed-pointer bitcode in func.bc. That got past the earlier parsing issue, but I hit a new error:

ERROR: LLVM error: REPL[2]:2:0: in function preprocess_julia_f_8476 double (double, double, double, double): Enzyme: Number of arg operands != function parameters
  %6 = call double bitcast (i32 (double*, { i8*, i32, i8*, i8*, i32 }**, double, double, double, double)* @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd to double (double, double, double, double)*)(double %0, double %1, double %2, double %3) #6, !dbg !15

See the attachment for the whole error message; it exceeded the total number of characters allowed in the body of a post.
err.toml (43.7 KB)

Any suggestions on how I can further investigate the issue?
Also, how can I set the opaque-pointers flag correctly in Julia? Did I miss something?

julia> versioninfo()
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × AMD Ryzen 5 7640U w/ Radeon 760M Graphics
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver4)
Threads: 1 default, 0 interactive, 1 GC (on 12 virtual cores)

Okay, that’s an actual error (which, incidentally, means your ccall could segfault even without Enzyme).

Specifically see:


  ; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal noundef i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture nofree readnone %excinfo, double %arg.a, double %arg.w, double %arg.p, double %arg.t) unnamed_addr #3 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9)
  %.22 = fmul double %.17, %arg.a
  store double %.22, double* %retptr, align 8, !noalias !16
  ret i32 0
}

  %6 = call double bitcast (i32 (double*, { i8*, i32, i8*, i8*, i32 }**, double, double, double, double)* @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd to double (double, double, double, double)*)(double %0, double %1, double %2, double %3), !dbg !15

Your original function takes two extra arguments at the start, a double* (for returning the result by reference) and the excinfo pointer, which you aren’t passing to the function you’re calling. Also, the return type isn’t a double; it’s an i32 (presumably a success status).

If you fix the call on the Julia side, this will presumably work.


You made it so easy for me to fix it! Indeed, I was calling Numba’s internal function, namely _ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd, which returns an i32: 0 for success and 1 for a Python exception. The first argument, double* %retptr, is where the result is stored; the second, { i8*, i32, i8*, i8*, i32 }** %excinfo, is a pointer to Numba’s exception info block; and the remaining four arguments are the doubles a, w, p and t.

If instead I call the @cfunc wrapper directly, I only need to pass the four double arguments a, w, p and t; %retptr and %excinfo are handled internally:

; Function Attrs: mustprogress nofree nosync nounwind willreturn writeonly
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, double* %.6, align 8
  ret double %.20
}

To expand a bit on this @cfunc wrapper,

  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8

is where we allocate the double output (8-byte aligned) as %.6 and store the initial value 0.000000e+00 in it. Then we have

  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3

where we pass the pointer %.6 as the first argument, a poison (effectively dummy) pointer for the exception info block { i8*, i32, i8*, i8*, i32 }**, and then the four doubles for the variables a, w, p and t.
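The two calling conventions can be mimicked with Python’s ctypes, without Numba or the .so, to make the argument layout concrete. In this sketch a Python callback stands in (hypothetically) for Numba’s internal function, with the exception-info struct reduced to an opaque void pointer, and cfunc_wrapper reproduces the alloca / call / load sequence of the @cfunc wrapper. The names NumbaInternalABI, internal and cfunc_wrapper are illustrative only.

```python
import ctypes
import math

# Hypothetical stand-in for Numba's internal entry point, shaped like the IR:
# i32 (double* %retptr, excinfo** %excinfo, double a, double w, double p, double t)
NumbaInternalABI = ctypes.CFUNCTYPE(
    ctypes.c_int32,                   # status: 0 means success
    ctypes.POINTER(ctypes.c_double),  # %retptr: the result is written here
    ctypes.c_void_p,                  # %excinfo: not modelled in this sketch
    ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double,
)


def _internal_impl(retptr, excinfo, a, w, p, t):
    retptr[0] = a * math.cos(w * t + p)  # store double %.22, double* %retptr
    return 0                             # ret i32 0


internal = NumbaInternalABI(_internal_impl)


def cfunc_wrapper(a, w, p, t):
    """Mimics the @cfunc wrapper: four doubles in, one double out."""
    out = ctypes.c_double(0.0)  # %.6 = alloca double; store 0.0
    # The generated wrapper passes a poison excinfo pointer; NULL is used here.
    # As in the IR above, the returned status (%.10) is not checked.
    internal(ctypes.byref(out), None, a, w, p, t)
    return out.value            # %.20 = load double, double* %.6
```

Calling cfunc_wrapper(1.0, 1.0, 1.0, 1.0) gives a * cos(w t + p) at that point, while calling internal directly requires allocating the result slot yourself, which is what Solution 2 below does on the Julia side with Ref{Cdouble}.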

Now I have two solutions:

  1. Call the @cfunc wrapper, which is the most straightforward option.
  2. Allocate a double for the result and a dummy pointer for excinfo, and pass them as arguments to Numba’s nopython ABI implementation.

Solution 1

julia> using Libdl

julia> using Enzyme

julia> const lib = Libdl.dlopen("./libfunc.so")
Ptr{Nothing} @0x0000000009d707f0

julia> const f_ptr = Libdl.dlsym(
           lib, Symbol(
               "cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd"
           )
       )
Ptr{Nothing} @0x00007fcc3aa21140

julia> f(a, w, p, t) = ccall(f_ptr, Float64, (Float64, Float64, Float64, Float64), a, w, p, t)
f (generic function with 1 method)

julia> f(1.0, 1.0, 1.0, 1.0)
-0.4161468365471424

julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
(-2.88904333752103e-310, nothing, nothing, nothing)

This returns a wrong gradient with respect to the amplitude: for $f(t; a, w, p) = a \cos(w t + p)$ we have $\partial f / \partial a = \cos(w t + p)$, which is $\cos 2 \approx -0.4161$ at the evaluation point.
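For reference, the expected value can be checked independently of Enzyme with a central finite difference in plain Python (the helper names are illustrative). Since f is linear in a, the estimate should agree with cos(w t + p) up to rounding. Note also that at a = 1 the gradient coincides with f(1, 1, 1, 1) itself, so a correct result here equals the primal value -0.4161468365471424.

```python
import math


def f(a, w, p, t):
    """Reference implementation of the Numba function: a * cos(w*t + p)."""
    return a * math.cos(w * t + p)


def df_da(a, w, p, t):
    """Analytic partial derivative with respect to the amplitude a."""
    return math.cos(w * t + p)


def central_difference(g, x, h=1e-6):
    """Second-order finite-difference estimate of g'(x)."""
    return (g(x + h) - g(x - h)) / (2.0 * h)


a, w, p, t = 1.0, 1.0, 1.0, 1.0
fd = central_difference(lambda a_: f(a_, w, p, t), a)
exact = df_da(a, w, p, t)
```

Both fd and exact are about -0.4161, far from the -2.889e-310 returned through the wrapper, which looks more like an uninitialized or garbage value than a rounding issue.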

Solution 2

julia> using Libdl

julia> using Enzyme

julia> const lib = Libdl.dlopen("./libfunc.so")
Ptr{Nothing} @0x0000000014c181a0

julia> const g_ptr = Libdl.dlsym(
           lib, :_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd)
Ptr{Nothing} @0x00007ff1a04eb100

julia> function g(a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)
           result = Ref{Cdouble}()
           exc_info = Ref{Ptr{Nothing}}()

           status = ccall(g_ptr, Cint,
               (Ptr{Cdouble}, Ptr{Ptr{Nothing}}, Cdouble, Cdouble, Cdouble, Cdouble),
               result, exc_info, a, w, p, t)

           status == 0 || error("Python exception raised!")
           result[]
       end
g (generic function with 1 method)

julia> g(1.0, 1.0, 1.0, 1.0)
-0.4161468365471424

julia> gradient(Reverse, g, 1.0, Const(1.0), Const(1.0), Const(1.0))
(-0.4161468365471424, nothing, nothing, nothing)

which returns the correct gradient!

So thanks a lot for helping me out, @wsmoses; I really appreciate your patience and interest in this. One last favour: could you tell me why the gradient through the @cfunc wrapper is wrong?


Re the first one, can you run Enzyme.API.printall!(true) right after loading Enzyme? It prints the IR Enzyme sees before (and after) differentiation, so we can see more closely what’s happening.

Here you are

julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))

after simplification :
; Function Attrs: mustprogress nofree willreturn memory(read, argmem: none, inaccessiblemem: none)
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_6502(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %2, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %3) local_unnamed_addr #6 !dbg !18 {
top:
  %pgcstack = call {}*** @julia.get_pgcstack() #7
  %ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %4 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !8
  %5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #7, !dbg !19
  fence syncscope("singlethread") seq_cst
  %6 = call fastcc double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %0, double %1, double %2, double %3) #8, !dbg !19
  ret double %6, !dbg !19
}

after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
define internal fastcc double @preprocess_cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) unnamed_addr #4 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8, !noalias !15
  call fastcc void @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %.6, double %.1, double %.2, double %.3, double %.4) #7
  %.20 = load double, double* %.6, align 8
  ret double %.20
}

after simplification :
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define internal fastcc void @preprocess__ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture nofree noundef nonnull writeonly align 8 dereferenceable(8) %retptr, double %arg.a, double %arg.w, double %arg.p, double %arg.t) unnamed_addr #5 {
for.end:
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9) #7
  %.22 = fmul double %.17, %arg.a
  store double %.22, double* %retptr, align 8, !noalias !15
  ret void
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffe_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* noalias nocapture nofree writeonly align 8 dereferenceable(8) %retptr, double* nocapture nofree align 8 %"retptr'", double %arg.a, double %arg.w, double %arg.p, double %arg.t) unnamed_addr #7 {
for.end:
  %".22'de" = alloca double, align 8
  %0 = getelementptr double, double* %".22'de", i64 0
  store double 0.000000e+00, double* %0, align 8
  %"arg.a'de" = alloca double, align 8
  %1 = getelementptr double, double* %"arg.a'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %.8 = fmul double %arg.w, %arg.t
  %.9 = fadd double %.8, %arg.p
  %.17 = tail call double @llvm.cos.f64(double %.9) #8
  br label %invertfor.end

invertfor.end:                                    ; preds = %for.end
  %2 = load double, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
  store double 0.000000e+00, double* %"retptr'", align 8, !alias.scope !27, !noalias !30
  %3 = load double, double* %".22'de", align 8
  %4 = fadd fast double %3, %2
  store double %4, double* %".22'de", align 8
  %5 = load double, double* %".22'de", align 8
  store double 0.000000e+00, double* %".22'de", align 8
  %6 = fmul fast double %5, %.17
  %7 = load double, double* %"arg.a'de", align 8
  %8 = fadd fast double %7, %6
  store double %8, double* %"arg.a'de", align 8
  %9 = load double, double* %"arg.a'de", align 8
  %10 = insertvalue { double } undef, double %9, 0
  ret { double } %10
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind
define internal fastcc { double } @diffecfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4, double %differeturn) unnamed_addr #7 {
entry:
  %".20'de" = alloca double, align 8
  %0 = getelementptr double, double* %".20'de", i64 0
  store double 0.000000e+00, double* %0, align 8
  %".1'de" = alloca double, align 8
  %1 = getelementptr double, double* %".1'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %".6'ipa" = alloca double, align 8
  store double 0.000000e+00, double* %".6'ipa", align 8
  br label %invertentry

invertentry:                                      ; preds = %entry
  store double %differeturn, double* %".20'de", align 8
  %2 = load double, double* %".20'de", align 8
  store double 0.000000e+00, double* %".20'de", align 8
  %3 = load double, double* %".6'ipa", align 8, !alias.scope !22, !noalias !25
  %4 = fadd fast double %3, %2
  store double %4, double* %".6'ipa", align 8, !alias.scope !22, !noalias !25
  %5 = call fastcc { double } @diffe_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nocapture nofree writeonly align 8 undef, double* nocapture nofree align 8 %".6'ipa", double %.1, double %.2, double %.3, double %.4)
  %6 = extractvalue { double } %5, 0
  %7 = load double, double* %".1'de", align 8
  %8 = fadd fast double %7, %6
  store double %8, double* %".1'de", align 8
  store double 0.000000e+00, double* %".6'ipa", align 8, !alias.scope !22, !noalias !27
  %9 = load double, double* %".1'de", align 8
  %10 = insertvalue { double } undef, double %9, 0
  ret { double } %10
}

; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_6502(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %2, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="140025427211840" "enzymejl_parmtype_ref"="0" %3, double %differeturn) local_unnamed_addr #7 !dbg !20 {
top:
  %"'de" = alloca double, align 8
  %4 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de1" = alloca double, align 8
  %5 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %5, align 8
  %pgcstack = call {}*** @julia.get_pgcstack() #10
  %ptls_field3 = getelementptr inbounds {}**, {}*** %pgcstack, i64 2
  %6 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %6, align 8, !tbaa !8, !alias.scope !21, !noalias !24
  %7 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %7, align 8, !tbaa !12, !alias.scope !26, !noalias !29
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #10, !dbg !31
  fence syncscope("singlethread") seq_cst
  br label %inverttop, !dbg !31

inverttop:                                        ; preds = %top
  store double %differeturn, double* %"'de", align 8
  %8 = load double, double* %"'de", align 8, !dbg !31
  %9 = call fastcc { double } @diffecfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %0, double %1, double %2, double %3, double %8), !dbg !31
  %10 = extractvalue { double } %9, 0, !dbg !31
  %11 = load double, double* %"'de1", align 8, !dbg !31
  %12 = fadd fast double %11, %10, !dbg !31
  store double %12, double* %"'de1", align 8, !dbg !31
  store double 0.000000e+00, double* %"'de", align 8, !dbg !31
  %13 = load double, double* %"'de1", align 8
  %14 = insertvalue { double } undef, double %13, 0
  ret { double } %14
}

(-2.8789781999412e-310, nothing, nothing, nothing)
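Reading the two diffe functions above as plain arithmetic may help narrow things down. The Python transcription below is a hand-written sketch (shadow memory modelled as one-element lists, mangled names shortened); it ignores calling conventions and the undef primal retptr, so it cannot reveal a problem in those parts. As transcribed, the adjoint of a comes out as differeturn * cos(w t + p), which would be the correct value.

```python
import math


def diffe_internal(retptr_shadow, a, w, p, t):
    """Mirrors @diffe_ZN8__main__4func...: returns the adjoint of arg.a."""
    cos_term = math.cos(w * t + p)  # %.17, recomputed in the forward sweep
    d_ret = retptr_shadow[0]        # %2 = load double* %"retptr'"
    retptr_shadow[0] = 0.0          # the shadow is zeroed after being read
    d_a = 0.0                       # %"arg.a'de"
    d_a += d_ret * cos_term         # %6 = fmul %5, %.17 ; %8 = fadd %7, %6
    return (d_a,)                   # insertvalue { double }


def diffe_cfunc(a, w, p, t, differeturn):
    """Mirrors @diffecfunc...: seeds the shadow of %.6, then calls the above."""
    shadow_6 = [0.0]                # %".6'ipa" = alloca double; store 0.0
    shadow_6[0] += differeturn      # %4 = fadd %3, %2 ; store into %".6'ipa"
    (d_a,) = diffe_internal(shadow_6, a, w, p, t)
    d_arg1 = 0.0                    # %".1'de"
    d_arg1 += d_a                   # %8 = fadd %7, %6
    shadow_6[0] = 0.0
    return (d_arg1,)                # %10 = insertvalue { double } undef, %9, 0
```

With differeturn = 1.0, as in gradient(Reverse, ...), diffe_cfunc(1.0, 1.0, 1.0, 1.0, 1.0) evaluates to (cos 2,), i.e. about -0.4161.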


Hm, on a first skim I don’t see anything obvious. Can you simplify the example as much as you can while it still errors, and post it as an issue on Enzyme.jl?

I will do that. I will post the link to the issue here once I am done testing with a simpler example and reproducing the error.

Sorry it took a while. Here is the link to the issue