I get slightly better timings, but that may be specific to the hardware:
1.317 s (51 allocations: 686.72 MiB)
1.178 s (1256 allocations: 57.74 KiB)
and with Float32 :
636.211 ms (75 allocations: 515.07 MiB)
41.175 ms (1325 allocations: 45.65 KiB)
code :
using CUDA, Statistics, LinearAlgebra, BenchmarkTools, Random
"""
    operation(ξ, y, obs; dev = identity,
              cov_obs = dev(Diagonal(fill(T(0.1), length(obs)))))

Apply an ensemble-Kalman-style analysis update to the state ensemble `ξ`,
mutating it in place and returning it.

Computes the sample cross-covariance `C_ξy` and observation covariance `C_yy`
over the ensemble columns, solves `(C_yy + cov_obs) \\ (obs .- y)` through a
Cholesky factorization, and adds the correction `C_ξy * solution` to `ξ`.

# Arguments
- `ξ`: state ensemble, `dim_ξ × batch`; updated in place.
- `y`: predicted observations, `dim_y × batch`.
- `obs`: observation column, `dim_y × 1` (broadcast against `y`).

# Keywords
- `dev`: array-transfer function (e.g. `CuArray`); applied to the default `cov_obs`.
- `cov_obs`: observation-noise covariance; defaults to `0.1 * I` of size `length(obs)`.

# Throws
- `ArgumentError` if the ensemble has fewer than two members (the sample
  covariance divides by `batch - 1`).
"""
function operation(
    ξ::AbstractArray{T}, y, obs;
    dev = identity,
    cov_obs = dev(Diagonal(fill(T(0.1), length(obs))))
) where T
    batch = size(y, 2)  # size(y, 1) was previously bound but never used
    batch > 1 || throw(ArgumentError("need at least two ensemble members, got batch = $batch"))
    ξ_mean = mean(ξ, dims=2)
    y_mean = mean(y, dims=2)
    ξ_centered = ξ .- ξ_mean
    y_centered = y .- y_mean
    C_ξy = (ξ_centered * y_centered') ./ (batch - 1)
    C_yy = (y_centered * y_centered') ./ (batch - 1)
    # C_yy is a PSD sample covariance and cov_obs has a positive diagonal,
    # so the sum is symmetric positive definite and Cholesky is safe.
    S = cholesky!(Symmetric(C_yy .+ cov_obs))
    innovation = obs .- y
    correction = C_ξy * (S \ innovation)
    ξ .+= correction
    return ξ
end
"""
    operation(ξ, y, obs, L; dev = identity,
              cov_obs = dev(Diagonal(fill(T(0.1), L))))

Convenience method that builds the default observation-noise covariance from
an explicitly supplied size `L` instead of `length(obs)`, then forwards to the
three-argument method.
"""
operation(ξ::AbstractArray{T}, y, obs, L; dev = identity,
          cov_obs = dev(Diagonal(fill(T(0.1), L)))) where {T} =
    operation(ξ, y, obs; dev, cov_obs)
# Reproducible test problem: 3000-dim state/observations, 3000-member ensemble.
prng = Xoshiro(42)
state_dim, observed_dim, ensemble = 3000, 3000, 3_000
# Host-side data: state ensemble, scaled predictions, one observation column.
state_host = randn(prng, Float64, state_dim, ensemble)
pred_host = 10 * randn(prng, Float64, observed_dim, ensemble)
obs_host = 10 * randn(prng, Float64, observed_dim, 1)
# Device copies of the same data.
state_dev = CuArray(state_host)
pred_dev = CuArray(pred_host)
obs_dev = CuArray(obs_host)
# Benchmark CPU operation
@btime operation($state_host, $pred_host, $obs_host);
# Benchmark GPU operation (CUDA.@sync blocks until the device finishes,
# so the timing covers the whole computation)
@btime CUDA.@sync operation($state_dev, $pred_dev, $obs_dev; dev = CuArray);
I don’t think we can do a lot better without fusion, though.
Actually, it seems like it’s better not to fuse — funny one.
Fused version (Linux only):
using Reactant, Statistics, LinearAlgebra, BenchmarkTools, Random
import Reactant: to_rarray
Reactant.set_default_backend("gpu")
using Statistics
using LinearAlgebra
"""
    operation(ξ, y, obs; dev = identity,
              cov_obs = dev(Diagonal(fill(T(0.1), length(obs)))))

Apply an ensemble-Kalman-style analysis update to the state ensemble `ξ`,
mutating it in place and returning it.

Computes the sample cross-covariance `C_ξy` and observation covariance `C_yy`
over the ensemble columns, solves `(C_yy + cov_obs) \\ (obs .- y)` through a
Cholesky factorization, and adds the correction `C_ξy * solution` to `ξ`.

# Arguments
- `ξ`: state ensemble, `dim_ξ × batch`; updated in place.
- `y`: predicted observations, `dim_y × batch`.
- `obs`: observation column, `dim_y × 1` (broadcast against `y`).

# Keywords
- `dev`: array-transfer function (e.g. `CuArray`); applied to the default `cov_obs`.
- `cov_obs`: observation-noise covariance; defaults to `0.1 * I` of size `length(obs)`.

# Throws
- `ArgumentError` if the ensemble has fewer than two members (the sample
  covariance divides by `batch - 1`).
"""
function operation(
    ξ::AbstractArray{T}, y::AbstractArray{T}, obs::AbstractArray{T};
    dev = identity,
    cov_obs = dev(Diagonal(fill(T(0.1), length(obs))))
) where T
    batch = size(y, 2)  # size(y, 1) was previously bound but never used
    batch > 1 || throw(ArgumentError("need at least two ensemble members, got batch = $batch"))
    ξ_mean = mean(ξ, dims=2)
    y_mean = mean(y, dims=2)
    ξ_centered = ξ .- ξ_mean
    y_centered = y .- y_mean
    C_ξy = (ξ_centered * y_centered') ./ (batch - 1)
    C_yy = (y_centered * y_centered') ./ (batch - 1)
    # C_yy is a PSD sample covariance and cov_obs has a positive diagonal,
    # so the sum is symmetric positive definite and Cholesky is safe.
    S = cholesky!(Symmetric(C_yy .+ cov_obs))
    innovation = obs .- y
    correction = C_ξy * (S \ innovation)
    ξ .+= correction
    return ξ
end
"""
    operation(ξ, y, obs, L; dev = identity,
              cov_obs = dev(Diagonal(fill(T(0.1), L))))

Convenience method that builds the default observation-noise covariance from
an explicitly supplied size `L` instead of `length(obs)`, then forwards to the
three-argument method.
"""
operation(ξ::AbstractArray{T}, y, obs, L; dev = identity,
          cov_obs = dev(Diagonal(fill(T(0.1), L)))) where {T} =
    operation(ξ, y, obs; dev, cov_obs)
# Reproducible test problem: 3000-dim state/observations, 3000-member ensemble.
rng = Xoshiro(42)
dim_ξ, dim_y, batch = 3000, 3000, 3_000
ξ_cpu = randn(rng, Float64, dim_ξ, batch)
y_cpu = randn(rng, Float64, dim_y, batch) * 10
obs_cpu = randn(rng, Float64, dim_y, 1) * 10
# Wrap the host arrays as Reactant arrays for tracing/compilation.
ξ_gpu = to_rarray(ξ_cpu)
y_gpu = to_rarray(y_cpu)
obs_gpu = to_rarray(obs_cpu)
# Benchmark CPU operation
@btime operation($ξ_cpu, $y_cpu, $obs_cpu; dev = $identity);
# Observation length passed explicitly to the 4-arg method — presumably so the
# default cov_obs size is a plain Int under tracing; TODO confirm.
L = length(obs_cpu)
# Benchmark GPU operation
# NOTE(review): dev = Reactant.TracedRArray{Float32, 1} while the data are
# Float64 — looks intended for the Float32 run; verify the eltype is right here.
op_comp = @compile operation(ξ_gpu, y_gpu, obs_gpu, L; dev = Reactant.TracedRArray{Float32, 1})
op_comp(ξ_gpu, y_gpu, obs_gpu, L)  # warm-up call, kept outside the benchmark
@btime Reactant.synchronize($op_comp($ξ_gpu, $y_gpu, $obs_gpu, $L));
gives
1.249 s (51 allocations: 686.75 MiB)
1.256 s (19 allocations: 624 bytes)
the hlo doesn’t look too bad though
module @reactant_operation attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<3000x3000xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}, %arg1: tensor<3000x3000xf64> {enzymexla.memory_effects = []}, %arg2: tensor<1x3000xf64> {enzymexla.memory_effects = []}) -> tensor<3000x3000xf64> attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<3.3333333333333332E-4> : tensor<3000xf64>
%cst_0 = stablehlo.constant dense<3.3344448149383126E-4> : tensor<3000x3000xf64>
%cst_1 = stablehlo.constant dense<1.000000e-01> : tensor<3000xf64>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<3000x3000xf64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%2 = stablehlo.iota dim = 0 : tensor<3000x2xi64>
%3 = "stablehlo.scatter"(%cst_3, %2, %cst_1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
stablehlo.return %arg4 : tensor<f64>
}) {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<3000x3000xf64>, tensor<3000x2xi64>, tensor<3000xf64>) -> tensor<3000x3000xf64>
%4 = stablehlo.reduce(%arg0 init: %cst_2) applies stablehlo.add across dimensions = [0] : (tensor<3000x3000xf64>, tensor<f64>) -> tensor<3000xf64>
%5 = stablehlo.multiply %4, %cst : tensor<3000xf64>
%6 = stablehlo.reduce(%arg1 init: %cst_2) applies stablehlo.add across dimensions = [0] : (tensor<3000x3000xf64>, tensor<f64>) -> tensor<3000xf64>
%7 = stablehlo.multiply %6, %cst : tensor<3000xf64>
%8 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<3000xf64>) -> tensor<3000x3000xf64>
%9 = stablehlo.subtract %0, %8 : tensor<3000x3000xf64>
%10 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<3000xf64>) -> tensor<3000x3000xf64>
%11 = stablehlo.subtract %1, %10 : tensor<3000x3000xf64>
%12 = stablehlo.dot_general %9, %11, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%13 = stablehlo.dot_general %11, %11, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%14 = stablehlo.multiply %13, %cst_0 : tensor<3000x3000xf64>
%15 = stablehlo.add %14, %3 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : tensor<3000x3000xf64>
%16 = stablehlo.cholesky %15 : tensor<3000x3000xf64>
%17 = stablehlo.broadcast_in_dim %arg2, dims = [1, 0] : (tensor<1x3000xf64>) -> tensor<3000x3000xf64>
%18 = stablehlo.subtract %17, %1 : tensor<3000x3000xf64>
%19 = "stablehlo.triangular_solve"(%16, %18) <{left_side = true, lower = false, transpose_a = #stablehlo<transpose ADJOINT>, unit_diagonal = false}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%20 = "stablehlo.triangular_solve"(%16, %19) <{left_side = true, lower = false, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%21 = stablehlo.dot_general %20, %12, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3000x3000xf64>, tensor<3000x3000xf64>) -> tensor<3000x3000xf64>
%22 = stablehlo.multiply %cst_0, %21 : tensor<3000x3000xf64>
%23 = stablehlo.add %arg0, %22 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<3000x3000xf64>
return %23 : tensor<3000x3000xf64>
}
}
With Float32 it is on par:
julia> @btime Reactant.synchronize($op_comp($ξ_gpu, $y_gpu, $obs_gpu, $L, Reactant.TracedRArray{Float32, 1}));
45.400 ms (14 allocations: 416 bytes)
maybe its interesting for @wsmoses
PS: `\` didn’t work — I had to use `inv() *` — even though the HLO ends up factorizing anyway, so it’s fine.