I get a little better performance, but it may be specific to the hardware:
1.536 s (75 allocations: 1.01 GiB)
933.910 ms (1370 allocations: 58.83 KiB)
and with Float32 :
636.211 ms (75 allocations: 515.07 MiB)
41.175 ms (1325 allocations: 45.65 KiB)
code :
""
using CUDA, Statistics, LinearAlgebra, BenchmarkTools
"""
    operation(ξ, y, obs, dev = identity; cov_obs)

Ensemble-style update step: estimate the cross-covariance of `ξ` and `y` and
the covariance of `y` over the batch (second) dimension, then return
`ξ + C_ξy * ((C_yy + cov_obs) \\ (obs .- y + cov_obs))`. `ξ` is not mutated.

`dev` maps the default `cov_obs` (a `Diagonal` filled with `T(0.1)`) onto the
target device (e.g. `CuArray`). NOTE(review): `obs .- y + cov_obs` only
conforms dimensionally when the batch size equals `length(obs)` — confirm this
is intentional for non-square problems.
"""
function operation(
ξ::AbstractArray{T}, y, obs, dev = identity;
cov_obs = dev(Diagonal(fill(T(0.1), length(obs))))
) where T
    _, nsamples = size(y)
    # Column means over the batch dimension.
    μξ = mean(ξ, dims=2)
    μy = mean(y, dims=2)
    # Unbiased sample (cross-)covariance estimators.
    Cξy = (ξ * y' - nsamples .* μξ * μy') ./ (nsamples - 1)
    Cyy = (y * y' - nsamples .* μy * μy') ./ (nsamples - 1)
    # `\` factorizes `Cyy + cov_obs` once and solves against the innovation term.
    correction = Cξy * ((Cyy + cov_obs) \ (obs .- y + cov_obs))
    return ξ + correction
end
using CUDA, Random, LinearAlgebra, Statistics, BenchmarkTools

# Reproducible benchmark inputs. The problem is deliberately square
# (dim_y == batch == 3000), which is what makes the `obs .- y + cov_obs`
# expression inside `operation` conform dimensionally.
rng = Xoshiro(42)
dim_ξ, dim_y, batch = 3000, 3000, 3000
ξ_cpu = randn(rng, Float32, dim_ξ, batch)
y_cpu = randn(rng, Float32, dim_y, batch) * 10
obs_cpu = randn(rng, Float32, dim_y, 1) * 10

# Device-resident copies of the same data.
ξ_gpu = CuArray(ξ_cpu)
y_gpu = CuArray(y_cpu)
obs_gpu = CuArray(obs_cpu)

# Benchmark CPU operation
@btime operation($ξ_cpu, $y_cpu, $obs_cpu);

# Benchmark GPU operation.
# Use `CUDA.@sync`, not bare `@sync`: CUDA.jl does not export `@sync`, so the
# bare macro resolves to `Base.@sync` (task synchronization), which does not
# wait for asynchronously launched kernels — the timing would exclude pending
# GPU work.
@btime CUDA.@sync operation($ξ_gpu, $y_gpu, $obs_gpu, CuArray);
I don’t think we can do a lot better without fusion though.
Actually, it seems like it’s better not to fuse — funny one,
fuse version (linux only)
""
using Reactant, Statistics, LinearAlgebra, BenchmarkTools, Random
import Reactant: to_rarray
"""
    operation(ξ, y, obs, dev; cov_obs)

Same ensemble-style update as the CUDA version: estimate the cross-covariance
of `ξ` and `y` and the covariance of `y` over the batch (second) dimension,
then add `C_ξy * inv(C_yy + cov_obs) * (obs .- y + cov_obs)` to `ξ`.

`dev` maps the default `cov_obs` (a `Diagonal` of `T(0.1)`) onto the target
array type. NOTE(review): `obs .- y + cov_obs` only conforms when the batch
size equals `length(obs)` — it does in the benchmark below (3000 × 3000).
"""
function operation(
ξ::AbstractArray{T}, y, obs, dev;
cov_obs = dev(Diagonal(fill(T(0.1), length(obs))))
) where T
dim_y, batch = size(y)
# Ensemble means along the batch (second) dimension.
ξ_mean = mean(ξ, dims=2)
y_mean = mean(y, dims=2)
# Unbiased sample (cross-)covariance estimators.
C_ξy = (ξ * y' - batch .* ξ_mean * y_mean') ./ ( batch - 1)
C_yy = (y * y' - batch .* y_mean * y_mean') ./ ( batch - 1)
# `inv(...) *` instead of `\` on purpose: per the author's note at the end of
# the post, `\` did not work when tracing through Reactant, and the compiled
# HLO ends up factorizing anyway.
ξ += C_ξy * (inv(C_yy + cov_obs) * (obs .- y + cov_obs))
return ξ
end
"""
    operation(ξ, y, obs, L, dev; cov_obs)

Same update as the 4-argument method, but takes the observation length `L`
explicitly so the default `cov_obs` can be built without calling
`length(obs)` (convenient when `obs` is a traced array), and uses `\\`
instead of `inv`.
"""
function operation(
ξ::AbstractArray{T}, y, obs, L, dev;
cov_obs = dev(Diagonal(fill(T(0.1), L)))
) where T
    _, nens = size(y)
    # Ensemble means (column-wise statistics over the batch dimension).
    mξ = mean(ξ, dims=2)
    my = mean(y, dims=2)
    # Unbiased sample (cross-)covariance estimators.
    Cxy = (ξ * y' - nens .* mξ * my') ./ (nens - 1)
    Cyy = (y * y' - nens .* my * my') ./ (nens - 1)
    # Gain applied to the innovation term; `\` factorizes `Cyy + cov_obs`.
    correction = Cxy * ((Cyy + cov_obs) \ (obs .- y + cov_obs))
    return ξ + correction
end
# Reproducible benchmark inputs; note the problem is square (dim_y == batch),
# which is what makes `obs .- y + cov_obs` inside `operation` conform.
rng = Xoshiro(42)
dim_ξ, dim_y, batch = 3000, 3000, 3000
ξ_cpu = randn(rng, Float64, dim_ξ, batch)
y_cpu = randn(rng, Float64, dim_y, batch) * 10
obs_cpu = randn(rng, Float64, dim_y, 1) * 10
# Convert the inputs to Reactant arrays so they can be fed to a compiled function.
ξ_gpu = to_rarray(ξ_cpu)
y_gpu = to_rarray(y_cpu)
obs_gpu = to_rarray(obs_cpu)
# Benchmark CPU operation
@btime operation($ξ_cpu, $y_cpu, $obs_cpu, $identity);
L = length(obs_cpu)
# Benchmark GPU operation
# `@compile` traces `operation` once (the 5-arg method, so the default `cov_obs`
# is sized from the plain integer `L` rather than a traced `length(obs)`) and
# returns a compiled executable; `Reactant.synchronize` makes the timing wait
# for the asynchronous XLA execution to finish.
op_comp = @compile operation(ξ_gpu, y_gpu, obs_gpu, L, Reactant.TracedRArray{Float64, 1})
@btime Reactant.synchronize($op_comp($ξ_gpu, $y_gpu, $obs_gpu, $L, Reactant.TracedRArray{Float64, 1}));
gives
julia> @benchmark Reactant.synchronize($op_comp($ξ_gpu, $y_gpu, $obs_gpu, $L, Reactant.TracedRArray{Float64, 1}))
BenchmarkTools.Trial: 4 samples with 1 evaluation per sample.
Range (min … max): 1.422 s … 1.424 s ┊ GC (min … max): 0.00% … 0.00%
Time (median): 1.424 s ┊ GC (median): 0.00%
Time (mean ± σ): 1.423 s ± 1.014 ms ┊ GC (mean ± σ): 0.00% ± 0.00%
█ █ █ █
█▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁█ ▁
1.42 s Histogram: frequency by time 1.42 s <
Memory estimate: 416 bytes, allocs estimate: 14.
the hlo doesn’t look too bad though
julia> @code_hlo operation(ξ_gpu, y_gpu, obs_gpu, Reactant.TracedRArray{Float64, 1})
module @reactant_operation attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<3000x3000xf64> {enzymexla.memory_effects = []}, %arg1: tensor<3000x3000xf64> {enzymexla.memory_effects = []}, %arg2: tensor<1x3000xf64> {enzymexla.memory_effects = []}) -> tensor<3000x3000xf64> attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<3.3333333333333332E-4> : tensor<3000xf64>
%c = stablehlo.constant dense<1> : tensor<3000xi32>
%c_0 = stablehlo.constant dense<1> : tensor<3000x1xi32>
%cst_1 = stablehlo.constant dense<3.3344448149383126E-4> : tensor<3000x3000xf64>
%cst_2 = stablehlo.constant dense<-3.000000e+03> : tensor<3000x3000xf64>
%c_3 = stablehlo.constant dense<1> : tensor<3000x1xi64>
%c_4 = stablehlo.constant dense<1> : tensor<3000x3000xi64>
%c_5 = stablehlo.constant dense<-3000> : tensor<3000x3000xi64>
%cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<3000xf64>
%c_7 = stablehlo.constant dense<3000> : tensor<3000x3000xi64>
%c_8 = stablehlo.constant dense<-1> : tensor<3000x3000xi64>
%c_9 = stablehlo.constant dense<true> : tensor<i1>
%cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%cst_11 = stablehlo.constant {enzymexla.non_negative = [#enzymexla<guaranteed GUARANTEED>]} dense<0.000000e+00> : tensor<3000x3000xf64>
%cst_12 = stablehlo.constant dense<1.000000e-01> : tensor<3000xf64>
%0 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%1 = stablehlo.iota dim = 0 : tensor<3000x2xi64>
%2 = "stablehlo.scatter"(%cst_11, %1, %cst_12) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
stablehlo.return %arg4 : tensor<f64>
}) : (tensor<3000x3000xf64>, tensor<3000x2xi64>, tensor<3000xf64>) -> tensor<3000x3000xf64>
%3 = stablehlo.reduce(%arg0 init: %cst_10) applies stablehlo.add across dimensions = [0] : (tensor<3000x3000xf64>, tensor<f64>) -> tensor<3000xf64>
%4 = stablehlo.reduce(%arg1 init: %cst_10) applies stablehlo.add across dimensions = [0] : (tensor<3000x3000xf64>, tensor<f64>) -> tensor<3000xf64>
%5 = stablehlo.multiply %4, %cst : tensor<3000xf64>
%6 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%7 = stablehlo.dot_general %3, %5, contracting_dims = [] x [], precision = [DEFAULT, DEFAULT] : (tensor<3000xf64>, tensor<3000xf64>) -> tensor<3000x3000xf64>
%8 = stablehlo.subtract %6, %7 : tensor<3000x3000xf64>
%9 = stablehlo.dot_general %arg1, %arg1, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%10 = stablehlo.dot_general %5, %5, contracting_dims = [] x [], precision = [DEFAULT, DEFAULT] : (tensor<3000xf64>, tensor<3000xf64>) -> tensor<3000x3000xf64>
%11 = stablehlo.multiply %cst_2, %10 : tensor<3000x3000xf64>
%12 = stablehlo.add %9, %11 : tensor<3000x3000xf64>
%13 = stablehlo.multiply %12, %cst_1 : tensor<3000x3000xf64>
%14 = stablehlo.add %13, %2 {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xf64>
%15 = stablehlo.iota dim = 0 : tensor<3000x3000xi64>
%16 = stablehlo.iota dim = 1 : tensor<3000x3000xi64>
%17 = stablehlo.subtract %16, %c_8 : tensor<3000x3000xi64>
%18 = stablehlo.compare GE, %15, %17 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%19 = stablehlo.select %18, %14, %cst_11 {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xi1>, tensor<3000x3000xf64>
%20 = stablehlo.compare EQ, %19, %cst_11 : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xi1>
%21 = stablehlo.reduce(%20 init: %c_9) applies stablehlo.and across dimensions = [0, 1] : (tensor<3000x3000xi1>, tensor<i1>) -> tensor<i1>
%22 = stablehlo.subtract %16, %c_7 : tensor<3000x3000xi64>
%23 = stablehlo.compare LE, %15, %22 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%24 = stablehlo.select %23, %14, %cst_11 {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xi1>, tensor<3000x3000xf64>
%25 = stablehlo.compare EQ, %24, %cst_11 : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xi1>
%26 = stablehlo.reduce(%25 init: %c_9) applies stablehlo.and across dimensions = [0, 1] : (tensor<3000x3000xi1>, tensor<i1>) -> tensor<i1>
%27 = stablehlo.and %21, %26 : tensor<i1>
%28 = "stablehlo.scatter"(%cst_11, %1, %cst_6) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
stablehlo.return %arg4 : tensor<f64>
}) : (tensor<3000x3000xf64>, tensor<3000x2xi64>, tensor<3000xf64>) -> tensor<3000x3000xf64>
%29 = "stablehlo.if"(%27) ({
%37 = "stablehlo.triangular_solve"(%14, %28) <{left_side = false, lower = false, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%38 = stablehlo.compare LE, %15, %16 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%39 = stablehlo.select %38, %37, %cst_11 : tensor<3000x3000xi1>, tensor<3000x3000xf64>
stablehlo.return %39 : tensor<3000x3000xf64>
}, {
%37 = stablehlo.subtract %16, %c_5 : tensor<3000x3000xi64>
%38 = stablehlo.compare GE, %15, %37 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%39 = stablehlo.select %38, %14, %cst_11 {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xi1>, tensor<3000x3000xf64>
%40 = stablehlo.compare EQ, %39, %cst_11 : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xi1>
%41 = stablehlo.reduce(%40 init: %c_9) applies stablehlo.and across dimensions = [0, 1] : (tensor<3000x3000xi1>, tensor<i1>) -> tensor<i1>
%42 = stablehlo.subtract %16, %c_4 : tensor<3000x3000xi64>
%43 = stablehlo.compare LE, %15, %42 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%44 = stablehlo.select %43, %14, %cst_11 {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xi1>, tensor<3000x3000xf64>
%45 = stablehlo.compare EQ, %44, %cst_11 : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xi1>
%46 = stablehlo.reduce(%45 init: %c_9) applies stablehlo.and across dimensions = [0, 1] : (tensor<3000x3000xi1>, tensor<i1>) -> tensor<i1>
%47 = stablehlo.and %41, %46 : tensor<i1>
%48 = "stablehlo.if"(%47) ({
%49 = "stablehlo.triangular_solve"(%14, %28) <{left_side = false, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%50 = stablehlo.compare GE, %15, %16 : (tensor<3000x3000xi64>, tensor<3000x3000xi64>) -> tensor<3000x3000xi1>
%51 = stablehlo.select %50, %49, %cst_11 : tensor<3000x3000xi1>, tensor<3000x3000xf64>
stablehlo.return %51 : tensor<3000x3000xf64>
}, {
%49:3 = stablehlo.custom_call @cusolver_getrf_ffi(%14) {api_version = 4 : i32, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3000x3000xf64>) -> (tensor<3000x3000xf64>, tensor<3000xi32>, tensor<i32>)
%50 = stablehlo.subtract %49#1, %c : tensor<3000xi32>
%51 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%50) {api_version = 4 : i32, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3000xi32>) -> tensor<3000xi32>
%52 = stablehlo.reshape %51 : (tensor<3000xi32>) -> tensor<3000x1xi32>
%53 = stablehlo.add %52, %c_0 : tensor<3000x1xi32>
%54 = stablehlo.convert %53 : (tensor<3000x1xi32>) -> tensor<3000x1xi64>
%55 = stablehlo.subtract %54, %c_3 : tensor<3000x1xi64>
%56 = "stablehlo.gather"(%28, %55) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 3000>}> : (tensor<3000x3000xf64>, tensor<3000x1xi64>) -> tensor<3000x3000xf64>
%57 = "stablehlo.triangular_solve"(%49#0, %56) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = true}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%58 = "stablehlo.triangular_solve"(%49#0, %57) <{left_side = true, lower = false, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
stablehlo.return %58 : tensor<3000x3000xf64>
}) : (tensor<i1>) -> tensor<3000x3000xf64>
stablehlo.return %48 : tensor<3000x3000xf64>
}) : (tensor<i1>) -> tensor<3000x3000xf64>
%30 = stablehlo.broadcast_in_dim %arg2, dims = [1, 0] : (tensor<1x3000xf64>) -> tensor<3000x3000xf64>
%31 = stablehlo.subtract %30, %0 : tensor<3000x3000xf64>
%32 = stablehlo.add %31, %2 : tensor<3000x3000xf64>
%33 = stablehlo.dot_general %29, %32, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%34 = stablehlo.dot_general %33, %8, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%35 = stablehlo.multiply %cst_1, %34 : tensor<3000x3000xf64>
%36 = stablehlo.add %arg0, %35 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xf64>
return %36 : tensor<3000x3000xf64>
}
}
On F32 it is on par:
julia> @btime Reactant.synchronize($op_comp($ξ_gpu, $y_gpu, $obs_gpu, $L, Reactant.TracedRArray{Float32, 1}));
45.400 ms (14 allocations: 416 bytes)
maybe its interesting for @wsmoses
PS: `\` didn’t work — I had to use `inv() *` instead — but the HLO ends up factorizing anyway, so it’s fine.