Hi,
I am using Julia to solve a 1D Poisson equation with a shallow, single-hidden-layer PINN, and I have successfully implemented code based on the 2D PINN tutorial on the Lux website. Out of curiosity, I wanted to experiment with freezing the first layer and training only the second. However, I always hit the following error:
```
error: 'stablehlo.transpose' op using value defined outside the region
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_sI7QMw/module_000_nUkZ_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"
```
To work out why this is happening, I created a notebook with heavily simplified code that trains against a plain MSELoss on the data. Note that this is not the PINN code, just a minimal reproduction to isolate the problem with training frozen layers:
```julia
using Lux, Optimisers, Random, Printf, Statistics, MLUtils, Enzyme, Plots,
      Reactant, LinearAlgebra, Zygote

dev = reactant_device()
cdev = cpu_device()

u_analytical(x) = -(x^5) / 20 + x / 20

x0 = 0.0
xn = 1.0
n = 10
bc0 = 0.0
bcn = 0.0
nepochs = 15000

rng = Random.default_rng()
Random.seed!(rng, 1234)

# ---------- Dataset ----------
function createtrainsets(x0::Float64, xn::Float64, n::Int64, u_analytical::Function)
    x = range(x0, xn; length=n)
    u = u_analytical.(x) #.+ randn(rng, Float32, (n, 1)) .* 0.0001
    return collect(x), vec(u)
end

xs, us = createtrainsets(x0, xn, n, u_analytical)
xs = dev(collect(xs'))
us = dev(collect(us'))

pinn = Chain(Dense(1 => 64, tanh), Dense(64 => 1, identity))
function freeze_first_dense(layer, ps, st, name)
    if name == KeyPath(:layer_1)
        return Lux.Experimental.freeze(layer, ps, st, (:weight, :bias))
    else
        return layer, ps, st
    end
end
ps, st = dev(Lux.setup(rng, pinn))
pinn_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_first_dense, pinn, ps, st)
#println(ps_frozen)
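# As far as I understand the freezing API, `freeze` moves the frozen weights
# out of `ps` and into the layer's state (under `frozen_params`), so
# `ps_frozen.layer_1` should be empty while `st_frozen.layer_1` holds them.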
function train_model!(model, ps, st, opt, nepochs::Int)
    tstate = Training.TrainState(model, ps, st, opt)
    enzyme_mode = Enzyme.set_runtime_activity(Enzyme.Reverse)
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for i in 1:nepochs
        grads, loss, _, tstate = Training.single_train_step!(
            ad, lossfn, (xs, us), tstate
        )
        if i % 100 == 1 || i == nepochs
            @printf "Loss Value after %6d iterations: %.8f\n" i loss
        end
    end
    return tstate.model, tstate.parameters, tstate.states
end
lr = 1e-3
opt = Adam(lr)
lossfn = MSELoss()
trained_pinn, trained_parameters, trained_states = train_model!(pinn_frozen, ps_frozen, st_frozen, opt, nepochs)
x_plot = reshape(range(x0, xn; length=1000), 1, :)
u_true = u_analytical.(x_plot[1, :])
# Move parameters/states back to the CPU before evaluating on a plain Array
u_pred, _ = trained_pinn(x_plot, cdev(trained_parameters), cdev(trained_states))

plot(x_plot[1, :], u_true; label="Analytical", lw=2)
plot!(x_plot[1, :], vec(u_pred); label="Model", lw=2, ls=:dash)
```
The code runs fine when I use Zygote AD (after removing the Reactant bits), but I hit the same error as in the PINN implementation when running with Enzyme AD.
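For completeness, here is roughly what I mean by "removing the Reactant bits"; a minimal sketch, the only real change being the device (the `dev isa ReactantDevice` check in `train_model!` then selects `AutoZygote()`):

```julia
# Zygote variant (sketch): use the CPU device instead of the Reactant one.
# Everything else in the script stays the same, and training then runs fine.
dev = cpu_device()
```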
This is the full error message:
```
error: 'stablehlo.transpose' op using value defined outside the region
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_sI7QMw/module_001_UUvS_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:119
ERROR: LoadError: "failed to run pass manager on module"
Stacktrace:
[1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
@ Reactant.MLIR.IR ~/.julia/packages/Reactant/iZa5R/src/mlir/IR/Pass.jl:163
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
@ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1140
[3] run_pass_pipeline!
@ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1135 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(LuxReactantExt.compute_gradients_internal_and_step!), args::Tuple{…}, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, optimize::Bool, cudnn_hlo_optimize::Bool, shardy_passes::Symbol, no_nan::Bool, transpose_propagate::Symbol, reshape_propagate::Symbol, optimize_communications::Bool, assert_nonallocating::Bool, backend::String, raise::Bool, raise_first::Bool, donated_args::Symbol, optimize_then_pad::Bool, runtime::Val{…}, kwargs::@Kwargs{})
@ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1511
[5] compile_mlir!
@ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:1319 [inlined]
[6] compile_xla(f::Function, args::Tuple{…}; client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3223
[7] compile_xla
@ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3205 [inlined]
[8] compile(f::Function, args::Tuple{…}; sync::Bool, kwargs::@Kwargs{…})
@ Reactant.Compiler ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:3280
[9] macro expansion
@ ~/.julia/packages/Reactant/iZa5R/src/Compiler.jl:2385 [inlined]
[10] (::LuxReactantExt.var"#6#7"{Lux.Training.ReactantBackend{Static.True}, GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}, Tuple{Matrix{…}, Matrix{…}}})()
@ LuxReactantExt ~/.julia/packages/Lux/lRugP/ext/LuxReactantExt/training.jl:107
[11] with(f::LuxReactantExt.var"#6#7"{…}, pair::Pair{…}, rest::Pair{…})
@ Base.ScopedValues ./scopedvalues.jl:269
[12] #with_config#7
@ ~/.julia/packages/Reactant/iZa5R/src/Configuration.jl:62 [inlined]
[13] with_config
@ ~/.julia/packages/Reactant/iZa5R/src/Configuration.jl:34 [inlined]
[14] single_train_step_impl!(backend::Lux.Training.ReactantBackend{…}, objective_function::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ LuxReactantExt ~/.julia/packages/Lux/lRugP/ext/LuxReactantExt/training.jl:103
[15] #single_train_step!#6
@ ~/.julia/packages/Lux/lRugP/src/helpers/training.jl:294 [inlined]
[16] single_train_step!(backend::AutoEnzyme{…}, obj_fn::GenericLossFunction{…}, data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/lRugP/src/helpers/training.jl:288
[17] train_model!(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, opt::Adam{…}, nepochs::Int64)
@ Main ~/***/freezetesting.jl:48
[18] top-level scope
@ ~/***/freezetesting.jl:61
[19] include(fname::String)
@ Main ./sysimg.jl:38
[20] top-level scope
@ REPL[1]:1
in expression starting at ***/freezetesting.jl:61
Some type information was truncated. Use `show(err)` to see complete types.
```
I do not want to simply switch to Zygote, since my PINN code relies on nested AD with Enzyme (the physics loss needs derivatives of the network output with respect to its inputs), so I want to make this work with Enzyme.
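For context, this is roughly the kind of nested differentiation the PINN loss relies on; a minimal sketch with a stand-in scalar function (the names `u`, `du`, and `d2u` are illustrative, not from my actual code):

```julia
using Enzyme

# Stand-in for the scalar network output u(x); the 1D Poisson residual
# u''(x) - f(x) needs the second derivative of this with respect to x.
u(x) = sin(3x)

# First derivative via reverse mode, second via forward-over-reverse.
du(x) = Enzyme.gradient(Enzyme.Reverse, u, x)[1]
d2u(x) = Enzyme.gradient(Enzyme.Forward, du, x)[1]
```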
Has anyone else run into this issue?