My end goal is to train a UDE which is solved via MethodOfLines. For now, my minimal example tries to learn a single parameter \alpha (the initial condition), but gets a warning I don’t understand:
Warning: : no method matching get_unit for arguments (Pair{Num, Float64},).
and then another warning
Warning: setup found no trainable parameters in this model
followed by the error
ERROR: MethodError: no method matching size(::IRTools.Inner.Undefined)
(full stacktrace at the end)
Here is my code. I suspect there is more than one problem with this setup. What would be the correct way of doing it?
using MethodOfLines, ModelingToolkit, DomainSets, OrdinaryDiffEq
using Optimization, ComponentArrays, OptimizationOptimisers, Zygote
using Statistics
x_dim = 50
t_dim = 100
#
# Define PDE
#
# Parameters, variables, and derivatives
@parameters t, x
@parameters α
@variables u(..)
Dt = Differential(t)
Dxx = Differential(x)^2
eq = Dt(u(t, x)) ~ 1.0e-4 * Dxx(u(t, x))
domain = [x ∈ Interval(0.0, 1.0),
t ∈ Interval(0.0, 500.0)]
ic_bc = [u(0.0, x) ~ α,
u(t, 0.0) ~ 1.0,
u(t, 1.0) ~ 0]
@named sys = PDESystem(eq, ic_bc, domain, [t, x], [u(t, x)], [α .=> 0.5])
discretization = MOLFiniteDifference([x => 1.0 / x_dim], t)
prob = discretize(sys, discretization)
#
# Learn α
#
function predict(θ)
_prob = remake(prob, p = [α .=> θ])
return solve(_prob, Tsit5(), saveat = 500.0 / t_dim)
end
function loss(θ)
sol = predict(θ)
return mean(abs2, 0.5 .- Array(sol[u(t,x)]))
end
losses = Float64[]
callback = function (p, l)
push!(losses, l)
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, 0.5)
res = Optimization.solve(optprob, OptimizationOptimisers.ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
Stacktrace:
[1] axes(A::IRTools.Inner.Undefined)
@ Base ./abstractarray.jl:98
[2] _tryaxes(x::IRTools.Inner.Undefined)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:188
[3] map
@ ./tuple.jl:274 [inlined]
[4] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/array.jl:322 [inlined]
[5] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[6] _pullback
@ ./iterators.jl:370 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(zip), ::IRTools.Inner.Undefined, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[8] _pullback
@ ~/.julia/packages/ModelingToolkit/8ZXtB/src/utils.jl:659 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::IRTools.Inner.Undefined)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[10] _pullback
@ ~/.julia/packages/ModelingToolkit/8ZXtB/src/variables.jl:149 [inlined]
[11] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.process_p_u0_symbolic), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, ::Vector{Pair{Num, Float64}}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[12] _pullback
@ ~/.julia/packages/SciMLBase/KcGs1/src/remake.jl:78 [inlined]
[13] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#575", ::Missing, ::Missing, ::Missing, ::Vector{Pair{Num, Float64}}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[14] _pullback
@ ~/.julia/packages/SciMLBase/KcGs1/src/remake.jl:52 [inlined]
[15] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:p,), Tuple{Vector{Pair{Num, Float64}}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#520"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xa4f18dd0, 0x9699d144, 0x553ba557, 0xd7c31365, 0x19957603), Expr}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabe41c3, 0x0ba0eac4, 0x035838e9, 0x4644fed0, 0x4607be1a), Expr}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#565#generated_observed#528"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[16] _pullback
@ ~/myUDE/src/mwe.jl:39 [inlined]
[17] _pullback(ctx::Zygote.Context{false}, f::typeof(predict), args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[18] _pullback
@ ~/myUDE/src/mwe.jl:44 [inlined]
[19] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[20] _pullback
@ ~/myUDE/src/mwe.jl:59 [inlined]
[21] _pullback(::Zygote.Context{false}, ::var"#29#30", ::Float64, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[22] _apply
@ ./boot.jl:838 [inlined]
[23] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
[24] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[25] _pullback
@ ~/.julia/packages/SciMLBase/KcGs1/src/scimlfunctions.jl:3626 [inlined]
[26] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Float64, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[27] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[28] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
[29] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[30] _pullback
@ ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:31 [inlined]
[31] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[32] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[33] adjoint
@ ~/.julia/packages/Zygote/JeHtr/src/lib/lib.jl:203 [inlined]
[34] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[35] _pullback
@ ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:35 [inlined]
[36] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#263#272"{Tuple{}, Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[37] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:44
[38] pullback
@ ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:42 [inlined]
[39] gradient(f::Function, args::Float64)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:96
[40] (::Optimization.var"#262#271"{Optimization.var"#261#270"{OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::Float64, ::Float64)
@ Optimization ~/.julia/packages/Optimization/vFala/src/function/zygote.jl:33
[41] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
[42] macro expansion
@ ~/.julia/packages/Optimization/vFala/src/utils.jl:37 [inlined]
[43] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, var"#29#30", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Float64, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Adam{Float32}, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
[44] __solve (repeats 2 times)
@ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
[45] #solve#553
@ ~/.julia/packages/SciMLBase/KcGs1/src/solve.jl:86 [inlined]
[46] top-level scope
@ ~/myUDE/src/mwe.jl:62