Mixing CUDA.jl with external GPU compute (OpenMM / DLPack.jl)

I am using OpenMM, a molecular dynamics package, through PyCall.
Now I want to combine my Julia integrator with the OpenMM force computation (on GPU).

To this end I tried using DLPack.jl together with openmm-dlext (many thanks to @PabloZubieta!), which, if I understand correctly, passes the GPU memory addresses from OpenMM to Julia, so that I can access OpenMM's internal positions and forces through Julia's CuArrays.

Very cool :slight_smile:
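
To illustrate what this memory sharing means in practice, here is a minimal sketch that is independent of OpenMM. It assumes a working CuPy installation reachable through PyCall; the arrays `x_py` and `x_jl` are purely illustrative and not part of my implementation.

```julia
using CUDA, PyCall
import DLPack

cupy = pyimport("cupy")

# Allocate a small array on the Python side. Requesting a PyObject keeps
# PyCall from attempting any automatic conversion of the result.
x_py = pycall(cupy.zeros, PyObject, 4, dtype="float32")

# Wrap the same device memory as a Julia array (no copy is made).
x_jl = DLPack.from_dlpack(x_py)

# Write from the Julia side and wait for the kernel to finish ...
x_jl .= 1f0
CUDA.synchronize()

# ... then read back through CuPy to see whether the write is visible there.
@show cupy.asnumpy(x_py)
```

If both sides really alias the same buffer, the values written from Julia should show up in the CuPy array without any explicit transfer.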

However, I am now running into performance problems I do not understand.
In order to compute the forces (compforce) I launch OpenMM’s force calculation via a PyCall to OpenMM’s context.getState.

I then want to read out the forces, which requires some rescaling on the GPU (readforce).

Each operation on its own is fast (@benchmark shows roughly 30 μs and 10 μs respectively).
Called one after the other, however, they are very slow (around 300 μs), much slower than just copying the memory from GPU to CPU from within OpenMM/Python (~100 μs).

For the combined call the profiler shows all time spent in the pycall.

For context, here is (the DLPack part of) my implementation:

Code
using CUDA    # provides CuArray
using PyCall
import DLPack

struct DLForce7
    pysim::PyObject
    positions::CuArray{Float32, 2}
    forces::CuArray{Int, 2}  # yes, that's how the forces are passed
    forceout::CuArray{Float32, 2}
    scaling::Float32
end

function add_dlpack2(pysim)
    dlext = pyimport("openmm.dlext")
    cupy = pyimport("cupy")
    dlforce = dlext.Force()
    pysim.system.addForce(dlforce)
    dlview = dlforce.view(pysim.context)
    positions = DLPack.from_dlpack(cupy.from_dlpack(pycall(dlext.positions, PyObject, dlview), copy=false))
    forces = DLPack.from_dlpack(cupy.from_dlpack(pycall(dlext.forces, PyObject, dlview), copy=false))

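    # Reference forces from the regular (non-DLPack) path; these methods for
    # the Python simulation object are defined elsewhere in my code.
    f2 = force(pysim, coords(pysim))
    # Conversion factor from the raw integer forces to OpenMM's force values.
    scaling = Float32(f2[1] / collect(forces)[1])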
    
    DLForce7(pysim, positions, forces, forces .* scaling, scaling)
end

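# Views into OpenMM's device memory; writing to them changes OpenMM's state.
coords(dl::DLForce7) = dl.positions
setcoords(dl::DLForce7, x) = (dl.positions .= x)

# Trigger OpenMM's force computation; the Python return value is discarded.
compforce(dl::DLForce7) = pycall(dl.pysim.context.getState, Nothing, getForces=true)
# Rescale the raw integer forces into the Float32 output buffer on the GPU.
readforce(dl::DLForce7) = (dl.forceout .= dl.forces .* dl.scaling)

force(dl::DLForce7) = (compforce(dl); readforce(dl))
force(dl::DLForce7, x) = (setcoords(dl, x); force(dl))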
Benchmarks
julia> @benchmark OpenMM.readforce(dl)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  10.606 μs … 142.563 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     11.056 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   11.592 μs ±   2.548 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁▆▇██▇▆▅▄▄▃▁▁              ▁▂▂▃▃▂▃▂▁▁                        ▂
  ██████████████▇█▅▇▇▅▆▅▅▄▅▇███████████████▇▇▄▆▅▆▆▆▆▅▆▆▆▅▆▆▅▅▅ █
  10.6 μs       Histogram: log(frequency) by time      16.3 μs <

 Memory estimate: 2.55 KiB, allocs estimate: 79.

julia> @benchmark OpenMM.compforce(dl)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  29.945 μs …  2.535 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     31.637 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.442 μs ± 25.973 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▁▄▇█▇█▆▄
  ▂▃▅█████████▆▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂ ▃
  29.9 μs         Histogram: frequency by time        42.5 μs <

 Memory estimate: 896 bytes, allocs estimate: 16.

julia> @benchmark (OpenMM.compforce(dl); OpenMM.readforce(dl))
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  293.398 μs …  3.462 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     323.793 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   323.201 μs ± 33.031 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                                  ▅▅▁      ▃█▄
  ▂▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▂▁▂▁▂▂▂▁▂▂▂▂▂▂▆███▆▄▃▃▃▄███▇▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂ ▃
  293 μs          Histogram: frequency by time          338 μs <

 Memory estimate: 3.42 KiB, allocs estimate: 95.
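
One thing that may be worth checking is whether the separate timings include the GPU work at all. CUDA.jl launches kernels asynchronously, so a plain `@benchmark` of `readforce` can end up measuring only the launch, with the actual execution time surfacing in whichever later call has to wait for it. A sketch of synchronized timings, assuming `dl` is the `DLForce7` instance from above and that `CUDA.device_synchronize()` is an acceptable (if heavy-handed) way to also wait for the kernels OpenMM launches from the Python side:

```julia
using CUDA, BenchmarkTools

# Wait for the broadcast kernel to finish instead of only queuing it.
@benchmark CUDA.@sync OpenMM.readforce($dl)

# OpenMM's kernels are launched from Python, possibly on another stream,
# so synchronize the whole device rather than only Julia's task-local stream.
@benchmark (OpenMM.compforce($dl); CUDA.device_synchronize())

# Combined call with the same synchronization, for comparison.
@benchmark (OpenMM.compforce($dl); CUDA.@sync OpenMM.readforce($dl))
```

If the synchronized individual timings add up to roughly the combined timing, the slowdown would be a measurement artifact rather than an interaction between the two calls. This is only a diagnostic, not an explanation of the numbers above.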

Am I overlooking something?

Best, Alex