Loss function with symbolic selection of variable parameters

I’m trying to create a loss function with a reduced set of adjustable parameters. I’m using LabelledArrays to keep track of the parameters using symbolic names. What I’d like is to be able to say “hold this set of parameters fixed at the initial-guess values while solving the problem.” Looks like Zygote is having trouble with my implementation. I feel like there’s got be a better way, but I’m not sure what it would be. Maybe just using the original loss function & passing some constraints to the minimizer, or turning this into a ControlSystem?

using LabelledArrays
using GalacticOptim
using ModelingToolkit
using DiffEqFlux
using Zygote
MT = ModelingToolkit

pars = @parameters  β α
@variables t E(t)
D = Differential(t)

eqs = [D(E) ~ α*t + β]
sys = ODESystem(eqs)

u0 = LVector(E = 1.0)
p = LVector(β = 0.8,
            α = 0.75)
tspan = (0.0,1.0)
prob = ODEProblem(sys,u0,tspan,p)
sol_true = solve(prob,Tsit5(),saveat=0.1)

function loss(inputs)
    N = length(prob.p)
    _p = inputs[1:N]
    _u0 = inputs[N:end]
    tmp_prob = remake(prob,p=_p,u0=_u0)
    tmp_sol = solve(tmp_prob,Tsit5(),saveat=0.1)
    sum(map(abs2,tmp_sol[E] - sol_true[E]))
end

indexof(sym,syms) = findfirst(isequal(sym),syms)

"Fix the values of certain parameters & return appropriate loss function and initial guess"
function fix_params(loss,p,u0,fixed)
    tot_vec = vcat(p,u0)
    tot_names = propertynames(tot_vec)
    free_names = setdiff(tot_names,Symbol.(fixed))
    free_indices = [indexof(x,totnames) for x in free_names]
    free_tuples = [(q,tot_vec[q]) for q in free_names]
    free_vec = LVector(NamedTuple(free_tuples))
    function loss_fixed(par_vec)
        input = [hasproperty(par_vec,name) ? (name,par_vec[name]) : (name,totvec[name]) for name in tot_names]
        return loss(input)
    end
    return loss_fixed,free_vec,free_indices
end

loss_w,p0,free_inds = fix_params(loss,p,u0,[β])
loss_w(p0)
Zygote.gradient(loss_w,p0)

Gets an error:

julia> Zygote.gradient(loss_w,p0)
ERROR: ArgumentError: invalid index: E(t) of type Num
Stacktrace:
  [1] to_index(i::Num)
    @ Base ./indices.jl:300
  [2] to_index(A::Matrix{Float64}, i::Num)
    @ Base ./indices.jl:277
  [3] to_indices
    @ ./indices.jl:333 [inlined]
  [4] to_indices
    @ ./indices.jl:325 [inlined]
  [5] view
    @ ./subarray.jl:176 [inlined]
  [6] (::Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}})(dy::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/lib/array.jl:43
  [7] (::Zygote.var"#2248#back#404"{Zygote.var"#408#410"{2, Float64, ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, ModelingToolkit.var"#f#152"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xbabc66d3, 0x633a919b, 0xe4683138, 0x0e254e2d, 0x36925fe8)}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7eb6766a, 0x02a244e1, 0x54678655, 0x9f7f1370, 0x31962eb1)}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, ModelingToolkit.var"#161#generated_observed#159"{Bool, ODESystem, Dict{Any, Any}}, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats}, Tuple{Num}}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [8] Pullback
    @ ./REPL[177]:7 [inlined]
  [9] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[276]:13 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::LArray{Float64, 1, Vector{Float64}, (:α, :E)})
    @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:59
 [14] top-level scope
    @ REPL[280]:1

[Edit: wrong version of MWE]

I’m going to create a SciMLParameters system for this kind of thing. I need an open day though, so I’ll just wait on the response here.

The getindex with Num needs an adjoint if you want to differentiate that, since it’s an odd indexing implementation.

1 Like

Have you tried ParameterHandling.jl? If you wrap a parameter with fixed, it will not be differentiated/optimized.