Julia crashes for larger PINNs?

Hello, I am trying to train a PINN using NeuralPDE. I am using the example from the documentation (Introduction to NeuralPDE for PDEs · NeuralPDE.jl), and when I try to increase the network size, performance degrades badly and Julia even crashes. Is there a known problem with larger PINNs in NeuralPDE?

The architecture I want to implement is the following (all layers are Dense; a Lux sketch follows the list):
Input dim: 3
Num hidden layers: 3
Output dim: 4
Hidden layer size: 500
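
A minimal sketch of that chain in Lux (the σ activation below is carried over from the documentation example and is an assumption for this architecture):

using Lux

# 3 inputs, 3 hidden layers of width 500, 4 outputs; activation assumed to be σ
in_dim, hidden, out_dim = 3, 500, 4
chain = Lux.Chain(Dense(in_dim, hidden, Lux.σ),
                  Dense(hidden, hidden, Lux.σ),
                  Dense(hidden, hidden, Lux.σ),
                  Dense(hidden, out_dim))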

For hidden layer size = 100, performance is already much worse than in PyTorch, and for size = 200, Julia crashes during training with no error displayed.

Thanks in advance.


Are you running this from a terminal? That should show more information unless your computer is going OOM. Any details on your setup?

Thanks Chris. Yes, I am running from the terminal on macOS. My machine is definitely not out of memory; I monitor the RAM live. I am running on a MacBook Pro M3 Pro with 18 GB of RAM.

Maybe there is some restriction on how much memory Julia can use?

Your terminal exits as well?

It shouldn’t get to 18 GB, but that depends on the sampling strategy. Can you share the code you’re running?

Yes, the terminal itself gets killed altogether.

It happens with the exact same code as in the documentation (Introduction to NeuralPDE for PDEs · NeuralPDE.jl), just changing the network architecture to:

dim = 2
size = 200
chain = Lux.Chain(Dense(dim, size, Lux.σ), Dense(size, size, Lux.σ),
                  Dense(size, size, Lux.σ), Dense(size, size, Lux.σ),
                  Dense(size, 1))

As you can see, the sampling strategy is GridTraining(0.05) - I do not think that should be the problem. It also happens in another attempt with QuasiRandomTraining.
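
For reference, this is roughly how the discretization is set up for each strategy, following the documentation example (where the system is called pde_system); the QuasiRandomTraining point count below is an arbitrary choice, not something from the docs:

using NeuralPDE

# Grid sampling as in the documentation example (dx = 0.05)
discretization = PhysicsInformedNN(chain, GridTraining(0.05))

# Alternative attempt with quasi-random sampling (point count chosen arbitrarily here)
# discretization = PhysicsInformedNN(chain, QuasiRandomTraining(500))

prob = discretize(pde_system, discretization)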

It is indeed going out of memory. I tried LBFGS with:

dim = 2
size = 200
chain = Lux.Chain(Dense(dim, size, Lux.σ), Dense(size, size, Lux.σ),
                  Dense(size, size, Lux.σ), Dense(size, size, Lux.σ),
                  Dense(size, 1))
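
(For completeness: opt here is the L-BFGS wrapper from Optim.jl; constructing it with default options as below is an assumption, since that line is not shown in the thread.)

using OptimizationOptimJL

# L-BFGS through Optim.jl; default memory and line-search options assumed
opt = OptimizationOptimJL.LBFGS()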

Timing the solve with @time (second run, so compilation is excluded):

julia> @time res = Optimization.solve(prob, opt, callback = callback, maxiters = 10)
Current loss is: 6.683189295799788
Current loss is: 3.1048545916265975
Current loss is: 0.2267500861331974
Current loss is: 0.2267498694045826
Current loss is: 0.22667766241472545
Current loss is: 0.22665704654163812
Current loss is: 0.22663500928990737
Current loss is: 0.22660555057974188
Current loss is: 0.22656138117556937
Current loss is: 0.22637768121431287
Current loss is: 0.2259676898608068
127.172276 seconds (258.85 M allocations: 234.668 GiB, 17.15% gc time)
retcode: Failure
u: ComponentVector{Float64}(layer_1 = (weight = [0.16307427793188822 0.15625658880583795; -0.004732452370397745 0.09710655730283399; … ; -0.15517306889593407 -0.008747806390736995; -0.01322809045456459 -0.07411391779699049], bias = [0.025185533054337572; 0.007461577325928262; … ; -0.016892476827341296; -0.003241416169991354;;]), layer_2 = (weight = [0.11177680410774743 -0.0918048914721283 … -0.014141108005049039 -0.10100793225287882; 0.027110037101477434 -0.03187748836764777 … -0.08927054656619957 0.10174297055412043; … ; -0.05724745503950367 0.02108687659043006 … -0.11359340416089839 0.0952017547059512; 0.05137302379590572 -0.026006053489993596 … 0.0670188883032111 -0.046550459855953484], bias = [0.0001301556381884588; 0.00598757464764547; … ; 0.002319653146856979; 0.0004080263747713927;;]), layer_3 = (weight = [0.10640035502618432 0.10340110317492913 … 0.020890185065004317 -0.12035398640546481; -0.08770679373824165 0.10240709520834215 … -0.08300252849231268 0.0547737057037142; … ; 0.062221109664531364 0.01063019520511943 … -0.08095690861229035 -0.07113803350797965; 0.04529870934282639 -0.04278160418800191 … 0.0035466719685386884 0.09096613579544281], bias = [-0.00018283332593967515; -0.0019018446786608603; … ; -0.0016334418549810723; -4.894159180550134e-7;;]), layer_4 = (weight = [0.10808998188043394 -0.043216541421113655 … -0.12302522171686488 -0.05148983732554242; -0.022591070338873722 -0.038620657796605326 … -0.011153600604047579 0.0046260131310429855; … ; 0.07766561004466804 -0.07421688050623126 … -0.053538025197706965 -0.08614639292249386; -0.11811868294719909 -0.09991723998755593 … -0.1185594574353186 -0.021999823677533104], bias = [0.0007397206294582508; -0.0009908259383677195; … ; 0.0009206312586290302; 0.0009211772132128078;;]), layer_5 = (weight = [0.11829070222142016 -0.05592343563791806 … 0.1262594015995021 0.07405717152760742], bias = [0.02107732385275324;;]))

It is allocating ~235 GiB.

Using Adam does not allocate that much:

julia> @time res = Optimization.solve(prob, opt, callback = callback, maxiters = 10)
Current loss is: 0.2860917022317445
Current loss is: 1.234617242825595
Current loss is: 0.3182132576728858
Current loss is: 0.41630261071783803
Current loss is: 0.7456721600143762
Current loss is: 0.5565277739882979
Current loss is: 0.27479753975045523
Current loss is: 0.25557875923806944
Current loss is: 0.42141515377168065
Current loss is: 0.4900158951016975
Current loss is: 0.25557875923806944
 11.135609 seconds (74.97 M allocations: 4.395 GiB, 21.08% gc time, 12.80% compilation time)
retcode: Default
u: ComponentVector{Float64}(layer_1 = (weight = [-0.017126087360129345 -0.09950323514387924; -0.01668894473333998 -0.1711759554153898; … ; 0.06495176259666341 -0.007905871818531104; 0.17173032901916585 -0.16881519728491048], bias = [0.00014874152991890187; 1.0067188527172223e-5; … ; -0.0004325924902529166; 0.0004040320174820873;;]), layer_2 = (weight = [0.028316474600908068 0.08255876051702847 … 0.03135000735238347 -0.04444160575241615; 0.03640751972919155 0.07750772925637393 … -0.011311141098601886 0.07791238018148845; … ; 0.05208708784431761 -0.11726502014816484 … 0.039481469796620705 0.059643986613829; -0.09987618211252872 -0.04913035328075708 … 0.0381908663320885 -0.07312186114394491], bias = [0.00031591848762609406; -0.00027364305539604675; … ; -0.0010004771779222625; -0.001259332261635275;;]), layer_3 = (weight = [-0.07440508016975876 0.06999317747565037 … 0.08227497776780567 -0.10963093275696452; -0.10010097483013762 0.0714791809269551 … 0.03885212593888927 0.10161415442679306; … ; 0.049009877031242356 0.0879132976049354 … -0.010864286103423585 0.012135223185550586; -0.07696560986258962 0.05342967529850782 … -0.0604492134553235 0.0009869506077696522], bias = [-0.001423979216000808; -0.0024177462754570484; … ; -0.000798799900723637; 4.395984297853381e-5;;]), layer_4 = (weight = [0.09853829934143915 0.10943196611633882 … 0.007966299281793937 -0.10053939931173753; 0.12098132039971743 0.022736636684751125 … 0.06449369655316996 -0.11685271448620674; … ; 0.003740479397246388 -0.0320462985480092 … 0.04577196598745555 -0.07550104544598418; -0.10115814171039753 0.032378965184045996 … -0.11143487785993128 -0.046648813598238976], bias = [0.00025704002465073994; -0.0005391071101855953; … ; -0.00018129965297292416; 0.00032167707097030914;;]), layer_5 = (weight = [0.03013573472803205 -0.11118695582935643 … -0.13702784853293276 0.09655725087631234], bias = [0.0003434485619647583;;]))
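
For reference, the Adam run only changes the optimizer passed to Optimization.solve; a minimal sketch (the 0.01 learning rate below is an assumption, it is not stated above):

using OptimizationOptimisers

# Adam via Optimisers.jl; learning rate chosen only for illustration
opt = OptimizationOptimisers.Adam(0.01)
res = Optimization.solve(prob, opt, callback = callback, maxiters = 10)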

So, I would suggest not using BFGS or LBFGS with large networks.
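
A common pattern in this situation (a sketch only, not benchmarked here) is to do most of the training with Adam and finish with a short (L)BFGS refinement, warm-started from the Adam result:

# Warm-start a short LBFGS refinement from the Adam solution (sketch, not benchmarked here)
prob2 = remake(prob, u0 = res.u)
res2 = Optimization.solve(prob2, OptimizationOptimJL.LBFGS(), callback = callback, maxiters = 5)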
