Are there alternatives to Diagonal and eigvals that are compatible with Reactant?

Hello!

I’m in the process of refactoring a complicated function so that it can be compiled by Reactant. It has not been very straightforward, but the speedup I’ve seen so far certainly justifies the effort.

Nonetheless, I’ve become stuck at the last step, where I need to calculate something along the lines of the function f defined in the following snippet:

```julia
using Reactant, LinearAlgebra

function f(A, diag_entries)
    D = diagm(diag_entries)
    F = A' * D * A
    sum(inv, eigvals(F))
end

A = randn(10^4, 10^2)
diag_entries = randn(10^4)

args = A, diag_entries
rargs = map(Reactant.ConcreteRArray, args)

rf = @compile f(rargs...)
```

When I run this, I get ERROR: MethodError: no method matching eigvals!(::Reactant.TracedRArray{Float64, 2}). Would there be any alternative way to express this calculation?
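Because only the sum of reciprocal eigenvalues is needed, one possible way around eigvals (an editor's suggestion, not from the thread) is a trace identity. It assumes F is invertible, and whether Reactant can trace an inverse or a linear solve is not verified here. F = AᵀDA is symmetric, so it has an eigendecomposition F = QΛQᵀ with orthogonal Q and n equal to the number of columns of A:

```latex
\begin{aligned}
F^{-1} &= Q \Lambda^{-1} Q^{\top}, \\
\operatorname{tr}\left(F^{-1}\right) &= \operatorname{tr}\left(\Lambda^{-1} Q^{\top} Q\right)
  = \operatorname{tr}\left(\Lambda^{-1}\right)
  = \sum_{i=1}^{n} \frac{1}{\lambda_i}.
\end{aligned}
```

In other words, sum(inv, eigvals(F)) and tr(inv(F)) are the same quantity up to rounding, and the second form needs no eigendecomposition.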

From the structure of the function, the matrix F is symmetric, and the smallest eigenvalues are the most important, although I don't know how many of them I would need to calculate to get a good approximation. Also, in my use case, A will be tall.
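On the CPU, with plain LinearAlgebra and no Reactant, one way to check how many of the smallest eigenvalues are needed is to compare a partial sum with the full one. This is a sketch under stated assumptions. The positive weights (abs.(randn(...))) are hypothetical and make F positive definite, so the algebraically smallest eigenvalues are also the ones with the largest reciprocals. For an indefinite F, eigvals(S, 1:k) would instead return the most negative eigenvalues, which is not the same thing. Sizes are shrunk from the post's 10^4 × 10^2.

```julia
using LinearAlgebra, Random

Random.seed!(0)
m, n = 1_000, 20                   # tall A: m ≫ n
A = randn(m, n)
d = abs.(randn(m))                 # hypothetical positive weights -> F is SPD
F = Symmetric(A' * (d .* A))       # n × n; the m × m matrix D is never formed

λ_all = eigvals(F)                 # all n eigenvalues, ascending
k = 5
λ_small = eigvals(F, 1:k)          # only the k algebraically smallest eigenvalues

full = sum(inv, λ_all)
partial = sum(inv, λ_small)
via_trace = tr(inv(F))             # same quantity as `full`, no eigendecomposition

println("k = $k captures $(round(100 * partial / full; digits = 1))% of sum(inv, λ)")
```

With random data like this the eigenvalues are tightly clustered, so a handful of the smallest ones capture only part of the sum. A real problem may have a very different spectrum, which is exactly what a check like this is meant to reveal.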

Another issue that I had is that switching from diagm to Diagonal gives me ERROR: Scalar indexing is disallowed. Therefore, I would also be interested to know if there is a more efficient way to construct D in this setting.
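A minimal sketch, checked only on the CPU with plain LinearAlgebra: because D is diagonal, D * A just scales row i of A by diag_entries[i], and a broadcast can express that directly without building D at all. That the broadcast form traces under Reactant without scalar indexing is an assumption (it is an elementwise broadcast followed by one matrix product) and is not verified here. Sizes are shrunk from the post's.

```julia
using LinearAlgebra

A = randn(500, 10)
diag_entries = randn(500)

# Dense D: allocates a 500 × 500 matrix here (10^4 × 10^4, about 800 MB, in the post).
F_dense = A' * diagm(diag_entries) * A
# Diagonal D: no dense matrix, but reported in this thread to hit scalar indexing under Reactant.
F_diag = A' * Diagonal(diag_entries) * A
# Broadcast: row i of A scaled by diag_entries[i], then a single matrix product.
F_bcast = A' * (diag_entries .* A)
```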

Thanks in advance!

> When I run this, I get ERROR: MethodError: no method matching eigvals!(::Reactant.TracedRArray{Float64, 2}). Would there be any alternative way to express this calculation?

Ah right, LinearAlgebra.eigvals! is not implemented yet. Would you mind opening an issue in EnzymeAD/Reactant.jl? I can work on it and let you know when it's ready.

> From the structure of the function, the matrix F is symmetric and the smallest eigenvalues are the most important…

Then you should wrap F in LinearAlgebra.Symmetric.
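A small plain-Julia illustration (no Reactant) of why the wrapper helps. In floating point, A' * D * A is not guaranteed to come out exactly symmetric. In that case eigvals on the raw matrix may fall back to the general nonsymmetric solver. Wrapping it in Symmetric selects the symmetric eigensolver, which returns real eigenvalues in ascending order. The sizes and random data are arbitrary.

```julia
using LinearAlgebra

A = randn(200, 8)
d = randn(200)
F = A' * (d .* A)          # mathematically symmetric, possibly not bit-for-bit

S = Symmetric(F)           # by default only the upper triangle of F is read
λ = eigvals(S)             # real eigenvalues, sorted ascending
```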

> Another issue that I had is that switching from diagm to Diagonal gives me ERROR: Scalar indexing is disallowed. Therefore, I would also be interested to know if there is a more efficient way to construct D in this setting.

The error says that elements of the TracedRArray are being accessed one at a time through getindex, which is disallowed by default because scalar indexing can be very slow on GPUs. If you really need it, you can disable the check for a single expression with @allowscalar, but it should be simpler than that: Diagonal(x) should just wrap x without calling getindex.
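In plain Julia (no Reactant), it is easy to confirm that Diagonal(x) stores x itself without copying it, and that multiplying by a Diagonal is just column or row scaling, which a broadcast can express. In the stack trace posted later in the thread, the getindex call appears to come from rmul!, reached while multiplying an Adjoint by a Diagonal, so writing that product as a broadcast is one possible way to sidestep it. That the broadcast then traces cleanly under Reactant is an assumption.

```julia
using LinearAlgebra

x = randn(4)
D = Diagonal(x)            # wraps x: no copy, no element access

M = randn(3, 4)
N = randn(4, 3)

# M * D scales column j of M by x[j]; D * N scales row i of N by x[i].
# The same results as broadcasts, bypassing Diagonal's multiplication methods:
MD = M .* x'
DN = x .* N
```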

Do you mind showing the full stacktrace of this last error? I have the suspicion that it happens when calling show.

I’ve just opened the issue. Also, here is the full stack trace:

ERROR: Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error
    @ ./error.jl:35 [inlined]
  [2] (::Nothing)(none::typeof(error), none::String)
    @ Reactant ./<missing>:0
  [3] ErrorException
    @ ./boot.jl:282 [inlined]
  [4] error
    @ ./error.jl:35 [inlined]
  [5] call_with_reactant(::Reactant.MustThrowError, ::typeof(error), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
  [6] errorscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151 [inlined]
  [7] (::Nothing)(none::typeof(GPUArraysCore.errorscalar), none::String)
    @ Reactant ./<missing>:0
  [8] string
    @ ./strings/substring.jl:225 [inlined]
  [9] scalardesc
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:134 [inlined]
 [10] errorscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:150 [inlined]
 [11] call_with_reactant(::Reactant.MustThrowError, ::typeof(GPUArraysCore.errorscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [12] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124 [inlined]
 [13] (::Nothing)(none::typeof(GPUArraysCore._assertscalar), none::String, none::GPUArraysCore.ScalarIndexing)
    @ Reactant ./<missing>:0
 [14] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:123 [inlined]
 [15] call_with_reactant(::typeof(GPUArraysCore._assertscalar), ::String, ::GPUArraysCore.ScalarIndexing)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [16] assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112 [inlined]
 [17] (::Nothing)(none::typeof(GPUArraysCore.assertscalar), none::String)
    @ Reactant ./<missing>:0
 [18] current_task
    @ ./task.jl:150 [inlined]
 [19] task_local_storage
    @ ./task.jl:269 [inlined]
 [20] assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:97 [inlined]
 [21] call_with_reactant(::typeof(GPUArraysCore.assertscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [22] getindex
    @ ~/.julia/packages/Reactant/TF3tW/src/Indexing.jl:40 [inlined]
 [23] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 2}, none::Tuple{Int64, Int64})
    @ Reactant ./<missing>:0
 [24] getindex
    @ ~/.julia/packages/Reactant/TF3tW/src/Indexing.jl:40 [inlined]
 [25] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Float64, 2}, ::Int64, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [26] rmul!
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:301 [inlined]
 [27] (::Nothing)(none::typeof(rmul!), none::Reactant.TracedRArray{…}, none::Diagonal{…})
    @ Reactant ./<missing>:0
 [28] getproperty
    @ ./Base.jl:37 [inlined]
 [29] size
    @ ~/.julia/packages/Reactant/TF3tW/src/TracedRArray.jl:259 [inlined]
 [30] size
    @ ./abstractarray.jl:42 [inlined]
 [31] _muldiag_size_check
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:255 [inlined]
 [32] rmul!
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:298 [inlined]
 [33] call_with_reactant(::typeof(rmul!), ::Reactant.TracedRArray{…}, ::Diagonal{…})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [34] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:343 [inlined]
 [35] (::Nothing)(none::typeof(*), none::Adjoint{…}, none::Diagonal{…})
    @ Reactant ./<missing>:0
 [36] _any
    @ ./reduce.jl:1219 [inlined]
 [37] any
    @ ./reduce.jl:1235 [inlined]
 [38] TupleOrBottom
    @ ./promotion.jl:482 [inlined]
 [39] promote_op
    @ ./promotion.jl:498 [inlined]
 [40] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:342 [inlined]
 [41] call_with_reactant(::typeof(*), ::Adjoint{…}, ::Diagonal{…})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [42] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1174 [inlined]
 [43] (::Nothing)(none::typeof(LinearAlgebra._tri_matmul), none::Adjoint{…}, none::Diagonal{…}, none::Reactant.TracedRArray{…}, none::Nothing)
    @ Reactant ./<missing>:0
 [44] getproperty
    @ ./Base.jl:37 [inlined]
 [45] size
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/adjtrans.jl:321 [inlined]
 [46] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1166 [inlined]
 [47] call_with_reactant(::typeof(LinearAlgebra._tri_matmul), ::Adjoint{…}, ::Diagonal{…}, ::Reactant.TracedRArray{…}, ::Nothing)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [48] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1166 [inlined]
 [49] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1162 [inlined]
 [50] f
    @ ~/Code/optimal_phase_transformation/post.jl:5 [inlined]
 [51] (::Nothing)(none::typeof(f), none::Reactant.TracedRArray{Float64, 2}, none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [52] Diagonal
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:10 [inlined]
 [53] Diagonal
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:14 [inlined]
 [54] f
    @ ~/Code/optimal_phase_transformation/post.jl:4 [inlined]
 [55] call_with_reactant(::typeof(f), ::Reactant.TracedRArray{Float64, 2}, ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [56] make_mlir_fn(f::typeof(f), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/TF3tW/src/TracedUtils.jl:348
 [57] make_mlir_fn
    @ ~/.julia/packages/Reactant/TF3tW/src/TracedUtils.jl:277 [inlined]
 [58] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(f), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:1734
 [59] compile_mlir!
    @ ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:1696 [inlined]
 [60] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3691
 [61] compile_xla
    @ ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3663 [inlined]
 [62] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3767
Some type information was truncated. Use `show(err)` to see complete types.

Thanks for taking your time to look into it!