Hi.
In the following code, I would like to use one neural net with 3 outputs instead of 3 neural nets with 1 output each, but I get errors when I try. Is it possible? All the examples I have seen use a different net for each function, but I think a single net with more outputs would be more effective.
What do I need to change to make it possible?
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimisers, Random
import ModelingToolkit: Interval
import Dates
using GLMakie
# Independent variables and the cube [-L, L]^3 they range over.
@parameters x y z
L = 10
domains = [v ∈ Interval(-L, L) for v in (x, y, z)]
# Unknown components of the vector field, each a function of (x, y, z).
@variables Fx(..) Fy(..) Fz(..)
"""
    G(x, y, z)

Scalar source term of the system. Let `a` and `b` be the distances from
`(x, y, z)` to the points `(0, 0, -0.5)` and `(0, 0, 0.5)`. The function returns

    2 * ((z + 0.5) / a + (z - 0.5) / b) / ((a + b)^2 + 1)

It uses only arithmetic and `sqrt`, so it accepts plain numbers as well as
symbolic variables. Exactly at `(0, 0, ±0.5)` one distance is zero and the
function divides 0 by 0 (NaN for floats).
"""
function G(x, y, z)
    dist_below = sqrt(x^2 + y^2 + (z + 0.5)^2)
    dist_above = sqrt(x^2 + y^2 + (z - 0.5)^2)
    numerator = 2 * ((z + 0.5) / dist_below + (z - 0.5) / dist_above)
    return numerator / ((dist_below + dist_above)^2 + 1)
end
# Partial-derivative operators along each coordinate axis.
Dx, Dy, Dz = Differential(x), Differential(y), Differential(z)

# Component form of curl(F) - grad(G) == 0 for F = (Fx, Fy, Fz). G is built
# symbolically once, and each equation differentiates it along one axis.
# NOTE(review): since curl(grad(phi)) == 0, these equations only determine F up to
# an added gradient field. The single point condition below does not remove that
# freedom. Consider adding e.g. a divergence equation if a unique field is wanted.
eqs = let g = G(x, y, z)
    [Dy(Fz(x, y, z)) - Dz(Fy(x, y, z)) - Dx(g) ~ 0,
     Dz(Fx(x, y, z)) - Dx(Fz(x, y, z)) - Dy(g) ~ 0,
     Dx(Fy(x, y, z)) - Dy(Fx(x, y, z)) - Dz(g) ~ 0]
end

# Pin every component of F to zero at the origin.
bcs = [F(0, 0, 0) ~ 0 for F in (Fx, Fy, Fz)]
# Dependent variables of the system, in the order the chains below are paired
# with them (the PhysicsInformedNN docs state that the i-th chain of a vector
# is used for the i-th dependent variable of the PDESystem).
field_vars = [Fx(x, y, z), Fy(x, y, z), Fz(x, y, z)]

# One small network per dependent variable, each mapping (x, y, z) to a scalar.
# NOTE(review): a single `Lux.Chain(..., Dense(n, 3))` cannot simply replace this
# vector. As far as I can tell, NeuralPDE only treats `chain` as "one network per
# variable" when it is a vector. A lone chain is used as one network for every
# dependent variable, so its 3 outputs are not split into Fx/Fy/Fz. Check the
# installed NeuralPDE version before relying on a shared multi-output network.
input_ = length(domains)
n = 12
chain = [Lux.Chain(Dense(input_, n, Lux.asinh), Dense(n, n, Lux.asinh),
                   Dense(n, n, Lux.asinh), Dense(n, 1))
         for _ in field_vars]  # exactly one chain per dependent variable

strategy = QuadratureTraining()
discretization = PhysicsInformedNN(chain, strategy)
@named pdesystem = PDESystem(eqs, bcs, domains, [x, y, z], field_vars)
prob = discretize(pdesystem, discretization)
# The symbolic form exposes the individual loss terms used for progress reports.
sym_prob = symbolic_discretize(pdesystem, discretization)
# Individual loss terms, evaluated only to report progress in the callback.
pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions

inter = 1000                        # total number of optimizer iterations
report_every = max(1, inter ÷ 10)   # Int, never 0; avoids shadowing Base.mod
iteration = Ref(0)                  # call counter; a Ref avoids `global` mutation

callback = function (p, l)
    # Newer Optimization.jl passes an `OptimizationState` whose `u` field holds the
    # parameters. Older versions pass the parameters themselves.
    θ = hasproperty(p, :u) ? p.u : p
    if iteration[] % report_every == 0
        print("$(iteration[]) loss: ", l)
        print(" pde: ", map(l_ -> l_(θ), pde_inner_loss_functions))
        println(" bcs: ", map(l_ -> l_(θ), bcs_inner_loss_functions))
    end
    iteration[] += 1
    return false  # never request early termination
end

res = Optimization.solve(prob, Adam(0.1); callback = callback, maxiters = inter)
# Trained parameters of each network, in `sym_prob.depvars` order (the order the
# chains, and therefore `discretization.phi`, are paired with). The name is not
# `minimum`, to avoid shadowing `Base.minimum`.
minimizers = [res.u.depvar[dv] for dv in sym_prob.depvars]

# 6×6×6 grid of sample points covering the domain.
xs, ys, zs = [range(infimum(dom.domain), supremum(dom.domain), length = 6)
              for dom in domains]
ns = [Point3f(px, py, pz) for px in xs for py in ys for pz in zs]

# Predicted field (Fx, Fy, Fz) at point `p`. Each network ends in a 1-wide Dense
# layer, so `first` extracts its scalar output.
fp(p) = Point3f((first(discretization.phi[k](p[1:3], minimizers[k]))
                 for k in eachindex(minimizers))...)
f = arrows(ns, fp)