Training neural nets with multiple chains (Lux.jl and LuxCUDA.jl): LoadError: AssertionError: length(init_params) == length(depvars) when using multiple chains in the NeuralPDE.jl GPU example

When adapting the GPU example (Using GPUs · NeuralPDE.jl) of NeuralPDE.jl to multiple chains, it fails with an assertion error.

using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf

const gpud = gpu_device()

@parameters t x y
@variables u(..) v(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE
eq = [Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y)),
      Dt(v(t, x, y)) ~ Dxx(v(t, x, y)) + Dyy(v(t, x, y))]

analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max),

    v(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    v(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    v(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    v(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    v(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

# Neural network
inner = 25
chain1 = Chain(Dense(3, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, 1))

chain2 = Chain(Dense(3, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, 1))

strategy = QuasiRandomTraining(100)
ps1 = Lux.setup(Random.default_rng(), chain1)[1]
ps1 = ps1 |> ComponentArray |> gpud .|> Float64
ps2 = Lux.setup(Random.default_rng(), chain2)[1]
ps2 = ps2 |> ComponentArray |> gpud .|> Float64
discretization = PhysicsInformedNN([chain1, chain2],
    strategy,
    init_params = [ps1, ps2])

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y), v(t, x, y)])
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)

callback = function (p, l)
    println("Current loss is: $l")
    return false
end

res = Optimization.solve(prob, OptimizationOptimisers.Adam(1e-2); callback = callback, maxiters = 2500)

This fails with:

ERROR: LoadError: AssertionError: length(init_params) == length(depvars)

It looks like I have the same problem as Leo (Neural Nets training with multiple Chains Lux.jl and CUDA.jl), but the solution there does not work for me. First, CUDA.jl has been replaced by LuxCUDA in the current version. Second, even after switching to LuxCUDA and changing `gpu` to `gpud`, the scalar-indexing error is still reported.
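In case it helps locate the scalar-indexing problem, one thing that can be done is to make scalar GPU indexing a hard error so the stack trace points at the offending call. A minimal sketch, assuming CUDA.jl is installed alongside LuxCUDA:

```julia
using CUDA

# Turn scalar indexing on GPU arrays into an error instead of a slow
# silent fallback, so the stack trace shows which call triggers it.
CUDA.allowscalar(false)

# If a specific scalar read is known to be harmless, it can be permitted
# locally with the @allowscalar macro instead of globally:
# CUDA.@allowscalar x[1]
```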

One workaround I can think of is to use an older version of NeuralPDE.jl together with its compatible packages. However, I can only find the Project.toml and Manifest files provided after v5.4 (NeuralPDE.jl/v5.4.0/assets/Project.toml at gh-pages · SciML/NeuralPDE.jl · GitHub), and that Project.toml does not pin a CUDA version number. Perhaps someone can provide me with a complete Project.toml/Manifest.toml for an older version. Thank you!
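For reference, one way to try an older NeuralPDE.jl without a full Manifest is to pin just its version in a fresh environment and let the resolver pick compatible versions of the rest. A sketch (v5.4.0 here is only an example, not a known-good combination):

```julia
using Pkg

# Fresh, throwaway environment so the downgrade does not touch the default one
Pkg.activate(; temp = true)

# Pin NeuralPDE to an older release; the resolver then chooses
# CUDA/Lux versions compatible with it.
Pkg.add(name = "NeuralPDE", version = "5.4.0")
Pkg.add("Lux")
Pkg.add("CUDA")
```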