Help with Flux.jl, Metal.jl (Apple Silicon) and Conv layers

After using Flux.jl to reimplement some PyTorch code (SpinalVGG), I ran into difficulties when trying to run it on an M2 Max's GPU with Metal.jl. The problem boils down to the following code, which first produces a warning (identical, by the way, to the warning I get when I run this model on a 2070 Super) and then fails with an error.

Code (minimal):

using Flux, Metal
layer = gpu(Conv((3, 3), 1=>16, relu))
x = gpu(rand(28, 28, 1, 1, 1))
layer(x)

Warning (beginning):

┌ Warning: Performing scalar indexing on task Task (runnable) @0x00000003671e8c90.
│ Invocation of getindex resulted in scalar indexing of a GPU array.


        nested task error: ArgumentError: cannot take the CPU address of a MtlArray{Float32, 5, Metal.MTL.MTLResourceStorageModePrivate}

Based on the above error, I suspect that the Conv layer (or its parameters) is not transferred to the GPU by Metal.jl. But I am even more puzzled by the warning, which persists even when I run this with CUDA instead of Metal on an Nvidia GPU. What am I doing wrong or misunderstanding? Is the Conv layer not supported by Metal.jl?

I would be glad if anyone could help me solve this issue. I have tried running this with Julia 1.9 (and 1.10-beta3) with freshly updated packages.

I get a different warning:

julia> using Flux, Metal

julia> layer = gpu(Conv((3, 3), 1=>16, relu))
┌ Info: The CUDA functionality is being called but
│ `CUDA.jl` must be loaded to access it.
└ Add `using CUDA` or `import CUDA` to your code.
Conv((3, 3), 1 => 16, relu)  # 160 parameters

and an error:

julia> layer(x)
ERROR: DimensionMismatch: Rank of x and w must match! (5 vs. 4)
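This error is about the input shape rather than the GPU: a 2-D Conv in Flux has a rank-4 weight array (3×3×1×16 here) and expects its input in WHCN order (width, height, channels, batch), so `rand(28, 28, 1, 1, 1)` has one dimension too many. A minimal CPU-only sketch of the shape Conv accepts, assuming a recent Flux.jl (not tested on Metal):

```julia
using Flux

# 2-D convolution: 3×3 kernel, 1 input channel => 16 output channels.
# Its weight array is rank 4, with size (3, 3, 1, 16).
layer = Conv((3, 3), 1 => 16, relu)

# 2-D conv input is WHCN: width × height × channels × batch, i.e. rank 4.
# Float32 matches the element type of Flux's default weight initialisation.
x = rand(Float32, 28, 28, 1, 1)

y = layer(x)
# With stride 1 and no padding each spatial side shrinks to 28 - 3 + 1 = 26.
println(size(y))  # prints (26, 26, 16, 1)
```

This only fixes the rank mismatch; it does not by itself make the convolution run on the Metal backend.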

Flux does detect the Metal backend:

julia> device = Flux.get_device(; verbose=true)
[ Info: Using backend: Metal.
(::Flux.FluxMetalDevice) (generic function with 1 method)

julia> device.deviceID
<AGXG14CDevice: 0x125850400>
    name = Apple M2 Max

Per the GitHub issue "Missing functionalities - Metal with Conv and ConvTranspose layers" (FluxML/Flux.jl #2278), convolutions are not supported on the Metal backend because nobody has stepped up to wrap the relevant functionality in Julia.