Intended way to learn parameters in a PDE system?

My end goal is to train a UDE which is solved via MethodOfLines. For now, my minimal example tries to learn a single parameter α (the initial condition), but it produces a warning I don’t understand:

Warning: no method matching get_unit for arguments (Pair{Num, Float64},).

and then another warning

Warning: setup found no trainable parameters in this model

followed by the error

ERROR: MethodError: no method matching size(::IRTools.Inner.Undefined)

(full stacktrace at the end)

Here is my code. I suspect there is more than one problem with this setup. What would be the correct way of doing it?

using MethodOfLines, ModelingToolkit, DomainSets, OrdinaryDiffEq
using Optimization, ComponentArrays, OptimizationOptimisers, Zygote
using Statistics

x_dim = 50
t_dim = 100

#
# Define PDE
#

# Parameters, variables, and derivatives
@parameters t, x
@parameters α
@variables u(..)
Dt = Differential(t)
Dxx = Differential(x)^2

eq = Dt(u(t, x)) ~ 1.0e-4 * Dxx(u(t, x))

domain = [x ∈ Interval(0.0, 1.0),
          t ∈ Interval(0.0, 500.0)]

ic_bc = [u(0.0, x) ~ α,
         u(t, 0.0) ~ 1.0,
         u(t, 1.0) ~ 0]

@named sys = PDESystem(eq, ic_bc, domain, [t, x], [u(t, x)], [α .=> 0.5])

discretization = MOLFiniteDifference([x => 1.0 / x_dim], t)

prob = discretize(sys, discretization)

#
# Learn α
#

function predict(θ)
    _prob = remake(prob, p = [α .=> θ])
    return solve(_prob, Tsit5(), saveat = 500.0 / t_dim)
end

function loss(θ)
    sol = predict(θ)
    return mean(abs2, 0.5 .- Array(sol[u(t,x)]))
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, 0.5)

res = Optimization.solve(optprob, OptimizationOptimisers.ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Stacktrace:
  [1] axes(A::IRTools.Inner.Undefined)
    @ Base ./abstractarray.jl:98
  [2] _tryaxes(x::IRTools.Inner.Undefined)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:188
  [3] map
    @ ./tuple.jl:274 [inlined]
  [4] adjoint
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:322 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
  [6] _pullback
    @ ./iterators.jl:370 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(zip), ::IRTools.Inner.Undefined, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/ModelingToolkit/8ZXtB/src/utils.jl:659 [inlined]
  [9] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::IRTools.Inner.Undefined)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/ModelingToolkit/8ZXtB/src/variables.jl:149 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.process_p_u0_symbolic), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, ::Vector{Pair{Num, Float64}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [12] _pullback
    @ ~/.julia/packages/SciMLBase/KcGs1/src/remake.jl:78 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#575", ::Missing, ::Missing, ::Missing, ::Vector{Pair{Num, Float64}}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [14] _pullback
    @ ~/.julia/packages/SciMLBase/KcGs1/src/remake.jl:52 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:p,), Tuple{Vector{Pair{Num, Float64}}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [16] _pullback
    @ ~/myUDE/src/mwe.jl:39 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::typeof(predict), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [18] _pullback
    @ ~/myUDE/src/mwe.jl:44 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [20] _pullback
    @ ~/myUDE/src/mwe.jl:59 [inlined]
 [21] _pullback(::Zygote.Context{false}, ::var"#29#30", ::Float64, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [22] _apply
    @ ./boot.jl:838 [inlined]
 [23] adjoint
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
 [24] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [25] _pullback
    @ ~/.julia/packages/SciMLBase/KcGs1/src/scimlfunctions.jl:3626 [inlined]
 [26] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Float64, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [27] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [28] adjoint
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
 [29] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [30] _pullback
    @ ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:31 [inlined]
 [31] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [32] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [33] adjoint
    @ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
 [34] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [35] _pullback
    @ ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:35 [inlined]
 [36] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#263#272"{Tuple{}, Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
 [37] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:44
 [38] pullback
    @ ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:42 [inlined]
 [39] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:96
 [40] (::Optimization.var"#262#271"{Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Float64, ::Float64)
    @ Optimization ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:33
 [41] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
 [42] macro expansion
    @ ~/.julia/packages/Optimization/vFala/src/utils.jl:37 [inlined]
 [43] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Float64, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Adam{Float32}, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
 [44] __solve (repeats 2 times)
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
 [45] #solve#553
    @ ~/.julia/packages/SciMLBase/KcGs1/src/solve.jl:86 [inlined]
 [46] top-level scope
    @ ~/myUDE/src/mwe.jl:62

Can you clarify what you want the loss function to be here? The one you have currently gives a fixed value of 0.05254691992173411 regardless of the theta parameter’s value.

In reality the loss function is

function loss(θ)
    sol = predict(θ)
    return mean(abs2, data .- Array(sol[u(t,x)]))
end

where data comes from a file, so instead I just picked some arbitrary loss function for MWE purposes (how close we are to u = 0.5). I don’t think this is the issue though: remake is not assigning α to θ as I desire.
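A quick way to check whether a loss actually responds to the parameter is to evaluate it at two different values of θ. The sketch below is illustration only: `toy_predict`, `toy_loss`, and the constant `data` array are hypothetical stand-ins for the PDE solve and the file data, not part of the original code.

```julia
using Statistics

# Hypothetical stand-in for the PDE solve: a constant field equal to θ.
toy_predict(θ) = fill(θ, 4, 3)

# Placeholder "measurements"; in the real problem these come from a file.
data = fill(0.5, 4, 3)

# Mean squared error between the data and the prediction.
toy_loss(θ) = mean(abs2, data .- toy_predict(θ))

# If the loss depends on θ, these two values differ.
l1 = toy_loss(0.5)
l2 = toy_loss(0.7)
println("loss at 0.5: ", l1, ", loss at 0.7: ", l2)
```

If both calls return the same number, as reported above for the original `loss`, the parameter is not reaching the solve.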

I think the constructor might have to look like this:

@named sys = PDESystem(eq, ic_bc, domain, [t, x], [u(t, x)], [α], defaults = Dict([α .=> 0.5]))

but this causes

discretize(sys, discretization)

to throw ERROR: type Num has no field first

That looks like a bug. @xtalax take a look at that?

I’m not certain what’s causing this, but as a hunch, try not broadcasting your α pairs, i.e. write α => θ instead of α .=> θ.

Thanks for the reply. I have the same hunch, but I don’t know another way of doing it, hence my question.

I tried some random things like p=[θ] but none worked.

How is the p argument of remake treated? I can’t even tell if it is supposed to be a list or a dictionary.

I mean, do something like [α[i] => p[i] for i in eachindex(α)] or just directly [α => 0.5]

EDIT: Yep, that’s the problem. Where you have a vector-valued α, doing

    _prob = remake(prob, p = [α .=> θ])

will give p a vector of vectors of pairs, not the vector of pairs that is needed. Also note that if you made α with @variables α[1:n], you will need to collect the broadcast before passing it to remake.
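The nesting can be seen in plain Julia without any symbolic packages. This is a minimal sketch that uses Symbols as stand-ins for the symbolic parameters (an assumption for illustration; `names`, `flat`, `nested`, and `by_index` are made-up names, and symbolic array variables may additionally need `collect`, as noted above).

```julia
# Symbols stand in for the symbolic parameters (illustration only).
names = [:a1, :a2, :a3]
θ = [0.1, 0.2, 0.3]

# Broadcasting `=>` over two vectors already returns a vector of pairs ...
flat = names .=> θ

# ... so wrapping that result in brackets nests it one level too deep.
nested = [names .=> θ]

# The comprehension form builds the same flat vector explicitly.
by_index = [names[i] => θ[i] for i in eachindex(names)]

println(typeof(flat))
println(typeof(nested))
```

Printing the two types shows the extra `Vector` layer in `nested`, which is the shape that should not be handed to `p`.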

Sorry about the late reply


What do you mean by “collect the broadcast before passing it to remake”? I tried using @variables instead of @parameters in my problem (for a vector parameter) and got an error:

ERROR: MethodError: no method matching hasmetadata(::Vector{Num}, ::Type{Symbolics.VariableDefaultValue})

Do:
_prob = remake(prob, p = collect(α .=> θ))

If you get this working, please submit an example to the docs if you have the time, or otherwise share the working code if possible. This would be good to have.