How to Efficiently Index PyArrays?

I have been trying to rewrite a Python script I wrote a while back in Julia to increase its performance. The Python script used a package called gym-retro, which ran a video game emulator that the script could interact with directly. As far as I am aware, there is no Julia equivalent, so I am importing gym-retro using PyCall. Game information such as the current frame and RAM is returned as NumPy arrays. My attempts to read the data from these arrays have been incredibly slow and memory-inefficient. For example, if I run this code:

function index_test(pyarray::PyArray{UInt8, 3})
    a = pyarray[1, 1, 1]
end


function conversion_test(pyarray::PyArray{UInt8, 3})
    converted_frame = convert(Array{UInt8, 3}, pyarray)
end


function reset(arcade::Arcade)
    frame = pycall(arcade.env.reset, PyArray)
    println(typeof(frame))
    println(size(frame))
    @btime index_test($frame)
    @btime conversion_test($frame)
end

It outputs this response in the REPL:

PyArray{UInt8, 3}
(224, 240, 3)
  308.230 ns (14 allocations: 416 bytes)
  184.150 ms (5653541 allocations: 182.49 MiB)

I am unsure why indexing the array causes 14 allocations. This is bottlenecking my program, as I can only process a maximum of 5 frames per second. Any advice on how to index data from a PyArray more efficiently would be greatly appreciated.
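One way to narrow this down without BenchmarkTools is Base's `@allocated` macro, measured inside a function after a warm-up call so that neither compilation nor global-variable access is counted. A minimal sketch using a plain `Array` as the baseline; `first_elem` and `alloc_bytes` are hypothetical helper names. With PyCall loaded, passing the `PyArray` returned by `pycall(arcade.env.reset, PyArray)` instead should show whether the allocations come from `getindex` itself.

```julia
# Hypothetical helpers for counting the heap allocations of a single call.
first_elem(a::AbstractArray{T,3}) where {T} = a[1, 1, 1]

function alloc_bytes(f, a)
    f(a)                      # warm-up call, so compilation is not measured
    return @allocated f(a)    # bytes allocated by the second call only
end

frame = zeros(UInt8, 224, 240, 3)          # stand-in for the emulator frame
println(alloc_bytes(first_elem, frame))    # a plain Array reports 0
```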

EDIT: Here are the results if I use the view() function:

function view_test(pyarray::PyArray{UInt8, 3})
    a = view(pyarray, 1, 1, 1)
end


function reset(arcade::Arcade)
    frame = pycall(arcade.env.reset, PyArray)
    @btime view_test($frame)
end
  806.977 ns (16 allocations: 336 bytes)

Here is an MWE:

using BenchmarkTools
using PyCall

function index_test(pyarray::PyArray{UInt8, 3})
    s = 0
    for i in 1:size(array, 1)
        for j in 1:size(array, 2)
            for k in 1:size(array, 3)
                s += array[i, j, k]
            end
        end
    end
    s
end

function from_python(pyarray::PyArray{UInt8, 3})
    return convert(Array{UInt8, 3}, pyarray)
end

function index_test(array::Array{UInt8, 3})
    s = 0
    for i in 1:size(array, 1)
        for j in 1:size(array, 2)
            for k in 1:size(array, 3)
                s += array[i, j, k]
            end
        end
    end
    s
end

function to_python(array::Array{UInt8, 3})
    PyArray(PyObject(array))
end

array = Array{UInt8, 3}(undef, (250, 250, 3))
@btime index_test($array)
@btime to_python($array)
pyarray = PyArray(PyObject(array))
@btime index_test($pyarray)
@btime from_python($pyarray)
println()

with results

  89.900 μs (0 allocations: 0 bytes)
  1.440 μs (20 allocations: 1.16 KiB)
  19.330 ms (500491 allocations: 12.41 MiB)
  444.500 μs (2 allocations: 183.20 KiB)

So indeed, indexing into a PyArray is quite slow, as is the backward conversion. But I wonder whether this is really necessary, because the documentation states

Assuming you have NumPy installed (true by default if you use Conda), then a Julia a::Array of NumPy-compatible elements is converted by PyObject(a) into a NumPy wrapper for the same data, i.e. without copying the data.

which I read as: if you call Python with PyObject(array), the resulting NumPy array still references the memory of array, i.e. changes made in Python are immediately reflected in array. So you should be able to use Julia indexing on the passed array after the call.

Good question … as far as I can tell, the getindex operation should completely inline to a direct pointer dereference (an unsafe_load on a cached pointer), and hence should be non-allocating and nearly as fast as native Julia array accesses: see PyCall.jl/pyarray.jl at commit 8a98fb45ef39d09e11c1270e9120a32df2578d50 (JuliaPy/PyCall.jl on GitHub).
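To make "an unsafe_load on a cached pointer" concrete, here is a minimal sketch of a pointer-backed 3-D array. `PtrArray3` is a hypothetical, simplified stand-in, not PyCall's actual `PyArray` code; it assumes the underlying buffer is kept alive (hence `GC.@preserve`) and that strides are given in elements. Arbitrary strides matter because a C-ordered NumPy `uint8` array of shape (224, 240, 3) would have element strides (720, 3, 1), whereas a Julia `Array` is column-major. When `getindex` inlines, each access is just an offset computation plus one load, with nothing to allocate.

```julia
# Hypothetical, simplified pointer-backed array (NOT PyCall's PyArray).
struct PtrArray3{T} <: AbstractArray{T,3}
    ptr::Ptr{T}              # cached data pointer
    dims::NTuple{3,Int}
    st::NTuple{3,Int}        # strides, in elements
end

Base.size(a::PtrArray3) = a.dims

@inline function Base.getindex(a::PtrArray3{T}, i::Int, j::Int, k::Int) where {T}
    @boundscheck checkbounds(a, i, j, k)
    off = (i - 1) * a.st[1] + (j - 1) * a.st[2] + (k - 1) * a.st[3]
    return unsafe_load(a.ptr, off + 1)   # unsafe_load uses a 1-based element index
end

A = rand(UInt8, 4, 5, 3)
GC.@preserve A begin
    # A Julia Array has column-major strides, here (1, 4, 20).
    P = PtrArray3(pointer(A), size(A), strides(A))
    println(P[2, 3, 1] == A[2, 3, 1])
    println(sum(P) == sum(A))
end
```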

However, I haven’t looked at this code (written by @JobJob) for quite some time, so it’s possible that some optimization is glitching. It would be worth taking a look at the generated code.

I did a few tests earlier today and found that there are fewer allocations if the PyArray is two-dimensional. I needed to convert the RGB image to grayscale anyway, so flattening it to 2D with state = pycall(np.sum, PyArray, state, axis=2) and then indexing it brings the conversion down from 184 ms to 10 ms. Still a bit longer than I would like, but it is definitely a big improvement.
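If element access on the wrapped array is cheap, the same channel sum could also be done on the Julia side into a preallocated matrix. A minimal sketch, assuming the frame is indexed as (row, column, channel) like the (224, 240, 3) array above; `sum_channels!` is a hypothetical helper, and on a `PyArray` its speed depends on exactly the indexing cost discussed in this thread.

```julia
# Hypothetical helper: sum the channels of an (H, W, C) frame into `out`.
function sum_channels!(out::AbstractMatrix{UInt32}, frame::AbstractArray{UInt8,3})
    H, W, C = size(frame)
    size(out) == (H, W) || throw(DimensionMismatch("out must be $H x $W"))
    for j in 1:W, i in 1:H            # i innermost: contiguous writes into a column-major `out`
        acc = UInt32(0)
        for c in 1:C
            acc += frame[i, j, c]     # UInt32 + UInt8 stays UInt32, so no overflow for 3 channels
        end
        out[i, j] = acc
    end
    return out
end

gray = sum_channels!(Matrix{UInt32}(undef, 224, 240), zeros(UInt8, 224, 240, 3))
```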

function test(pyarray::PyArray{UInt32, 2})
    a = pyarray[1, 1]
end


function main(arcade::Arcade)
    frame = pycall(arcade.env.reset, PyArray)
    frame = pycall(np.sum, PyArray, frame, axis=2)
    println(typeof(frame))
    println(size(frame))
    @btime test($frame)
end
PyArray{UInt32, 2}
(224, 240)
  218.672 ns (9 allocations: 272 bytes)

Could you elaborate a bit more on how to pass a Julia array into numpy?

Just checked my understanding of the documentation: we pass arrays by reference to Python. Modified example from this thread:

using PyCall

py"""
import numpy
xs = numpy.zeros((2, 2))
"""

xs = PyArray(py"xs"o)
println(xs)

py"""
import numpy
xs[0, 0] = 1
"""

println(xs)

resolves to

[0.0 0.0; 0.0 0.0]
[1.0 0.0; 0.0 0.0]

But of course you are out of luck if Python generates the array.

Oh I see. Thank you for the clarification. Unfortunately Python does generate my array. :frowning:

Here is the LLVM IR for the following code:

function test(pyarray::PyArray{UInt8, 3})
    a = pyarray[1, 1, 1]
end


function main(arcade::Arcade)
    frame = pycall(arcade.env.reset, PyArray)
    println(typeof(frame))
    println(size(frame))
    @code_llvm test(frame)
end
PyArray{UInt8, 3}
(224, 240, 3)

;  @ C:\Users\user\PycharmProjects\juliaProject\arcade.jl:145 within `test'
; Function Attrs: uwtable
define i8 @julia_test_3329({}* nonnull align 8 dereferenceable(80) %0) #0 {
top:
;  @ C:\Users\user\PycharmProjects\juliaProject\arcade.jl:146 within `test'
  %1 = call i8 @j_getindex_3331({}* nonnull %0, i64 signext 1, i64 signext 1, i64 signext 1) #0
  ret i8 %1
}

Here is the assembly code:

	.text
; ┌ @ arcade.jl:145 within `test'
	pushq	%rbp
	movq	%rsp, %rbp
	subq	$32, %rsp
; │ @ arcade.jl:146 within `test'
	movabsq	$getindex, %rax
	movl	$1, %edx
	movl	$1, %r8d
	movl	$1, %r9d
	callq	*%rax
	addq	$32, %rsp
	popq	%rbp
	retq
	nopl	(%rax,%rax)
; └

Just for completeness' sake: I profiled the MWE. from_python looks OK to me. index_test shows this picture (profiler image not preserved):

So I can’t even say which PyCall function is the culprit.

So your option would be to convert first and then index in Julia? At least you should get more than 5 frames per second this way.

Thank you for your help. As I need to convert to grayscale, crop, and downsample the frame, the fastest implementation I could find was the following:

function format_state(state::PyArray{UInt32, 2})
    new_state = zeros(Float32, (92, 120))
    @simd for i in 1:92
        i2 = 2 * i + 25
        @simd for j in 1:120
            j2 = 2 * j
            new_state[i, j] = (Float32(state[i2-1, j2-1]) + Float32(state[i2, j2-1]) +
                               Float32(state[i2-1, j2]) + Float32(state[i2, j2])) / 3060.0
        end
    end
    return new_state
end


function main(arcade::Arcade)
    state = pycall(arcade.env.reset, PyArray)
    state = pycall(np.sum, PyArray, state, axis=2)
    formatted_state = format_state(state)
end

The full conversion of a single frame takes just under 10 ms, so I can now process 100 frames per second.
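Since this runs once per frame, one possible refinement is to allocate the output once and fill it in place, and to divide by a `Float32` literal so the arithmetic stays in `Float32` (the original `/ 3060.0` promotes to `Float64` before storing). A minimal sketch; `format_state!` is a hypothetical name, and it keeps the original crop (input rows 26:209, all 240 columns) and the original loop order, with `j` innermost, which matches a row-major NumPy buffer rather than a column-major Julia `Matrix`.

```julia
# Hypothetical in-place variant of format_state that reuses `out` across frames.
# Assumes `state` is the (224, 240) channel-summed frame (at most 3 * 255 per pixel).
function format_state!(out::AbstractMatrix{Float32}, state::AbstractMatrix{<:Integer})
    size(out) == (92, 120) || throw(DimensionMismatch("out must be 92x120"))
    size(state, 1) >= 2 * 92 + 25 && size(state, 2) >= 2 * 120 ||
        throw(DimensionMismatch("state is too small for this crop"))
    @inbounds for i in 1:92
        i2 = 2 * i + 25                 # uses input rows 26:209
        for j in 1:120
            j2 = 2 * j                  # uses input columns 1:240
            s = Float32(state[i2-1, j2-1]) + Float32(state[i2, j2-1]) +
                Float32(state[i2-1, j2]) + Float32(state[i2, j2])
            out[i, j] = s / 3060f0      # 4 pixels * 3 channels * 255 -> range [0, 1]
        end
    end
    return out
end

out = Matrix{Float32}(undef, 92, 120)   # allocate once, outside the frame loop
```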

No, that is what PyArray is for: it exposes a copy-free AbstractArray wrapper for a NumPy array.


MWE

using PyCall

function index_test(pyarray::PyArray{UInt8, 3})
    s = 0
    for i in 1:size(array, 1)
        for j in 1:size(array, 2)
            for k in 1:size(array, 3)
                s += array[i, j, k]
            end
        end
    end
    s
end

array = Array{UInt8, 3}(undef, (250, 250, 3))
pyarray = PyArray(PyObject(array))
@code_warntype index_test(pyarray)

outputs

Variables
  #self#::Core.Const(index_test)
  pyarray::PyArray{UInt8, 3}
  @_3::Any
  s::Any
  @_5::Any
  i::Any
  @_7::Any
  j::Any
  k::Any

Body::Any
1 ──       (s = 0)
│    %2  = Main.size(Main.array, 1)::Any
│    %3  = (1:%2)::Any
│          (@_3 = Base.iterate(%3))
│    %5  = (@_3 === nothing)::Bool
│    %6  = Base.not_int(%5)::Bool
└───       goto #10 if not %6
2 ┄─ %8  = @_3::Any
│          (i = Core.getfield(%8, 1))
│    %10 = Core.getfield(%8, 2)::Any
│    %11 = Main.size(Main.array, 2)::Any
│    %12 = (1:%11)::Any
│          (@_5 = Base.iterate(%12))
│    %14 = (@_5 === nothing)::Bool
│    %15 = Base.not_int(%14)::Bool
└───       goto #8 if not %15
3 ┄─ %17 = @_5::Any
│          (j = Core.getfield(%17, 1))
│    %19 = Core.getfield(%17, 2)::Any
│    %20 = Main.size(Main.array, 3)::Any
│    %21 = (1:%20)::Any
│          (@_7 = Base.iterate(%21))
│    %23 = (@_7 === nothing)::Bool
│    %24 = Base.not_int(%23)::Bool
└───       goto #6 if not %24
4 ┄─ %26 = @_7::Any
│          (k = Core.getfield(%26, 1))
│    %28 = Core.getfield(%26, 2)::Any
│    %29 = s::Any
│    %30 = Base.getindex(Main.array, i, j, k)::Any
│          (s = %29 + %30)
│          (@_7 = Base.iterate(%21, %28))
│    %33 = (@_7 === nothing)::Bool
│    %34 = Base.not_int(%33)::Bool
└───       goto #6 if not %34
5 ──       goto #4
6 ┄─       (@_5 = Base.iterate(%12, %19))
│    %38 = (@_5 === nothing)::Bool
│    %39 = Base.not_int(%38)::Bool
└───       goto #8 if not %39
7 ──       goto #3
8 ┄─       (@_3 = Base.iterate(%3, %10))
│    %43 = (@_3 === nothing)::Bool
│    %44 = Base.not_int(%43)::Bool
└───       goto #10 if not %44
9 ──       goto #2
10 ┄       return s
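A note on this output: the `Main.array` entries show that the loop body reads the global variable `array` rather than the argument `pyarray` (the same slip is in the earlier MWE's PyArray method). A non-constant global is inferred as `Any`, which alone could account for the type instability and many of the allocations in that benchmark, independently of PyArray. A minimal corrected sketch, written for any 3-D `AbstractArray` so that it also accepts a `PyArray`; whether PyArray indexing then matches a native `Array` is not settled here.

```julia
# Type-stable version: every access goes through the argument `a`, not a global.
function index_test(a::AbstractArray{UInt8,3})
    s = 0                                   # Int accumulator; Int + UInt8 stays Int
    for k in axes(a, 3), j in axes(a, 2), i in axes(a, 1)
        s += a[i, j, k]                     # i innermost suits a column-major Julia Array
    end
    return s
end
```

For example, `index_test(ones(UInt8, 250, 250, 3))` returns 187500.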

Not helpful since it looks like getindex is not inlined, so you can’t see what it is doing. Maybe add @inbounds?

Sure thing

function test(pyarray::PyArray{UInt8, 3})
    a = pyarray[1, 1, 1]
end


function main(arcade::Arcade)
    state = pycall(arcade.env.reset, PyArray)
    @code_native @inbounds test(state)
end
	.text
	.file	"@inbounds"
	.globl	julia_@inbounds_2913            # -- Begin function julia_@inbounds_2913
	.p2align	4, 0x90
	.type	julia_@inbounds_2913,@function
julia_@inbounds_2913:                   # @"julia_@inbounds_2913"
	.cfi_startproc
# %bb.0:                                # %top
	pushq	%rbp
	.cfi_def_cfa_offset 16
	.cfi_offset %rbp, -16
	movq	%rsp, %rbp
	.cfi_def_cfa_register %rbp
	pushq	%r14
	pushq	%rsi
	pushq	%rdi
	pushq	%rbx
	andq	$-32, %rsp
	subq	$160, %rsp
	.cfi_offset %rbx, -48
	.cfi_offset %rdi, -40
	.cfi_offset %rsi, -32
	.cfi_offset %r14, -24
	movq	%r8, %rsi
	vxorps	%xmm0, %xmm0, %xmm0
	vmovaps	%ymm0, 96(%rsp)
	movq	$0, 128(%rsp)
	movl	$41024112, %eax                 # imm = 0x271FA70
	vzeroupper
	callq	*%rax
	movq	%rax, %r14
	movq	$12, 96(%rsp)
	movq	(%r14), %rax
	movq	%rax, 104(%rsp)
	leaq	96(%rsp), %rax
	movq	%rax, (%r14)
	movq	$243776616, 56(%rsp)            # imm = 0xE87BC68
	movq	$1802904512, 64(%rsp)           # imm = 0x6B7623C0
	movabsq	$j1_Expr_2915, %rax
	leaq	56(%rsp), %rdi
	movl	$1802224016, %ecx               # imm = 0x6B6BC190
	movq	%rdi, %rdx
	movl	$2, %r8d
	callq	*%rax
	movq	%rax, %rbx
	movq	%rbx, 120(%rsp)
	movq	%rsi, 56(%rsp)
	movabsq	$j1_esc_2916, %rax
	movl	$1832004368, %ecx               # imm = 0x6D322B10
	movq	%rdi, %rdx
	movl	$1, %r8d
	callq	*%rax
	movq	%rax, 112(%rsp)
	movq	$243260656, 56(%rsp)            # imm = 0xE7FDCF0
	movq	$243314000, 64(%rsp)            # imm = 0xE80AD50
	movq	%rax, 72(%rsp)
	movabsq	$j1_Expr_2917, %rax
	movl	$1802224016, %ecx               # imm = 0x6B6BC190
	movq	%rdi, %rdx
	movl	$3, %r8d
	callq	*%rax
	movq	%rax, 112(%rsp)
	movq	$243838488, 56(%rsp)            # imm = 0xE88AE18
	movq	%rax, 64(%rsp)
	movabsq	$j1_Expr_2918, %rax
	movl	$1802224016, %ecx               # imm = 0x6B6BC190
	movq	%rdi, %rdx
	movl	$2, %r8d
	callq	*%rax
	movq	%rax, %rsi
	movq	%rsi, 128(%rsp)
	movq	$243776616, 56(%rsp)            # imm = 0xE87BC68
	movq	$244007064, 64(%rsp)            # imm = 0xE8B4098
	movabsq	$j1_Expr_2919, %rax
	movl	$1802224016, %ecx               # imm = 0x6B6BC190
	movq	%rdi, %rdx
	movl	$2, %r8d
	callq	*%rax
	movq	%rax, 112(%rsp)
	movq	$243513968, 56(%rsp)            # imm = 0xE83BA70
	movq	%rbx, 64(%rsp)
	movq	%rsi, 72(%rsp)
	movq	%rax, 80(%rsp)
	movq	$243314000, 88(%rsp)            # imm = 0xE80AD50
	movabsq	$j1_Expr_2920, %rax
	movl	$1802224016, %ecx               # imm = 0x6B6BC190
	movq	%rdi, %rdx
	movl	$5, %r8d
	callq	*%rax
	movq	104(%rsp), %rcx
	movq	%rcx, (%r14)
	leaq	-32(%rbp), %rsp
	popq	%rbx
	popq	%rdi
	popq	%rsi
	popq	%r14
	popq	%rbp
	retq
.Lfunc_end0:
	.size	julia_@inbounds_2913, .Lfunc_end0-julia_@inbounds_2913
	.cfi_endproc
                                        # -- End function
	.section	".note.GNU-stack","",@progbits

(@code_llvm or @code_lowered or @code_warntype are much more readable than @code_native) … I would also suggest putting @inbounds inside test rather than at the call site.

Sorry about that.

function test(pyarray::PyArray{UInt8, 3})
    @inbounds a = pyarray[1, 1, 1]
end


function main(arcade::Arcade)
    state = pycall(arcade.env.reset, PyArray)
    @code_llvm test(state)
end
;  @ C:\Users\user\PycharmProjects\juliaProject\arcade.jl:159 within `test'
; Function Attrs: uwtable
define i8 @julia_test_2994({}* nonnull align 8 dereferenceable(80) %0) #0 {
top:
;  @ C:\Users\user\PycharmProjects\juliaProject\arcade.jl:160 within `test'
  %1 = call i8 @j_getindex_2996({}* nonnull %0, i64 signext 1, i64 signext 1, i64 signext 1) #0
  ret i8 %1
}
"C:\Users\user\AppData\Local\Programs\Julia-1.6.3\bin\julia.exe" --check-bounds=no --history-file=no --inline=no --color=yes --math-mode=ieee --handle-signals=no --startup-file=yes --compile=yes --depwarn=yes --code-coverage=none --track-allocation=none "C:/Users/user/PycharmProjects/juliaProject/arcade.jl"
Variables
  #self#::Core.Const(test)
  pyarray::PyArray{UInt8, 3}
  val::UInt8
  a::UInt8

Body::UInt8
1 ─      $(Expr(:inbounds, true))
│   %2 = Base.getindex(pyarray, 1, 1, 1)::UInt8
│        (a = %2)
│        (val = %2)
│        $(Expr(:inbounds, :pop))
└──      return val
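The command line printed above includes `--inline=no` (and `--check-bounds=no`), apparently added by the IDE's run configuration. With inlining disabled globally, `getindex` cannot be inlined into `test`, which is consistent with the `call i8 @j_getindex_...` in the LLVM output and may account for much of the overhead; comparing against a plain `julia arcade.jl` run seems worthwhile. A small self-check for a script, assuming the internal, undocumented `Base.JLOptions()` fields `can_inline` and `check_bounds` (the numeric encoding of `check_bounds` in the comment is also an assumption):

```julia
# Base.JLOptions() is internal and undocumented; its fields may change between versions.
inlining_enabled() = Base.JLOptions().can_inline != 0
bounds_setting() = Base.JLOptions().check_bounds   # assumed: 0 default, 1 yes, 2 no

if !inlining_enabled()
    @warn "Julia was started with --inline=no; timings will not reflect normal performance"
end
println("inlining enabled: ", inlining_enabled(), ", check_bounds: ", bounds_setting())
```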

Since you overwrite this, it’s more efficient to do

new_state = Matrix{Float32}(undef, 92, 120)

I did test this. I found the latter to actually be faster. :woman_shrugging:

That would be expected. But I’m getting the impression you mean the opposite, which would be highly surprising.