NeuralPDE with tuple input

,

I want to use a PINN to solve a PDE whose time dependence is mild but whose spatial dependence is complicated. I constructed the following neural network:

inner = 8
# Intended layout: one Parallel branch forwards its input unchanged (NoOpLayer),
# the other applies Dense(1 => inner); the branch outputs are stacked with vcat
# and fed to the final Dense(inner + 1 => 1) layer.
chain = Chain(
    Parallel(vcat,
        Chain(NoOpLayer()),
        Chain(Dense(1, inner, relu))),
    Dense(inner + 1, 1, relu),
)

It has the following form

     t  →  ↘ 
           vcat → layer2 → u
x → layer1 ↗                         

The argument x passes through more layers than the argument t.

Here is the problem: this NN needs a tuple input, and I don’t know how to provide one.

This is a minimal working example:

# Minimal reproduction: build a PINN discretization of a PDE in (t, x) with a
# Parallel-branch Lux network and run a few Adam iterations on it.
using Lux, NeuralPDE, Random,Optimization,OptimizationOptimisers
import NeuralPDE.ModelingToolkit: Interval
# Independent variables t, x and the unknown function u.
@parameters t x
@variables u(..)
Dx = Differential(x)
Dt = Differential(t)
# Domain bounds, written as Float32 literals.
t_min = 0.0f0
t_max = 1.0f0
x_min = 0.0f0
x_max = 1.0f0
# Width of the hidden layer on the Dense branch.
inner = 8
# Intended layout: one Parallel branch forwards its input unchanged, the other
# applies Dense(1 => inner); vcat of both feeds Dense(inner + 1 => 1).
# NOTE(review): the stack trace reported for this script shows Parallel, the
# inner Chain and Dense all being called with a plain Matrix, so the
# Dense(1, inner) branch presumably receives both input rows (t and x) instead
# of x alone, which would explain "1 and 2 are not equal" — TODO confirm.
chain = Chain(Parallel(vcat,
					   Chain(NoOpLayer()),
					   Chain(Dense(1, inner, relu))), Dense(inner + 1, 1, relu))
# PDE: Dt(u) equals 0 times Dx(u).
eq = Dt(u(t, x)) ~ 0f0*Dx(u(t, x))

# Condition imposed at t = t_min.
bcs = [u(t_min, x) ~ 0.0f0]

domains = [t ∈ Interval(t_min, t_max),
	x ∈ Interval(x_min, x_max)]

strategy = QuasiRandomTraining(50)
# Keep only the first element returned by Lux.setup and pass it as init_params.
ps = Lux.setup(Random.default_rng(), chain)[1]
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)
@named pde_system = PDESystem(eq, bcs, domains, [t,x], [u(t, x)])
prob = discretize(pde_system, discretization)
# symprob is not referenced again in this snippet.
symprob = symbolic_discretize(pde_system, discretization)
# Callback: print the current loss and return false.
callback = function (p, l)
    println("Current loss is: $l")
    return false
end
# Run Adam(0.0001f0) for at most 10 iterations.
res = Optimization.solve(prob,OptimizationOptimisers.Adam(0.0001f0);maxiters = 10,callback=callback)

Running this gives me ERROR: 1 and 2 are not equal.

What’s the error?

Hi, Chris. This is the error:

ERROR: 1 and 2 are not equal
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _static_promote
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:326 [inlined]
  [3] static_promote
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:324 [inlined]
  [4] static_promote
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:352 [inlined]
  [5] static_promote
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:363 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:745 [inlined]
  [7] reduce_tup
    @ ~/.julia/packages/Static/1Mvph/src/Static.jl:745 [inlined]
  [8] indices
    @ ~/.julia/packages/StaticArrayInterface/lkDPR/src/ranges.jl:60 [inlined]
  [9] macro expansion
    @ ~/.julia/packages/LoopVectorization/tIJUA/src/condense_loopset.jl:1179 [inlined]
 [10] matmul_loopvec!
    @ ~/.julia/packages/LuxLib/lUGRG/ext/LuxLibLoopVectorizationExt.jl:27 [inlined]
 [11] matmul_cpu!(C::Matrix{…}, ::Static.False, ::Static.False, A::Base.ReshapedArray{…}, B::Matrix{…})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/lUGRG/src/impl/matmul.jl:117
 [12] matmul!
    @ ~/.julia/packages/LuxLib/lUGRG/src/impl/matmul.jl:90 [inlined]
 [13] fused_dense!
    @ ~/.julia/packages/LuxLib/lUGRG/src/impl/dense.jl:30 [inlined]
 [14] fused_dense
    @ ~/.julia/packages/LuxLib/lUGRG/src/impl/dense.jl:24 [inlined]
 [15] rrule
    @ ~/.julia/packages/LuxLib/lUGRG/src/impl/dense.jl:49 [inlined]
 [16] chain_rrule
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:224 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0 [inlined]
 [18] _pullback
    @ ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:87 [inlined]
 [19] fused_dense
    @ ~/.julia/packages/LuxLib/lUGRG/src/impl/dense.jl:11 [inlined]
 [20] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/lUGRG/src/api/dense.jl:35 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(fused_dense_bias_activation), ::typeof(relu), ::Base.ReshapedArray{…}, ::Matrix{…}, ::SubArray{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [22] Dense
    @ ~/.julia/packages/Lux/bRE88/src/layers/basic.jl:343 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::Dense{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [24] apply
    @ ~/.julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [26] applychain
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:0 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [28] Chain
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:480 [inlined]
 [29] _pullback(::Zygote.Context{…}, ::Chain{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [30] apply
    @ ~/.julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [31] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [32] applyparallel
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:0 [inlined]
 [33] _pullback(::Zygote.Context{…}, ::typeof(Lux.applyparallel), ::@NamedTuple{…}, ::typeof(vcat), ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [34] Parallel
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:173 [inlined]
 [35] _pullback(::Zygote.Context{…}, ::Parallel{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [36] apply
    @ ~/.julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [37] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Parallel{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [38] applychain
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:0 [inlined]
 [39] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [40] Chain
    @ ~/.julia/packages/Lux/bRE88/src/layers/containers.jl:480 [inlined]
 [41] _pullback(::Zygote.Context{…}, ::Chain{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [42] apply
    @ ~/.julia/packages/LuxCore/Pl5NJ/src/LuxCore.jl:155 [inlined]
 [43] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Chain{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [44] StatefulLuxLayer
    @ ~/.julia/packages/Lux/bRE88/src/helpers/stateful.jl:119 [inlined]
 [45] _pullback(::Zygote.Context{…}, ::StatefulLuxLayer{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [46] Phi
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/pinn_types.jl:42 [inlined]
 [47] _pullback(::Zygote.Context{…}, ::NeuralPDE.Phi{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [48] #7
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/pinn_types.jl:354 [inlined]
 [49] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#7#8", ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::NeuralPDE.Phi{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [50] numeric_derivative
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/pinn_types.jl:384 [inlined]
 [51] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.numeric_derivative), ::NeuralPDE.Phi{…}, ::NeuralPDE.var"#7#8", ::Matrix{…}, ::Vector{…}, ::Int64, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [52] generated_callfunc
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/discretize.jl:130 [inlined]
 [53] _pullback(::Zygote.Context{…}, ::typeof(RuntimeGeneratedFunctions.generated_callfunc), ::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::NeuralPDE.Phi{…}, ::typeof(NeuralPDE.numeric_derivative), ::NeuralPDE.var"#284#291"{…}, ::NeuralPDE.var"#7#8", ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [54] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:946
 [55] adjoint
    @ ~/.julia/packages/Zygote/nyzjS/src/lib/lib.jl:203 [inlined]
 [56] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [57] RuntimeGeneratedFunction
    @ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150 [inlined]
 [58] _pullback(::Zygote.Context{…}, ::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…}, ::NeuralPDE.Phi{…}, ::typeof(NeuralPDE.numeric_derivative), ::NeuralPDE.var"#284#291"{…}, ::NeuralPDE.var"#7#8", ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [59] #242
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/discretize.jl:150 [inlined]
 [60] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#242#243"{…}, ::Matrix{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [61] #102
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/training_strategies.jl:274 [inlined]
 [62] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#102#105"{…}, args::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [63] #308
    @ ./none:0 [inlined]
 [64] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::NeuralPDE.var"#102#105"{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [65] #666
    @ ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188 [inlined]
 [66] iterate
    @ ./generator.jl:48 [inlined]
 [67] _collect(c::Vector{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:800
 [68] collect_similar
    @ ./array.jl:709 [inlined]
 [69] map
    @ ./abstractarray.jl:3371 [inlined]
 [70] ∇map(cx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:188
 [71] _pullback(cx::Zygote.Context{false}, ::typeof(collect), g::Base.Generator{Vector{…}, NeuralPDE.var"#308#329"{…}})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/lib/array.jl:231
 [72] full_loss_function
    @ ~/.julia/packages/NeuralPDE/BcXJm/src/discretize.jl:462 [inlined]
 [73] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#full_loss_function#328"{…}, ::ComponentArrays.ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
 [74] pullback(::Function, ::Zygote.Context{…}, ::ComponentArrays.ComponentVector{…}, ::Vararg{…})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:90
 [75] pullback(::Function, ::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:88
 [76] withgradient(::Function, ::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:205
 [77] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/gSdHF/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:92 [inlined]
 [78] value_and_gradient!(f::Function, grad::ComponentArrays.ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentArrays.ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/gSdHF/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:105
 [79] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentArrays.ComponentVector{…}, θ::ComponentArrays.ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:53
 [80] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/JgTMl/src/OptimizationOptimisers.jl:101 [inlined]
 [81] macro expansion
    @ ~/.julia/packages/Optimization/cfp9i/src/utils.jl:32 [inlined]
 [82] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/JgTMl/src/OptimizationOptimisers.jl:83
 [83] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/ZyZAV/src/solve.jl:186
 [84] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/ZyZAV/src/solve.jl:94
Some type information was truncated. Use `show(err)` to see complete types.

I think this neural net needs a tuple input. The following neural net works well:

# Plain two-layer network that takes both inputs at once (2 => inner => 1).
chain = Chain(
    Dense(2, inner, relu),
    Dense(inner, 1, relu),
)

What about

# Two-layer network (2 => inner => 1) with `collect` applied to the input first.
chain = Chain(
    collect,
    Dense(2, inner, relu),
    Dense(inner, 1, relu),
)

Hi, Chris. Thanks for your reply.

I need the following neural network. The variable rho (the spatial variable, called x in the example above) goes through more layers than the variable t.

I implemented it as follows:

inner = 256
chain = Chain(
    # Split the input matrix into its first and second rows (each kept as a
    # 1-row matrix) so the next layer is given a tuple.
    x -> (x[1:1, :], x[2:2, :]),
    # Intended: the first row is forwarded unchanged, the second row goes
    # through Dense(1 => inner); the results are stacked with vcat.
    Parallel(vcat,
        Chain(NoOpLayer()),
        Chain(Dense(1, inner, relu))),
    Dense(inner + 1, inner, relu),
    Dense(inner, inner, relu),
    Dense(inner, 1, σ),
)

The function x -> (x[1:1, :], x[2:2, :]) incurs extra memory allocations, but x -> (@view x[1:1, :], @view x[2:2, :]) does not work. I’m not sure whether a better solution exists.

Thanks again for your reply.