Error when using Enzyme to train a model with frozen layers

Hi,

I am using Julia to solve a 1D Poisson equation with a shallow single-hidden-layer PINN, and I have successfully implemented it based on the 2D PINN tutorial on the Lux website. Out of curiosity, I wanted to experiment with freezing the first layer and training only the second. However, I always encounter an error saying:

error: 'stablehlo.transpose' op using value defined outside the region
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_sI7QMw/module_000_nUkZ_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"

To work out why this is happening, I have created a notebook with heavily simplified code that just fits the data with a basic MSELoss. The code is as follows:

This is not the PINN code, just a minimal reproduction to isolate the problem with training frozen layers.

using Lux, Optimisers, Random, Printf, Statistics, MLUtils, Enzyme, Plots, Reactant, LinearAlgebra, Zygote


dev = reactant_device()
cdev = cpu_device()
u_analytical(x) = -(x^5)/20 + x/20
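# Analytical solution used to generate the training data; it satisfies -u''(x) = x^3
# with u(0) = u(1) = 0, so (presumably) the 1D Poisson problem being solved is -u'' = x^3 on [0, 1].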

x0 = 0.0
xn = 1.0
n = 10
bc0 = 0.0
bcn = 0.0
nepochs = 15000
rng = Random.default_rng()
Random.seed!(rng, 1234)
# ---------- Dataset ----------
function createtrainsets(x0::Float64, xn::Float64, n::Int64, u_analytical::Function)
    x = range(x0, xn; length=n)
    u = u_analytical.(x) #.+ randn(rng, Float32, (n, 1)) .* 0.0001
    return collect(x), vec(u)
end

xs, us = createtrainsets(x0, xn, n, u_analytical)
xs = dev(collect(xs'))
us = dev(collect(us'))

pinn = Chain(Dense(1 => 64, tanh), Dense(64 => 1, identity))

function freeze_first_dense(layer, ps, st, name)
    if name == KeyPath(:layer_1)
        return Lux.Experimental.freeze(layer, ps, st, (:weight, :bias))
    else
        return layer, ps, st
    end
end


ps, st = dev(Lux.setup(rng, pinn))
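# layer_map walks the Chain and applies freeze_first_dense to each layer, so the first
# Dense layer's weight and bias become frozen (no longer trainable) while the second
# Dense layer remains trainable.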
pinn_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_first_dense, pinn, ps, st)
#println(ps_frozen)


function train_model!(model, ps, st, opt, nepochs::Int)
    tstate = Training.TrainState(model, ps, st, opt)
    enzyme_mode = Enzyme.set_runtime_activity(Enzyme.Reverse)
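    # (note: enzyme_mode is defined here but not actually passed to AutoEnzyme() below)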
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for i in 1:nepochs
        grads, loss, _, tstate = Training.single_train_step!(
            ad, lossfn, (xs, us), tstate
        )
        if i % 100 == 1 || i == nepochs
            @printf "Loss Value after %6d iterations: %.8f\n" i loss
        end
    end
    return tstate.model, tstate.parameters, tstate.states
end

lr = 1e-3
opt = Adam(lr)
lossfn = MSELoss()
trained_pinn, trained_parameters, trained_states = train_model!(pinn_frozen, ps_frozen, st_frozen, opt, nepochs)


x_plot = reshape(range(x0, xn, length=1000), 1, :)
u_true = u_analytical.(x_plot[1, :])
u_pred, _ = trained_pinn(x_plot, trained_parameters, trained_states)

plot(x_plot[1, :], u_true, label="Analytical", lw=2)
plot!(x_plot[1, :], vec(u_pred), label="Model", lw=2, ls=:dash)

The code runs fine when I use Zygote AD (obviously after removing the Reactant bits), but I encounter the same error as in the PINN implementation when running it with Enzyme AD.
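For reference, 'removing the Reactant bits' roughly means swapping the device; a sketch of the only change (everything else above stays the same):

dev = cpu_device()  # instead of reactant_device()
# train_model! then selects AutoZygote() via the `dev isa ReactantDevice` branch,
# and training runs without any error.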

This is the full error message:

error: 'stablehlo.transpose' op using value defined outside the region
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_sI7QMw/module_001_UUvS_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"
Stacktrace:
  [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:163
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1140
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1135 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(LuxReactantExt.compute_gradients_internal_and_step!), args::Tuple{…}, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, optimize::Bool, cudnn_hlo_optimize::Bool, shardy_passes::Symbol, no_nan::Bool, transpose_propagate::Symbol, reshape_propagate::Symbol, optimize_communications::Bool, assert_nonallocating::Bool, backend::String, raise::Bool, raise_first::Bool, donated_args::Symbol, optimize_then_pad::Bool, runtime::Val{…}, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1511
  [5] compile_mlir!
    @ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1319 [inlined]
  [6] compile_xla(f::Function, args::Tuple{…}; client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3223
  [7] compile_xla
    @ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3205 [inlined]
  [8] compile(f::Function, args::Tuple{…}; sync::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3280
  [9] macro expansion
    @ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:2385 [inlined]
 [10] (::LuxReactantExt.var"#6#7"{Lux.Training.ReactantBackend{Static.True}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{Matrix{…}, Matrix{…}}})()
    @ LuxReactantExt ~/.julia/packages/Lux/lRugP/ext/LuxReactantExt/training.jl:107
 [11] with(f::LuxReactantExt.var"#6#7"{…}, pair::Pair{…}, rest::Pair{…})
    @ Base.ScopedValues ./scopedvalues.jl:269
 [12] #with_config#7
    @ ~/.julia/packages/Reactant/iZa5R/src/Configuration.jl:62 [inlined]
 [13] with_config
    @ ~/.julia/packages/Reactant/iZa5R/src/Configuration.jl:34 [inlined]
 [14] single_train_step_impl!(backend::Lux.Training.ReactantBackend{…}, objective_function::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxReactantExt ~/.julia/packages/Lux/lRugP/ext/LuxReactantExt/training.jl:103
 [15] #single_train_step!#6
    @ ~/.julia/packages/Lux/lRugP/src/helpers/training.jl:294 [inlined]
 [16] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/lRugP/src/helpers/training.jl:288
 [17] train_model!(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, opt::Adam{…}, nepochs::Int64)
    @ Main ~/***/freezetesting.jl:48
 [18] top-level scope
    @ ~/***/freezetesting.jl:61
 [19] include(fname::String)
    @ Main ./sysimg.jl:38
 [20] top-level scope
    @ REPL[1]:1
in expression starting at ***/freezetesting.jl:61
Some type information was truncated. Use `show(err)` to see complete types.

I do not want to simply swap over to Zygote, as I am using nested AD with Enzyme in my PINN code; I want to make it work with Enzyme.

Does anyone else have the same issue?

Yeah, this is a known issue (with a fix that is already implemented). We were waiting on an upstream XLA patch to be merged, and that was done yesterday; we are now progressively merging the downstream PRs.

In short, we should have the Lux patch merged and tagged as the latest release by tomorrow.


Thanks!