Using SpecialFunctions spherical Bessel functions with Reactant

I’m trying to apply Reactant to my code which uses spherical Bessel functions from the SpecialFunctions package. I’m running into this error; has anyone been able to use spherical Bessel functions with Reactant? @avikpal

# Test if the spherical bessel functions I use play nice with Reactant
using Reactant
using SpecialFunctions

"""
    foo(x)

Return `sphericalbessely` evaluated with `x[1]` as its first argument and `x[2]`
as its second argument.
"""
function foo(x)
    nu, z = x[1], x[2]
    return sphericalbessely(nu, z)
end

# Wrap a random 2-element vector as a Reactant array so `foo` can be traced.
x = Reactant.to_rarray(rand(2))
# Compile and run `foo` with `@jit`; `@allowscalar` is there because `foo`
# indexes single elements (`x[1]`, `x[2]`) of the Reactant array.
@allowscalar @jit foo(x)
LoadError: MethodError: no method matching bessely(::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
The function `bessely` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  bessely(::Float32, ::ComplexF32)
   @ SpecialFunctions ~/.julia/packages/SpecialFunctions/mf0qH/src/bessel.jl:647
  bessely(::Float64, ::ComplexF64)
   @ SpecialFunctions ~/.julia/packages/SpecialFunctions/mf0qH/src/bessel.jl:449
  bessely(::Float16, ::ComplexF16)
   @ SpecialFunctions ~/.julia/packages/SpecialFunctions/mf0qH/src/bessel.jl:646
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(bessely), ::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:875
  [3] sphericalbessely
    @ ~/.julia/packages/SpecialFunctions/mf0qH/src/bessel.jl:782 [inlined]
  [4] (::Nothing)(none::typeof(sphericalbessely), none::Reactant.TracedRNumber{Float64}, none::Reactant.TracedRNumber{Float64})
    @ Reactant ./<missing>:0
  [5] TracedRNumber
    @ ~/.julia/packages/Reactant/EsmaI/src/TracedRNumber.jl:101 [inlined]
  [6] convert
    @ ./number.jl:7 [inlined]
  [7] zero
    @ ./number.jl:309 [inlined]
  [8] float
    @ ./float.jl:391 [inlined]
  [9] sphericalbessely
    @ ~/.julia/packages/SpecialFunctions/mf0qH/src/bessel.jl:782 [inlined]
 [10] call_with_reactant(::Reactant.MustThrowError, ::typeof(sphericalbessely), ::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [11] foo
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:6 [inlined]
 [12] (::Nothing)(none::typeof(foo), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [13] foo
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:6 [inlined]
 [14] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
 [15] make_mlir_fn(f::typeof(foo), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/EsmaI/src/TracedUtils.jl:332
 [16] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1555
 [17] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1522 [inlined]
 [18] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3433
 [19] compile_xla
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3406 [inlined]
 [20] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3505
 [21] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:2586 [inlined]
 [22] top-level scope
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:206
 [23] include(fname::String)
    @ Main ./sysimg.jl:38
 [24] top-level scope
    @ REPL[2]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:10
Some type information was truncated. Use `show(err)` to see complete types.

Bessel functions in SpecialFunctions.jl eventually call functions in a C library; I’m not sure Reactant.jl supports `@ccall`s at all.

I think we just need to add a definition for it: either on the Julia side (by decomposing it into other primitives) or, alternatively, on the MLIR side (which might also merit additional optimizations, more stable autodiff, etc.).

I found Bessels.jl as an alternative package which is pure Julia. I seem to run into the same issue:

# Test if the spherical bessel functions I use play nice with Reactant
using Reactant
using Bessels
using SpecialFunctions
using BenchmarkTools

"""
    foo(x)

Return `Bessels.sphericalbessely` evaluated with `x[1]` as its first argument and
`x[2]` as its second argument.
"""
function foo(x)
    nu, z = x[1], x[2]
    return Bessels.sphericalbessely(nu, z)
end

"""
    bar(x)

Return `SpecialFunctions.sphericalbessely` evaluated with `x[1]` as its first
argument and `x[2]` as its second argument.
"""
function bar(x)
    nu, z = x[1], x[2]
    return SpecialFunctions.sphericalbessely(nu, z)
end

# Wrap a random 2-element vector as a Reactant array so `foo` can be traced.
x = Reactant.to_rarray(rand(2))
# Compile and run the Bessels.jl-based `foo` with `@jit`; `@allowscalar` is there
# because `foo` indexes single elements (`x[1]`, `x[2]`) of the Reactant array.
@allowscalar @jit foo(x)
LoadError: MethodError: no method matching sphericalbessely(::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
The function `sphericalbessely` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  sphericalbessely(::Real, ::Real)
   @ Bessels ~/.julia/packages/Bessels/eaWGd/src/sphericalbessel.jl:153

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::typeof(sphericalbessely), ::Reactant.TracedRNumber{Float64}, ::Reactant.TracedRNumber{Float64})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:875
  [3] foo
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:8 [inlined]
  [4] (::Nothing)(none::typeof(foo), none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
  [5] foo
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:8 [inlined]
  [6] call_with_reactant(::typeof(foo), ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/EsmaI/src/utils.jl:0
  [7] make_mlir_fn(f::typeof(foo), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/EsmaI/src/TracedUtils.jl:332
  [8] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1555
  [9] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:1522 [inlined]
 [10] compile_xla(f::Function, args::Tuple{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3433
 [11] compile_xla
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3406 [inlined]
 [12] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:3505
 [13] macro expansion
    @ ~/.julia/packages/Reactant/EsmaI/src/Compiler.jl:2586 [inlined]
 [14] top-level scope
    @ ~/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:206
 [15] include(fname::String)
    @ Main ./sysimg.jl:38
 [16] top-level scope
    @ REPL[2]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/bessel_react.jl:20
Some type information was truncated. Use `show(err)` to see complete types.