Matrix multiplication fails with Reactant and Lux

Hello,

I have this simple script:

using LinearAlgebra
using Reactant
using Lux

dev = reactant_device()
#dev = cpu_device()

A = dev(rand(Float32, 2, 2))
B = dev(rand(Float32, 2, 3))
A * B  # ERROR: conversion to pointer not defined for ConcretePJRTArray...

which unfortunately returns the error:

ERROR: conversion to pointer not defined for ConcretePJRTArray{Float32, 2, 1}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] unsafe_convert(::Type{Ptr{Float32}}, a::ConcretePJRTArray{Float32, 2, 1})
   @ Base ./pointer.jl:67
 [3] gemm!(transA::Char, transB::Char, alpha::Float32, A::ConcretePJRTArray{…}, B::ConcretePJRTArray{…}, beta::Float32, C::ConcretePJRTArray{…})
   @ LinearAlgebra.BLAS ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:1524
 [4] gemm_wrapper!(C::ConcretePJRTArray{…}, tA::Char, tB::Char, A::ConcretePJRTArray{…}, B::ConcretePJRTArray{…}, _add::LinearAlgebra.MulAddMul{…})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:605
 [5] generic_matmatmul!
   @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:352 [inlined]
 [6] mul!
   @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:263 [inlined]
 [7] mul!
   @ ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
 [8] *(A::ConcretePJRTArray{Float32, 2, 1}, B::ConcretePJRTArray{Float32, 2, 1})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.10+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:113

I am on Julia 1.10.10 with the following package versions:
Lux v1.27.1
Reactant v0.2.185

The following operations work:

A .+ A
sum(A)

Matrix multiplication, however, does not, and it is crucial for training neural networks. Am I missing a package or extension?

CUDA info: driver version 535.274.02, CUDA version 12.2.

Many thanks for your help,

Michael

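You need to compile the function first. Called directly, `A * B` falls through to LinearAlgebra's BLAS `gemm!` (see frames [3] to [8] of your stack trace), which needs a raw pointer to host memory; that is why the conversion to pointer fails for `ConcretePJRTArray`. Compile the call instead: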

fn = @compile A * B

fn(A, B)
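The same pattern works for any function of Reactant arrays. Below is a minimal sketch, assuming only Reactant is loaded: `Reactant.to_rarray` moves a host array to the device (standing in for Lux's `reactant_device()`), and `matmul` is a hypothetical helper name used only for illustration. As far as I understand, the compiled function is specialized on the element type and size of the arguments, so reuse it for inputs of the same shape.

```julia
using Reactant

# Move host arrays to the Reactant device (stand-in for `reactant_device()`).
A = Reactant.to_rarray(rand(Float32, 2, 2))
B = Reactant.to_rarray(rand(Float32, 2, 3))

# Hypothetical helper: any function of Reactant arrays can be compiled.
matmul(x, y) = x * y

# Trace and compile once for these argument types and sizes...
matmul_compiled = @compile matmul(A, B)

# ...then call the compiled function; this runs through XLA, not BLAS.
C = matmul_compiled(A, B)

# Copy the result back to a regular host Array and compare it with a
# plain CPU product. The tolerance is loose in case the backend uses
# reduced-precision (e.g. TF32) matmuls on GPU.
C_host = Array(C)
C_ref = Array(A) * Array(B)
println(isapprox(C_host, C_ref; rtol = 1e-2))
```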

Thank you for the quick response!