When I adapt the GPU example from the NeuralPDE.jl documentation (Using GPUs · NeuralPDE.jl) to use multiple chains, it does not work. This is the script:
using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf
# Device handle from `gpu_device()`; the network parameters are piped through it further down.
const gpud = gpu_device()

# Symbolic independent variables and the two unknown fields.
@parameters t x y
@variables u(..) v(..)

# Differential operators that appear in the PDEs.
Dt = Differential(t)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# Lower/upper bounds of the time interval and of the square spatial domain.
t_min, t_max = 0.0, 2.0
x_min, x_max = 0.0, 2.0
y_min, y_max = 0.0, 2.0
# 2D PDE: the same equation, Dt(w) = Dxx(w) + Dyy(w), is imposed on each unknown field.
field_equation(w) = Dt(w(t, x, y)) ~ Dxx(w(t, x, y)) + Dyy(w(t, x, y))
eq = [field_equation(u), field_equation(v)]

# Closed-form expression used as the right-hand side of every initial/boundary condition.
analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)

# Conditions for one field `w`, in this order: initial time, then the
# x_min, x_max, y_min and y_max faces of the domain.
function analytic_conditions(w)
    return [
        w(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
        w(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
        w(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
        w(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
        w(t, x, y_max) ~ analytic_sol_func(t, x, y_max),
    ]
end

# Initial and boundary conditions: all conditions for `u` first, then all for `v`.
bcs = vcat(analytic_conditions(u), analytic_conditions(v))

# Space and time domains
time_domain = t ∈ Interval(t_min, t_max)
x_domain = x ∈ Interval(x_min, x_max)
y_domain = y ∈ Interval(y_min, y_max)
domains = [time_domain, x_domain, y_domain]
# Neural networks: one identically shaped chain per dependent variable.
inner = 25

# Builds a fresh chain: 3 inputs -> four sigmoid layers of width `inner` -> 1 output.
function make_chain()
    return Chain(Dense(3, inner, Lux.σ),
        ntuple(_ -> Dense(inner, inner, Lux.σ), 3)...,
        Dense(inner, 1))
end

chain1 = make_chain()
chain2 = make_chain()

strategy = QuasiRandomTraining(100)

# Takes the parameters returned by `Lux.setup`, wraps them in a ComponentArray,
# passes them through `gpud` and converts every element to Float64.
function device_params(chain)
    params, _ = Lux.setup(Random.default_rng(), chain)
    return Float64.(gpud(ComponentArray(params)))
end

ps1 = device_params(chain1)
ps2 = device_params(chain2)

# NOTE(review): the reported `length(init_params) == length(depvars)` assertion
# presumably concerns the vector passed as `init_params` here; the form NeuralPDE
# expects for multiple chains is not visible in this script. TODO confirm against
# the PhysicsInformedNN documentation.
discretization = PhysicsInformedNN([chain1, chain2], strategy; init_params = [ps1, ps2])
# Assemble the symbolic system from the equations, conditions and domains defined above.
independent_vars = [t, x, y]
dependent_vars = [u(t, x, y), v(t, x, y)]
@named pde_system = PDESystem(eq, bcs, domains, independent_vars, dependent_vars)

# Optimization problem for the solver, plus the symbolic form of the same discretization.
prob = discretize(pde_system, discretization)
# NOTE(review): `symprob` is not used anywhere else in this script.
symprob = symbolic_discretize(pde_system, discretization)

# Prints the current loss and returns `false`.
# NOTE(review): `callback` is defined but never passed to `solve` below, so nothing is printed.
callback = (p, l) -> begin
    println("Current loss is: $l")
    false
end

optimizer = OptimizationOptimisers.Adam(1e-2)
res = Optimization.solve(prob, optimizer; maxiters = 2500)
Running it fails with: “ERROR: LoadError: AssertionError: length(init_params) == length(depvars)”