Error with ReLU activation function when solving a problem with NeuralPDE

I tried to solve an ODE problem, formulated as a PDESystem, with NeuralPDE, using ReLU as the activation function.

# One Lux network per dependent variable: 12 networks here (presumably `length(dvs) == 12`;
# TODO confirm). Each maps the scalar input t to one scalar output through three ReLU layers.
# NOTE(review): ReLU is piecewise linear, so the network's derivatives with respect to t are
# piecewise constant. Smooth activations (e.g. `Lux.σ`, `tanh`) are often preferred for PINNs;
# consider them if the fit stays poor.
chain =[Lux.Chain(Dense(1,10,Lux.relu),Dense(10,20,Lux.relu),Dense(20,10,Lux.relu),Dense(10,1)) for _ in 1:12]

# `eqs`, `bcs`, `domains`, `t` and `dvs` are defined outside this snippet.
@named pde_system = PDESystem(eqs,bcs,domains,[t],dvs)

# Evaluate the losses on a uniform grid with spacing 0.01.
strategy = NeuralPDE.GridTraining(0.01)

discretization = PhysicsInformedNN(chain, strategy)
# Build the symbolic PINN problem so the individual loss terms can be combined by hand below.
sym_prob = NeuralPDE.symbolic_discretize(pde_system, discretization)

# Keep the equation-residual losses separate from the boundary/initial-condition losses.
pde_loss_functions = sym_prob.loss_functions.pde_loss_functions
bc_loss_functions = sym_prob.loss_functions.bc_loss_functions

# Optimization.jl progress callback: print the current loss `l`.
# Returning `false` tells the optimizer not to halt early.
callback = (p, l) -> begin
    println("loss: ", l)
    false
end
# Gather every loss term into one vector: PDE residual terms first, then boundary-condition terms.
loss_functions = vcat(pde_loss_functions, bc_loss_functions)

"""
    loss_function(θ, p)

Return the total training loss at parameters `θ`: the sum of every entry of the global
`loss_functions`, each evaluated at `θ`. The second argument `p` is accepted because
`OptimizationFunction` calls the objective as `f(θ, p)`, but it is not used.
"""
function loss_function(θ, p)
    per_term = map(term -> term(θ), loss_functions)
    return sum(per_term)
end

# Wrap the summed loss and differentiate it with Zygote.
f_ = OptimizationFunction(loss_function, Optimization.AutoZygote())
# Start from the initial network parameters produced by `symbolic_discretize`.
prob = Optimization.OptimizationProblem(f_, sym_prob.flat_init_params);
# Trial-solution function(s) for evaluating the trained networks (with 12 chains this is
# presumably one phi per network; TODO confirm).
phi = discretization.phi;

# NOTE(review): this BFGS uses Optim's default line search. `maxiters` is only an upper
# bound, so the solver can stop earlier; inspect `res.retcode` / `res.original` for the reason.
res = Optimization.solve(prob,OptimizationOptimJL.BFGS(); callback = callback, maxiters = 10000)

Although I set the maximum number of iterations to 10000, the solver stopped after far fewer than 10000 iterations while the loss was still high.

loss: 858278.3955554628
loss: 856463.6937835787
loss: 854564.9613937325
loss: 848744.6992970144
loss: 844192.9726565297
loss: 742945.8872565811
loss: 740183.5543177559
loss: 737515.1168722081
loss: 545432.9718656855
loss: 540181.8503273559
loss: 539055.5701096215
loss: 520512.0307348489
loss: 512659.92624604533
loss: 478501.33506512915
loss: 437132.39857234206
loss: 384331.8369923271
loss: 313843.8857253902
loss: 303354.3593772265
loss: 297767.2273346808
loss: 294820.9831285363
loss: 290904.85949187336
loss: 289918.8168230032
loss: 288182.35507709265
loss: 283039.3822765338
loss: 276946.3874674323
loss: 271870.4964458572
loss: 255643.9181664962
loss: 238663.42878303188
loss: 218013.8331590277
loss: 196454.177072889
loss: 190372.2087389529
loss: 168903.7443233281
loss: 164963.39916406816
loss: 147559.9764464995
loss: 141447.18909037876
loss: 138384.72193233084
loss: 129345.41290518362
loss: 116445.79383935902
loss: 91374.3203737037
loss: 89485.18298122114
loss: 84653.07072666878
loss: 81020.43471064388
loss: 73091.46577748483
loss: 48506.802178812715
loss: 39917.960525598435
loss: 38799.207634892104
loss: 30318.074717297426
loss: 26402.828670269
loss: 26081.22332583516
loss: 25038.16780004235
loss: 23374.226884985626
loss: 18791.497062457056
loss: 17423.433265059004
loss: 16463.749916774726
loss: 15773.215439678199
loss: 15386.079200618415
loss: 14142.035615042794
loss: 12199.585941177424
loss: 8538.832513897847
loss: 8538.717261101248
loss: 7132.266979463222
loss: 5119.1559550463935
loss: 4918.474226803187
loss: 4487.433671475639
loss: 3973.5861161633757
loss: 3554.500411875686
loss: 2979.9664090216
loss: 2274.116148459351
loss: 1931.5014570325222
loss: 1541.463285584013
loss: 1472.313334387744
loss: 1116.4626250624528
loss: 970.3930473445745
loss: 745.9895023271383
loss: 588.6190561775101
loss: 541.4798054831839
loss: 541.479492909954
loss: 533.8799794981358
loss: 504.73026995647626
loss: 434.52439638118864
loss: 413.80488116970133
loss: 413.52624912762997
loss: 382.4362046002455
loss: 351.76119069426295
loss: 339.0367298925745
loss: 322.8163792840978
loss: 316.67460250475386
loss: 313.11541560156996
loss: 298.8601539056951
loss: 291.3923269099095
loss: 270.3685513467391
loss: 270.10416093226144
loss: 255.07360967742693
loss: 241.9438455543283
loss: 226.42604240805713
loss: 224.9809675894178
loss: 217.66072763210852
loss: 206.5653642949857
loss: 197.68570539700738
loss: 192.38674141104124
loss: 189.414817336702
loss: 183.39592378127503
loss: 179.71653683727317
loss: 168.70299533885955
loss: 158.54762963149085
loss: 154.70628723487135
loss: 153.56284795904784
loss: 149.2241104000355
loss: 146.64880703232464
loss: 144.4570618641078
loss: 142.43968486040717
loss: 140.09609669918882
loss: 138.65270580610368
loss: 136.03424013974796
loss: 133.03576148998556
loss: 132.63648941494128
loss: 131.64267730992577
loss: 130.35732809168573
loss: 129.5219113322507
loss: 129.05767855799726
loss: 127.63270743828342
loss: 126.39353560958149
loss: 124.12161739875702
loss: 123.3816622894179
loss: 120.11232883700671
loss: 118.29923021922474
loss: 116.57739346881023
loss: 116.5772528444146
loss: 116.18342018408835
loss: 114.0508020702462
loss: 112.26287666759802
loss: 110.20587695650816
loss: 109.61104450488509
loss: 109.29447872271312
loss: 108.22481026640781
loss: 107.8995956612347
loss: 107.75042696446585
loss: 107.446364683631
loss: 107.09921002378725
loss: 106.59211701280759
loss: 105.738967433186
loss: 105.53584346567021
loss: 104.80824427334345
loss: 104.31261331671128
loss: 103.22199106346315
loss: 102.66291448622157
loss: 100.89788311785944
loss: 98.93879552518318
loss: 97.69268956893542
loss: 97.09874614773695
loss: 97.03951853187024
loss: 95.61592205712046
loss: 94.77695241993669
loss: 93.84068339939971
loss: 93.17189533022913
loss: 92.6718694739493
loss: 92.66549530626511
loss: 92.1595073649036
loss: 91.89175738548845
loss: 90.68799993622896
loss: 90.10243523749644
loss: 89.7492699833171
loss: 88.78650340844007
loss: 87.83362652562846
loss: 87.03035901301268
loss: 86.34233132122309
loss: 85.9721778949065
loss: 85.30509004396055
loss: 84.95913107130357
loss: 83.86464986194882
loss: 83.27223379241507
loss: 82.84592429504629
loss: 81.82793586444959
loss: 80.23909942804964
loss: 79.567292961009
loss: 78.95412021782379
loss: 77.16295318937054
loss: 75.47773258072651
loss: 75.03967759713917
loss: 74.54666365737181
loss: 74.24805406218115
loss: 73.47620874881889
loss: 72.27332185744055
loss: 71.12571404524951
loss: 70.79928827602133
loss: 70.13495201158985
loss: 70.01188113698035
loss: 69.82470708567564
loss: 69.8003809238722
loss: 69.80037973393411
loss: 69.65271944501237
loss: 69.65154428205781
loss: 69.41804902089196
loss: 68.88192379128341
loss: 68.58551009223441
loss: 68.15645504426224
loss: 67.95735404226693
loss: 67.85633640725494
loss: 67.49017313796499
loss: 67.30868597596725
loss: 67.13474437717937
loss: 66.85962061441609
loss: 66.11350278873212
loss: 65.91695977221916
loss: 65.81003320247291
loss: 65.59586065448333
loss: 65.33519432222405
loss: 65.10551796981623
loss: 64.59817370627844
loss: 64.36862345260994
loss: 64.00315535223346
loss: 63.86785452811675
loss: 63.85439494468053

When I remade the problem and continued solving from the previous result, it again ran for only a few iterations and stopped while the loss was still high.

How can I make the solver keep running until the loss reaches the desired value?
Thank you all.

What was the error?

Check the return code `res.retcode` to see why it stopped. There are several possible stopping reasons besides the iteration count.

The solver stopped well before 10000 iterations, and the loss values were still high. I thought the solving process would only stop once the loss became small (if a tolerance is given) or the maximum number of iterations was reached.

BFGS from Optim has stopping conditions for sufficiently small gradients. If you check the retcode you should see that.

After checking the return code `res.retcode`, I got `ReturnCode.Failure = 9`. How can I fix the code? (Would `ReturnCode.Default = 0` mean that the code ran successfully without failure?)

I have checked the retcode, but it only returns `ReturnCode.Failure = 9`.

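Check `sol.original` (that is, `res.original` in your code), which holds the raw result returned by the underlying optimizer.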

Did you mean this output, with the convergence measures?

 * Status: failure (line search failed)

 * Candidate solution
    Final objective value:     6.188523e+01

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 3.96e-05 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.74e-06 ≰ 0.0e+00
    |f(x) - f(x')|         = 2.07e-03 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 3.35e-05 ≰ 0.0e+00
    |g(x)|                 = 1.60e+04 ≰ 1.0e-08

 * Work counters
    Seconds run:   1585  (vs limit Inf)
    Iterations:    240
    f(x) calls:    3295
    ∇f(x) calls:   3295

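There you go: the status says the line search failed (Optim's BFGS uses the HagerZhang line search by default). You may want to change the line search to BackTracking.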

Sorry, I am just a newbie. Could you please explain more about BackTracking?

# NOTE(review): `BackTracking` is presumably provided by LineSearches.jl (e.g. `using LineSearches`,
# or write `LineSearches.BackTracking()`); confirm it is in scope with your imports.
BFGS(linesearch=BackTracking())

1 Like

After switching to `BackTracking()`, the status is now success:

 * Status: success

 * Candidate solution
    Final objective value:     7.977749e+02

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 6.94e-18 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.88e-19 ≰ 0.0e+00
    |f(x) - f(x')|         = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 0.00e+00 ≤ 0.0e+00
    |g(x)|                 = 1.35e+03 ≰ 1.0e-08

 * Work counters
    Seconds run:   438  (vs limit Inf)
    Iterations:    237
    f(x) calls:    2891
    ∇f(x) calls:   238

However, the NeuralPDE results are very different from those of the ODE solver (`Tsit5()`). Also, I do not understand why the status is success while the loss value is still very large (nearly 800).

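Success in Optim means it found a minimum, or more precisely that one of its convergence criteria was met. In the output above, |f(x) − f(x')| reached 0 even though |g(x)| is still 1.35e+03. A minimum of a neural-network loss may only be a local minimum, which may not be that good. This is part of the reason why we use the PolyAlg approach that mixes Adam and BFGS.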

Regarding PolyAlg: I searched for information on the package, but its page is blank. Has the package been released? Someone gave me code for NNODE with a new training strategy that imported the PolyAlg package, but I don't think it was actually used in the code (issue #708 on the NeuralPDE GitHub).

We need to document it better. It’s on the roadmap.

2 Likes

This is unrelated, but may I ask about an issue with PINNs in NeuralPDE (the issue with the sin and cos functions, opened as #710 on the NeuralPDE GitHub)? Is it being fixed?
My current package versions:

[b2108857] Lux v0.4.58
[961ee093] ModelingToolkit v8.63.0
[315f7962] NeuralPDE v5.7.0
[7f7a1694] Optimization v3.15.2
[36348300] OptimizationOptimJL v0.1.9

Which issue?

Let's take the example from the NeuralPDE tutorial:

using NeuralPDE, Lux, Optimization, OptimizationOptimJL
import ModelingToolkit: Interval

# Independent variables x, y and the unknown function u(x, y).
@parameters x y
@variables u(..)
# Second-derivative operators ∂²/∂x² and ∂²/∂y².
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D Poisson equation: u_xx + u_yy = -sin(πx)·sin(πy)
eq  = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)

# Boundary conditions: u = 0 on all four edges of the unit square
bcs = [u(0,y) ~ 0.0, u(1,y) ~ 0.0,
       u(x,0) ~ 0.0, u(x,1) ~ 0.0]
# Spatial domains: x and y are both space coordinates on [0, 1] (there is no time variable)
domains = [x ∈ Interval(0.0,1.0),
           y ∈ Interval(0.0,1.0)]

# Neural network: 2 inputs (x, y) -> two hidden layers of 16 sigmoid units -> 1 output (u)
dim = 2 # number of input dimensions (x, y)
chain = Lux.Chain(Dense(dim,16,Lux.σ),Dense(16,16,Lux.σ),Dense(16,1))

# Discretization: evaluate the PINN losses on a uniform grid with spacing dx
dx = 0.05
discretization = PhysicsInformedNN(chain,GridTraining(dx))

@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
# Turn the symbolic PDE system into an optimization problem over the network parameters
prob = discretize(pde_system,discretization)

# Optimizer: BFGS from Optim (with Optim's default line search)
opt = OptimizationOptimJL.BFGS()

#Callback function
# Optimization.jl progress callback: print the current loss `l`.
# Returning `false` means "keep optimizing".
callback = (p, l) -> begin
    println("Current loss is: $l")
    false
end

# Train the PINN. `maxiters` is an upper bound; the optimizer may stop earlier.
res = Optimization.solve(prob, opt, callback = callback, maxiters=1000)
# Trial solution: `phi(point, params)` evaluates the network at `point` with parameters `params`.
phi = discretization.phi

using Plots

# Plotting grids: each domain is sampled with step dx/10, i.e. 10x finer than the training grid.
xs,ys = [infimum(d.domain):dx/10:supremum(d.domain) for d in domains]
# Exact solution of the Poisson problem above: u(x, y) = sin(πx)·sin(πy) / (2π²).
function analytic_sol_func(x, y)
    numerator = sin(π * x) * sin(π * y)
    return numerator / (2π^2)
end

# Evaluate the trained network and the analytic solution on the plotting grid.
# Plots.jl expects the z matrix of a contour plot to have size (length(ys), length(xs)),
# with z[i, j] being the value at (xs[j], ys[i]). The 2-D comprehensions below (y along
# rows, x along columns) build exactly that layout, so the plots stay correctly oriented
# even when the x and y grids have different lengths.
u_predict = [first(phi([x, y], res.u)) for y in ys, x in xs]
u_real = [analytic_sol_func(x, y) for y in ys, x in xs]
# Pointwise absolute error between the PINN prediction and the exact solution.
diff_u = abs.(u_predict .- u_real)

p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic");
p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict");
p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error");
plot(p1,p2,p3)

The error message is:

MethodError: no method matching sin(::Matrix{Float64})
You may have intended to import Base.sin

Closest candidates are:
  sin(::ForwardDiff.Dual{T}) where T
   @ ForwardDiff C:\Users\Strawberry\.julia\packages\ForwardDiff\vXysl\src\dual.jl:238
  sin(::DualNumbers.Dual)
   @ DualNumbers C:\Users\Strawberry\.julia\packages\DualNumbers\5knFX\src\dual.jl:327
  sin(::Float64)
   @ NaNMath C:\Users\Strawberry\.julia\packages\NaNMath\ceWIc\src\NaNMath.jl:9
  ...

A stack trace follows this message, but it is too long to show here.

What exactly is the issue? It looks like the code that you posted runs fine.

Did you run it in a fresh REPL? The message `You may have intended to import Base.sin` indicates that `sin` has been redefined. That means something you did in your script is not in your MWE (minimal working example).

1 Like