I tried to run the tutorial of retrain by neural_adapter as below:
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, DiffEqBase
import ModelingToolkit: Interval, infimum, supremum
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
# 2D PDE
eq = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)
# Initial and boundary conditions
bcs = [u(0,y) ~ 0.0, u(1,y) ~ -sin(pi*1)*sin(pi*y),
u(x,0) ~ 0.0, u(x,1) ~ -sin(pi*x)*sin(pi*1)]
# Space and time domains
domains = [x ∈ Interval(0.0,1.0),
y ∈ Interval(0.0,1.0)]
quadrature_strategy = NeuralPDE.QuadratureTraining(reltol=1e-2,abstol=1e-2,
maxiters =50, batch=100)
inner = 8
af = Lux.tanh
chain1 = Chain(Dense(2,inner,af),
Dense(inner,inner,af),
Dense(inner,1))
discretization = NeuralPDE.PhysicsInformedNN(chain1,
quadrature_strategy)
@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
prob = NeuralPDE.discretize(pde_system,discretization)
sym_prob = NeuralPDE.symbolic_discretize(pde_system,discretization)
res = Optimization.solve(prob, BFGS(); maxiters=2000)
phi = discretization.phi
inner_ = 12
af = Lux.tanh
chain2 = Lux.Chain(Dense(2,inner_,af),
Dense(inner_,inner_,af),
Dense(inner_,inner_,af),
Dense(inner_,1))
init_params2 = Float64.(ComponentArray(Lux.setup(Random.default_rng(), chain)[1]))
# the rule by which the training will take place is described here in loss function
function loss(cord,θ)
chain2(cord,θ) .- phi(cord,res.u)
end
strategy = NeuralPDE.GridTraining(0.02)
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
callback = function (p,l)
println("Current loss is: $l")
return false
end
res_ = Optimization.solve(prob_, BFGS();callback = callback, maxiters=1000)
phi_ = NeuralPDE.get_phi(chain2)
xs,ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
analytic_sol_func(x,y) = (sin(pi*x)*sin(pi*y))/(2pi^2)
u_predict = reshape([first(phi([x,y],res.u)) for x in xs for y in ys],(length(xs),length(ys)))
u_predict_ = reshape([first(phi_([x,y],res_.minimizer)) for x in xs for y in ys],(length(xs),length(ys)))
u_real = reshape([analytic_sol_func(x,y) for x in xs for y in ys], (length(xs),length(ys)))
diff_u = u_predict .- u_real
diff_u_ = u_predict_ .- u_real
using Plots
p1 = plot(xs, ys, u_predict, linetype=:contourf,title = "first predict");
p2 = plot(xs, ys, u_predict_,linetype=:contourf,title = "second predict");
p3 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic");
p4 = plot(xs, ys, diff_u,linetype=:contourf,title = "error 1");
p5 = plot(xs, ys, diff_u_,linetype=:contourf,title = "error 2");
plot(p1,p2,p3,p4,p5)
And I got the error for the code res_ = Optimization.solve(prob_, BFGS();callback = callback, maxiters=1000)
this is the solving for the neural_adapter model. The error follows:
MethodError: no method matching (::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(::Matrix{Float64}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:96, Axis(weight = ViewAxis(1:64, ShapedAxis((8, 8), NamedTuple())), bias = ViewAxis(65:72, ShapedAxis((8, 1), NamedTuple())))), layer_3 = ViewAxis(97:105, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))))}}})
Closest candidates are:
(::Chain)(::Any, ::Any, ::NamedTuple)
@ Lux C:\Users\htran\.julia\packages\Lux\s0bDu\src\layers\containers.jl:456
@ SciMLBase C:\Users\htran\.julia\packages\SciMLBase\XHyFZ\src\solve.jl:83
[35] top-level scope
@ In[131]:57
``` (This is not the full message of error due to the limit of words in a post, please run the code to see the full error).
Please help me on how to fix it.
Thank you all.