Are there alternatives to Diagonal and eigvals compatible with Reactant?

Hello!

I’m in the process of refactoring a complicated function so that it can be compiled by Reactant. It has not been very straightforward, but the speedup I’ve seen so far certainly justifies the effort.

Nonetheless, I’ve become stuck at the last step, where I need to calculate something along the lines of the function f defined in the following snippet:

using Reactant, LinearAlgebra

function f(A, diag_entries)
    D = diagm(diag_entries)
    F = A' * D * A
    sum(inv, eigvals(F))
end

A = randn(10^4, 10^2)
diag_entries = randn(10^4)

args = A, diag_entries
rargs = map(Reactant.ConcreteRArray, args)

rf = @compile f(rargs...)  

When I run this, I get ERROR: MethodError: no method matching eigvals!(::Reactant.TracedRArray{Float64, 2}). Would there be any alternative way to express this calculation?

From the structure of the function, the matrix F is symmetric and the smallest eigenvalues are the most important, although I don’t know how many of them I would need to calculate to get a good approximation. Also, in my use case, A will be tall.

Another issue that I had is that switching from diagm to Diagonal gives me ERROR: Scalar indexing is disallowed. Therefore, I would also be interested to know if there is a more efficient way to construct D in this setting.

Thanks in advance!

When I run this, I get ERROR: MethodError: no method matching eigvals!(::Reactant.TracedRArray{Float64, 2}). Would there be any alternative way to express this calculation?

Ah right, LinearAlgebra.eigvals! is not yet implemented. Do you mind opening an issue in EnzymeAD/Reactant.jl? I can work on that and let you know when it’s ready.

From the structure of the function, the matrix F is symmetric and the smallest eigenvalues are the most important…

You should use LinearAlgebra.Symmetric on F then.
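
Something along these lines, in plain LinearAlgebra terms (illustrative sketch; f_sym is just my name for it, and eigvals on a TracedRArray still needs the Reactant support tracked in the issue above):

using LinearAlgebra

function f_sym(A, diag_entries)
    D = diagm(diag_entries)
    F = Symmetric(A' * D * A)   # assert symmetry so eigvals dispatches to the symmetric solver
    sum(inv, eigvals(F))
end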

Another issue that I had is that switching from diagm to Diagonal gives me ERROR: Scalar indexing is disallowed. Therefore, I would also be interested to know if there is a more efficient way to construct D in this setting.

The error is saying that the elements of the TracedRArray are being accessed through getindex, which is discouraged as it can be detrimental to performance on GPUs. If you really need it, you can disable the check for a single expression with @allowscalar, but it shouldn’t come to that: Diagonal(x) should just wrap x and not call getindex.
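
For reference, a minimal sketch of how @allowscalar is scoped (x here is only a stand-in so the sketch runs; the macro silences the check, the element-by-element access itself stays slow):

using GPUArraysCore: @allowscalar

x = rand(2, 2)            # stand-in for whatever GPU-backed/traced array triggered the check
v = @allowscalar x[1, 1]  # the check is disabled for this one expression

@allowscalar begin        # or for a whole block
    s = x[1, 1] + x[2, 2]
end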

Do you mind showing the full stacktrace of this last error? I have the suspicion that it happens when calling show.

I’ve just opened the issue. Also, here is the full stack trace:

ERROR: Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error
    @ ./error.jl:35 [inlined]
  [2] (::Nothing)(none::typeof(error), none::String)
    @ Reactant ./<missing>:0
  [3] ErrorException
    @ ./boot.jl:282 [inlined]
  [4] error
    @ ./error.jl:35 [inlined]
  [5] call_with_reactant(::Reactant.MustThrowError, ::typeof(error), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
  [6] errorscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:151 [inlined]
  [7] (::Nothing)(none::typeof(GPUArraysCore.errorscalar), none::String)
    @ Reactant ./<missing>:0
  [8] string
    @ ./strings/substring.jl:225 [inlined]
  [9] scalardesc
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:134 [inlined]
 [10] errorscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:150 [inlined]
 [11] call_with_reactant(::Reactant.MustThrowError, ::typeof(GPUArraysCore.errorscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [12] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:124 [inlined]
 [13] (::Nothing)(none::typeof(GPUArraysCore._assertscalar), none::String, none::GPUArraysCore.ScalarIndexing)
    @ Reactant ./<missing>:0
 [14] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:123 [inlined]
 [15] call_with_reactant(::typeof(GPUArraysCore._assertscalar), ::String, ::GPUArraysCore.ScalarIndexing)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [16] assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:112 [inlined]
 [17] (::Nothing)(none::typeof(GPUArraysCore.assertscalar), none::String)
    @ Reactant ./<missing>:0
 [18] current_task
    @ ./task.jl:150 [inlined]
 [19] task_local_storage
    @ ./task.jl:269 [inlined]
 [20] assertscalar
    @ ~/.julia/packages/GPUArraysCore/aNaXo/src/GPUArraysCore.jl:97 [inlined]
 [21] call_with_reactant(::typeof(GPUArraysCore.assertscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [22] getindex
    @ ~/.julia/packages/Reactant/TF3tW/src/Indexing.jl:40 [inlined]
 [23] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 2}, none::Tuple{Int64, Int64})
    @ Reactant ./<missing>:0
 [24] getindex
    @ ~/.julia/packages/Reactant/TF3tW/src/Indexing.jl:40 [inlined]
 [25] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Float64, 2}, ::Int64, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [26] rmul!
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:301 [inlined]
 [27] (::Nothing)(none::typeof(rmul!), none::Reactant.TracedRArray{…}, none::Diagonal{…})
    @ Reactant ./<missing>:0
 [28] getproperty
    @ ./Base.jl:37 [inlined]
 [29] size
    @ ~/.julia/packages/Reactant/TF3tW/src/TracedRArray.jl:259 [inlined]
 [30] size
    @ ./abstractarray.jl:42 [inlined]
 [31] _muldiag_size_check
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:255 [inlined]
 [32] rmul!
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:298 [inlined]
 [33] call_with_reactant(::typeof(rmul!), ::Reactant.TracedRArray{…}, ::Diagonal{…})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [34] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:343 [inlined]
 [35] (::Nothing)(none::typeof(*), none::Adjoint{…}, none::Diagonal{…})
    @ Reactant ./<missing>:0
 [36] _any
    @ ./reduce.jl:1219 [inlined]
 [37] any
    @ ./reduce.jl:1235 [inlined]
 [38] TupleOrBottom
    @ ./promotion.jl:482 [inlined]
 [39] promote_op
    @ ./promotion.jl:498 [inlined]
 [40] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:342 [inlined]
 [41] call_with_reactant(::typeof(*), ::Adjoint{…}, ::Diagonal{…})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [42] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1174 [inlined]
 [43] (::Nothing)(none::typeof(LinearAlgebra._tri_matmul), none::Adjoint{…}, none::Diagonal{…}, none::Reactant.TracedRArray{…}, none::Nothing)
    @ Reactant ./<missing>:0
 [44] getproperty
    @ ./Base.jl:37 [inlined]
 [45] size
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/adjtrans.jl:321 [inlined]
 [46] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1166 [inlined]
 [47] call_with_reactant(::typeof(LinearAlgebra._tri_matmul), ::Adjoint{…}, ::Diagonal{…}, ::Reactant.TracedRArray{…}, ::Nothing)
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [48] _tri_matmul
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1166 [inlined]
 [49] *
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:1162 [inlined]
 [50] f
    @ ~/Code/optimal_phase_transformation/post.jl:5 [inlined]
 [51] (::Nothing)(none::typeof(f), none::Reactant.TracedRArray{Float64, 2}, none::Reactant.TracedRArray{Float64, 1})
    @ Reactant ./<missing>:0
 [52] Diagonal
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:10 [inlined]
 [53] Diagonal
    @ ~/.julia/juliaup/julia-1.10.10+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:14 [inlined]
 [54] f
    @ ~/Code/optimal_phase_transformation/post.jl:4 [inlined]
 [55] call_with_reactant(::typeof(f), ::Reactant.TracedRArray{Float64, 2}, ::Reactant.TracedRArray{Float64, 1})
    @ Reactant ~/.julia/packages/Reactant/TF3tW/src/utils.jl:0
 [56] make_mlir_fn(f::typeof(f), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, within_autodiff::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/TF3tW/src/TracedUtils.jl:348
 [57] make_mlir_fn
    @ ~/.julia/packages/Reactant/TF3tW/src/TracedUtils.jl:277 [inlined]
 [58] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(f), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:1734
 [59] compile_mlir!
    @ ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:1696 [inlined]
 [60] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3691
 [61] compile_xla
    @ ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3663 [inlined]
 [62] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/TF3tW/src/Compiler.jl:3767
Some type information was truncated. Use `show(err)` to see complete types.

Thanks for taking the time to look into it!

It’s not what I thought. Mind opening another issue for this?

cc @avikpal presumably this can just call our symm/etc mul ops directly [and we should add an overload?]

You should use LinearAlgebra.Symmetric on F then.

F is captured in the IR itself, so we shouldn’t need a manual Symmetric annotation. That said, we should add a pattern that optimizes eigvals(x) where x is proven symmetric.

That said, we should add a pattern that optimizes eigvals(x) where x is proven symmetric.

Mmm, now that I remember, the eigendecomposition already assumes that the matrix is Hermitian or symmetric. If not, you will end up with left- and right-eigenvectors but the same eigenvalues. eigvals doesn’t need an optimization.

A possible optimization, though it may be too niche, is that if we prove the matrix x is Hermitian/symmetric and we call inv on the eigenvectors, then the inv can be replaced with a transpose (or |det| = 1, because the eigenvectors then form a unitary matrix).
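
A quick plain-LinearAlgebra check of that claim (purely illustrative, nothing Reactant-specific):

using LinearAlgebra

S = Symmetric(randn(5, 5))
V = eigvecs(S)                  # eigenvectors of a symmetric matrix are orthogonal
@assert inv(V) ≈ V'             # so inv(V) can be replaced by the transpose
@assert abs(det(V)) ≈ 1         # and |det(V)| = 1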

Here it is. Weirdly, I only get the error when multiplying the three matrices on the same line. If I store the intermediate result, it runs fine.
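
For concreteness, the split version looks something like this (illustrative sketch; the variable names and exact grouping may differ from my real code, and eigvals still needs the fix discussed above):

using Reactant, LinearAlgebra

function f_split(A, diag_entries)
    B = Diagonal(diag_entries) * A   # store the intermediate product explicitly
    F = A' * B                       # avoids the three-argument *(A', D, A) path that failed
    sum(inv, eigvals(F))
end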

Note that this function is the same as the trace of the inverse.
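
Spelling out the identity (a standard fact, just for anyone following along): the eigenvalues of F^{-1} are the reciprocals of the eigenvalues of F, so

$$\sum_i \frac{1}{\lambda_i(F)} = \sum_i \lambda_i\left(F^{-1}\right) = \operatorname{tr}\left(F^{-1}\right).$$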

trinv(A, d) = tr(inv(Hermitian(A' * Diagonal(d) * A)))

which is more efficient than using eigenvalues — it is about 10x faster for a 1000x1000 matrix on my machine. It should also be easier to differentiate.

julia> A = rand(5,5); d = rand(5);

julia> f(A, d)
557.778533254039

julia> trinv(A, d)
557.7785332540502

If the diag_entries are nonnegative and A has full column rank, so that you know F is positive definite, I would instead use a Cholesky factorization to be a bit faster still:

trinv_posdef(A, d) = tr(inv(cholesky(Hermitian(A' * Diagonal(d) * A))))
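
A quick consistency check of the three formulations (the sizes and names here are mine, just for illustration):

using LinearAlgebra

A = randn(1000, 100); d = rand(1000)   # nonnegative weights; A has full column rank almost surely
F = Hermitian(A' * Diagonal(d) * A)

v1 = sum(inv, eigvals(F))              # original f
v2 = tr(inv(F))                        # trinv
v3 = tr(inv(cholesky(F)))              # trinv_posdef
@assert v1 ≈ v2 ≈ v3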