Slow Reactant compilation for modified DeepONet

I’m trying to implement the modified Deep Operator Network from the paper “Improved architectures and training algorithms for deep operator networks” (arXiv:2110.01654). The addition over the vanilla DeepONet is a pair of encoders that connect the branch and trunk networks. I’ve come up with the implementation below, which compiles and trains, but compilation is extremely slow, to the point where Reactant itself notices and suggests filing a bug report. I’m wondering whether this is an issue with my implementation or a current limitation of Reactant. Any suggestions for improvements are welcome. Happy to contribute this back to NeuralOperators if people think it’s useful.
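For context, the update rule from the paper (as I read it) couples the two networks through the encoded inputs U and V: each hidden layer computes Zᵏ = σ(Wᵏ Hᵏ + bᵏ) and then Hᵏ⁺¹ = (1 − Zᵏ) ⊙ U + Zᵏ ⊙ V, in both the branch and the trunk.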

Code:

module DeepONet
using ConcreteStructs: @concrete
using Lux
using LinearAlgebra

export ModifiedDeepONet

@concrete struct ModifiedDeepONet <: AbstractLuxContainerLayer{(:branch, :trunk, :b_encoder, :t_encoder)}
    branch
    trunk
    b_encoder
    t_encoder
end


function ModifiedDeepONet(;
        branch = (64, 32, 32, 16),
        trunk = (1, 32, 32, 16),
        branch_activation = identity,
        trunk_activation = identity,
    )

    # Checks for dimension sizes:
    # the branch and trunk nets must share the same number
    # of nodes in the last layer, otherwise Σᵢ bᵢⱼ tᵢₖ
    # won't work. Additionally, the modified DeepONet
    # needs all of the hidden layers to have the same size,
    # so that the encoder outputs can be broadcast against them.
    @assert length(branch) == length(trunk)
    for i in 2:length(branch)
        @assert branch[i] == trunk[i]
    end

    # Branch Encoder
    U = Dense(branch[1] => branch[2])
    # Trunk Encoder
    V = Dense(trunk[1] => trunk[2])
    branch_net = Chain(
        [
            Dense(
                    branch[i] => branch[i + 1],
                    ifelse(i == length(branch) - 1, identity, branch_activation),
                ) for i in 1:(length(branch) - 1)
        ]...,
    )

    trunk_net = Chain(
        [
            Dense(
                    trunk[i] => trunk[i + 1],
                    ifelse(i == length(trunk) - 1, identity, trunk_activation),
                ) for i in 1:(length(trunk) - 1)
        ]...,
    )
    return ModifiedDeepONet(branch_net, trunk_net, U, V)
end


function (mdon::ModifiedDeepONet)(
    x::Tuple{AbstractArray{T, 2}, AbstractArray{T, 2}}, ps::NamedTuple, st::NamedTuple
) where {T}
    u, v = x
    V = first(mdon.t_encoder(v, ps.t_encoder, st.t_encoder))
    # now we run the network for each function in u

    # output size must be N_points × N_functions, i.e. size(v, 2) × size(u, 2)
    result = zeros(eltype(u), size(v, 2), size(u, 2))
    for funcIdx in axes(u, 2)
        U = first(mdon.b_encoder(u[:, funcIdx], ps.b_encoder, st.b_encoder))
        # Hidden layer 1
        Hᵥ = first(mdon.trunk.layers[1](v, ps.trunk[1], st.trunk[1]))
        Hᵤ = first(mdon.branch.layers[1](u[:, funcIdx], ps.branch[1], st.branch[1]))
        # Hidden layers 2:end-1
        # Note that we're computing H(L+1) here
        for idx in 2:(length(mdon.branch.layers)-1)
            b_layer = mdon.branch.layers[idx]
            t_layer = mdon.trunk.layers[idx]
            layer_sym = Symbol(:layer_, idx)
            
            Zᵤ = first(b_layer(Hᵤ, ps.branch[layer_sym], st.branch[layer_sym]))
            Zᵥ = first(t_layer(Hᵥ, ps.trunk[layer_sym], st.trunk[layer_sym]))
            Hᵤ = (1 .- Zᵤ) .* U .+ Zᵤ .* V
            Hᵥ = (1 .- Zᵥ) .* U .+ Zᵥ .* V
        end
        # last layer of the branch and trunk networks
        # final layer does not have an activation function
        n_layers = length(mdon.branch.layers)
        last_layer_sym = Symbol(:layer_, n_layers)
        Hᵤ = first(mdon.branch.layers[end](Hᵤ, ps.branch[last_layer_sym], st.branch[last_layer_sym]))
        Hᵥ = first(mdon.trunk.layers[end](Hᵥ, ps.trunk[last_layer_sym], st.trunk[last_layer_sym]))
        result[:, funcIdx] .= vec(sum(Hᵤ .* Hᵥ, dims=1))
    end
    return result, st
end

end # module
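An aside on the forward pass above: I suspect the explicit Julia loop over the columns of u is part of the problem, since Reactant unrolls Julia loops while tracing, so the emitted program grows with the number of input functions, and the in-place writes into result add traced setindex! calls on top. A loop-free variant that batches all functions along a third array dimension might look like this (untested sketch, and it assumes Lux Dense layers accept arrays with trailing batch dimensions, treating the first dimension as features; batched_forward is just a name for the sketch):

function batched_forward(mdon::ModifiedDeepONet, x, ps, st)
    u, v = x                                                    # u: (in_b, N_f), v: (in_t, N_p)
    U = first(mdon.b_encoder(u, ps.b_encoder, st.b_encoder))    # (h, N_f)
    V = first(mdon.t_encoder(v, ps.t_encoder, st.t_encoder))    # (h, N_p)
    U₃ = reshape(U, size(U, 1), 1, size(U, 2))                  # (h, 1, N_f)
    V₃ = reshape(V, size(V, 1), size(V, 2), 1)                  # (h, N_p, 1)
    Hᵤ = reshape(first(mdon.branch.layers[1](u, ps.branch[1], st.branch[1])), :, 1, size(u, 2))
    Hᵥ = reshape(first(mdon.trunk.layers[1](v, ps.trunk[1], st.trunk[1])), :, size(v, 2), 1)
    for idx in 2:(length(mdon.branch.layers) - 1)
        layer_sym = Symbol(:layer_, idx)
        Zᵤ = first(mdon.branch.layers[idx](Hᵤ, ps.branch[layer_sym], st.branch[layer_sym]))
        Zᵥ = first(mdon.trunk.layers[idx](Hᵥ, ps.trunk[layer_sym], st.trunk[layer_sym]))
        # broadcasting expands everything to (h, N_p, N_f), matching the loop version
        Hᵤ = (1 .- Zᵤ) .* U₃ .+ Zᵤ .* V₃
        Hᵥ = (1 .- Zᵥ) .* U₃ .+ Zᵥ .* V₃
    end
    last_sym = Symbol(:layer_, length(mdon.branch.layers))
    Hᵤ = first(mdon.branch.layers[end](Hᵤ, ps.branch[last_sym], st.branch[last_sym]))
    Hᵥ = first(mdon.trunk.layers[end](Hᵥ, ps.trunk[last_sym], st.trunk[last_sym]))
    # contract over the feature dimension: (N_p, N_f), same layout as result
    return dropdims(sum(Hᵤ .* Hᵥ; dims = 1); dims = 1), st
end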

Test:

using Lux
include("Modified_DeepONet.jl")
using .DeepONet
using Random
using Optimisers
using Reactant

rng = Random.default_rng()
Random.seed!(rng, 1234)
deeponet = ModifiedDeepONet(;branch = (64, 32, 32, 16),
    trunk = (1, 32, 32, 16),
    branch_activation = tanh,
    trunk_activation = tanh
)

xdev = reactant_device()
cdev = cpu_device()

ps, st = Lux.setup(rng, deeponet) |> xdev
u_input = rand(Float32, 64, 3) |> xdev
y_input = rand(Float32, 1, 10) |> xdev
v_data = rand(Float32, 10, 3) |> xdev
data = [((u_input, y_input), v_data)];

# out = deeponet((u_input, y_input), ps, st)
# println(size(out))

tstate = Training.TrainState(deeponet, ps, st, Adam(0.0001f0))
(_, loss, _, tstate) = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), data[1], tstate; return_gradients=Val(false)
)
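For what it’s worth, the forward pass can also be compiled on its own, to separate its compile time from that of the gradient computation. A minimal sketch using Reactant’s @compile:

# compile only the forward pass; the returned callable reuses the executable
forward = @compile deeponet((u_input, y_input), ps, st)
out, _ = forward((u_input, y_input), ps, st)
println(size(out))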

Can you post the log here?

With pleasure. To be clear, the network I’m compiling in my actual work is a bit bigger than the example:

deeponet = ModifiedDeepONet(;branch = (EAXIS_STEPS, 200, 200, 200, 200, 32),
    trunk = (1, 200, 200, 200, 200, 32),
    branch_activation = tanh,
    trunk_activation = tanh
)

The log:

I0000 00:00:1768764720.915397 1987208 dot_merger.cc:481] Merging Dots in computation: main.1749
I0000 00:00:1768764959.563355 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_25', 616 bytes spill stores, 628 bytes spill loads
I0000 00:00:1768764964.172583 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_multiply_reduce_fusion', 8372 bytes spill stores, 9128 bytes spill loads
I0000 00:00:1768764965.617236 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_28', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764966.801038 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764969.647473 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_8', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764970.829797 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_6', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764972.226716 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_1', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764974.841714 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_31', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764976.025502 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_32', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764977.203300 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_7', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764978.907805 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_34', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764980.245706 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_36', 216 bytes spill stores, 216 bytes spill loads
I0000 00:00:1768764983.055892 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_9', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764986.295274 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_4', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764990.680155 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_26', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764996.176879 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_5', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768764997.349667 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_27', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765007.219362 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_3', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765015.302654 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_11', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765017.512038 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_29', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765023.805989 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_10', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765025.779900 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_33', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765032.997574 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_35', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765034.202515 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_2', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765035.673559 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_30', 176 bytes spill stores, 176 bytes spill loads
I0000 00:00:1768765037.431872 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_12', 232 bytes spill stores, 232 bytes spill loads

E0000 00:00:1768765164.109254 1989983 slow_operation_alarm.cc:73] 
********************************
[Compiling module reactant_compute... for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
I0000 00:00:1768765240.079944 1987208 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_25', 616 bytes spill stores, 628 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_multiply_reduce_fusion', 8300 bytes spill stores, 9048 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_1', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_2', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_3', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_4', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_5', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_6', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_7', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_8', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_9', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_10', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_11', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_12', 232 bytes spill stores, 232 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_36', 216 bytes spill stores, 216 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_35', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_34', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_33', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_32', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_31', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_30', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_29', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_28', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_27', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_26', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_65', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_11', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_30', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_53', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_20', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_34', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_49', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_44', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_25', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_16', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_60', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_39', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_51', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_32', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_27', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_37', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_64', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_13', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_46', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_41', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_23', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_56', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_18', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_59', 12 bytes spill stores, 12 bytes spill loads

E0000 00:00:1768765248.238571 1987208 slow_operation_alarm.cc:140] The operation took 3m24.12941995s

********************************
[Compiling module reactant_compute... for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************

Googling around a bit: could part of the issue be the zeros call that allocates the result array? Is that going to be allocated on the CPU while the rest of the arrays are on the GPU, or does the compiler somehow figure this out?
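If the in-place writes are part of the problem, the accumulation could also be made mutation-free. A toy sketch of the pattern (stack requires Julia ≥ 1.9; per_function is a hypothetical stand-in for the per-function pipeline, not part of my code):

# build each column functionally, then stack, instead of preallocating
# with zeros and writing into result[:, funcIdx]
per_function(col) = tanh.(col)                 # hypothetical stand-in
x = rand(Float32, 10, 3)
result = stack(map(per_function, eachcol(x)))  # 10×3, no setindex! involved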

To follow up on my own question: various attempts at explicitly moving it to the device didn’t result in an improvement, or simply broke things.

@avikpal It looks like there is a regression with the latest Reactant and/or Lux; the compilation now just hangs… It might be useful to add something like this as a regression test. Note that this is the case for a vanilla DeepONet from NeuralOperators as well.
It also looks like the reason for the hanging is excessive memory use.

Can you test v0.2.203 (or newer)? Locally I am not seeing any of these issues. (I did fix an issue with broadcasted stores in the latest release, so maybe that fixed it.)

@Jan_Strube Since we already have a DeepONet in CI, this is a bit surprising (Reactant.jl/benchmark/nn/neural_operators.jl at 592c165632e9298bbf7fc5b04659c8f0ce7fac5d · EnzymeAD/Reactant.jl · GitHub)?

Maybe you can make a PR including your test case; it would help us see what’s up (and ensure no future regressions).

I did test with v0.2.203: same issues. Testing with v0.2.205 now. I’ll also check whether the current DeepONet test in CI behaves differently from my code.

I think I see what’s going on.
In my case, one of my arrays is a Base.ReshapedArray{Float32, 2, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, while the test case just uses ConcretePJRTArray{Float32, 2, 1}. I can easily avoid this in my own code, but do you want me to file a bug report for this?

So, to clarify:

y_data = reshape(range(1f0, 10f0, step=k_STEP), 1, EVAL_POINTS)
y_data = y_data |> xdev;

results in a Base.ReshapedArray{Float32, 2, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}. I’m not sure if that’s a change from an earlier version of Reactant, but I’ve always had this snippet in my code, and it was compiling before, just slowly, which is what caused me to create this thread in the first place.
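Materializing the range into a plain Array before moving it to the device avoids the wrapper type. A sketch, with k_STEP and EVAL_POINTS as above:

y_data = reshape(collect(range(1f0, 10f0, step=k_STEP)), 1, EVAL_POINTS) |> xdev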