Slow Reactant compilation for modified DeepONet

I’m trying to implement the modified Deep Operator Network from the paper "Improved architectures and training algorithms for deep operator networks" (arXiv:2110.01654). Its addition over the vanilla DeepONet is a pair of encoders that connect the branch and trunk networks. I’ve come up with the implementation below, which appears to compile and train, but compilation is extremely slow, to the point where XLA's slow-operation alarm fires and suggests filing a bug report. I’m wondering whether this is an issue with my implementation or just a current limitation of Reactant. Any suggestions for improvements are welcome. I’m happy to contribute this back to NeuralOperators if people think it’s useful.
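For reference, here is a transcription of the modified DeepONet forward pass, written with the names used in the code below. φ is the activation, ⊙ the elementwise product, $W^{(k)}, b^{(k)}$ the weights of the k-th `Dense` layer of the branch (subscript u) or trunk (subscript v) chain, and L the number of hidden states. The indexing is a sketch and worth checking against the paper's equations; as far as I can tell, in the paper the encoders U and V also pass through φ.

$$
\begin{aligned}
U &= \phi(W_U u + b_U), & V &= \phi(W_V v + b_V),\\
H_u^{(1)} &= \phi(W_u^{(1)} u + b_u^{(1)}), & H_v^{(1)} &= \phi(W_v^{(1)} v + b_v^{(1)}),\\
Z_u^{(l)} &= \phi(W_u^{(l+1)} H_u^{(l)} + b_u^{(l+1)}), & Z_v^{(l)} &= \phi(W_v^{(l+1)} H_v^{(l)} + b_v^{(l+1)}), \quad l = 1,\dots,L-1,\\
H_u^{(l+1)} &= (1 - Z_u^{(l)}) \odot U + Z_u^{(l)} \odot V, & H_v^{(l+1)} &= (1 - Z_v^{(l)}) \odot U + Z_v^{(l)} \odot V,\\
G(u)(v) &= \sum_i \big(W_u^{(L+1)} H_u^{(L)} + b_u^{(L+1)}\big)_i \big(W_v^{(L+1)} H_v^{(L)} + b_v^{(L+1)}\big)_i .
\end{aligned}
$$

With `branch = (64, 32, 32, 16)` the chain has three `Dense` layers, so L = 2: layer 1 gives $H^{(1)}$, layer 2 gives the gate $Z^{(1)}$ and hence $H^{(2)}$, and layer 3 is the linear output layer.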

Code:

module DeepONet
using ConcreteStructs: @concrete
using Lux
using LinearAlgebra

export ModifiedDeepONet

@concrete struct ModifiedDeepONet <: AbstractLuxContainerLayer{(:branch, :trunk, :b_encoder, :t_encoder)}
    branch
    trunk
    b_encoder
    t_encoder
end


function ModifiedDeepONet(;
        branch = (64, 32, 32, 16),
        trunk = (1, 32, 32, 16),
        branch_activation = identity,
        trunk_activation = identity,
    )

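    # Checks for dimension sizes:
    # the branch and trunk nets must have the same number of
    # nodes in the last layer, otherwise Σᵢ bᵢⱼ tᵢₖ won't work.
    # Additionally, the modified DeepONet needs every hidden layer
    # to have the same width as the encoders, so that the gates
    # (1 .- Z) .* U .+ Z .* V can be formed elementwise.
    @assert length(branch) == length(trunk) "branch and trunk must have the same depth"
    for i in 2:length(branch)
        @assert branch[i] == trunk[i] "branch and trunk widths differ at position $i"
    end
    @assert all(==(branch[2]), branch[2:(end - 1)]) "all hidden layers must have the same width"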

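    # Branch encoder U = φ(W_U u + b_U) and trunk encoder V = φ(W_V v + b_V);
    # the paper's modified DeepONet applies the activation to the encoders too
    U = Dense(branch[1] => branch[2], branch_activation)
    V = Dense(trunk[1] => trunk[2], trunk_activation)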
    branch_net = Chain(
        [
            Dense(
                    branch[i] => branch[i + 1],
                    ifelse(i == length(branch) - 1, identity, branch_activation),
                ) for i in 1:(length(branch) - 1)
        ]...,
    )

    trunk_net = Chain(
        [
            Dense(
                    trunk[i] => trunk[i + 1],
                    ifelse(i == length(trunk) - 1, identity, trunk_activation),
                ) for i in 1:(length(trunk) - 1)
        ]...,
    )
    return ModifiedDeepONet(branch_net, trunk_net, U, V)
end


function (mdon::ModifiedDeepONet)(
    x::Tuple{AbstractArray{T, 2}, AbstractArray{T, 2}}, ps::NamedTuple, st::NamedTuple
) where {T}
    u, v = x
    V = first(mdon.t_encoder(v, ps.t_encoder, st.t_encoder))    
    # now we run the network for each function in u
    
    # output size must be N_points × N_functions, i.e. size(v, 2) × size(u, 2)
    result = zeros(eltype(u), size(v,2), size(u,2))
    for funcIdx in axes(u, 2)
        U = first(mdon.b_encoder(u[:, funcIdx], ps.b_encoder, st.b_encoder))
        # Hidden layer 1
        Hᵥ = first(mdon.trunk.layers[1](v, ps.trunk[1], st.trunk[1]))
        Hᵤ = first(mdon.branch.layers[1](u[:, funcIdx], ps.branch[1], st.branch[1]))
        # Hidden layers 2:end-1: each pass computes the gate Z⁽ˡ⁾
        # from H⁽ˡ⁾ and then mixes the encoders into H⁽ˡ⁺¹⁾
        for idx in 2:(length(mdon.branch.layers)-1)
            b_layer = mdon.branch.layers[idx]
            t_layer = mdon.trunk.layers[idx]
            layer_sym = Symbol(:layer_, idx)
            
            Zᵤ = first(b_layer(Hᵤ, ps.branch[layer_sym], st.branch[layer_sym]))
            Zᵥ = first(t_layer(Hᵥ, ps.trunk[layer_sym], st.trunk[layer_sym]))
            Hᵤ = (1 .- Zᵤ) .* U .+ Zᵤ .* V
            Hᵥ = (1 .- Zᵥ) .* U .+ Zᵥ .* V
        end
        # last layer of the branch and trunk networks
        # final layer does not have an activation function
        n_layers = length(mdon.branch.layers)
        last_layer_sym = Symbol(:layer_, n_layers)
        Hᵤ = first(mdon.branch.layers[end](Hᵤ, ps.branch[last_layer_sym], st.branch[last_layer_sym]))
        Hᵥ = first(mdon.trunk.layers[end](Hᵥ, ps.trunk[last_layer_sym], st.trunk[last_layer_sym]))
        result[:, funcIdx] .= vec(sum(Hᵤ .* Hᵥ, dims=1))
    end
    return result, st
end

end # module
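A likely cause of the compile time: Reactant traces the Julia code with concrete shapes, and a plain Julia `for` loop simply runs during tracing. So `for funcIdx in axes(u, 2)` emits a separate copy of both networks for every function in the batch (and the gradient adds a reverse-mode copy of each), which means the graph XLA has to compile grows with the number of functions. The first trunk layer is also recomputed in every iteration, although it does not depend on `funcIdx`. Below is a loop-free sketch that broadcasts over a third array axis instead: U and the branch stream vary along functions, V and the trunk stream along points, and after the first gate both streams have shape (width, n_points, n_funcs). The helper names `apply_batched` and `modified_deeponet_forward` are hypothetical, the sketch assumes the parameter layout of `ModifiedDeepONet` above (`ps.branch`, `ps.trunk`, `ps.b_encoder`, `ps.t_encoder`), and whether it compiles faster under Reactant is untested.

```julia
using Lux, Random

# Apply a Lux layer that expects (features, batch) input to an array of shape
# (features, n_points, n_funcs) by folding the two batch axes into one.
function apply_batched(layer, H::AbstractArray{<:Any, 3}, ps, st)
    h, a, b = size(H)
    Y, _ = layer(reshape(H, h, a * b), ps, st)
    return reshape(Y, size(Y, 1), a, b)
end

# Hypothetical loop-free forward pass for the modified DeepONet.
# u: (n_sensors, n_funcs) branch input, v: (d, n_points) trunk input.
# Returns an (n_points, n_funcs) matrix, the same layout as the loop version.
function modified_deeponet_forward(branch, trunk, b_enc, t_enc, u, v, ps, st)
    nf, np = size(u, 2), size(v, 2)
    n = length(branch.layers)
    # Encoders: U depends only on the function, V only on the query point.
    U, _ = b_enc(u, ps.b_encoder, st.b_encoder)
    V, _ = t_enc(v, ps.t_encoder, st.t_encoder)
    h = size(U, 1)
    U3 = reshape(U, h, 1, nf)    # broadcast along the point axis
    V3 = reshape(V, h, np, 1)    # broadcast along the function axis
    # First hidden layer: evaluated once per function and once per point.
    Hu1, _ = branch.layers[1](u, ps.branch[1], st.branch[1])
    Hv1, _ = trunk.layers[1](v, ps.trunk[1], st.trunk[1])
    Hu = reshape(Hu1, h, 1, nf)
    Hv = reshape(Hv1, h, np, 1)
    # Gated hidden layers; after the first gate both are (h, np, nf).
    # This loop runs over layers, not over the batch, so it stays small.
    for i in 2:(n - 1)
        Zu = apply_batched(branch.layers[i], Hu, ps.branch[i], st.branch[i])
        Zv = apply_batched(trunk.layers[i], Hv, ps.trunk[i], st.trunk[i])
        Hu = (1 .- Zu) .* U3 .+ Zu .* V3
        Hv = (1 .- Zv) .* U3 .+ Zv .* V3
    end
    # Linear output layers, then the inner product over the feature axis.
    Bu = apply_batched(branch.layers[n], Hu, ps.branch[n], st.branch[n])
    Tv = apply_batched(trunk.layers[n], Hv, ps.trunk[n], st.trunk[n])
    return reshape(sum(Bu .* Tv; dims = 1), np, nf)
end

# Small usage example with hand-built chains (sizes chosen arbitrarily).
rng = Random.MersenneTwister(0)
branch = Chain(Dense(8 => 5, tanh), Dense(5 => 5, tanh), Dense(5 => 4))
trunk = Chain(Dense(1 => 5, tanh), Dense(5 => 5, tanh), Dense(5 => 4))
b_enc, t_enc = Dense(8 => 5, tanh), Dense(1 => 5, tanh)
parts = (branch = branch, trunk = trunk, b_encoder = b_enc, t_encoder = t_enc)
setups = map(l -> Lux.setup(rng, l), parts)
ps, st = map(first, setups), map(last, setups)
u = rand(rng, Float32, 8, 3)    # 3 functions sampled at 8 sensors
v = rand(rng, Float32, 1, 10)   # 10 query points
out = modified_deeponet_forward(branch, trunk, b_enc, t_enc, u, v, ps, st)
```

Inside the posted module, the call operator could then reduce to `return modified_deeponet_forward(mdon.branch, mdon.trunk, mdon.b_encoder, mdon.t_encoder, x[1], x[2], ps, st), st`. The trade-off is memory: after the first gate, activations of size width × n_points × n_funcs are materialised at once, which is the same amount of work the loop does, just expressed as a few large ops instead of many small copies.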

Test:

using Lux
include("Modified_DeepONet.jl")
using .DeepONet
using Random
using Optimisers
using Reactant

rng = Random.default_rng()
Random.seed!(rng, 1234)
deeponet = ModifiedDeepONet(;branch = (64, 32, 32, 16),
    trunk = (1, 32, 32, 16),
    branch_activation = tanh,
    trunk_activation = tanh
)

xdev = reactant_device()
cdev = cpu_device()

ps, st = Lux.setup(rng, deeponet) |> xdev
u_input = rand(Float32, 64, 3) |> xdev
y_input = rand(Float32, 1, 10) |> xdev
v_data = rand(Float32, 10, 3) |> xdev
data = [((u_input, y_input), v_data)];

# out = deeponet((u_input, y_input), ps, st)
# println(size(out))

tstate = Training.TrainState(deeponet, ps, st, Adam(0.0001f0))
(_, loss, _, tstate) = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), data[1], tstate; return_gradients=Val(false)
)

Can you post the log here?

With pleasure. To be clear, the network I’m compiling in my actual work is a bit bigger than the example:

deeponet = ModifiedDeepONet(;branch = (EAXIS_STEPS, 200, 200, 200, 200, 32),
    trunk = (1, 200, 200, 200, 200, 32),
    branch_activation = tanh,
    trunk_activation = tanh
)

The log:

I0000 00:00:1768764720.915397 1987208 dot_merger.cc:481] Merging Dots in computation: main.1749
I0000 00:00:1768764959.563355 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_25', 616 bytes spill stores, 628 bytes spill loads

I0000 00:00:1768764964.172583 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_multiply_reduce_fusion', 8372 bytes spill stores, 9128 bytes spill loads

I0000 00:00:1768764965.617236 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_28', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764966.801038 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764969.647473 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_8', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764970.829797 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_6', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764972.226716 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_1', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764974.841714 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_31', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764976.025502 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_32', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764977.203300 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_7', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764978.907805 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_34', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764980.245706 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_36', 216 bytes spill stores, 216 bytes spill loads

I0000 00:00:1768764983.055892 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_9', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764986.295274 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_4', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764990.680155 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_26', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764996.176879 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_5', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768764997.349667 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_27', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765007.219362 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_3', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765015.302654 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_11', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765017.512038 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_29', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765023.805989 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_10', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765025.779900 1987470 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_33', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765032.997574 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_35', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765034.202515 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_2', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765035.673559 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_30', 176 bytes spill stores, 176 bytes spill loads

I0000 00:00:1768765037.431872 1987460 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_12', 232 bytes spill stores, 232 bytes spill loads

E0000 00:00:1768765164.109254 1989983 slow_operation_alarm.cc:73] 
********************************
[Compiling module reactant_compute... for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
I0000 00:00:1768765240.079944 1987208 subprocess_compilation.cc:348] ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_25', 616 bytes spill stores, 628 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_multiply_reduce_fusion', 8300 bytes spill stores, 9048 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_1', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_2', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_3', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_4', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_5', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_6', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_7', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_8', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_9', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_10', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_11', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_12', 232 bytes spill stores, 232 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_36', 216 bytes spill stores, 216 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_35', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_34', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_33', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_32', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_31', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_30', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_29', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_28', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_27', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'input_add_reduce_fusion_26', 176 bytes spill stores, 176 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_65', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_11', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_30', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_53', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_20', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_34', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_49', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_44', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_25', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_16', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_60', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_39', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_51', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_32', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_27', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_37', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_64', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_13', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_46', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_41', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_23', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_56', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_18', 12 bytes spill stores, 12 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_tanh_fusion_59', 12 bytes spill stores, 12 bytes spill loads

E0000 00:00:1768765248.238571 1987208 slow_operation_alarm.cc:140] The operation took 3m24.12941995s

********************************
[Compiling module reactant_compute... for GPU] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
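To produce the dump the alarm asks for, the flag can also be set from within Julia. This is a minimal sketch that assumes XLA reads `XLA_FLAGS` when Reactant initialises its client, so the variable has to be set before Reactant is loaded; `/tmp/foo` is just the example path from the log message.

```julia
# Set before `using Reactant`, so that XLA sees the flag when it initialises.
ENV["XLA_FLAGS"] = "--xla_dump_to=/tmp/foo"

using Reactant
# ... then run the training step from the test script above as usual;
# the HLO modules XLA compiles are written under /tmp/foo.
```

Starting Julia from a shell with the variable already exported achieves the same thing without relying on load order.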