I want to build an Equation using a neural network (based on Lux.jl) as follows:
using Lux
using LuxCore
using Symbolics
using StableRNGs
using ModelingToolkit
using ModelingToolkitNeuralNets
using ModelingToolkit: t_nounits as t
using ComponentArrays
using ModelingToolkitStandardLibrary.Blocks: RealInputArray
nn = Lux.Chain(
    Lux.Dense(3 => 16),   # , Lux.tanh
    Lux.Dense(16 => 16),  # , Lux.leakyrelu
    Lux.Dense(16 => 2),   # , Lux.leakyrelu
)
init_params = Lux.initialparameters(StableRNG(42), nn)
@parameters p[1:length(ComponentVector(init_params))] = Vector(ComponentVector(init_params))
@parameters ptype::typeof(typeof(init_params)) = typeof(init_params) [tunable = false]
lazyconvert_p = Symbolics.array_term(convert, ptype, p, size=size(p))
@variables v1(t), v2(t), v3(t)
exprs = LuxCore.stateless_apply(nn, [v1, v2, v3], lazyconvert_p)[1]
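Two details of the setup above can be checked numerically without Symbolics: length(init_params) counts layers rather than scalar parameters, and LuxCore.stateless_apply returns only the network output (unlike LuxCore.apply, which returns an (output, state) tuple), so the trailing [1] selects the first of the two outputs. A minimal sketch, assuming the same three-layer chain without activations; the names `ps_nt`, `ca` and `first_output` are illustrative only:

```julia
using Lux, LuxCore, StableRNGs, ComponentArrays

# Same architecture as above (activations left out, as in the question)
nn = Lux.Chain(Lux.Dense(3 => 16), Lux.Dense(16 => 16), Lux.Dense(16 => 2))
rng = StableRNG(42)
ps_nt = Lux.initialparameters(rng, nn)  # NamedTuple with one entry per layer
st = Lux.initialstates(rng, nn)
ca = ComponentVector(ps_nt)             # flat vector over all weights and biases

n_layers = length(ps_nt)  # 3: counts layers, not parameters
n_params = length(ca)     # (3*16+16) + (16*16+16) + (16*2+2) scalar parameters

x = Float32[1, 2, 3]
y_full, st_new = LuxCore.apply(nn, x, ca, st)  # apply returns (output, state)
y = LuxCore.stateless_apply(nn, x, ca)         # stateless_apply returns the output only
first_output = y[1]                            # so [1] is the first network output
```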
This code follows the ModelingToolkitNeuralNets.jl reference. Next, I want to build a RuntimeGeneratedFunction from exprs, which gives the following result:
temp_func = build_function(exprs, [v1, v2, v3], [p, ptype], expression=Val{false})
>RuntimeGeneratedFunction(#=in Symbolics=#, #=using Symbolics=#, :((ˍ₋arg1, ˍ₋arg2)->begin
#= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:373 =#
#= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:374 =#
#= D:\Julia\Julia-1.10.0\packages\packages\SymbolicUtils\c0xQb\src\code.jl:375 =#
(getindex)((LuxCore.stateless_apply)(Chain(), Num[v1(t), v2(t), v3(t)], (convert)(ˍ₋arg2[2], ˍ₋arg2[1])), 1)
As the output shows, the input variables [v1, v2, v3] are not replaced by the function argument ˍ₋arg1; they stay embedded in the generated code as the literal Num[v1(t), v2(t), v3(t)].
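For comparison, here is a minimal sketch of the substitution I expected, using plain scalar variables. The names `a` and `b` are illustrative and not part of the model above. Here the variables are nodes of the expression tree, so build_function binds them to the generated argument and the function evaluates numerically:

```julia
using Symbolics

@variables a b    # illustrative scalar variables
ex = a^2 + b      # a and b appear directly as nodes of the expression

# A single vector argument [a, b]; the generated code indexes into it
f = build_function(ex, [a, b]; expression = Val{false})
val = f([2.0, 3.0])   # a -> 2.0, b -> 3.0
```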
In addition, when this function is called with actual data, the parameter type conversion fails to convert the model parameters via broadcasting:
temp_func([1,2,3], [ComponentVector(init_params), eltype(ComponentVector(init_params))])
>ERROR: MethodError: Cannot `convert` an object of type ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:64, Axis(weight = ViewAxis(1:48, ShapedAxis((16, 3))), bias = ViewAxis(49:64, ShapedAxis((16, 1))))), layer_2 = ViewAxis(65:98, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = ViewAxis(33:34, ShapedAxis((2, 1))))))}}} to an object of type Float32
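The generated code evaluates `convert(ˍ₋arg2[2], ˍ₋arg2[1])`, so the second argument's slot order matters. With `[ComponentVector(init_params), eltype(...)]` this becomes `convert(Float32, <ComponentVector>)`, which matches the error above. The sketch below reproduces that outside Symbolics and shows the slot contents the ModelingToolkitNeuralNets pattern appears to expect: a flat vector first and the ComponentVector type second. Note that `typeof(init_params)` is a NamedTuple type, not a ComponentVector type. The variable names are illustrative only:

```julia
using Lux, LuxCore, StableRNGs, ComponentArrays

nn = Lux.Chain(Lux.Dense(3 => 16), Lux.Dense(16 => 16), Lux.Dense(16 => 2))
ca = ComponentVector(Lux.initialparameters(StableRNG(42), nn))

# Slot contents as passed in the question: [parameters, element type]
arg2_question = [ca, eltype(ca)]
err = try
    convert(arg2_question[2], arg2_question[1])  # convert(Float32, ca)
    nothing
catch e
    e
end

# Slot contents the generated code can work with: [flat vector, ComponentVector type]
arg2_expected = [Vector(ca), typeof(ca)]
ps = convert(arg2_expected[2], arg2_expected[1])  # rebuilds the named layer structure
y = LuxCore.stateless_apply(nn, Float32[1, 2, 3], ps)
```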