Issue in ModelingToolkit when using Flux.destructure

After updating the packages in my project, I am getting an unexpected error.
The project builds an ODE system composed of multiple subsystems. The subsystem that raises the error is constructed by

function MLP_controller(nn; n=:MLP)
    @parameters t
    # Flatten the full chain and its two stages into parameter vectors plus reconstructors
    pAll, re = Flux.destructure(nn)
    pOutput, reOutput = Flux.destructure(nn[2:end])
    pInput, reInput = Flux.destructure(nn[1:1])
    input_dims = size(nn.layers[1].weight, 2)
    output_dims = size(nn.layers[2].weight, 1)
    N = size(nn.layers[1].weight, 1)
    # Symbolic parameters for the weights, symbolic states for input, hidden, and output layers
    @parameters psymI[1:length(pInput)]
    @parameters psymO[1:length(pOutput)]
    @variables u[1:input_dims](t)
    @variables o[1:output_dims](t)
    @variables h[1:N](t)
    eqs = vcat(
        scalarize(h .~ reInput(psymI)(u)),
        scalarize(o .~ reOutput(psymO)(h))
    )
    return ODESystem(eqs, t, vcat(u, o, h), vcat(psymI, psymO); name=n,
        defaults = Dict(vcat(psymI, psymO) .=> vcat(pInput, pOutput)))
end

The error is

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::SymbolicUtils.Term{Real, Base.ImmutableDict{DataType, Any}})
Closest candidates are:
  (::ChainRulesCore.ProjectTo{var"#s12", D} where {var"#s12"<:Number, D<:NamedTuple})(::ChainRulesCore.Tangent{var"#s11", T} where {var"#s11"<:Complex, T}) at /Users/adrianaperezrotondo/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:192
  (::ChainRulesCore.ProjectTo{var"#s12", D} where {var"#s12"<:Number, D<:NamedTuple})(::ChainRulesCore.Tangent{var"#s11", T} where {var"#s11"<:Number, T}) at /Users/adrianaperezrotondo/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:193
  (::ChainRulesCore.ProjectTo{T, D} where D<:NamedTuple)(::ChainRulesCore.Tangent{var"#s12", T} where {var"#s12"<:T, T}) where T at /Users/adrianaperezrotondo/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:143
  ...
Stacktrace:
  [1] _map(::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ::Symbolics.ArrayOp{Vector{Real}})
    @ Symbolics ~/.julia/packages/Symbolics/hgePJ/src/array-lib.jl:271
  [2] macro expansion
    @ ~/.julia/packages/Symbolics/hgePJ/src/array-lib.jl:262 [inlined]
  [3] var"Base.map_4228762831532280247"(f::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, x::Symbolics.ArrayOp{Vector{Real}})
    @ Symbolics ~/.julia/packages/Symbolics/hgePJ/src/wrapper-types.jl:113
  [4] map(f::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, x::Symbolics.Arr{Symbolics.Num, 1})
    @ Symbolics ~/.julia/packages/Symbolics/hgePJ/src/wrapper-types.jl:122
  [5] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Symbolics.Arr{Symbolics.Num, 1})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:236
  [6] _getat(y::Vector{Float64}, o::Int64, flat::Symbolics.Arr{Symbolics.Num, 1})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:90
  [7] (::Optimisers.var"#20#21"{Symbolics.Arr{Symbolics.Num, 1}})(y::Vector{Float64}, o::Int64)
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:85
  [8] fmap(f::Function, x::Vector{Float64}, ys::Int64; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
  [9] #31
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
 [10] (::Optimisers.var"#22#23"{Functors.var"#31#32"{typeof(Optimisers.isnumeric), typeof(Optimisers._trainable_biwalk), IdDict{Any, Any}, Functors.NoKeyword, Optimisers.var"#20#21"{Symbolics.Arr{Symbolics.Num, 1}}}})(c::Vector{Float64}, t::Vector{Float64}, a::Int64)
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:101
 [11] map (repeats 2 times)
    @ ./tuple.jl:252 [inlined]
 [12] map(::Function, ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, typeof(tanh)}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float64}, Vector{Float64}, typeof(tanh)}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}})
    @ Base ./namedtuple.jl:197
 [13] _trainmap
    @ ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:100 [inlined]
 [14] _trainable_biwalk(f::Function, x::Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, aux::NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:96
 [15] fmap(f::Function, x::Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, ys::NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
 [16] #31
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
 [17] #22
    @ ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:101 [inlined]
 [18] map
    @ ./tuple.jl:252 [inlined]
 [19] _trainmap(f::Functors.var"#31#32"{typeof(Optimisers.isnumeric), typeof(Optimisers._trainable_biwalk), IdDict{Any, Any}, Functors.NoKeyword, Optimisers.var"#20#21"{Symbolics.Arr{Symbolics.Num, 1}}}, ch::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, tr::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:100
 [20] _trainable_biwalk(f::Function, x::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:96
 [21] fmap(f::Function, x::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, ys::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
 [22] #31
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
 [23] (::Optimisers.var"#22#23"{Functors.var"#31#32"{typeof(Optimisers.isnumeric), typeof(Optimisers._trainable_biwalk), IdDict{Any, Any}, Functors.NoKeyword, Optimisers.var"#20#21"{Symbolics.Arr{Symbolics.Num, 1}}}})(c::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, t::Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}, a::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:101
 [24] map
    @ ./tuple.jl:252 [inlined]
 [25] map(::Function, ::NamedTuple{(:layers,), Tuple{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}, ::NamedTuple{(:layers,), Tuple{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}, ::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}})
    @ Base ./namedtuple.jl:197
 [26] _trainmap
    @ ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:100 [inlined]
 [27] _trainable_biwalk(f::Function, x::Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, aux::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:96
 [28] fmap(f::Function, x::Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, ys::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
 [29] _rebuild(x::Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, off::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}, flat::Symbolics.Arr{Symbolics.Num, 1}, len::Int64; walk::Function, kw::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:84
 [30] _rebuild
    @ ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:83 [inlined]
 [31] (::Optimisers.Restructure{Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}})(flat::Symbolics.Arr{Symbolics.Num, 1})
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/destructure.jl:51
 [32] MLP_controller(nn::Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Flux.Dense{typeof(identity), Matrix{Float64}, Bool}}}; n::Symbol)
    @ CerebellarMotorLearning ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/src/systemComponents_functions.jl:149
 [33] MLP_controller(nn::Flux.Chain{Tuple{Flux.Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Flux.Dense{typeof(identity), Matrix{Float64}, Bool}}})
    @ CerebellarMotorLearning ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/src/systemComponents_functions.jl:137
 [34] CerebellarMotorLearning.NeuralNetwork(nnDims::Tuple{Int64, Int64, Int64}, Z::Matrix{Float64}, W::Matrix{Float64})
    @ CerebellarMotorLearning ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/src/types.jl:169
 [35] CerebellarMotorLearning.System(plantMatrices::NTuple{4, Matrix{Int64}}, pidGains::NTuple{4, Float64}, nnDims::Tuple{Int64, Int64, Int64}, fc::Float64, trajTime::Float64, lookahead_times::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}, Z::Matrix{Float64}, W::Matrix{Float64})
    @ CerebellarMotorLearning ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/src/types.jl:55
 [36] build_system
    @ ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/src/types.jl:69 [inlined]
 [37] build_systems_sim(Ns::StepRange{Int64, Int64}, plantMatrices::NTuple{4, Matrix{Int64}}, Ks::NTuple{4, Float64}, fc::Float64, trajTime::Float64, lookahead_times::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}, num_nn_inputs::Int64, num_nn_outputs::Int64, K::Int64, randomSeed::Int64)
    @ Main ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/scripts/sizeSim_functions.jl:60
 [38] (::var"#401#402")(r::Int64)
    @ Main ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/scripts/testSize_static_Ls-SS_simulate.jl:68
 [39] iterate
    @ ./generator.jl:47 [inlined]
 [40] _collect(c::Vector{Int64}, itr::Base.Generator{Vector{Int64}, var"#401#402"}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:691
 [41] collect_similar(cont::Vector{Int64}, itr::Base.Generator{Vector{Int64}, var"#401#402"})
    @ Base ./array.jl:606
 [42] map(f::Function, A::Vector{Int64})
    @ Base ./abstractarray.jl:2294
 [43] top-level scope
    @ ~/Dropbox (Cambridge University)/ControlLab/Projects/CerebellumExpansion/Code/Julia/CerebellarMotorLearning/scripts/testSize_static_Ls-SS_simulate.jl:67

After debugging, I have narrowed the issue down to the line scalarize(h .~ reInput(psymI)(u)). If I change that line to scalarize(h .~ ones(N)), no error is raised. This is especially frustrating because the very similar line scalarize(o .~ reOutput(psymO)(h)) works fine.

MWE

using Flux, ModelingToolkit, Plots, DifferentialEquations, LinearAlgebra

function MLP_controller(nn; n=:MLP)
    @parameters t
    pAll, re = Flux.destructure(nn)
    pOutput, reOutput = Flux.destructure(nn[2:end])
    pInput, reInput = Flux.destructure(nn[1:1])
    input_dims = size(nn.layers[1].weight, 2)
    output_dims = size(nn.layers[2].weight, 1)
    N = size(nn.layers[1].weight, 1)
    @parameters psymI[1:length(pInput)]
    @parameters psymO[1:length(pOutput)]
    @variables u[1:input_dims](t)
    @variables o[1:output_dims](t)
    @variables h[1:N](t)
    eqs = vcat(
        # o .~ re(vcat(psymI,psymO))(u),
        scalarize(h .~ reInput(psymI)(u)),
        scalarize(o .~ reOutput(psymO)(h))
    )
    return ODESystem(eqs, t, vcat(u,o,h), vcat(psymI,psymO); name=n, defaults = Dict(vcat(psymI,psymO) .=> vcat(pInput,pOutput)))
    # return ODESystem(eqs,t,vcat(u,o,h),vcat(psymI,psymO))
end

N = 20
num_nn_inputs = 10


Z0 = ones(N,num_nn_inputs)
W = ones(N,1)

nn = Chain(
    Dense(Z0, true, tanh),
    Dense(W,false,identity)
)

MLP_controller(nn)

Thanks for your help!

You’ll need to @register_symbolic the function. Though this won’t be very well optimized right now. We still need to work on connecting the two.
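Roughly, registration means wrapping the network evaluation in a named function and telling Symbolics to treat calls to it as opaque primitives instead of tracing through them. A minimal sketch, assuming a small scalar-output network and an illustrative wrapper name nn_eval (the exact registration form for array-valued outputs may differ between Symbolics versions):

using Flux, Symbolics

nn = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
p0, re = Flux.destructure(nn)

# A named wrapper around the restructured network; it returns a scalar.
nn_eval(p, u1, u2, u3) = re(collect(p))([u1, u2, u3])[1]

# Keep calls to nn_eval symbolic rather than tracing into Flux internals.
@register_symbolic nn_eval(p::AbstractVector, u1, u2, u3)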

Hi Chris,

Thanks for the response!
I have been looking at the documentation of @register_symbolic, but I can’t quite figure out how to register the re functions, since they are not named functions but objects returned by Flux.destructure().

Also, do you have any insight into why this was working without errors until very recently? Could I pin ModelingToolkit to an earlier version to avoid the error?

Thanks for your help!

Previously it traced down to scalars, and I guess that happened to work. I don’t think this was an MTK change; instead, this seems to be fallout from the Flux v0.13 upgrade. It’s failing in the ChainRulesCore projection machinery that was recently added to the AD libraries. @dhairyagandhi96, do you have an idea why this would fail now?

If you do the registration, though, the generated code is going to be extremely slow without the optimizations it needs.

Thanks so much, Chris! I went back to Flux version 0.12.9 and it is working again!
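For anyone hitting the same issue, pinning can be done roughly like this (a sketch; the exact compat bound is up to your project):

using Pkg
Pkg.add(name = "Flux", version = "0.12.9")
Pkg.pin(name = "Flux")   # prevent Pkg.update() from moving it again
# or, equivalently, put Flux = "=0.12.9" under [compat] in Project.toml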

That would be the internal callable struct (::Optimisers.Restructure)(...), which is defined here.

I’m afraid the fact that it worked before was probably a fluke. We had to completely overhaul destructure’s internals in Flux, because the old implementation was accumulating new, fundamentally breaking, yet unresolvable bugs every few weeks.

As for a fix, I don’t understand enough about Symbolics to know what level of granularity is best for registering new functions. If excluding (::Restructure)(...) and its entire call chain is too much, you could try (::ChainRulesCore.ProjectTo{AbstractArray})(::Symbolics.Arr{Symbolics.Num, 1}) or something similar.
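An untested sketch of that second option, just to make the granularity concrete; note this is type piracy on ChainRulesCore and Symbolics types and may have unintended effects elsewhere:

using ChainRulesCore, Symbolics

# Make the array projection a no-op for symbolic arrays, so the rebuild path in
# destructure does not try to map a Float64 projector over symbolic elements.
(project::ChainRulesCore.ProjectTo{AbstractArray})(dx::Symbolics.Arr{Symbolics.Num, 1}) = dx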