Auto-differentiation of Numba LLVM IR using Enzyme.jl: linear algebra and opaque external functions

This is related to the thread "Get Numba LLVM IR differentiated in Julia using Enzyme.jl: differentiating Python functions decorated as Numba CFunc functions". I am sorry that this post is quite long, but I tried to be as precise as possible, and I appreciate any help that could make this work as intended. Note that the majority of the debugging, prototyping and checking was done by Claude.

The goal is to differentiate Python functions whose LLVM IR is provided by Numba. The earlier thread showed, for a simple cosine function, that this is indeed possible. But when linear algebra is used in the body of the Python function, the LLVM IR contains calls to Numba internals that are opaque to Enzyme. Consider, for instance, the following simple example

import numba as nb
import numpy as np

sig = nb.types.float64(
    nb.types.Array(nb.types.float64, 1, 'C', True, aligned=True)
)

y = np.random.rand(2)

def outerfunc(y):
    @nb.cfunc(sig, error_model="numpy")
    def func(x):
        return np.dot(x, y)
    return func

func = outerfunc(y)

with open("numba_func.ll", "w") as f:
    f.write(func.inspect_llvm())
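
To see which calls in the dumped IR are opaque, it can help to list the symbols that are only declared and never defined. The following is a minimal sketch, assuming the IR text uses the same `declare` layout as the excerpt in this post; the helper name `external_symbols` and the inline sample are illustrative and not part of Numba or Enzyme. Quoted symbol names (`@"..."`) are not handled.

```python
import re

# Sample `declare` lines copied from the IR in this post; in practice the text
# would come from func.inspect_llvm() or from numba_func.ll.
SAMPLE_IR = """\
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #0
declare i32 @numba_xxdot(i8, i8, i64, i8*, i8*, i8*)
declare void @numba_gil_ensure(i32*) local_unnamed_addr
declare void @Py_FatalError(i8*) local_unnamed_addr #1
declare double @llvm.fmuladd.f64(double, double, double) #4
"""

# Capture the symbol name that follows '@' on every `declare` line.
DECLARE_RE = re.compile(r"^declare\s+.*?@([\w.$]+)\(", re.MULTILINE)

def external_symbols(ir_text):
    """Return declared-only symbols, skipping LLVM intrinsics (llvm.*)."""
    names = DECLARE_RE.findall(ir_text)
    return [n for n in names if not n.startswith("llvm.")]

print(external_symbols(SAMPLE_IR))
# ['numba_xxdot', 'numba_gil_ensure', 'Py_FatalError']
```

Of these, only numba_xxdot sits on the differentiated path in this example; the other two appear only in the error branch.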

The following is the relevant part of the LLVM IR after llvm-extract, plus the augmented and gradient custom rules, which are linked in later (the rest of this post shows how they are linked to the extracted IR). This is not the whole IR, so that the post stays within the maximum number of characters.

; ModuleID = 'llvm-link'
source_filename = "llvm-link"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@.const.array.data = hidden constant [16 x i8] c"G\1F\E9\8E\16\B7\EF?\CA\F3\E0\8F>\D8\D5?", align 8
@".const.Error creating Python tuple from runtime exception arguments" = hidden constant [61 x i8] c"Error creating Python tuple from runtime exception arguments\00"
@".const.unknown error when calling native function" = hidden constant [43 x i8] c"unknown error when calling native function\00"
@".const.<numba.core.cpu.CPUContext object at 0x7f3d6eb54350>" = hidden constant [53 x i8] c"<numba.core.cpu.CPUContext object at 0x7f3d6eb54350>\00"
@".const.BLAS wrapper returned with an error" = hidden constant [36 x i8] c"BLAS wrapper returned with an error\00"
@.const.pickledata.139901827191872 = hidden constant [192 x i8] c"\80\04\95\B5\00\00\00\00\00\00\00\8C\08builtins\94\8C\0AValueError\94\93\94\8C;incompatible array sizes for np.dot(a, b) (vector * vector)\94\85\94\8C\0Acheck_args\94\8CC/home/devuser/.venv/lib/python3.11/site-packages/numba/np/linalg.py\94M\0C\02\87\94\87\94."
@.const.pickledata.139901827191872.sha1 = hidden constant [20 x i8] c"\11VV\9F\06\A6\DA..\BC\1E\13+\B7\ADY\0F\A6\F1\C6"
@.const.picklebuf.139901827191872 = hidden constant { i8*, i32, i8*, i8*, i32 } { i8* getelementptr inbounds ([192 x i8], [192 x i8]* @.const.pickledata.139901827191872, i32 0, i32 0), i32 192, i8* getelementptr inbounds ([20 x i8], [20 x i8]* @.const.pickledata.139901827191872.sha1, i32 0, i32 0), i8* null, i32 0 }
@llvm.compiler.used = appending global [1 x i8*] [i8* bitcast ([3 x i8*]* @__enzyme_register_gradient_numba_xxdot to i8*)], section "llvm.metadata"
@__enzyme_register_gradient_numba_xxdot = dso_local global [3 x i8*] [i8* bitcast (i32 (i8, i8, i64, i8*, i8*, i8*)* @numba_xxdot to i8*), i8* bitcast ({ i8*, i32 } (i8, i8, i8, i8, i64, i64, i8*, i8*, i8*, i8*, i8*, i8*)* @augmented_numba_xxdot to i8*), i8* bitcast (void (i8, i8, i8, i8, i64, i64, i8*, i8*, i8*, i8*, i8*, i8*, i8*)* @gradient_numba_xxdot to i8*)], align 16

define i32 @_ZN8__main__9outerfunc12_3clocals_3e4funcB2v1B56c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAg6cJqBrMeIZJFEGM0QQAE5ArrayIdLi1E1C8readonly7alignedE(double* noalias nocapture writeonly %retptr, { i8*, i32, i8*, i8*, i32 }** noalias nocapture writeonly %excinfo, i8* nocapture readnone %arg.x.0, i8* nocapture readnone %arg.x.1, i64 %arg.x.2, i64 %arg.x.3, double* %arg.x.4, i64 %arg.x.5.0, i64 %arg.x.6.0) local_unnamed_addr {
entry:
  %.95.i = alloca double, align 8
  %.107.i = alloca i32, align 4
  %0 = bitcast double* %.95.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0)
  %1 = bitcast i32* %.107.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %1)
  store double 0.000000e+00, double* %.95.i, align 8, !noalias !5
  store i32 0, i32* %.107.i, align 4, !noalias !5
  %.58.not.i.i = icmp eq i64 %arg.x.5.0, 2
  br i1 %.58.not.i.i, label %B0.endif.endif.i, label %B0.if, !prof !9

B0.endif.endif.i:                                 ; preds = %entry
  %2 = bitcast double* %.95.i to i8*
  %.101.i = bitcast double* %arg.x.4 to i8*
  %.104.i = call i32 @numba_xxdot(i8 100, i8 0, i64 2, i8* %.101.i, i8* getelementptr inbounds ([16 x i8], [16 x i8]* @.const.array.data, i64 0, i64 0), i8* nonnull %2), !noalias !5
  %.105.not.i = icmp eq i32 %.104.i, 0
  br i1 %.105.not.i, label %B0.endif, label %B0.endif.endif.if.i, !prof !10

B0.endif.endif.if.i:                              ; preds = %B0.endif.endif.i
  call void @numba_gil_ensure(i32* nonnull %.107.i), !noalias !5
  call void @Py_FatalError(i8* getelementptr inbounds ([36 x i8], [36 x i8]* @".const.BLAS wrapper returned with an error", i64 0, i64 0)), !noalias !5
  unreachable

B0.if:                                            ; preds = %entry
  %3 = bitcast i32* %.107.i to i8*
  %4 = bitcast double* %.95.i to i8*
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %4)
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %3)
  store { i8*, i32, i8*, i8*, i32 }* @.const.picklebuf.139901827191872, { i8*, i32, i8*, i8*, i32 }** %excinfo, align 8
  br label %common.ret

common.ret:                                       ; preds = %B0.endif, %B0.if
  %common.ret.op = phi i32 [ 0, %B0.endif ], [ 1, %B0.if ]
  ret i32 %common.ret.op

B0.endif:                                         ; preds = %B0.endif.endif.i
  %5 = bitcast i32* %.107.i to i8*
  %6 = bitcast double* %.95.i to i8*
  %.112.i = load double, double* %.95.i, align 8, !noalias !5
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %6)
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %5)
  store double %.112.i, double* %retptr, align 8
  br label %common.ret
}

; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #0

declare i32 @numba_xxdot(i8, i8, i64, i8*, i8*, i8*)

declare void @numba_gil_ensure(i32*) local_unnamed_addr

; Function Attrs: noreturn
declare void @Py_FatalError(i8*) local_unnamed_addr #1

; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #0

; Function Attrs: nounwind uwtable
define dso_local { i8*, i32 } @augmented_numba_xxdot(i8 noundef signext %0, i8 signext %1, i8 noundef signext %2, i8 signext %3, i64 noundef %4, i64 %5, i8* noundef %6, i8* nocapture readnone %7, i8* noundef %8, i8* nocapture readnone %9, i8* noundef %10, i8* nocapture readnone %11) #2 {
  %13 = tail call i32 @numba_xxdot(i8 noundef signext %0, i8 noundef signext %2, i64 noundef %4, i8* noundef %6, i8* noundef %8, i8* noundef %10) #5
  %14 = insertvalue { i8*, i32 } { i8* null, i32 poison }, i32 %13, 1
  ret { i8*, i32 } %14
}

; Function Attrs: argmemonly nofree nosync nounwind uwtable
define dso_local void @gradient_numba_xxdot(i8 signext %0, i8 signext %1, i8 signext %2, i8 signext %3, i64 noundef %4, i64 %5, i8* nocapture readnone %6, i8* nocapture noundef %7, i8* nocapture noundef readonly %8, i8* nocapture readnone %9, i8* nocapture readnone %10, i8* nocapture noundef %11, i8* nocapture readnone %12) #3 {
  %14 = bitcast i8* %8 to double*
  %15 = bitcast i8* %7 to double*
  %16 = bitcast i8* %11 to double*
  %17 = load double, double* %16, align 8, !tbaa !11
  store double 0.000000e+00, double* %16, align 8, !tbaa !11
  %18 = icmp sgt i64 %4, 0
  br i1 %18, label %20, label %19

19:                                               ; preds = %20, %13
  ret void

20:                                               ; preds = %20, %13
  %21 = phi i64 [ %27, %20 ], [ 0, %13 ]
  %22 = getelementptr inbounds double, double* %14, i64 %21
  %23 = load double, double* %22, align 8, !tbaa !11
  %24 = getelementptr inbounds double, double* %15, i64 %21
  %25 = load double, double* %24, align 8, !tbaa !11
  %26 = tail call double @llvm.fmuladd.f64(double %23, double %17, double %25)
  store double %26, double* %24, align 8, !tbaa !11
  %27 = add nuw nsw i64 %21, 1
  %28 = icmp eq i64 %27, %4
  br i1 %28, label %19, label %20, !llvm.loop !15
}

; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
declare double @llvm.fmuladd.f64(double, double, double) #4

attributes #0 = { argmemonly nocallback nofree nosync nounwind willreturn }
attributes #1 = { noreturn }
attributes #2 = { nounwind uwtable "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #3 = { argmemonly nofree nosync nounwind uwtable "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #4 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
attributes #5 = { nounwind }

!llvm.ident = !{!0}
!llvm.module.flags = !{!1, !2, !3, !4}

!0 = !{!"Debian clang version 15.0.7"}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{i32 7, !"PIC Level", i32 2}
!3 = !{i32 7, !"PIE Level", i32 2}
!4 = !{i32 7, !"uwtable", i32 2}
!5 = !{!6, !8}
!6 = distinct !{!6, !7, !"_ZN5numba2np6linalg10dot_2_impl12_3clocals_3e12_3clambda_3eB2v2B54c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAw4hWqFpgsIBZmgA_3dE5ArrayIdLi1E1C8readonly7alignedE5ArrayIdLi1E1C8readonly7alignedE: %retptr"}
!7 = distinct !{!7, !"_ZN5numba2np6linalg10dot_2_impl12_3clocals_3e12_3clambda_3eB2v2B54c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAw4hWqFpgsIBZmgA_3dE5ArrayIdLi1E1C8readonly7alignedE5ArrayIdLi1E1C8readonly7alignedE"}
!8 = distinct !{!8, !7, !"_ZN5numba2np6linalg10dot_2_impl12_3clocals_3e12_3clambda_3eB2v2B54c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAw4hWqFpgsIBZmgA_3dE5ArrayIdLi1E1C8readonly7alignedE5ArrayIdLi1E1C8readonly7alignedE: %excinfo"}
!9 = !{!"branch_weights", i32 2126008812, i32 21474836}
!10 = !{!"branch_weights", i32 99, i32 1}
!11 = !{!12, !12, i64 0}
!12 = !{!"double", !13, i64 0}
!13 = !{!"omnipotent char", !14, i64 0}
!14 = !{!"Simple C/C++ TBAA"}
!15 = distinct !{!15, !16, !17}
!16 = !{!"llvm.loop.mustprogress"}
!17 = !{!"llvm.loop.unroll.disable"}

numba_xxdot is an opaque external function which I am trying to make differentiable by defining its augmented and gradient functions.

At runtime, numba_xxdot calls the OpenBLAS shipped with SciPy. Here is a C provider for numba_xxdot

/* Runtime provider for the `numba_xxdot` symbol that the Numba-generated IR calls
   (in Numba itself it lives in numba/_lapack.c, which calls scipy.linalg.cython_blas).

   This provider calls the BLAS that Numba uses, i.e. SciPy's bundled OpenBLAS, whose
   symbols are prefixed `scipy_` and are LP64 (32-bit int). We use the cblas `_sub` forms
   (result via pointer) to avoid the Fortran complex-return-value ABI pitfall.
   Returns 0 (success) as the i32 status that the Numba IR checks.
*/

#include <stddef.h>

extern float  scipy_cblas_sdot(int, const void *, int, const void *, int);
extern double scipy_cblas_ddot(int, const void *, int, const void *, int);
extern void scipy_cblas_cdotu_sub(int, const void *, int, const void *, int, void *);
extern void scipy_cblas_cdotc_sub(int, const void *, int, const void *, int, void *);
extern void scipy_cblas_zdotu_sub(int, const void *, int, const void *, int, void *);
extern void scipy_cblas_zdotc_sub(int, const void *, int, const void *, int, void *);

int numba_xxdot(char kind, char conjugate, size_t n,
                const void *dx, const void *dy, void *out) {
    int _n = (int)n;
    switch (kind) {
        case 's': *(float *)out  = scipy_cblas_sdot(_n, dx, 1, dy, 1); break;
        case 'd': *(double *)out = scipy_cblas_ddot(_n, dx, 1, dy, 1); break;
        case 'c':
            (conjugate ? scipy_cblas_cdotc_sub : scipy_cblas_cdotu_sub)(_n, dx, 1, dy, 1, out);
            break;
        case 'z':
            (conjugate ? scipy_cblas_zdotc_sub : scipy_cblas_zdotu_sub)(_n, dx, 1, dy, 1, out);
            break;
        default: return 1;   /* unknown kind -> error status */
    }
    return 0;
}
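
In the IR above, the call site passes `i8 100` and `i8 0` as the first two arguments, i.e. kind `'d'` (ASCII 100) with no conjugation. Below is a minimal pure-Python reference of the same dispatch, meant only for sanity-checking values. `xxdot_reference` is a hypothetical helper; it does not call BLAS and does not reproduce single-precision rounding for the `'s'` case.

```python
def xxdot_reference(kind, conjugate, n, x, y):
    """Mimic the provider's dispatch; returns (status, value)."""
    k = chr(kind)                      # the IR passes the kind as an i8 character code
    xs, ys = x[:n], y[:n]
    if k in ("s", "d"):                # real dot product
        return 0, sum(a * b for a, b in zip(xs, ys))
    if k in ("c", "z"):                # complex: dotc conjugates the first vector, dotu does not
        if conjugate:
            xs = [complex(a).conjugate() for a in xs]
        return 0, sum(a * b for a, b in zip(xs, ys))
    return 1, None                     # unknown kind -> error status, as in the C provider

print(chr(100))                                             # d
print(xxdot_reference(100, 0, 2, [1.0, 2.0], [3.0, 4.0]))   # (0, 11.0)
```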

Then I compile it into a shared object, libnumba_xxdot.so, and load it with RTLD_GLOBAL

CLANG="${CLANG:-clang}"
SCIPY_LIBS="$(python3 -c 'import scipy, os; print(os.path.dirname(scipy.__file__) + ".libs")')"
OPENBLAS="$(ls "$SCIPY_LIBS"/libscipy_openblas-*.so | head -1)"
$CLANG -O3 -fPIC -shared numba_xxdot_provider.c -o libnumba_xxdot.so \
    -L"$SCIPY_LIBS" -l:"$(basename "$OPENBLAS")" -Wl,-rpath,"$SCIPY_LIBS"

and then, in Julia,

using Libdl
const WDIR = @__DIR__

Libdl.dlopen(joinpath(WDIR, "libnumba_xxdot.so"), Libdl.RTLD_GLOBAL)

This way, I can call the function compiled from the LLVM IR via ccall, and the result agrees with the reference value calculated directly in Julia.

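const ENTRY = let ll = read(joinpath(WDIR, "numba_func_extracted.ll"), String)
    # Mangled name of the inner Numba function, taken from its `define` line.
    m = match(r"define i32 @(_ZN8__main__[^(]+)\(", ll)
    m === nothing && error("inner entry symbol not found in numba_func_extracted.ll")
    String(m.captures[1])
end
# Assumption: KLIB is the handle of numba_func.so, the shared object built from the combined IR.
const KLIB = Libdl.dlopen(joinpath(WDIR, "numba_func.so"))
const FPTR = Libdl.dlsym(KLIB, Symbol(ENTRY))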

function numba_dot(x::Vector{Float64})
    result = Ref{Float64}(0.0)
    excinfo = Ref{Ptr{Cvoid}}(C_NULL)
    n = length(x)
    status = ccall(
        FPTR, Cint,
        (
            Ptr{Float64}, Ptr{Ptr{Cvoid}}, Ptr{Cvoid}, Ptr{Cvoid},
            Int64, Int64, Ptr{Float64}, Int64, Int64,
        ),
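        result, excinfo, C_NULL, C_NULL, n, 8, x, n, 8  # after retptr/excinfo: meminfo, parent, nitems, itemsize, data, shape, byte stride (the IR above never reads the stride)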
    )
    status == 0 || error("numba func raised (status=$status); x length must match baked y (=2)")
    return result[]
end


y_rec = [numba_dot([1.0, 0.0]), numba_dot([0.0, 1.0])]
println("entry symbol      : ", ENTRY)
println("recovered baked y : ", y_rec)

x = rand(2)
val = numba_dot(x)
ref = x[1] * y_rec[1] + x[2] * y_rec[2]   # dot(x, y)
println("x                 : ", x)
println("func(x) via ccall : ", val)
println("dot(x, y_rec)     : ", ref)
println(abs(val - ref) < 1.0e-12 ? "PASS" : "FAIL")

and the result is

$ julia --project call_numba.jl
entry symbol      : _ZN8__main__9outerfunc12_3clocals_3e4funcB2v1B56c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAg6cJqBrMeIZJFEGM0QQAE5ArrayIdLi1E1C8readonly7alignedE
recovered baked y : [0.9910996237967787, 0.34132350969940417]
x                 : [0.7165329173376253, 0.6884605753942334]
func(x) via ccall : 0.9451432846945598
dot(x, y_rec)     : 0.9451432846945598
PASS
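
As a cross-check that does not involve calling the function at all, the baked `y` can be read straight from `@.const.array.data` in the IR above. This is a minimal sketch, assuming the constant holds two little-endian float64 values (consistent with the x86_64 data layout in the IR). `decode_llvm_cstring` is an illustrative helper, not an LLVM API.

```python
import struct

def decode_llvm_cstring(lit):
    """Decode the body of an LLVM c"..." literal: \\XX is a hex byte, \\\\ a backslash."""
    out = bytearray()
    i = 0
    while i < len(lit):
        if lit[i] == "\\":
            if lit[i + 1] == "\\":
                out.append(0x5C)
                i += 2
            else:
                out.append(int(lit[i + 1:i + 3], 16))
                i += 3
        else:
            out.append(ord(lit[i]))
            i += 1
    return bytes(out)

# Literal copied from @.const.array.data in the IR shown in this post.
raw = decode_llvm_cstring(r"G\1F\E9\8E\16\B7\EF?\CA\F3\E0\8F>\D8\D5?")
baked_y = struct.unpack("<2d", raw)   # two little-endian doubles
print(len(raw), baked_y)
```

The two decoded values should agree with the `recovered baked y` line printed by the Julia script.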

Now I want to provide Enzyme with a custom rule for the forward (augmented) and reverse passes of numba_xxdot via __enzyme_register_gradient_numba_xxdot

#include <stdint.h>

/* Opaque primal (i32-returning); body provided at runtime by libnumba_xxdot.so. */
extern int numba_xxdot(char kind, char conjugate, int64_t n,
                       const void *x, const void *y, void *result);

typedef struct { void *tape; int status; } xxdot_aug_t;

xxdot_aug_t augmented_numba_xxdot(char kind,      char d_kind,
                                  char conjugate, char d_conjugate,
                                  int64_t n,       int64_t d_n,
                                  const void *x,   void *d_x,
                                  const void *y,   void *d_y,
                                  void *result,    void *d_result) {
    (void)d_kind; (void)d_conjugate; (void)d_n; (void)d_x; (void)d_y; (void)d_result;
    int status = numba_xxdot(kind, conjugate, n, x, y, result);   /* primal */
    xxdot_aug_t out = { (void *)0, status };                      /* no tape; pass status */
    return out;
}

void gradient_numba_xxdot(char kind,      char d_kind,
                          char conjugate, char d_conjugate,
                          int64_t n,       int64_t d_n,
                          const void *x,   void *d_x,
                          const void *y,   void *d_y,
                          void *result,    void *d_result,
                          void *tape) {
    (void)kind; (void)d_kind; (void)conjugate; (void)d_conjugate; (void)d_n;
    (void)x; (void)d_y; (void)result; (void)tape;
    const double *yv  = (const double *)y;
    double       *dxv = (double *)d_x;
    double rbar = *(double *)d_result;
    *(double *)d_result = 0.0;                         /* consume the output adjoint */
    for (int64_t i = 0; i < n; i++) dxv[i] += yv[i] * rbar;
}

__attribute__((used))
void *__enzyme_register_gradient_numba_xxdot[] = {
    (void *)numba_xxdot,
    (void *)augmented_numba_xxdot,
    (void *)gradient_numba_xxdot,
};
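
The arithmetic that these two C functions perform can be checked in isolation. The following is a minimal Python sketch of the same rule for the `'d'` case, assuming (as in the example) that only `x` is differentiated and `y` is a baked constant. The function names and the sample values are illustrative, and nothing here calls Enzyme or BLAS.

```python
def augmented_xxdot(x, y, result):
    """Forward pass: compute the primal, keep no tape, return (tape, status)."""
    result[0] = sum(a * b for a, b in zip(x, y))
    return None, 0

def gradient_xxdot(y, d_x, d_result):
    """Reverse pass: d_x += y * rbar, then zero the consumed output adjoint."""
    rbar = d_result[0]
    d_result[0] = 0.0
    for i in range(len(d_x)):
        d_x[i] += y[i] * rbar

def central_difference(x, y, h=1e-6):
    """Finite-difference gradient of dot(x, y) with respect to x."""
    grad = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        fp = sum(a * b for a, b in zip(xp, y))
        fm = sum(a * b for a, b in zip(xm, y))
        grad.append((fp - fm) / (2 * h))
    return grad

x = [0.7, 0.6]
y = [0.99, 0.34]                      # stand-in for the baked array
result, d_result, d_x = [0.0], [1.0], [0.0, 0.0]

tape, status = augmented_xxdot(x, y, result)
gradient_xxdot(y, d_x, d_result)      # seed adjoint of 1.0 gives d_x == y
fd = central_difference(x, y)
print(d_x, fd)
```

With a seed adjoint of 1.0 the accumulated `d_x` equals `y`, which is the gradient a working Enzyme call on numba_dot would be expected to return.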

Now I emit the LLVM IR of the rules above and link it to numba_func_extracted.ll

CLANG="${CLANG:-clang}"
LLVM_LINK="${LLVM_LINK:-llvm-link}"
RULE="${RULE:-xxdot_rule_i32}"
EXTRACTED="numba_func_extracted.ll"

$CLANG -S -emit-llvm -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops \
    -Xclang -no-opaque-pointers "$RULE.c" -o "$RULE.ll"
$LLVM_LINK "$EXTRACTED" "$RULE.ll" -S -o numba_func_combined.ll
echo "linked $EXTRACTED + $RULE.ll -> numba_func_combined.ll"

$CLANG -x ir -O2 -fPIC -shared -fembed-bitcode -Xclang -no-opaque-pointers \
    numba_func_combined.ll -o numba_func.so

which results in the combined LLVM IR (numba_func_combined.ll) shown above.
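
For orientation, the argument layout of the two rule functions in the linked IR can be derived mechanically from the primal signature: every primal argument is followed by a shadow of the same type, the augmented function returns a `{ tape, primal return }` pair, and the gradient function takes the tape as an extra trailing argument. The sketch below only reproduces what the `__enzyme_register_gradient_numba_xxdot` global in this post shows. It is not a general statement about Enzyme's registration interface, and `rule_signatures` is a hypothetical helper.

```python
def rule_signatures(primal_args, primal_ret, tape="i8*"):
    """Build (return type, argument types) for the augmented and gradient functions."""
    # Each primal argument is immediately followed by its shadow.
    shadowed = [t for arg in primal_args for t in (arg, arg)]
    augmented = ("{ %s, %s }" % (tape, primal_ret), shadowed)
    gradient = ("void", shadowed + [tape])
    return augmented, gradient

# Primal signature of numba_xxdot as declared in the IR: i32 (i8, i8, i64, i8*, i8*, i8*)
aug, grad = rule_signatures(["i8", "i8", "i64", "i8*", "i8*", "i8*"], "i32")
print(aug[0], "(" + ", ".join(aug[1]) + ")")
print(grad[0], "(" + ", ".join(grad[1]) + ")")
```

The printed types can be compared by eye with the bitcasts inside the `@__enzyme_register_gradient_numba_xxdot` initializer.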

Now, if I try to compute the gradient of the numba_dot function (shown in the Julia script above) with Enzyme, I get the following error

using Enzyme

dx = Enzyme.gradient(Enzyme.Reverse, numba_dot, x)[1]

┌ Warning: <unknown>:0:0: in function numba_xxdot i32 (i8, i8, i64, i8*, i8*, i8*): Massaging provided custom augmented forward pass to handle constant argumented
└ @ LLVM ~/.julia/packages/LLVM/upRII/src/core/context.jl:170
julia: /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:5969: void AdjointGenerator::recursivelyHandleSubfunction(llvm::CallInst&, llvm::Function*, bool, const std::vector<bool>&, bool, DIFFE_TYPE, bool): Assertion `argsInverted[i] == DIFFE_TYPE::DUP_NONEED' failed.

[17731] signal 6 (-6): Aborted
in expression starting at /workspace/numba/call_numba.jl:60
unknown function (ip: 0x7ff66906c95c)
gsignal at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x7ff66900041f)
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:5969
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6746
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4520
recursivelyHandleSubfunction at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:5974
visitCallInst at /workspace/srcdir/Enzyme/enzyme/Enzyme/AdjointGenerator.h:6746
visit at /opt/x86_64-linux-gnu/x86_64-linux-gnu/sys-root/usr/local/include/llvm/IR/InstVisitor.h:111 [inlined]
CreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/EnzymeLogic.cpp:4520
EnzymeCreatePrimalAndGradient at /workspace/srcdir/Enzyme/enzyme/Enzyme/CApi.cpp:709
EnzymeCreatePrimalAndGradient at /home/devuser/.julia/dev/Enzyme/src/api.jl:270
jfptr_EnzymeCreatePrimalAndGradient_27651 at /home/devuser/.julia/compiled/v1.11/Enzyme/G1p5n_JANKt.so (unknown line)
macro expansion at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:2826 [inlined]
macro expansion at /home/devuser/.julia/packages/LLVM/upRII/src/base.jl:113 [inlined]
enzyme! at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:2697
unknown function (ip: 0x7ff65d1e1d6d)
compile_unhooked at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:5932
#compile#153 at /home/devuser/.julia/packages/GPUCompiler/KwfWk/src/driver.jl:67 [inlined]
compile at /home/devuser/.julia/packages/GPUCompiler/KwfWk/src/driver.jl:55 [inlined]
_thunk at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:6865
_thunk at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:6863 [inlined]
cached_compilation at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:6933 [inlined]
thunkbase at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:7049
thunk_generator at /home/devuser/.julia/dev/Enzyme/src/compiler.jl:7193
unknown function (ip: 0x7ff65d1a67a0)
jl_call_staged at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/method.c:601
ijl_code_for_staged at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/method.c:656
get_staged at ./compiler/utilities.jl:123
retrieve_code_info at ./compiler/utilities.jl:135 [inlined]
InferenceState at ./compiler/inferencestate.jl:497
typeinf_edge at ./compiler/typeinfer.jl:913
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2423
abstract_eval_call at ./compiler/abstractinterpretation.jl:2438
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2454
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2752
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3044
typeinf_local at ./compiler/abstractinterpretation.jl:3331
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3413
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_apply at ./compiler/abstractinterpretation.jl:1690
abstract_call_known at ./compiler/abstractinterpretation.jl:2102
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2423
abstract_eval_call at ./compiler/abstractinterpretation.jl:2438
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2454
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2752
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3068
typeinf_local at ./compiler/abstractinterpretation.jl:3331
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3413
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_edge at ./compiler/typeinfer.jl:923
abstract_call_method at ./compiler/abstractinterpretation.jl:660
abstract_call_gf_by_type at ./compiler/abstractinterpretation.jl:102
abstract_call_known at ./compiler/abstractinterpretation.jl:2200
abstract_call at ./compiler/abstractinterpretation.jl:2282
abstract_call at ./compiler/abstractinterpretation.jl:2275
abstract_call at ./compiler/abstractinterpretation.jl:2423
abstract_eval_call at ./compiler/abstractinterpretation.jl:2438
abstract_eval_statement_expr at ./compiler/abstractinterpretation.jl:2454
abstract_eval_statement at ./compiler/abstractinterpretation.jl:2752
abstract_eval_basic_statement at ./compiler/abstractinterpretation.jl:3044
typeinf_local at ./compiler/abstractinterpretation.jl:3331
typeinf_nocycle at ./compiler/abstractinterpretation.jl:3413
_typeinf at ./compiler/typeinfer.jl:244
typeinf at ./compiler/typeinfer.jl:215
typeinf_ext at ./compiler/typeinfer.jl:1101
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1139
typeinf_ext_toplevel at ./compiler/typeinfer.jl:1135
jfptr_typeinf_ext_toplevel_39964.1 at /home/devuser/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_type_infer at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/gf.c:390
jl_generate_fptr_impl at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/jitlayers.cpp:519
jl_compile_method_internal at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/gf.c:2536 [inlined]
jl_compile_method_internal at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/gf.c:2423
_jl_invoke at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/gf.c:2940 [inlined]
ijl_apply_generic at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/gf.c:3125
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:666
jl_interpret_toplevel_thunk at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:824
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
include_string at ./loading.jl:2734
_include at ./loading.jl:2794
include at ./Base.jl:562
jfptr_include_46943.1 at /home/devuser/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
exec_options at ./client.jl:323
_start at ./client.jl:531
jfptr__start_73597.1 at /home/devuser/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x7ff669001ca7)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 39455723 (Pool: 39455244; Big: 479); GC: 26
[1]    17731 IOT instruction (core dumped)  julia --project=../../julia/python_env call_numba.jl

Claude says that the activity of the returned value defaults to DUP_NONEED when a custom rule is registered, whereas in this case numba_xxdot returns a CONSTANT i32. I tried to change the activity marker of the return value to enzyme_inactive and enzyme_constant_return via LLVM.jl, to no avail; I then get the following error

PASS RAISED: LLVM error: <unknown>:0:0: in function preprocess__ZN8__main__9outerfunc12_3clocals_3e4funcB2v1B56c8tJTIeFIjxB2IKSgI4CrvQCk0Z4yRYcWsBAg6cJqBrMeIZJFEGM0QQAE5ArrayIdLi1E1C8readonly7alignedE i32 (double*, { i8*, i32, i8*, i8*, i32 }**, i8*, i8*, i64, i64, double*, i64, i64): Enzyme: The required return activity calling into function: numba_xxdot was CONSTANT but the assumed (default) return activity was DUP_ARG
 at context:   %.104.i = call i32 @numba_xxdot(i8 100, i8 0, i64 2, i8* %.101.i, i8* getelementptr inbounds ([16 x i8], [16 x i8]* @.const.array.data, i64 0, i64 0), i8* nonnull %2) #10, !noalias !5

So marking the return value does not change the behaviour at all, and it seems Enzyme always falls back to the default DUP_ARG return activity. I wonder whether DUP_ARG is somehow hard-coded in AdjointGenerator.h for custom rules, or whether I am doing something wrong here.
