Using Reactant with Lux and Enzyme to speed up training in physics context

I am training a neural network in a physical context on observables; crucially, the network is producing values which are passed through some physics calculations before being compared to the training data. My code works, but has become quite slow as I’ve acquired more training data (on the order of 30,000 points). I would like to try moving it to Reactant to make use of its optimization and also HPC GPUs.

I’m running into a lot of errors that are either “Constant memory is stored (or returned) to a differentiable variable” or “No augmented forward pass found for ejlstr$UnsafeBufferPointer$…”.
I’d like some help dealing with this, or ideas for refactoring my code to be more Reactant-friendly. I’ve created an MWE that closely matches the high-level structure of my code but leaves out the underlying physics. Here is its current form and the error I’m getting:

begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    Random.seed!(1234)
end

begin # Functions
    """
        build_model(n_in, n_out, n_layers, n_nodes; act_fun=leakyrelu, last_fun=relu)

    Assemble a fully connected Lux `Chain` from three parts:
    - an input `Dense` layer from `n_in` to `n_nodes` features with `act_fun`;
    - `n_layers` hidden `n_nodes => n_nodes` layers with `act_fun`;
    - an output `n_nodes => n_out` layer with `last_fun`.
    """
    function build_model(n_in, n_out, n_layers, n_nodes;
                         act_fun=leakyrelu, last_fun=relu)
        input_layer = Lux.Dense(n_in, n_nodes, act_fun)
        hidden_layers = [Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
        output_layer = Lux.Dense(n_nodes => n_out, last_fun)
        return Chain(input_layer, hidden_layers..., output_layer)
    end

    """
        eval_model(m, params, st, x)

    Call model `m` on input `x` with parameters `params` and state `st`, and return
    only the first element of the result. The model call is expected to return an
    `(output, state)` pair; the state is discarded.
    """
    function eval_model(m, params, st, x)
        output, = m(x, params, st)
        return output
    end

    """
        combex(x, r)

    Expand every row of `x` over the radial grid `r`.

    `x` is `xlen × nx` and `r` has `rlen` entries. The result is an `(nx + 1) × (xlen * rlen)`
    `Matrix{T}` whose column `(i - 1) * rlen + j` is `[r[j]; x[i, :]]`. Each block of `rlen`
    consecutive columns therefore belongs to one row of `x`.

    The output is filled directly in this column-major layout and returned as a plain
    matrix, so no transpose wrapper is produced. Rows of `x` are read through views
    instead of being copied for every radial point.
    """
    function combex(x::AbstractMatrix{T}, r::AbstractVector{T}) where T<:AbstractFloat
        xlen, nx = size(x)
        rlen = length(r)

        # One column per (row of x, radial point) pair, element type T
        Xt = Matrix{T}(undef, nx + 1, xlen * rlen)

        for (i, row) in enumerate(eachrow(x))
            offset = (i - 1) * rlen
            for (j, rj) in enumerate(r)
                col = offset + j
                Xt[1, col] = rj          # radial coordinate
                Xt[2:end, col] .= row    # features of row i (in-place, no copy)
            end
        end

        return Xt
    end

    """
        recursive_convert(T, x)

    Convert every array leaf of `x` to element type `T`, elementwise. Nested `NamedTuple`s
    (such as a Lux parameter tree) are recursed into and keep their keys. Any other value
    is returned unchanged.

    Implemented with dispatch instead of `isa` checks. `map` over a `NamedTuple` rebuilds a
    `NamedTuple` with the same keys, which keeps the result type inferable; building it
    from a generator does not.
    """
    recursive_convert(T, x) = x
    recursive_convert(T, x::AbstractArray) = convert.(T, x)  # elementwise convert
    recursive_convert(T, x::NamedTuple) = map(v -> recursive_convert(T, v), x)

    """
        lossDiff(p, args)

    Loss for model parameters `p`. Computes the cross sections from
    `calculateMultiDiffCrossSections(p, args)` and scores them against the measured
    values `args[3]` with `ln_loss`.
    """
    lossDiff(p, args) = ln_loss(args[3], calculateMultiDiffCrossSections(p, args))

    """
        ln_loss(y_data, y_model)

    Mean squared difference between `log.(y_data)` and `log.(y_model)`. Comparing in log
    space makes each point count roughly by its relative error.
    """
    function ln_loss(y_data, y_model)
        log_residual = log.(y_data) - log.(y_model)
        return mean(log_residual .^ 2)
    end

    """
        calculateDifferentialCrossSection(A, Z, E, U, r, dr, theta, Lrange)

    Placeholder for the physics routine. Returns uniformly random `Float32` values with
    the same shape as `theta`; every other argument is ignored.

    NOTE(review): the result does not depend on `U`, so nothing computed from this stub
    depends on the model parameters. Any gradient taken through it is identically zero.
    """
    function calculateDifferentialCrossSection(A, Z, E, U, r, dr, theta, Lrange)
        dims = size(theta)
        return rand(Float32, dims)
    end

    """
        calculateMultiDiffCrossSections(p, args)

    Compute the cross sections predicted with model parameters `p` for every experiment.
    The per-experiment results are concatenated into one vector of element type
    `eltype(x)`.

    `args` is read positionally:
    - `x = args[1]`: one row per experiment; columns 1, 2, 3 are passed on as `E`, `A`, `Z`.
    - `X = args[2]`: model inputs, with `rlen` consecutive columns per experiment.
    - `r = args[4]`, `dr = args[5]`, `thetas = args[6]` (angles per experiment).
    - `Lrange = args[7]`, `M = args[8]` (model), `st = args[9]` (model state).

    Experiments with `x[i, 2] <= 0` are treated as natural elements. Their cross section
    is the abundance-weighted (percent) sum over the isotopes in the global tables
    `nat_zs`, `nat_zaids`, `nat_abunds` whose Z equals `x[i, 3]`. The caller's `X` is not
    modified.
    """
    function calculateMultiDiffCrossSections(p, args)
        x = args[1]
        X = args[2]
        r = args[4]
        dr = args[5]
        thetas = args[6]
        Lrange = args[7]
        M = args[8]
        st = args[9]
        nlen = size(x, 1)
        rlen = size(r, 1)
        datalen = sum(i -> length(thetas[i]), 1:nlen; init = 0)
        dσ = zeros(eltype(x), datalen)
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            cols = (i - 1) * rlen + 1 : i * rlen   # model-input columns of experiment i
            if x[i, 2] > 0
                U = eval_model(M, p, st, X[:, cols])
                sig = calculateDifferentialCrossSection(x[i, 2], x[i, 3], x[i, 1], U,
                                                        r, dr, thetas[i], Lrange)
            else
                # Natural element: select the isotopes whose Z matches this experiment's
                # Z (x[i, 3]); the predicate must not shadow the matrix `x`.
                Z = x[i, 3]
                nat_inds = findall(==(Z), nat_zs)
                nat_abund = nat_abunds[nat_inds]
                nat_As = mod.(nat_zaids[nat_inds], 1000)
                sig = zeros(eltype(x), exp_len)
                # Private copy of this experiment's inputs so the caller's X is untouched.
                Xi = X[:, cols]
                for (nat_A, abund) in zip(nat_As, nat_abund)
                    Xi[3, :] .= nat_A   # row 3 of the model input gets this isotope's A
                    U = eval_model(M, p, st, Xi)
                    sig += abund / 100 * calculateDifferentialCrossSection(
                        nat_A, Z, x[i, 1], U, r, dr, thetas[i], Lrange)
                end
            end
            dσ[j:j_next-1] = sig
            j = j_next
        end
        return dσ
    end
end


# Device handles. Only `xdev` (Reactant) is used below; `cdev` is currently unused.
const cdev = cpu_device()
const xdev = reactant_device()

# Set up physics parameters
# Radial grid r = rmin:dr:rmax (stored as data_type) and the range of L values 0..Lmax.
data_type = Float32
Lmax = 15
dr = 0.1
rmin = dr/100
rmax = 12.5
r = Vector{data_type}(rmin:dr:rmax)
Lrange = collect(0:Lmax)

# Abundances
# Each row is (ZAID, abundance in percent); Z = ZAID ÷ 1000 and A = ZAID mod 1000.
data_abund = [12024 78.99;
    12025 10;
    12026 11.01;
    14028 92.223;
    14029 4.685;
    14030 3.092]

# Global lookup tables read by calculateMultiDiffCrossSections for natural elements.
const global nat_zaids = Vector{Int}(data_abund[:,1])
const global nat_abunds = Vector{data_type}(data_abund[:,2])
const global nat_zs = nat_zaids .÷ 1000

# Generate dummy data
# x_train: 2 experiments × 3 columns (passed on as E, A, Z); X_train is built by combex.
# XSdiff_train has 9 = 5 + 4 entries, one per angle in theta_train.
x_train = rand(data_type, 2,3)
X_train = combex(x_train, r)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
XSdiff_train = rand(data_type, 9)

# Load a model
# Input width 4 = 1 radial coordinate + 3 columns of x_train (the rows of X_train).
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st

# Positional layout read by lossDiff / calculateMultiDiffCrossSections:
# (x, X, y, r, dr, thetas, Lrange, model, state). Only x, X and y are moved to Reactant.
args = (xdev(x_train), xdev(X_train), xdev(XSdiff_train), r, dr, theta_train, Lrange, model, _st)

# Test loss function evaluation
# NOTE(review): these calls run eagerly on Reactant arrays. Reactant code is meant to be
# compiled with `@compile` / `@jit` before it runs. For a scalar loss, `Enzyme.gradient`
# is the usual entry point rather than `Enzyme.jacobian`.
losstest = lossDiff(p, args)
Enzyme.jacobian(Reverse, p -> lossDiff(p, args),p)
losstest

The current error:

No augmented forward pass found for ejlstr$UnsafeBufferPointer$/Users/daningburg/.julia/artifacts/2c69783b22c1072452c6b137cf11806ce31f9f67/lib/libReactantExtra.dylib
 at context:   %39 = call i64 @"ejlstr$UnsafeBufferPointer$/Users/daningburg/.julia/artifacts/2c69783b22c1072452c6b137cf11806ce31f9f67/lib/libReactantExtra.dylib"(i64 %38) #271, !dbg !325


Stacktrace:
  [1] wait
    @ ~/.julia/packages/Reactant/QTNFa/src/Types.jl:195
  [2] setindex!
    @ ~/.julia/packages/Reactant/QTNFa/src/ConcreteRArray.jl:359
  [3] macro expansion
    @ ./cartesian.jl:62 [inlined]
  [4] _unsafe_getindex!
    @ ./multidimensional.jl:938 [inlined]
  [5] _unsafe_getindex
    @ ./multidimensional.jl:929
  [6] Array
    @ ./boot.jl:579 [inlined]
  [7] Array
    @ ./boot.jl:591 [inlined]
  [8] zeros
    @ ./array.jl:589 [inlined]
  [9] zeros
    @ ./array.jl:585 [inlined]
 [10] calculateMultiDiffCrossSections
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:98
 [11] augmented_julia_lossDiff_40511_inner_217wrap
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:0
 [12] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5691 [inlined]
 [13] enzyme_call
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5225 [inlined]
 [14] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5164 [inlined]
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/rules/jitrules.jl:447 [inlined]
 [16] runtime_generic_augfwd(::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::typeof(lossDiff), ::Nothing, ::ComponentVector{…}, ::ComponentVector{…}, ::Tuple{…}, ::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/iosr4/src/rules/jitrules.jl:574
 [17] #57
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170 [inlined]
 [18] augmented_julia__57_29225_inner_1wrap
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:0
 [19] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5691 [inlined]
 [20] enzyme_call
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5225 [inlined]
 [21] AugmentedForwardThunk
    @ ~/.julia/packages/Enzyme/iosr4/src/compiler.jl:5164 [inlined]
 [22] autodiff
    @ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:408 [inlined]
 [23] autodiff
    @ ~/.julia/packages/Enzyme/iosr4/src/Enzyme.jl:538 [inlined]
 [24] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:324 [inlined]
 [25] gradient
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:262 [inlined]
 [26] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:861 [inlined]
 [27] jacobian_helper
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:785 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1239 [inlined]
 [29] jacobian(mode::ReverseMode{false, false, false, FFIABI, false, false}, f::var"#57#58", xs::ComponentVector{Float32, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}}, Tuple{Axis{…}}})
    @ Enzyme ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213
 [30] top-level scope
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170
 [31] include(fname::String)
    @ Main ./sysimg.jl:38
 [32] top-level scope
    @ REPL[6]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactant_and_lux.jl:170

A Reactant function must be compiled with `@compile` before it is run, i.e.:

f(x) = sum(x)
g(x) = Enzyme.gradient(Reverse,f,x)
x = Reactant.to_rarray(rand(10))
g_comp = @compile g(x)
g_comp(x)

Note that a lot happens in your code; if something else goes wrong, try to minimize it and open an issue.
PS: if you will only run the function once, you can use `@jit` instead:

f(x) = sum(x)
g(x) = Enzyme.gradient(Reverse,f,x)
x = Reactant.to_rarray(rand(10))
@jit g(x)

I’ve written a simpler version of the above code here:

begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    Random.seed!(1234)
end

begin # Functions
    """
        build_model(n_in, n_out, n_layers, n_nodes; act_fun=leakyrelu, last_fun=relu)

    Assemble a fully connected Lux `Chain` from three parts:
    - an input `Dense` layer from `n_in` to `n_nodes` features with `act_fun`;
    - `n_layers` hidden `n_nodes => n_nodes` layers with `act_fun`;
    - an output `n_nodes => n_out` layer with `last_fun`.
    """
    function build_model(n_in, n_out, n_layers, n_nodes;
                         act_fun=leakyrelu, last_fun=relu)
        input_layer = Lux.Dense(n_in, n_nodes, act_fun)
        hidden_layers = [Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
        output_layer = Lux.Dense(n_nodes => n_out, last_fun)
        return Chain(input_layer, hidden_layers..., output_layer)
    end

    """
        eval_model(m, params, st, x)

    Call model `m` on input `x` with parameters `params` and state `st`, and return
    only the first element of the result. The model call is expected to return an
    `(output, state)` pair; the state is discarded.
    """
    function eval_model(m, params, st, x)
        output, = m(x, params, st)
        return output
    end


    """
        recursive_convert(T, x)

    Convert every array leaf of `x` to element type `T`, elementwise. Nested `NamedTuple`s
    (such as a Lux parameter tree) are recursed into and keep their keys. Any other value
    is returned unchanged.

    Implemented with dispatch instead of `isa` checks. `map` over a `NamedTuple` rebuilds a
    `NamedTuple` with the same keys, which keeps the result type inferable; building it
    from a generator does not.
    """
    recursive_convert(T, x) = x
    recursive_convert(T, x::AbstractArray) = convert.(T, x)  # elementwise convert
    recursive_convert(T, x::NamedTuple) = map(v -> recursive_convert(T, v), x)

    """
        lossDiff(p, args)

    Mean squared error between the measured values `args[2]` and the cross sections
    predicted for parameters `p` by `calculateMultiDiffCrossSections(p, args)`.
    """
    function lossDiff(p, args)
        measured = args[2]
        predicted = calculateMultiDiffCrossSections(p, args)
        squared_residuals = (measured - predicted) .^ 2
        return mean(squared_residuals)
    end

    """
        calculateDifferentialCrossSection(U, theta)

    Toy stand-in for the physics: `cos.(theta)` scaled by the sum of all model outputs `U`.

    NOTE(review): `cos` takes radians, while the angle lists in this script look like
    degrees. Use `cosd` if that is what is intended.
    """
    function calculateDifferentialCrossSection(U, theta)
        total = sum(U)
        return total * cos.(theta)
    end

    """
        calculateMultiDiffCrossSections(p, args; nlen=2, rlen=100)

    Evaluate the model on each experiment's input block and concatenate the resulting
    cross sections into one vector of element type `eltype(X)`.

    `args = (X, y, thetas, M, st)`; `y` is not used here.
    - `X` holds `rlen` consecutive input columns per experiment.
    - `thetas[i]` are the angles of experiment `i`.
    - `M` / `st` are the model and its state, passed to `eval_model`.

    `nlen` is the number of experiments to process. The defaults match this script's
    data: 2 experiments of 100 columns each.
    """
    function calculateMultiDiffCrossSections(p, args; nlen = 2, rlen = 100)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        datalen = sum(i -> length(thetas[i]), 1:nlen; init = 0)
        dσ = zeros(eltype(X), datalen)
        j = 1
        for i in 1:nlen
            j_next = j + length(thetas[i])
            U = eval_model(M, p, st, X[:, (i - 1) * rlen + 1 : i * rlen])
            dσ[j:j_next-1] = calculateDifferentialCrossSection(U, thetas[i])
            j = j_next
        end
        return dσ
    end
end


const xdev = reactant_device()
data_type = Float32

# Generate dummy data
# 9 targets = 5 + 4 angles. X_train holds 2 experiments × 100 input columns, matching
# nlen = 2 and rlen = 100 as used by calculateMultiDiffCrossSections.
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
X_train = rand(data_type, 4,200)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st

# Positional layout read by lossDiff / calculateMultiDiffCrossSections:
# (X, y, thetas, model, state).
args = (xdev(X_train), xdev(XSdiff_train), xdev(theta_train), model, xdev(_st))

# Test loss function evaluation
# losstest = lossDiff(p, args)
# losstest
# NOTE(review): `args` is a global captured by the closure, not an argument of `dl_dp`.
# `@compile` therefore only traces `p`, and the arrays in `args` appear to reach the traced
# code as concrete device arrays; the StackOverflow trace shows `getindex` on a
# `ConcretePJRTArray`. Passing `args` as an argument avoids this, e.g.
# `Enzyme.gradient(Reverse, lossDiff, p, Const(args))`.
dl_dp(p) = Enzyme.jacobian(Reverse, p -> lossDiff(p, args),p)
dldp_comp = @compile dl_dp(p)
dldp_comp(p)

This returns a StackOverflow error:

LoadError: StackOverflowError:
Stacktrace:
     [1] getindex
       @ ~/.julia/packages/Reactant/QTNFa/src/ConcreteRArray.jl:316 [inlined]
     [2] (::Nothing)(none::typeof(getindex), none::ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, none::Tuple{Function, UnitRange{Int64}})
       @ Reactant ./<missing>:0
--- the above 2 lines are repeated 8019 more times ---
 [16041] calculateMultiDiffCrossSections
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:71 [inlined]
 [16042] (::Nothing)(none::typeof(calculateMultiDiffCrossSections), none::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, none::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
       @ Reactant ./<missing>:0
 [16043] getindex
       @ ./tuple.jl:31 [inlined]
 [16044] calculateMultiDiffCrossSections
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:55 [inlined]
 [16045] call_with_reactant(::Reactant.MustThrowError, ::typeof(calculateMultiDiffCrossSections), ::ComponentVector{…}, ::Tuple{…})
       @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [16046] lossDiff
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:46 [inlined]
 [16047] (::Nothing)(none::typeof(lossDiff), none::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, none::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
       @ Reactant ./<missing>:0
 [16048] lossDiff
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:46 [inlined]
 [16049] call_with_reactant(::typeof(lossDiff), ::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, ::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
       @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [16050] #6
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [16051] (::Nothing)(none::var"#6#7", none::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
       @ Reactant ./<missing>:0
 [16052] #6
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [16053] call_with_reactant(::var"#6#7", ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
       @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [16054] macro expansion
       @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:850 [inlined]
 [16055] jacobian_helper
       @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:785 [inlined]
 [16056] macro expansion
       @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1239 [inlined]
 [16057] jacobian
       @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213 [inlined]
 [16058] (::Nothing)(none::typeof(jacobian), none::ReverseMode{false, false, false, FFIABI, false, false}, none::var"#6#7", none::Tuple{ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}}})
       @ Reactant ./<missing>:0
 [16059] jacobian
       @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213 [inlined]
 [16060] call_with_reactant(::typeof(jacobian), ::ReverseMode{false, false, false, FFIABI, false, false}, ::var"#6#7", ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
       @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [16061] dl_dp
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [16062] (::Nothing)(none::typeof(dl_dp), none::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
       @ Reactant ./<missing>:0
 [16063] dl_dp
       @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [16064] call_with_reactant(::typeof(dl_dp), ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
       @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [16065] make_mlir_fn(f::typeof(dl_dp), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
       @ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:332
 [16066] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
       @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1555
 [16067] compile_mlir! (repeats 2 times)
       @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1522 [inlined]
 [16068] compile_xla(f::Function, args::Tuple{ComponentVector{Float32, ConcretePJRTArray{…}, Tuple{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
       @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3433
 [16069] compile_xla
       @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3406 [inlined]
 [16070] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
       @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3505
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:106
Some type information was truncated. Use `show(err)` to see complete types.

I'm away from my computer, so I can't check this, but by default Reactant unrolls every `for` loop, which could cause that stack overflow. To avoid that, use the `@trace` macro, which currently has some rough edges (Julia coding patterns the map to XLA/scan resp. StableHLO/while? · Issue #1598 · EnzymeAD/Reactant.jl · GitHub), but it may work for you!

Edit: looking at it more closely, those aren’t large loops, so unrolling isn’t a big issue.

I updated that function to include @trace:

"""
    calculateMultiDiffCrossSections(p, args; nlen=2, rlen=100)

Evaluate the model on each experiment's input block and concatenate the resulting cross
sections into one vector of element type `eltype(X)`.

`args = (X, y, thetas, M, st)`; `y` is not used here.
- `X` holds `rlen` consecutive input columns per experiment.
- `thetas[i]` are the angles of experiment `i`.
- `M` / `st` are the model and its state, passed to `eval_model`.

The loop is deliberately not wrapped in `@trace`. It iterates over a host-side `Vector`
of angle arrays whose lengths differ per experiment, and it advances a host `Int` offset
`j`. Neither can be carried as state of a traced while loop: Reactant fails trying to
trace the `Vector{ConcretePJRTArray}` of angles. With only `nlen` iterations, letting
Reactant unroll the plain loop is cheap.
"""
function calculateMultiDiffCrossSections(p, args; nlen = 2, rlen = 100)
    X = args[1]
    thetas = args[3]
    M = args[4]
    st = args[5]
    datalen = sum(i -> length(thetas[i]), 1:nlen; init = 0)
    dσ = zeros(eltype(X), datalen)
    j = 1
    for i in 1:nlen
        j_next = j + length(thetas[i])
        U = eval_model(M, p, st, X[:, (i - 1) * rlen + 1 : i * rlen])
        dσ[j:j_next-1] = calculateDifferentialCrossSection(U, thetas[i])
        j = j_next
    end
    return dσ
end

and I get the following:

LoadError: "Unsupported mode: NoStopTracedTrack"
Stacktrace:
  [1] traced_type_inner(T::Type{<:ConcretePJRTArray}, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:317
  [2] traced_type_inner(A::Type{<:Array}, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:520
  [3] traced_type_inner(PT::Type{Base.RefValue{Vector{ConcretePJRTArray{Float64, 1, 1, Reactant.Sharding.ShardInfo{…}}}}}, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:565
  [4] traced_type(T::Type, ::Val{Reactant.NoStopTracedTrack}, track_numbers::Type, sharding::Reactant.Sharding.NoSharding, runtime::Nothing)
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:870
  [5] make_tracer_unknown(seen::Reactant.OrderedIdDict{Any, Any}, prev::Any, path::Any, mode::Reactant.TraceMode; track_numbers::Type, sharding::Any, runtime::Any, kwargs::@Kwargs{})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:1032
  [6] make_tracer_unknown
    @ ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:1009 [inlined]
  [7] #make_tracer#131
    @ ~/.julia/packages/Reactant/QTNFa/src/Tracing.jl:1146 [inlined]
  [8] while_loop(::var"#15#19", ::var"#16#20", ::Base.RefValue{Reactant.TracedRNumber{Int64}}, ::Vararg{Any}; track_numbers::Type, verify_arg_names::NTuple{16, Symbol}, checkpointing::Bool, mincut::Bool, location::Reactant.MLIR.IR.Location)
    @ Reactant.Ops ~/.julia/packages/Reactant/QTNFa/src/Ops.jl:2039
  [9] #traced_while#127
    @ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:20 [inlined]
 [10] traced_while
    @ ~/.julia/packages/Reactant/QTNFa/src/ControlFlow.jl:11 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/ReactantCore/9hY4Z/src/ReactantCore.jl:260 [inlined]
 [12] calculateMultiDiffCrossSections
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:67 [inlined]
 [13] (::Nothing)(none::typeof(calculateMultiDiffCrossSections), none::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, none::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
    @ Reactant ./<missing>:0
 [14] getindex
    @ ./tuple.jl:31 [inlined]
 [15] calculateMultiDiffCrossSections
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:55 [inlined]
 [16] call_with_reactant(::typeof(calculateMultiDiffCrossSections), ::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, ::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [17] lossDiff
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:46 [inlined]
 [18] (::Nothing)(none::typeof(lossDiff), none::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, none::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
    @ Reactant ./<missing>:0
 [19] getindex
    @ ./tuple.jl:31 [inlined]
 [20] lossDiff
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:45 [inlined]
 [21] call_with_reactant(::typeof(lossDiff), ::ComponentVector{Reactant.TracedRNumber{…}, Reactant.TracedRArray{…}, Tuple{…}}, ::Tuple{ConcretePJRTArray{…}, ConcretePJRTArray{…}, Vector{…}, Chain{…}, @NamedTuple{…}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [22] #21
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [23] (::Nothing)(none::var"#21#22", none::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
    @ Reactant ./<missing>:0
 [24] #21
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [25] call_with_reactant(::var"#21#22", ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [26] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:850 [inlined]
 [27] jacobian_helper
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:785 [inlined]
 [28] macro expansion
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1239 [inlined]
 [29] jacobian
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213 [inlined]
 [30] (::Nothing)(none::typeof(jacobian), none::ReverseMode{false, false, false, FFIABI, false, false}, none::var"#21#22", none::Tuple{ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}}})
    @ Reactant ./<missing>:0
 [31] jacobian
    @ ~/.julia/packages/Enzyme/iosr4/src/sugar.jl:1213 [inlined]
 [32] call_with_reactant(::typeof(jacobian), ::ReverseMode{false, false, false, FFIABI, false, false}, ::var"#21#22", ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [33] dl_dp
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [34] (::Nothing)(none::typeof(dl_dp), none::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
    @ Reactant ./<missing>:0
 [35] dl_dp
    @ ~/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:105 [inlined]
 [36] call_with_reactant(::typeof(dl_dp), ::ComponentVector{Reactant.TracedRNumber{Float32}, Reactant.TracedRArray{Float32, 1}, Tuple{Axis{…}}})
    @ Reactant ~/.julia/packages/Reactant/QTNFa/src/utils.jl:0
 [37] make_mlir_fn(f::typeof(dl_dp), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/QTNFa/src/TracedUtils.jl:332
 [38] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1555
 [39] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:1522 [inlined]
 [40] compile_xla(f::Function, args::Tuple{ComponentVector{Float32, ConcretePJRTArray{…}, Tuple{…}}}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{compile_options::CompileOptions, fn_kwargs::@NamedTuple{}})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3433
 [41] compile_xla
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3406 [inlined]
 [42] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:3505
 [43] top-level scope
    @ ~/.julia/packages/Reactant/QTNFa/src/Compiler.jl:2586
 [44] include(fname::String)
    @ Main ./sysimg.jl:38
 [45] top-level scope
    @ REPL[1]:1
in expression starting at /Users/daningburg/Documents/Code/nuclear-diffprog/MWEs/reactlux2.jl:106
Some type information was truncated. Use `show(err)` to see complete types.

I made it work on the CPU:

begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    Reactant.set_default_backend("cpu")
    Random.seed!(1234)
end

begin # Functions
    """
        build_model(n_in, n_out, n_layers, n_nodes; act_fun=leakyrelu, last_fun=relu)

    Assemble a fully connected Lux `Chain` from three parts:
    - an input `Dense` layer from `n_in` to `n_nodes` features with `act_fun`;
    - `n_layers` hidden `n_nodes => n_nodes` layers with `act_fun`;
    - an output `n_nodes => n_out` layer with `last_fun`.
    """
    function build_model(n_in, n_out, n_layers, n_nodes;
                         act_fun=leakyrelu, last_fun=relu)
        input_layer = Lux.Dense(n_in, n_nodes, act_fun)
        hidden_layers = [Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
        output_layer = Lux.Dense(n_nodes => n_out, last_fun)
        return Chain(input_layer, hidden_layers..., output_layer)
    end

    """
        eval_model(m, params, st, x)

    Call model `m` on input `x` with parameters `params` and state `st`, and return
    only the first element of the result. The model call is expected to return an
    `(output, state)` pair; the state is discarded.
    """
    function eval_model(m, params, st, x)
        output, = m(x, params, st)
        return output
    end


    """
        recursive_convert(T, x)

    Convert every array leaf of `x` to element type `T`, elementwise. Nested `NamedTuple`s
    (such as a Lux parameter tree) are recursed into and keep their keys. Any other value
    is returned unchanged.

    Implemented with dispatch instead of `isa` checks. `map` over a `NamedTuple` rebuilds a
    `NamedTuple` with the same keys, which keeps the result type inferable; building it
    from a generator does not.
    """
    recursive_convert(T, x) = x
    recursive_convert(T, x::AbstractArray) = convert.(T, x)  # elementwise convert
    recursive_convert(T, x::NamedTuple) = map(v -> recursive_convert(T, v), x)

    """
        lossDiff(p, args)

    Mean squared error between the measured values `args[2]` and the cross sections
    predicted for parameters `p` by `calculateMultiDiffCrossSections(p, args)`.
    """
    function lossDiff(p, args)
        measured = args[2]
        predicted = calculateMultiDiffCrossSections(p, args)
        squared_residuals = (measured - predicted) .^ 2
        return mean(squared_residuals)
    end

    """
        calculateDifferentialCrossSection(U, theta)

    Toy stand-in for the physics: `cos.(theta)` scaled by the sum of all model outputs `U`.

    NOTE(review): `cos` takes radians, while the angle lists in this script look like
    degrees. Use `cosd` if that is what is intended.
    """
    function calculateDifferentialCrossSection(U, theta)
        total = sum(U)
        return total * cos.(theta)
    end

    """
        calculateMultiDiffCrossSections(p, args; nlen=2, rlen=100)

    Evaluate the model on each experiment's input block and concatenate the resulting
    cross sections into one vector of zeros-initialised `eltype(X)` slots.

    `args = (X, y, thetas, M, st)`; `y` is not used here.
    - `X` holds `rlen` consecutive input columns per experiment; each block is passed to
      the model as a view, not a copy.
    - `thetas[i]` are the angles of experiment `i`.
    - `M` / `st` are the model and its state, passed to `eval_model`.

    `nlen` is the number of experiments to process. The defaults match this script's
    data: 2 experiments of 100 columns each.
    """
    function calculateMultiDiffCrossSections(p, args; nlen = 2, rlen = 100)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        datalen = sum(length.(@view thetas[1:nlen]))
        # A comprehension instead of `zeros(eltype(X), datalen)`: under Reactant tracing,
        # `zeros` hit a `fill!` method ambiguity for `TracedRNumber` elements.
        dσ = [zero(eltype(X)) for _ in 1:datalen]
        j = 1
        for i in 1:nlen
            j_next = j + length(thetas[i])
            U = eval_model(M, p, st, @view X[:, (i - 1) * rlen + 1 : i * rlen])
            dσ[j:j_next-1] .= calculateDifferentialCrossSection(U, thetas[i])
            j = j_next
        end
        return dσ
    end
end


const xdev = reactant_device()
data_type = Float32

# Generate dummy data
# 9 targets = 5 + 4 angles. X_train holds 2 experiments × 100 input columns, matching
# nlen = 2 and rlen = 100 as used by calculateMultiDiffCrossSections.
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
X_train = rand(data_type, 4,200)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st

# Positional layout read by lossDiff / calculateMultiDiffCrossSections:
# (X, y, thetas, model, state).
args = (xdev(X_train), xdev(XSdiff_train), xdev(theta_train), model, xdev(_st))

# Test loss function evaluation
# losstest = lossDiff(p, args)
# losstest

# `@allowscalar` is needed because the loss still does scalar indexing.
# `@jit` compiles and runs once.
display(@allowscalar @jit lossDiff(p, args))
# `args` is an explicit argument, marked `Const`, so Reactant traces it and Enzyme does
# not differentiate with respect to the data. `gradient` suits the scalar loss.
dl_dp(p,args) = Enzyme.gradient(Reverse, lossDiff,p,Const(args))
# Compile once; the compiled `dldp_comp` can then be called repeatedly.
@allowscalar dldp_comp = @compile dl_dp(p,args)
@allowscalar res = dldp_comp(p,args)
display(res)

It should also work on GPU; however, there is a lot of scalar indexing going on, so you should vectorize it a bit more first or write KernelAbstractions.jl kernels.

PS: if someone from the Reactant.jl team (@wsmoses, for instance) comes by, there is a method ambiguity with fill!(::Vector{Reactant.TracedRNumber{Float32}}, ::Reactant.TracedRNumber{Float32}) which appears when doing
zeros(eltype(X), datalen), which forced me to do [zero(eltype(X)) for _ in 1:datalen]. Maybe that’s intended, but if I instead make an f64/f32 array and send it to XLA, it’s filled with ConcreteNumbers and can’t be filled with TracedNumbers, leading to

LoadError: MethodError: no method matching _copyto!(::SubArray{Float32, 1, ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, Tuple{UnitRange{Int64}}, false}, ::Base.Broadcast.Broadcasted{Reactant.TracedRArrayOverrides.AbstractReactantArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Reactant.TracedRArray{Float64, 1}}})

. So, is there a preferred way to make a temporary array that will be filled with traced data for now?
The method I used is fine, but it is a Julia array of TracedNumbers, which I think is not ideal.

1 Like

This is great, thanks! I’ve been wanting to vectorize the code, but I’ve had a hard time figuring out how. I have experiments of different lengths, each of which is self-contained (hence “calculateMulti…”). To make matters worse, for some experiments I will need to evaluate the model multiple times and average the results. Would you mind showing me an example of vectorizing the loop in calculateMultiDiff… so I can try to apply it to the rest of my code?

I tried:

    # Vectorized prediction of the differential cross sections of the first `nlen`
    # experiments. The model is evaluated once on all `nlen * rlen` input columns.
    # Experiment i then uses the sum of its block of `rlen` output columns, scaled by
    # cos.(thetas[i]). Results are concatenated in experiment order.
    function calculateMultiDiffCrossSections(p, args)
        X, _, thetas, M, st = args
        nlen = 2    # number of experiments processed
        rlen = 100  # input columns per experiment
        U = eval_model(M, p, st, @view X[:, 1:rlen*nlen])
        # U only holds outputs for the first `nlen` experiments, so iterate over exactly
        # those rather than over every entry of `thetas` (which would index past U).
        sig = mapreduce(vcat, 1:nlen) do i
            sum(@view(U[:, (i-1)*rlen+1 : i*rlen])) .* cos.(thetas[i])
        end
        # Copy into a vector with element type eltype(X), so the result has the same
        # element type as the loop-based version even if the angles use another float type.
        dσ = [zero(eltype(X)) for _ in 1:length(sig)]
        dσ .= sig
        return dσ
    end
end

Edit: it seems like the mapreduce doesn’t hit the GPU implementation :cry:
Also, I’m not sure dσ .= sig is correct.

Very interesting, thank you. It felt like the vectorized code ran faster, but I checked with @btime and apparently it doesn’t. Should I expect a big speedup with large arrays? Obviously this is a tiny test case. If I do get significant speedups by vectorizing, it still achieves my goal even if it doesn’t work on GPU.

begin
    using Lux
    using Enzyme
    using ComponentArrays
    using Random
    using Reactant
    using Statistics
    using BenchmarkTools
    Reactant.set_default_backend("cpu")
    Random.seed!(1234)
end

begin # Functions
    # Assemble a fully connected Lux `Chain`:
    #   * an input layer n_in -> n_nodes using `act_fun`,
    #   * `n_layers` hidden layers n_nodes -> n_nodes using `act_fun`,
    #   * an output layer n_nodes -> n_out using `last_fun`.
    function build_model(n_in, n_out, n_layers, n_nodes;
                         act_fun=leakyrelu, last_fun=relu)
        input_layer = Lux.Dense(n_in, n_nodes, act_fun)
        hidden_layers = [Lux.Dense(n_nodes => n_nodes, act_fun) for _ in 1:n_layers]
        output_layer = Lux.Dense(n_nodes => n_out, last_fun)
        return Chain(input_layer, hidden_layers..., output_layer)
    end

    # Run the model with the Lux-style call signature `m(x, params, st)` and keep only
    # the first element of the returned value (the output; the state is discarded).
    eval_model(m, params, st, x) = m(x, params, st) |> first


    # Recursively convert the array leaves of a (possibly nested) NamedTuple to element
    # type `T`. Values that are neither arrays nor NamedTuples are returned unchanged.
    function recursive_convert(T, x)
        x isa AbstractArray && return convert.(T, x)  # elementwise convert
        x isa NamedTuple || return x
        # Rebuild the NamedTuple (same keys) with each field converted recursively.
        return map(field -> recursive_convert(T, field), x)
    end

    # Mean-squared error of the loop-based cross-section predictions against the
    # targets stored in `args[2]`.
    function lossDiff(p, args)
        targets = args[2]
        predicted = calculateMultiDiffCrossSections(p, args)
        squared_errors = (targets - predicted) .^ 2
        return mean(squared_errors)
    end

    # Mean-squared error of the vectorized cross-section predictions against the
    # targets stored in `args[2]`.
    lossDiffVec(p, args) = mean((args[2] - calculateMultiDiffCrossSectionsVec(p, args)) .^ 2)

    # Multiply cos.(theta) by the total of all entries of U.
    function calculateDifferentialCrossSection(U, theta)
        total = sum(U)
        return total * cos.(theta)
    end

    # Loop-based prediction of the differential cross sections of the first `nlen`
    # experiments; this is the function `lossDiff` calls.
    #
    # It must be named `calculateMultiDiffCrossSections` (not `...Vec`): a second
    # method with the `...Vec` name and the same (p, args) signature would replace it,
    # leaving `lossDiff` without a callee.
    #
    # `args` is read positionally as (X, _, thetas, M, st). Experiment i uses input
    # columns (i-1)*rlen+1 : i*rlen of X and the angles thetas[i].
    function calculateMultiDiffCrossSections(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2    # number of experiments processed
        rlen = 100  # input columns per experiment
        datalen = sum(length.(@view thetas[1:nlen]))
        # Output buffer with element type eltype(X); each block is converted on assignment.
        dσ = [zero(eltype(X)) for _ in 1:datalen]
        j = 1
        for i in 1:nlen
            exp_len = length(thetas[i])
            j_next = j + exp_len
            U = eval_model(M, p, st, @view X[:, (i-1)*rlen+1 : i*rlen])
            sig = sum(U) * cos.(thetas[i])
            dσ[j:j_next-1] .= sig
            j = j_next
        end
        return dσ
    end

    # Vectorized prediction of the differential cross sections of the first `nlen`
    # experiments. The model is evaluated once on all `nlen * rlen` input columns.
    # Experiment i then uses the sum of its block of `rlen` output columns, scaled by
    # cos.(thetas[i]). Results are concatenated in experiment order.
    function calculateMultiDiffCrossSectionsVec(p, args)
        X = args[1]
        thetas = args[3]
        M = args[4]
        st = args[5]
        nlen = 2    # number of experiments processed
        rlen = 100  # input columns per experiment
        U = eval_model(M, p, st, @view X[:, 1:rlen*nlen])
        # U only holds outputs for the first `nlen` experiments, so iterate over exactly
        # those rather than over every entry of `thetas` (which would index past U).
        sig = mapreduce(vcat, 1:nlen) do i
            sum(@view(U[:, (i-1)*rlen+1 : i*rlen])) .* cos.(thetas[i])
        end
        # Copy into a vector with element type eltype(X), so the result has the same
        # element type as the loop-based version even if the angles use another float type.
        dσ = [zero(eltype(X)) for _ in 1:length(sig)]
        dσ .= sig
        return dσ
    end
end


const xdev = reactant_device()
data_type = Float32

# Generate dummy data
XSdiff_train = rand(data_type, 9)
theta_train = [
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
]
X_train = rand(data_type, 4, 200)

# Load a model
nlayers = 2
nnodes = 16
model = build_model(4, 3, nlayers, nnodes)
ps, st = f32(Lux.setup(Random.default_rng(), model))
p = xdev(ComponentArray(recursive_convert(data_type, ps)))
const _st = st

# Positional arguments read by both losses: (inputs, targets, angles, model, state).
args = (xdev(X_train), xdev(XSdiff_train), xdev(theta_train), model, xdev(_st))

# Compiled functions are bound *outside* `@allowscalar`: an assignment written inside
# the macro's argument may stay local, depending on how the macro wraps it.
# In `@btime`, globals are `$`-interpolated so dynamic dispatch on non-constant globals
# is not timed. `@btime` returns the value of the benchmarked call, so the gradient is
# captured at global scope (an assignment inside `@btime` would be local to it).

# Scalar (loop-based) implementation
dl_dp(p, args) = Enzyme.gradient(Reverse, lossDiff, p, Const(args))
dldp_comp = @allowscalar @compile dl_dp(p, args)
res = @btime @allowscalar $dldp_comp($p, $args)

# Vectorized implementation
dl_dp_vec(p, args) = Enzyme.gradient(Reverse, lossDiffVec, p, Const(args))
dldp_comp_vec = @allowscalar @compile dl_dp_vec(p, args)
res_vec = @btime @allowscalar $dldp_comp_vec($p, $args)

times:

  6.050 μs (17 allocations: 448 bytes)
  6.385 μs (17 allocations: 448 bytes)

Edit: I also find it interesting that I still have to allow scalar indexing explicitly for the vectorized version, on account of this line sig = mapreduce(vcat,enumerate(thetas)) do (i,theta)

If you can, please open an issue for the ambiguity, since that shouldn’t ever happen.

That said, I think the correct thing to do would be writing zeros(X, datalen), which will then have it directly create a Reactant array instead of a base Julia array.

Scalar indexing is obviously bad, but we have optimizations that will get rid of a lot of the bad consequences of these (though still in those cases the scalar indexing might be bad for compile time).

Reactant tries to figure out how to vectorize/optimize/etc it under the hood during the compile step, so it’s possible it found something still somewhat decent.

If you look at the @code_hlo dl_dp(p,args) and @code_hlo dl_dp_vec(p,args) you can see what it compiled them into

OK, I will, thanks. It doesn’t seem to be defined yet:

ERROR: LoadError: MethodError: no method matching zero(::Reactant.TracedRArray{Float32, 2}, ::Int64)
The function `zero` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  zero(::Type{Union{}}, ::Any...)
   @ Base number.jl:310
  zero(::AbstractArray{T}) where T<:Number
   @ Base abstractarray.jl:1196
  zero(::AbstractArray{S}) where S<:Union{Missing, Number}
   @ Base abstractarray.jl:1197

Yeah, OK. Will theta be very big one day? Otherwise, you could go with a Tuple of XLA arrays and keep the old non-vectorized function:

# Each experiment's angles are a separate array inside a Tuple (not a Vector), so
# broadcasting `xdev.(theta_train)` moves each array to the device individually.
theta_train = (
    [1.0, 10.0, 20.0, 45.0, 60],
    [5.0, 15.0, 25.0, 60]
)

and do

# `xdev.(theta_train)` broadcasts over the Tuple, transferring each angle array separately.
args = (xdev(X_train), xdev(XSdiff_train), xdev.(theta_train), model, xdev(_st))

You will need to remove a view somewhere (the `@view` on `thetas[1:nlen]`, since views are not defined for Tuples):

function calculateMultiDiffCrossSections(p, args)
    # Loop-based prediction of the cross sections of the first `nlen` experiments.
    # `args` is read positionally as (X, _, thetas, M, st); `thetas` may be a Tuple
    # of device arrays, so it is indexed directly rather than through a view.
    X = args[1]
    thetas = args[3]
    M = args[4]
    st = args[5]
    nlen = 2    # number of experiments processed
    rlen = 100  # input columns per experiment
    datalen = sum(length.(thetas[1:nlen]))
    dσ = [zero(eltype(X)) for _ in 1:datalen]
    filled = 0
    for i in 1:nlen
        first_col = (i - 1) * rlen + 1
        U = eval_model(M, p, st, @view X[:, first_col:i*rlen])
        npts = length(thetas[i])
        sig = sum(U) * cos.(thetas[i])
        # Writing traced values element by element into dσ needs scalar indexing.
        @allowscalar dσ[filled+1:filled+npts] .= sig
        filled += npts
    end
    return dσ
end

Edit: if theta will be big, you can also use a Vector of XLA arrays.
Edit 2: if nlen becomes big, you want to go back to the vectorized form, because the loop seems untraceable.

Oh, apparently zero only exists for something of the same size (even for a regular Julia array):

julia> x = [3]
1-element Vector{Int64}:
 3

julia> zero(x)
1-element Vector{Int64}:
 0

julia> zero(x, 3)
ERROR: MethodError: no method matching zero(::Vector{Int64}, ::Int64)

Well, we definitely have similar and fill defined, so I guess use those (unless the size is the same, in which case regular zero ought to work).

Oh OK, yeah, this works:

        # Allocate with `similar` so the buffer has the array type of X (a Reactant
        # TracedRArray while tracing) instead of a Julia Vector of traced scalars.
        dσ = similar(X,datalen) #[zero(eltype(X)) for _ in 1:datalen] 
        fill!(dσ, zero(eltype(X)))

but when I index into it, I get the awesome error

LoadError: BoundsError: attempt to access 9-element Reactant.TracedRArray{Float32, 1} at index [1, 2, 3, 4, 5]

We all love this one

zero returns the neutral element of + for the input object. Maybe @yolhan_mannes meant to use zeros?

No, the issue is that zeros(…) always returns a CPU Julia array, so we are creating a CPU array of XPU numbers, which is really dirty. So we indeed want something like zero(X) that keeps the same memory layout but sets the elements to 0.

Note: it’s actually pretty good at removing our mess, though:

  func.func @main(%arg0: tensor<675xf32>, %arg1: tensor<200x4xf32>, %arg2: tensor<9xf32>, %arg3: tensor<5xf32>, %arg4: tensor<4xf32>) -> tensor<f32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<3x100xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<16x100xf32>
    %cst_1 = stablehlo.constant dense<0.00999999977> : tensor<16x100xf32>
    %cst_2 = stablehlo.constant dense<0.111111112> : tensor<f32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.slice %arg1 [100:200, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
    %1 = stablehlo.slice %arg1 [0:100, 0:4] : (tensor<200x4xf32>) -> tensor<100x4xf32>
    %2 = stablehlo.slice %arg0 [64:80] : (tensor<675xf32>) -> tensor<16xf32>
    %3 = stablehlo.slice %arg0 [0:64] : (tensor<675xf32>) -> tensor<64xf32>
    %4 = stablehlo.reshape %3 : (tensor<64xf32>) -> tensor<4x16xf32>
    %5 = stablehlo.dot_general %4, %1, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
    %6 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %7 = stablehlo.add %5, %6 : tensor<16x100xf32>
    %8 = stablehlo.compare  GT, %7, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %9 = stablehlo.multiply %cst_1, %7 : tensor<16x100xf32>
    %10 = stablehlo.select %8, %7, %9 : tensor<16x100xi1>, tensor<16x100xf32>
    %11 = stablehlo.slice %arg0 [336:352] : (tensor<675xf32>) -> tensor<16xf32>
    %12 = stablehlo.slice %arg0 [80:336] : (tensor<675xf32>) -> tensor<256xf32>
    %13 = stablehlo.reshape %12 : (tensor<256xf32>) -> tensor<16x16xf32>
    %14 = stablehlo.dot_general %13, %10, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %15 = stablehlo.broadcast_in_dim %11, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %16 = stablehlo.add %14, %15 : tensor<16x100xf32>
    %17 = stablehlo.compare  GT, %16, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %18 = stablehlo.multiply %cst_1, %16 : tensor<16x100xf32>
    %19 = stablehlo.select %17, %16, %18 : tensor<16x100xi1>, tensor<16x100xf32>
    %20 = stablehlo.slice %arg0 [608:624] : (tensor<675xf32>) -> tensor<16xf32>
    %21 = stablehlo.slice %arg0 [352:608] : (tensor<675xf32>) -> tensor<256xf32>
    %22 = stablehlo.reshape %21 : (tensor<256xf32>) -> tensor<16x16xf32>
    %23 = stablehlo.dot_general %22, %19, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %24 = stablehlo.broadcast_in_dim %20, dims = [0] : (tensor<16xf32>) -> tensor<16x100xf32>
    %25 = stablehlo.add %23, %24 : tensor<16x100xf32>
    %26 = stablehlo.compare  GT, %25, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %27 = stablehlo.multiply %cst_1, %25 : tensor<16x100xf32>
    %28 = stablehlo.select %26, %25, %27 : tensor<16x100xi1>, tensor<16x100xf32>
    %29 = stablehlo.slice %arg0 [672:675] : (tensor<675xf32>) -> tensor<3xf32>
    %30 = stablehlo.slice %arg0 [624:672] : (tensor<675xf32>) -> tensor<48xf32>
    %31 = stablehlo.reshape %30 : (tensor<48xf32>) -> tensor<16x3xf32>
    %32 = stablehlo.dot_general %31, %28, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
    %33 = stablehlo.broadcast_in_dim %29, dims = [0] : (tensor<3xf32>) -> tensor<3x100xf32>
    %34 = stablehlo.add %32, %33 : tensor<3x100xf32>
    %35 = stablehlo.maximum %cst, %34 : tensor<3x100xf32>
    %36 = stablehlo.reduce(%35 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
    %37 = stablehlo.cosine %arg3 : tensor<5xf32>
    %38 = stablehlo.broadcast_in_dim %36, dims = [] : (tensor<f32>) -> tensor<5xf32>
    %39 = stablehlo.dot_general %4, %0, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<4x16xf32>, tensor<100x4xf32>) -> tensor<16x100xf32>
    %40 = stablehlo.add %39, %6 : tensor<16x100xf32>
    %41 = stablehlo.compare  GT, %40, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %42 = stablehlo.multiply %cst_1, %40 : tensor<16x100xf32>
    %43 = stablehlo.select %41, %40, %42 : tensor<16x100xi1>, tensor<16x100xf32>
    %44 = stablehlo.dot_general %13, %43, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %45 = stablehlo.add %44, %15 : tensor<16x100xf32>
    %46 = stablehlo.compare  GT, %45, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %47 = stablehlo.multiply %cst_1, %45 : tensor<16x100xf32>
    %48 = stablehlo.select %46, %45, %47 : tensor<16x100xi1>, tensor<16x100xf32>
    %49 = stablehlo.dot_general %22, %48, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x16xf32>, tensor<16x100xf32>) -> tensor<16x100xf32>
    %50 = stablehlo.add %49, %24 : tensor<16x100xf32>
    %51 = stablehlo.compare  GT, %50, %cst_0 : (tensor<16x100xf32>, tensor<16x100xf32>) -> tensor<16x100xi1>
    %52 = stablehlo.multiply %cst_1, %50 : tensor<16x100xf32>
    %53 = stablehlo.select %51, %50, %52 : tensor<16x100xi1>, tensor<16x100xf32>
    %54 = stablehlo.dot_general %31, %53, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<16x3xf32>, tensor<16x100xf32>) -> tensor<3x100xf32>
    %55 = stablehlo.add %54, %33 : tensor<3x100xf32>
    %56 = stablehlo.maximum %cst, %55 : tensor<3x100xf32>
    %57 = stablehlo.reduce(%56 init: %cst_3) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x100xf32>, tensor<f32>) -> tensor<f32>
    %58 = stablehlo.cosine %arg4 : tensor<4xf32>
    %59 = stablehlo.broadcast_in_dim %57, dims = [] : (tensor<f32>) -> tensor<4xf32>
    %60 = stablehlo.concatenate %38, %59, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
    %61 = stablehlo.concatenate %37, %58, dim = 0 : (tensor<5xf32>, tensor<4xf32>) -> tensor<9xf32>
    %62 = stablehlo.multiply %60, %61 : tensor<9xf32>
    %63 = stablehlo.subtract %arg2, %62 : tensor<9xf32>
    %64 = stablehlo.multiply %63, %63 : tensor<9xf32>
    %65 = stablehlo.reduce(%64 init: %cst_3) applies stablehlo.add across dimensions = [0] : (tensor<9xf32>, tensor<f32>) -> tensor<f32>
    %66 = stablehlo.multiply %65, %cst_2 : tensor<f32>
    return %66 : tensor<f32>
  }
}

I don’t see any <1> tensors, so it seems like it figured it out somehow.