Accelerating calls to a Julia function from Python via juliacall and ctypes

I was creating an example of calling a Julia function from Python via juliacall (the pyjuliacall conda package), but I noticed the per-call overhead was quite high. Below I outline how to reduce that overhead by using Julia’s @cfunction and Python’s ctypes.

The objective is to call the following Julia function from Python efficiently.

sumsquares(v) = sum(x->x^2, v)

In Julia, I get the following benchmarks on my rather slow holiday computer.

julia> using BenchmarkTools

julia> v = rand(100_000);

julia> @benchmark sumsquares($v)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  31.792 μs … 116.156 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     32.068 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.393 μs ±   1.664 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▃▇█▇▁▂▁         ▁▁▁                                          ▁
  ███████▇▇▇▇▇█▇███████▇▇▅▄▃▄▄▃▅▆▅▅▅▆▅▅▅▅▆▆▅▆▅▅▄▅▅▄▅▅▄▄▅▂▄▅▄▅▄ █
  31.8 μs       Histogram: log(frequency) by time      38.4 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

To create and activate a conda environment, I invoked the following commands. Julia 1.11.2 was previously installed via juliaup.

mamba create -n juliacall_test python numpy ipython pyjuliacall
mamba activate juliacall_test

I get an environment with Python 3.13.1 and NumPy 2.2.1.

`mamba list` output
mamba list
# packages in environment at /private/conda/3/x86_64/envs/juliacall_test:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
asttokens                 3.0.0              pyhd8ed1ab_1    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2024.12.14           hbcca054_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_1    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_1    conda-forge
executing                 2.1.0              pyhd8ed1ab_1    conda-forge
ipython                   8.31.0             pyh707e725_0    conda-forge
jedi                      0.19.2             pyhd8ed1ab_1    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
libblas                   3.9.0           26_linux64_openblas    conda-forge
libcblas                  3.9.0           26_linux64_openblas    conda-forge
libexpat                  2.6.4                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
liblapack                 3.9.0           26_linux64_openblas    conda-forge
liblzma                   5.6.3                hb9d3cd8_1    conda-forge
libmpdec                  4.0.0                h4bc722e_0    conda-forge
libopenblas               0.3.28          pthreads_h94d23a6_1    conda-forge
libsqlite                 3.47.2               hee588c1_0    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_1    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
numpy                     2.2.1           py313hb30382a_0    conda-forge
openssl                   3.4.0                hb9d3cd8_0    conda-forge
parso                     0.8.4              pyhd8ed1ab_1    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_1    conda-forge
pickleshare               0.7.5           pyhd8ed1ab_1004    conda-forge
pip                       24.3.1             pyh145f28c_2    conda-forge
prompt-toolkit            3.0.48             pyha770c72_1    conda-forge
ptyprocess                0.7.0              pyhd8ed1ab_1    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_1    conda-forge
pygments                  2.18.0             pyhd8ed1ab_1    conda-forge
pyjuliacall               0.9.23             pyhd8ed1ab_1    conda-forge
pyjuliapkg                0.1.15             pyhd8ed1ab_0    conda-forge
python                    3.13.1          ha99a958_102_cp313    conda-forge
python_abi                3.13                    5_cp313    conda-forge
readline                  8.2                  h8228510_1    conda-forge
semver                    3.0.2              pyhd8ed1ab_1    conda-forge
stack_data                0.6.3              pyhd8ed1ab_1    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_1    conda-forge
typing_extensions         4.12.2             pyha770c72_1    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_1    conda-forge

The Python side is then set up as follows.

import ctypes
import numpy as np
from juliacall import Main as jl

# Julia setup: the generic method, plus a method taking a raw pointer and a
# length, which wraps the memory as a Julia Array without copying
# (own = false tells Julia not to free the buffer)
jl.seval("sumsquares(v) = sum(x->x^2, v)")
jl.seval("sumsquares(v::Ptr{Float64}, len::Int) = sumsquares(unsafe_wrap(Array, v, len; own = false))")
# @cfunction yields a C-callable pointer to the pointer-based method;
# Int(...) converts its address to an integer we can hand to ctypes
p = jl.seval("Int(@cfunction(sumsquares, Float64, (Ptr{Float64}, Int)))")
# Matching ctypes signature: double (*)(double*, int64)
FUNCTYPE = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.POINTER(ctypes.c_double), ctypes.c_int64)
jl_sumsquares = FUNCTYPE(p)

# Timing setup; these first calls also trigger Julia's JIT compilation
v = np.random.rand(100_000)
# Call Julia's sumsquares via juliacall
jl.sumsquares(v)
# Call Julia's sumsquares via ctypes
jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v))
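
Before timing anything, it is worth a quick check that the two call paths agree with NumPy; a minimal sanity-check sketch of my own (not part of the original script):

# Sanity check: both call paths should match NumPy up to rounding error
expected = float(np.sum(v**2))
assert abs(float(jl.sumsquares(v)) - expected) <= 1e-8 * expected
assert abs(jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), len(v)) - expected) <= 1e-8 * expected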

If I time this from Python via juliacall’s standard invocation, I get the following result. The 138 μs mean implies considerable overhead (about 106 μs) on top of the 32 μs measured in Julia itself. The objective is to improve on this.

In [6]: %timeit jl.sumsquares(v)
138 μs ± 2.88 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

If I use ctypes to invoke the Julia function via a C pointer, the overhead is reduced from 106 μs to 8 μs.

In [9]: %timeit jl_sumsquares(v.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
   ...:  len(v))
39.6 μs ± 86.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

The total time is about half that of my quick NumPy implementation, though I realize that this is not a one-to-one comparison.

In [12]: def sumsquares_numpy(v):
    ...:     return np.sum(v**2)
    ...: 

In [13]: %timeit sumsquares_numpy(v)
80.6 μs ± 220 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
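
As an aside, a somewhat fairer NumPy baseline avoids allocating the v**2 temporary: np.dot(v, v) computes the same sum of squares in a single BLAS call. This variant is my addition, not part of the original comparison.

def sumsquares_numpy_dot(v):
    # dot(v, v) equals sum(v**2) without the intermediate array
    return np.dot(v, v)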

As you might expect, a pure Python version is much slower. The timing is in milliseconds rather than microseconds.

In [14]: def sumsquares_purepython(v):
    ...:     return sum(map(lambda x: x**2, v))
    ...: 

In [15]: %timeit sumsquares_purepython(v)
28.3 ms ± 224 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In summary, by using Python’s ctypes to call a C function pointer, I was able to reduce the time to call a Julia function from Python from 138 μs to 40 μs. In comparison, benchmarking the function in Julia itself takes 32 μs. Thus it appears that using ctypes and @cfunction can reduce the overhead of calling a Julia function from Python from 106 μs to 8 μs.
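
For day-to-day use, the pointer plumbing can be hidden behind a small helper. This is a sketch of my own (not from the gist), assuming the input is, or can cheaply be converted to, a contiguous float64 array; the ascontiguousarray call adds a little per-call overhead but guards against handing a strided view to the raw-pointer interface.

def jl_sumsquares_np(a):
    # Ensure a dense float64 buffer; no copy is made if `a` already qualifies
    a = np.ascontiguousarray(a, dtype=np.float64)
    return jl_sumsquares(a.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), a.size)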

Edit: I posted a self-contained Python script as a GitHub gist.


Maybe that’s why jnumpy exists (GitHub - Suzhou-Tongyuan/jnumpy: Writing Python C extensions in Julia within 5 minutes).


Just to double-check, does %timeit jl.sumsquares(v) also execute the property access jl.sumsquares per loop? Could try jlsumsquares = jl.sumsquares and call that instead to see.

Also curious what various things account for juliacall’s interop overhead compared to calling a C function in Python or C code, but that may be off-topic.
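
One way to put a number on the fixed per-call cost, separate from the array work, would be a no-op round trip. A rough sketch (the noop function is illustrative, not from the thread):

from timeit import timeit

jl.seval("noop() = nothing")
jl_noop = jl.noop  # bind once so the property access is not timed
t = timeit("jl_noop()", globals={"jl_noop": jl_noop}, number=100_000)
print(f"fixed per-call cost: {t / 100_000 * 1e6:.2f} microseconds")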

I modified the gist with the following Python code.

from timeit import timeit  # required for the snippet below

juliacall_sumsquares = jl.sumsquares
juliacall_time2 = timeit(
    "juliacall_sumsquares(v)",
    "from __main__ import v, juliacall_sumsquares",
    number=number_of_calls,  # number_of_calls is defined earlier in the gist
)
print(f"juliacall_sumsquares(v) in Python via juliacall took {juliacall_time2/number_of_calls*10**6} microseconds")

Running the full script now prints the following.
sumsquares(vcopy) in Julia took 31.714999999999996 microseconds
jl.sumsquares(v) in Python via juliacall took 136.60029400125495 microseconds
juliacall_sumsquares(v) in Python via juliacall took 131.42659200093476 microseconds
jl_sumsquares(v) in Python via ctypes took 40.03097199893091 microseconds
sumsquares_numpy(v) in Python via numpy took 94.68027999901096 microseconds
sumsquares_purepython(v) in pure Python took 27943.604591997428 microseconds

Removing the property access seems to reduce the overhead by about 4–5 microseconds.


To put the timings in perspective, sumsquares(v) takes 0.3 ns per entry when called from Julia and 1 ns per entry when called from Python. These are both in the realm of “really fast” (i.e. compiled machine code fast) and I suspect the main difference is SIMD kicking in for the former case.

In Julia, v is an Array, whereas when passed from Python it arrives as a PyArray. A PyArray is intended to be as fast as an Array, but clearly it is not here; maybe there is some method missing that would let Julia SIMD your function.
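
If missing SIMD is indeed the culprit, one hedged workaround is to materialize the PyArray as a dense Vector{Float64} before reducing; the O(n) copy may pay for itself if PyArray element access is slow. The method name below is illustrative.

# Copy the incoming array into a plain Julia Vector, then reduce over that
jl.seval("sumsquares_dense(v) = sumsquares(Vector{Float64}(v))")
jl.sumsquares_dense(v)  # should now hit the fast Array method internally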

Perhaps there is a hint here? What does nocapture mean in the LLVM IR?

In [80]: jl.seval("@code_llvm sumsquares(vcopy::Array{Float64})")
; Function Signature: sumsquares(Array{Float64, 1})
;  @ none:1 within `sumsquares`
define double @julia_sumsquares_46981(ptr noundef nonnull align 8 dereferenceable(24) %"v::Array") #0 {
top:
; ┌ @ reducedim.jl:983 within `sum`
; │┌ @ reducedim.jl:983 within `#sum#934`
; ││┌ @ reducedim.jl:987 within `_sum`
; │││┌ @ reducedim.jl:987 within `#_sum#936`
; ││││┌ @ reducedim.jl:329 within `mapreduce`
; │││││┌ @ reducedim.jl:329 within `#mapreduce#926`
; ││││││┌ @ reducedim.jl:337 within `_mapreduce_dim`
         %0 = call double @j__mapreduce_46986(ptr nonnull %"v::Array")
         ret double %0
; └└└└└└└
}

In [81]: jl.seval("@code_llvm sumsquares(v::PyArray)")
; Function Signature: sumsquares(PythonCall.Wrap.PyArray{Float64, 1, true, true, Float64})
;  @ none:1 within `sumsquares`
define double @julia_sumsquares_46987(ptr nocapture noundef nonnull readonly align 8 dereferenceable(48) %"v::PyArray") #0 {
top:
; ┌ @ reducedim.jl:983 within `sum`
; │┌ @ reducedim.jl:983 within `#sum#934`
; ││┌ @ reducedim.jl:987 within `_sum`
; │││┌ @ reducedim.jl:987 within `#_sum#936`
; ││││┌ @ reducedim.jl:329 within `mapreduce`
; │││││┌ @ reducedim.jl:329 within `#mapreduce#926`
; ││││││┌ @ reducedim.jl:337 within `_mapreduce_dim`
         %0 = call double @j__mapreduce_46992(ptr nocapture nonnull readonly %"v::PyArray")
         ret double %0
; └└└└└└└
}

From the docs it means “a call to the function does not create any copies of the pointer value that outlive the call”. It is unclear whether this is relevant. The call stack otherwise looks the same in the two cases; can you get the LLVM code for _mapreduce_dim?
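
One way to see the LLVM for the inner call without spelling out its signature is to dump the whole module; a sketch, assuming v is bound to a PyArray on the Julia side as in the snippets above:

# dump_module=true includes generated code for callees such as _mapreduce_dim
jl.seval("@code_llvm dump_module=true sumsquares(v::PyArray)")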

Another thought is that Array is dense whereas PyArray can be strided, which probably leads to more work when indexing.

If you try the initial example with a strided array like view(rand(1_000_000), 1:10:1_000_000), how does the performance compare?

Edit: Or even view(rand(100_000), 1:1:100_000) to get a possibly-strided (but actually dense) vector.
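
For reference, both suggested experiments could be run from the same Python session via seval; a sketch, assuming BenchmarkTools is available in the juliacall environment:

jl.seval("using BenchmarkTools")
# Strided view: every 10th element of a longer vector
jl.seval("@benchmark sumsquares(w) setup=(w = view(rand(1_000_000), 1:10:1_000_000))")
# Unit-stride view: possibly-strided type, but actually dense memory
jl.seval("@benchmark sumsquares(w) setup=(w = view(rand(100_000), 1:1:100_000))")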