I’m trying to create a loss function with a reduced set of adjustable parameters. I’m using LabelledArrays
to keep track of the parameters using symbolic names. What I’d like is to be able to say “hold this set of parameters fixed at the initial-guess values while solving the problem.” Looks like Zygote is having trouble with my implementation. I feel like there’s got be a better way, but I’m not sure what it would be. Maybe just using the original loss function & passing some constraints to the minimizer, or turning this into a ControlSystem
?
using LabelledArrays
using GalacticOptim
using ModelingToolkit
using DiffEqFlux
using Zygote
MT = ModelingToolkit
pars = @parameters β α
@variables t E(t)
D = Differential(t)
eqs = [D(E) ~ α*t + β]
sys = ODESystem(eqs)
u0 = LVector(E = 1.0)
p = LVector(β = 0.8,
α = 0.75)
tspan = (0.0,1.0)
prob = ODEProblem(sys,u0,tspan,p)
sol_true = solve(prob,Tsit5(),saveat=0.1)
function loss(inputs)
N = length(prob.p)
_p = inputs[1:N]
_u0 = inputs[N:end]
tmp_prob = remake(prob,p=_p,u0=_u0)
tmp_sol = solve(tmp_prob,Tsit5(),saveat=0.1)
sum(map(abs2,tmp_sol[E] - sol_true[E]))
end
indexof(sym,syms) = findfirst(isequal(sym),syms)
"Fix the values of certain parameters & return appropriate loss function and initial guess"
function fix_params(loss,p,u0,fixed)
tot_vec = vcat(p,u0)
tot_names = propertynames(tot_vec)
free_names = setdiff(tot_names,Symbol.(fixed))
free_indices = [indexof(x,totnames) for x in free_names]
free_tuples = [(q,tot_vec[q]) for q in free_names]
free_vec = LVector(NamedTuple(free_tuples))
function loss_fixed(par_vec)
input = [hasproperty(par_vec,name) ? (name,par_vec[name]) : (name,totvec[name]) for name in tot_names]
return loss(input)
end
return loss_fixed,free_vec,free_indices
end
loss_w,p0,free_inds = fix_params(loss,p,u0,[β])
loss_w(p0)
Zygote.gradient(loss_w,p0)
Gets an error:
julia> Zygote.gradient(loss_w,p0)
ERROR: ArgumentError: invalid index: E(t) of type Num
Stacktrace:
[1] to_index(i::Num)
@ Base ./indices.jl:300
[2] to_index(A::Matrix{Float64}, i::Num)
@ Base ./indices.jl:277
[3] to_indices
@ ./indices.jl:333 [inlined]
[4] to_indices
@ ./indices.jl:325 [inlined]
[5] view
@ ./subarray.jl:176 [inlined]
[6] (::Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}})(dy::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/array.jl:43
[7] (::Zygote.var"#2248#back#404"{Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[8] Pullback
@ ./REPL[177]:7 [inlined]
[9] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[276]:13 [inlined]
[11] (::typeof(∂(λ)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
[12] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41
[13] gradient(f::Function, args::LArray{Float64, 1, Vector{Float64}, (:α, :E)})
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59
[14] top-level scope
@ REPL[280]:1
[Edit: wrong version of MWE]