Using Reactant with Lux and Enzyme to speed up training in a physics context

No, the issue is that `zeros(…)` always returns a CPU Julia array, so we end up creating a CPU array of XPU numbers, which is really dirty. What we actually want is something like `zero(X)`, which keeps the same memory layout as `X` but sets every element to 0.

Note: it's actually pretty good at cleaning up our mess, though. Here is the compiled output:

  // Lowered StableHLO for the loss function (entry point @main).
  //   %arg0: tensor<675xf32>   -- flat vector sliced below into four dense layers
  //                               (4x16 + 16, 16x16 + 16, 16x16 + 16, 16x3 + 3 = 675);
  //                               presumably the flattened Lux parameters -- confirm.
  //   %arg1: tensor<200x4xf32> -- split into rows [0:100] and [100:200]; each half is
  //                               pushed through the same four layers.
  //   %arg2: tensor<9xf32>     -- vector the 9-element model product is subtracted from
  //                               (presumably the targets -- confirm).
  //   %arg3: tensor<5xf32>, %arg4: tensor<4xf32> -- passed through cosine and used as
  //                               element-wise multipliers of each half's summed output.
  // Returns sum((%arg2 - product)^2) * 0.111111112, i.e. the mean over the 9 elements.
  func.func @main(%arg0: tensor<675xf32>, %arg1: tensor<200x4xf32>, %arg2: tensor<9xf32>, %arg3: tensor<5xf32>, %arg4: tensor<4xf32>) -> tensor<f32> {
    // Constants: zero tensors used by the ReLU / leaky-ReLU comparisons, the leaky-ReLU
    // slope 0.01 (0.00999999977 is 0.01 rounded to f32), 1/9 for the final mean, and a
    // scalar 0 used as the init value of every add-reduction.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3x100xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<16x100xf32>
    %cst_1 = stablehlo.constant dense<0.00999999977> : tensor<16x100xf32>
    %cst_2 = stablehlo.constant dense<0.111111112> : tensor<f32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    // Split the 200 input rows: %1 = rows 0:100 (first half), %0 = rows 100:200 (second half).
    %0 = stablehlo.slice %arg1 [100:200, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
    %1 = stablehlo.slice %arg1 [0:100, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
    // First half, layer 1 (4 -> 16): weights %arg0[0:64] viewed as 4x16, bias %arg0[64:80].
    // Contracting the size-4 dim of both operands yields 16x100 (one column per row of %1).
    %2 = stablehlo.slice %arg0 [64:80] : (tensor<675xf32>) -> tensor<16xf32>
    %3 = stablehlo.slice %arg0 [0:64] : (tensor<675xf32>) -> tensor<64xf32>
    %4 = stablehlo.reshape %3 : (tensor<64xf32>) -> tensor<4x16xf32>
    %5 = stablehlo.dot_general %4, %1, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
    %6 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %7 = stablehlo.add %5, %6 : tensor<16x100xf32>
    // Leaky ReLU: select(x > 0, x, 0.01 * x).
    %8 = stablehlo.compare  GT, %7, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %9 = stablehlo.multiply %cst_1, %7 : tensor<16x100xf32>
    %10 = stablehlo.select %8, %7, %9 : tensor<16x100xi1>, tensor<16x100xf32>
    // First half, layer 2 (16 -> 16): weights %arg0[80:336], bias %arg0[336:352], leaky ReLU.
    %11 = stablehlo.slice %arg0 [336:352] : (tensor<675xf32>) -> tensor<16xf32>
    %12 = stablehlo.slice %arg0 [80:336] : (tensor<675xf32>) -> tensor<256xf32>
    %13 = stablehlo.reshape %12 : (tensor<256xf32>) -> tensor<16x16xf32>
    %14 = stablehlo.dot_general %13, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %15 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %16 = stablehlo.add %14, %15 : tensor<16x100xf32>
    %17 = stablehlo.compare  GT, %16, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %18 = stablehlo.multiply %cst_1, %16 : tensor<16x100xf32>
    %19 = stablehlo.select %17, %16, %18 : tensor<16x100xi1>, tensor<16x100xf32>
    // First half, layer 3 (16 -> 16): weights %arg0[352:608], bias %arg0[608:624], leaky ReLU.
    %20 = stablehlo.slice %arg0 [608:624] : (tensor<675xf32>) -> tensor<16xf32>
    %21 = stablehlo.slice %arg0 [352:608] : (tensor<675xf32>) -> tensor<256xf32>
    %22 = stablehlo.reshape %21 : (tensor<256xf32>) -> tensor<16x16xf32>
    %23 = stablehlo.dot_general %22, %19, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %24 = stablehlo.broadcast_in_dim %20, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %25 = stablehlo.add %23, %24 : tensor<16x100xf32>
    %26 = stablehlo.compare  GT, %25, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %27 = stablehlo.multiply %cst_1, %25 : tensor<16x100xf32>
    %28 = stablehlo.select %26, %25, %27 : tensor<16x100xi1>, tensor<16x100xf32>
    // First half, layer 4 (16 -> 3): weights %arg0[624:672], bias %arg0[672:675];
    // plain ReLU via maximum(0, x), then summed over all 3x100 entries to a scalar.
    %29 = stablehlo.slice %arg0 [672:675] : (tensor<675xf32>) -> tensor<3xf32>
    %30 = stablehlo.slice %arg0 [624:672] : (tensor<675xf32>) -> tensor<48xf32>
    %31 = stablehlo.reshape %30 : (tensor<48xf32>) -> tensor<16x3xf32>
    %32 = stablehlo.dot_general %31, %28, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
    %33 = stablehlo.broadcast_in_dim %29, dims = [0] : (tensor<3xf32>) -> tensor<3x100xf32>
    %34 = stablehlo.add %32, %33 : tensor<3x100xf32>
    %35 = stablehlo.maximum %cst, %34 : tensor<3x100xf32>
    %36 = stablehlo.reduce(%35 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
    // cos(%arg3), and the first-half scalar sum broadcast to 5 elements.
    %37 = stablehlo.cosine %arg3 : tensor<5xf32>
    %38 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor<f32>) -> tensor<5xf32>
    // Second half (%0) through the same four layers. The weight/bias values
    // %4, %6, %13, %15, %22, %24, %31, %33 are reused rather than re-sliced.
    %39 = stablehlo.dot_general %4, %0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
    %40 = stablehlo.add %39, %6 : tensor<16x100xf32>
    %41 = stablehlo.compare  GT, %40, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %42 = stablehlo.multiply %cst_1, %40 : tensor<16x100xf32>
    %43 = stablehlo.select %41, %40, %42 : tensor<16x100xi1>, tensor<16x100xf32>
    %44 = stablehlo.dot_general %13, %43, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %45 = stablehlo.add %44, %15 : tensor<16x100xf32>
    %46 = stablehlo.compare  GT, %45, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %47 = stablehlo.multiply %cst_1, %45 : tensor<16x100xf32>
    %48 = stablehlo.select %46, %45, %47 : tensor<16x100xi1>, tensor<16x100xf32>
    %49 = stablehlo.dot_general %22, %48, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %50 = stablehlo.add %49, %24 : tensor<16x100xf32>
    %51 = stablehlo.compare  GT, %50, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %52 = stablehlo.multiply %cst_1, %50 : tensor<16x100xf32>
    %53 = stablehlo.select %51, %50, %52 : tensor<16x100xi1>, tensor<16x100xf32>
    %54 = stablehlo.dot_general %31, %53, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
    %55 = stablehlo.add %54, %33 : tensor<3x100xf32>
    %56 = stablehlo.maximum %cst, %55 : tensor<3x100xf32>
    %57 = stablehlo.reduce(%56 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
    // cos(%arg4), and the second-half scalar sum broadcast to 4 elements.
    %58 = stablehlo.cosine %arg4 : tensor<4xf32>
    %59 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<f32>) -> tensor<4xf32>
    // Product = [sum1 * cos(%arg3); sum2 * cos(%arg4)] (9 elements); the result is
    // sum((%arg2 - product)^2) * (1/9).
    %60 = stablehlo.concatenate %38, %59, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
    %61 = stablehlo.concatenate %37, %58, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
    %62 = stablehlo.multiply %60, %61 : tensor<9xf32>
    %63 = stablehlo.subtract %arg2, %62 : tensor<9xf32>
    %64 = stablehlo.multiply %63, %63 : tensor<9xf32>
    %65 = stablehlo.reduce(%64 init: %cst_3) applies stablehlo.add across dimensions = [0] : (tensor<9xf32>, tensor<f32>) -> tensor<f32>
    %66 = stablehlo.multiply %65, %cst_2 : tensor<f32>
    return %66 : tensor<f32>
  }
}

I don't see any `<1>` tensors, so it seems like it figured it out somehow.