Why is Flux/NNlib falling back to im2col instead of MIOpen on AMD GPU?

Hello! I’ve been trying to get Flux’s convolutional layers, specifically the Flux.NNlib.conv! function, to use the MIOpen implementation rather than the im2col or direct fallbacks. I’ve verified that MIOpen is available with AMDGPU.functional(:MIOpen), and the version info from AMDGPU shows that all of the necessary libraries, aside from rocFFT, are available. Julia also recognizes my GPU, an AMD Radeon RX 7900 XT.
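For anyone who wants to reproduce the availability checks described above, a minimal sketch follows. Only `AMDGPU.functional(:MIOpen)` is named in the post; `AMDGPU.versioninfo()` and `AMDGPU.devices()` are assumed to be the calls behind "version info" and "recognizes my GPU", so treat them as illustrative rather than the exact commands that were run.

```julia
using AMDGPU

# Check named in the post: should return true when MIOpen is usable.
@show AMDGPU.functional(:MIOpen)

# Assumption: this is the "version info" query; it lists the ROCm
# libraries (including rocFFT and MIOpen) that AMDGPU.jl found.
AMDGPU.versioninfo()

# Assumption: this is how the GPU was confirmed to be visible to Julia.
@show AMDGPU.devices()
```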

I’ve included the code below. Has anyone else run into this issue?

using Flux, AMDGPU

ct = Conv((3,3), 4 => 32, pad=1, stride=1)
r_ct = roc(ct)
x = ROCArray(rand(Float32, 40, 40, 4, 1))

r_ct(x)

This always errors due to scalar indexing, because Flux is not calling the MIOpen-compatible convolution function; the stack trace shows NNlib dispatching to conv_im2col! instead.

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151
  [2] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124
  [3] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112
  [4] getindex
    @ ~/.julia/packages/GPUArrays/3a5jB/src/host/indexing.jl:50 [inlined]
  [5] scalar_getindex
    @ ~/.julia/packages/GPUArrays/3a5jB/src/host/indexing.jl:36 [inlined]
  [6] _getindex
    @ ~/.julia/packages/GPUArrays/3a5jB/src/host/indexing.jl:19 [inlined]
  [7] getindex
    @ ~/.julia/packages/GPUArrays/3a5jB/src/host/indexing.jl:17 [inlined]
  [8] getindex
    @ ./subarray.jl:316 [inlined]
  [9] im2col!(col::ROCArray{Float32, 2, AMDGPU.Runtime.Mem.HIPBuffer}, x::SubArray{Float32, 4, ROCArray{…}, Tuple{…}, true}, cdims::DenseConvDims{3, 3, 3, 6, 3})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/impl/conv_im2col.jl:253
 [10] (::NNlib.var"#conv_part#538"{ROCArray{…}, Float32, Float32, SubArray{…}, SubArray{…}, ROCArray{…}, DenseConvDims{…}, Int64, Int64, Int64})(task_n::Int64, part::UnitRange{Int64})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/impl/conv_im2col.jl:53
 [11] conv_im2col!(y::SubArray{…}, x::SubArray{…}, w::ROCArray{…}, cdims::DenseConvDims{…}; col::ROCArray{…}, alpha::Float32, beta::Float32, ntasks::Int64)
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/impl/conv_im2col.jl:69
 [12] conv_im2col!(y::SubArray{…}, x::SubArray{…}, w::ROCArray{…}, cdims::DenseConvDims{…})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/impl/conv_im2col.jl:23
 [13] (::NNlib.var"#conv_group#186"{@Kwargs{}, ROCArray{…}, ROCArray{…}, ROCArray{…}, DenseConvDims{…}})(xc::UnitRange{Int64}, wc::UnitRange{Int64})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/conv.jl:209
 [14] conv!(out::ROCArray{…}, in1::ROCArray{…}, in2::ROCArray{…}, cdims::DenseConvDims{…}; kwargs::@Kwargs{})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/conv.jl:218
 [15] conv!
    @ ~/.julia/packages/NNlib/srXYX/src/conv.jl:185 [inlined]
 [16] #conv!#143
    @ ~/.julia/packages/NNlib/srXYX/src/conv.jl:145 [inlined]
 [17] conv!
    @ ~/.julia/packages/NNlib/srXYX/src/conv.jl:140 [inlined]
 [18] conv(x::ROCArray{Float32, 4, AMDGPU.Runtime.Mem.HIPBuffer}, w::ROCArray{Float32, 4, AMDGPU.Runtime.Mem.HIPBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::@Kwargs{})
    @ NNlib ~/.julia/packages/NNlib/srXYX/src/conv.jl:88
 [19] conv
    @ ~/.julia/packages/NNlib/srXYX/src/conv.jl:83 [inlined]
 [20] (::Conv{2, 4, typeof(identity), ROCArray{Float32, 4, AMDGPU.Runtime.Mem.HIPBuffer}, ROCArray{Float32, 1, AMDGPU.Runtime.Mem.HIPBuffer}})(x::ROCArray{Float32, 4, AMDGPU.Runtime.Mem.HIPBuffer})
    @ Flux ~/.julia/packages/Flux/DZYiO/src/layers/conv.jl:201
 [21] top-level scope
    @ REPL[30]:1
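
One way to narrow this down might be to ask Julia directly which `conv!` method it selects for `ROCArray` arguments, and whether NNlib's AMDGPU extension was loaded at all. The sketch below is a diagnostic, not a fix. It assumes the MIOpen-backed methods live in a package extension named `NNlibAMDGPUExt` (the name may differ between NNlib versions), and the array shapes simply mirror the `Conv((3,3), 4 => 32, pad=1, stride=1)` example above.

```julia
using Flux, AMDGPU
using InteractiveUtils  # provides @which outside the REPL

# Assumption: the extension is called NNlibAMDGPUExt. A result of
# `nothing` would mean the extension was not loaded in this session.
ext = Base.get_extension(Flux.NNlib, :NNlibAMDGPUExt)
@show ext

# Same shapes as the failing example: 3x3 kernel, 4 => 32 channels.
w = ROCArray(rand(Float32, 3, 3, 4, 32))
x = ROCArray(rand(Float32, 40, 40, 4, 1))
cdims = Flux.NNlib.DenseConvDims(x, w; padding=1, stride=1)
y = similar(x, 40, 40, 32, 1)  # output buffer; pad=1 keeps 40x40

# Prints the method that would be dispatched to, without running it.
println(@which Flux.NNlib.conv!(y, x, w, cdims))
```

If the printed method points into NNlib's generic `src/conv.jl` rather than an AMDGPU-specific file, that would be consistent with the stack trace above, where the call ends up in `conv_im2col!`.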