No, the issue is that zeros(…) always returns a CPU Julia array, so we end up creating a CPU array of XPU numbers, which is really dirty. What we actually want is something like zero(x) that keeps the same memory layout but fixes the elements to 0.
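A minimal sketch of the difference, using CUDA.jl purely as a stand-in for whatever XPU array type is in play:

using CUDA  # stand-in for any device/XPU array backend

x = CUDA.rand(Float32, 16, 100)    # lives on the device

# zeros(...) knows nothing about x's storage, so it always
# allocates a plain CPU Array, regardless of where x lives:
bad = zeros(Float32, size(x))      # ::Matrix{Float32}, on the CPU

# zero(x) goes through similar(x), so it preserves x's
# container type and memory layout:
good = zero(x)                     # ::CuMatrix{Float32}, on the device

fill!(similar(x), 0) is the manual spelling of the same thing, which is handy when you need a different element type or shape but the same container.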
Note: it's actually pretty good at cleaning up our mess, though:
func.func @main(%arg0: tensor<675xf32>, %arg1: tensor<200x4xf32>, %arg2: tensor<9xf32>, %arg3: tensor<5xf32>, %arg4: tensor<4xf32>) -> tensor<f32> {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<3x100xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<16x100xf32>
%cst_1 = stablehlo.constant dense<0.00999999977> : tensor<16x100xf32>
%cst_2 = stablehlo.constant dense<0.111111112> : tensor<f32>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.slice %arg1 [100:200, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
%1 = stablehlo.slice %arg1 [0:100, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
%2 = stablehlo.slice %arg0 [64:80] : (tensor<675xf32>) -> tensor<16xf32>
%3 = stablehlo.slice %arg0 [0:64] : (tensor<675xf32>) -> tensor<64xf32>
%4 = stablehlo.reshape %3 : (tensor<64xf32>) -> tensor<4x16xf32>
%5 = stablehlo.dot_general %4, %1, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
%6 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
%7 = stablehlo.add %5, %6 : tensor<16x100xf32>
%8 = stablehlo.compare GT, %7, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%9 = stablehlo.multiply %cst_1, %7 : tensor<16x100xf32>
%10 = stablehlo.select %8, %7, %9 : tensor<16x100xi1>, tensor<16x100xf32>
%11 = stablehlo.slice %arg0 [336:352] : (tensor<675xf32>) -> tensor<16xf32>
%12 = stablehlo.slice %arg0 [80:336] : (tensor<675xf32>) -> tensor<256xf32>
%13 = stablehlo.reshape %12 : (tensor<256xf32>) -> tensor<16x16xf32>
%14 = stablehlo.dot_general %13, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
%15 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
%16 = stablehlo.add %14, %15 : tensor<16x100xf32>
%17 = stablehlo.compare GT, %16, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%18 = stablehlo.multiply %cst_1, %16 : tensor<16x100xf32>
%19 = stablehlo.select %17, %16, %18 : tensor<16x100xi1>, tensor<16x100xf32>
%20 = stablehlo.slice %arg0 [608:624] : (tensor<675xf32>) -> tensor<16xf32>
%21 = stablehlo.slice %arg0 [352:608] : (tensor<675xf32>) -> tensor<256xf32>
%22 = stablehlo.reshape %21 : (tensor<256xf32>) -> tensor<16x16xf32>
%23 = stablehlo.dot_general %22, %19, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
%24 = stablehlo.broadcast_in_dim %20, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
%25 = stablehlo.add %23, %24 : tensor<16x100xf32>
%26 = stablehlo.compare GT, %25, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%27 = stablehlo.multiply %cst_1, %25 : tensor<16x100xf32>
%28 = stablehlo.select %26, %25, %27 : tensor<16x100xi1>, tensor<16x100xf32>
%29 = stablehlo.slice %arg0 [672:675] : (tensor<675xf32>) -> tensor<3xf32>
%30 = stablehlo.slice %arg0 [624:672] : (tensor<675xf32>) -> tensor<48xf32>
%31 = stablehlo.reshape %30 : (tensor<48xf32>) -> tensor<16x3xf32>
%32 = stablehlo.dot_general %31, %28, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
%33 = stablehlo.broadcast_in_dim %29, dims = [0] : (tensor<3xf32>) -> tensor<3x100xf32>
%34 = stablehlo.add %32, %33 : tensor<3x100xf32>
%35 = stablehlo.maximum %cst, %34 : tensor<3x100xf32>
%36 = stablehlo.reduce(%35 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
%37 = stablehlo.cosine %arg3 : tensor<5xf32>
%38 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor<f32>) -> tensor<5xf32>
%39 = stablehlo.dot_general %4, %0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
%40 = stablehlo.add %39, %6 : tensor<16x100xf32>
%41 = stablehlo.compare GT, %40, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%42 = stablehlo.multiply %cst_1, %40 : tensor<16x100xf32>
%43 = stablehlo.select %41, %40, %42 : tensor<16x100xi1>, tensor<16x100xf32>
%44 = stablehlo.dot_general %13, %43, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
%45 = stablehlo.add %44, %15 : tensor<16x100xf32>
%46 = stablehlo.compare GT, %45, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%47 = stablehlo.multiply %cst_1, %45 : tensor<16x100xf32>
%48 = stablehlo.select %46, %45, %47 : tensor<16x100xi1>, tensor<16x100xf32>
%49 = stablehlo.dot_general %22, %48, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
%50 = stablehlo.add %49, %24 : tensor<16x100xf32>
%51 = stablehlo.compare GT, %50, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
%52 = stablehlo.multiply %cst_1, %50 : tensor<16x100xf32>
%53 = stablehlo.select %51, %50, %52 : tensor<16x100xi1>, tensor<16x100xf32>
%54 = stablehlo.dot_general %31, %53, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
%55 = stablehlo.add %54, %33 : tensor<3x100xf32>
%56 = stablehlo.maximum %cst, %55 : tensor<3x100xf32>
%57 = stablehlo.reduce(%56 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
%58 = stablehlo.cosine %arg4 : tensor<4xf32>
%59 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<f32>) -> tensor<4xf32>
%60 = stablehlo.concatenate %38, %59, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
%61 = stablehlo.concatenate %37, %58, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
%62 = stablehlo.multiply %60, %61 : tensor<9xf32>
%63 = stablehlo.subtract %arg2, %62 : tensor<9xf32>
%64 = stablehlo.multiply %63, %63 : tensor<9xf32>
%65 = stablehlo.reduce(%64 init: %cst_3) applies stablehlo.add across dimensions = [0] : (tensor<9xf32>, tensor<f32>) -> tensor<f32>
%66 = stablehlo.multiply %65, %cst_2 : tensor<f32>
return %66 : tensor<f32>
}
I don’t see any <1> tensors, so it seems like it figured it out somehow.
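That makes sense: on the Julia side, a reduction that keeps dims yields a 1×1 array, which would trace as tensor<1x1xf32>, whereas the IR's stablehlo.reduce goes straight to tensor<f32> followed by broadcast_in_dim with dims = [], so the singleton dimensions were presumably canonicalized away. A toy illustration of the two shapes:

A = ones(Float32, 3, 100)

sum(A; dims=(1, 2))   # 1×1 Matrix{Float32}  → would trace as tensor<1x1xf32>
sum(A)                # plain Float32 scalar → traces as tensor<f32>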