You made it so easy for me to fix it! Indeed, I was calling Numba's internal function, namely _ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd, which returns an i32 status code: 0 on success and nonzero otherwise (e.g. when a Python exception is raised). The first argument, double* %retptr, is where the result is stored; the second, { i8*, i32, i8*, i8*, i32 }** %excinfo, is a pointer to Numba's exception info block; and the remaining four arguments are a, w, p and t as doubles.
If, instead, I call the @cfunc wrapper directly, I only need to pass the four double arguments a, w, p and t. The %retptr and %excinfo pointers are handled internally by the wrapper:
; Function Attrs: mustprogress nofree nosync nounwind willreturn writeonly
define double @cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double %.1, double %.2, double %.3, double %.4) local_unnamed_addr #2 {
entry:
  %.6 = alloca double, align 8
  store double 0.000000e+00, double* %.6, align 8
  %.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
  %.20 = load double, double* %.6, align 8
  ret double %.20
}
To expand a bit on this @cfunc wrapper,
%.6 = alloca double, align 8
store double 0.000000e+00, double* %.6, align 8
is where we allocate an 8-byte-aligned stack slot for the double output, store 0.0 (0.000000e+00) in it, and keep the pointer to that slot in %.6. Then we have
%.10 = call i32 @_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd(double* nonnull %.6, { i8*, i32, i8*, i8*, i32 }** nonnull poison, double %.1, double %.2, double %.3, double %.4) #3
where we pass the pointer %.6 as the first argument, a poison (undefined) pointer for the exception info block { i8*, i32, i8*, i8*, i32 }**, and then the four doubles a, w, p and t. Finally, the result is loaded back from %.6 into %.20 and returned; the status code %.10 is never checked.
So now I have two solutions:
- Make a call to the @cfunc wrapper, which is the most straightforward one.
- Define a double pointer and a dummy pointer for the excinfo part, and pass them as arguments to Numba's nopython ABI implementation.
Solution 1
julia> using Libdl
julia> using Enzyme
julia> const lib = Libdl.dlopen("./libfunc.so")
Ptr{Nothing} @0x0000000009d707f0
julia> const f_ptr = Libdl.dlsym(
lib, Symbol(
"cfunc._ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd"
)
)
Ptr{Nothing} @0x00007fcc3aa21140
julia> f(a, w, p, t) = ccall(f_ptr, Float64, (Float64, Float64, Float64, Float64), a, w, p, t)
f (generic function with 1 method)
julia> f(1.0, 1.0, 1.0, 1.0)
-0.4161468365471424
julia> gradient(Reverse, f, 1.0, Const(1.0), Const(1.0), Const(1.0))
(-2.88904333752103e-310, nothing, nothing, nothing)
This returns a wrong gradient with respect to the amplitude $a$: for $f(t; a, w, p) = a \cos(w t + p)$ we expect $\partial f / \partial a = \cos(w t + p)$.
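Evaluating this derivative at the point used in the REPL session gives the value the gradient should return:

$$
\left.\frac{\partial f}{\partial a}\right|_{a=w=p=t=1} = \cos(w t + p)\Big|_{w=p=t=1} = \cos 2 \approx -0.4161468365471424
$$

Because $a = 1$, the correct gradient coincides with f(1.0, 1.0, 1.0, 1.0) = -0.4161468365471424, whereas -2.88904333752103e-310 is a subnormal double, i.e. essentially zero.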
Solution 2
julia> using Libdl
julia> using Enzyme
julia> const lib = Libdl.dlopen("./libfunc.so")
Ptr{Nothing} @0x0000000014c181a0
julia> const g_ptr = Libdl.dlsym(
lib, :_ZN8__main__4funcB2v1B52c8tJTIeFIjxB2IKSgI4CrvQClUYkACQB1EiFSRRB9GgCAA_3d_3dEdddd)
Ptr{Nothing} @0x00007ff1a04eb100
julia> function g(a::Cdouble, w::Cdouble, p::Cdouble, t::Cdouble)
           result = Ref{Cdouble}()
           exc_info = Ref{Ptr{Nothing}}()
           status = ccall(g_ptr, Cint,
                          (Ptr{Cdouble}, Ptr{Ptr{Nothing}}, Cdouble, Cdouble, Cdouble, Cdouble),
                          result, exc_info, a, w, p, t)
           status == 0 || error("Python exception raised!")
           result[]
       end
g (generic function with 1 method)
julia> g(1.0, 1.0, 1.0, 1.0)
-0.4161468365471424
julia> gradient(Reverse, g, 1.0, Const(1.0), Const(1.0), Const(1.0))
(-0.4161468365471424, nothing, nothing, nothing)
which returns the correct gradient!
So, thanks a lot for helping me out, @wsmoses; I really appreciate your patience and your interest in this. One last favour: could you tell me why the gradient through the @cfunc wrapper is wrong?