I’m encountering an issue with Julia broadcasting when working with SubArrays in complex nested expressions. When using @views to create SubArrays and then applying broadcasting with @. to complex expressions, I get type mismatch errors.
Minimal Reproducible Example
using SymbolicUtils.Code
using SymbolicUtils
using Symbolics
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)
@variables a b c d
@variables p1 p2 p3
function build_f1(input_names, param_names, outputs, exprs)
    def_calls1 = [:(@views $nm = i[$idx]) for (idx, nm) in enumerate(input_names)]
    def_calls2 = [:(@views $nm = p[$idx]) for (idx, nm) in enumerate(param_names)]
    compute_calls = [:($nm = @. $expr) for (nm, expr) in zip(outputs, exprs)]
    return_calls = :(return [$(outputs...)])
    return :(function (i, p)
        $(def_calls1...)
        $(def_calls2...)
        $(compute_calls...)
        $(return_calls)
    end)
end
input_names = [:a, :b]
param_names = [:p1, :p2, :p3]
outputs = [:c, :d]
exprs = toexpr.([a * p1 + b * p2, sin(c * p1) + tanh(b * p3)])
f1_expr = build_f1(input_names, param_names, outputs, exprs)
bf1 = @RuntimeGeneratedFunction(f1_expr)
Printed Generated Function
RuntimeGeneratedFunction(#=in Main=#, #=using Main=#, :((i, p)->begin
#= e:\JlCode\HydroModels\dev\bug_submit.jl:18 =#
#= e:\JlCode\HydroModels\dev\bug_submit.jl:19 =#
#= e:\JlCode\HydroModels\dev\bug_submit.jl:14 =# @views a = i[1]
#= e:\JlCode\HydroModels\dev\bug_submit.jl:14 =# @views b = i[2]
#= e:\JlCode\HydroModels\dev\bug_submit.jl:20 =#
#= e:\JlCode\HydroModels\dev\bug_submit.jl:15 =# @views p1 = p[1]
#= e:\JlCode\HydroModels\dev\bug_submit.jl:15 =# @views p2 = p[2]
#= e:\JlCode\HydroModels\dev\bug_submit.jl:15 =# @views p3 = p[3]
#= e:\JlCode\HydroModels\dev\bug_submit.jl:21 =#
c = #= e:\JlCode\HydroModels\dev\bug_submit.jl:16 =# @__dot__((+)((*)(a, p1), (*)(b, p2)))
d = #= e:\JlCode\HydroModels\dev\bug_submit.jl:16 =# @__dot__((+)((tanh)((*)(b, p3)), (sin)((*)(c, p1))))
#= e:\JlCode\HydroModels\dev\bug_submit.jl:22 =#
return [c, d]
end))
Error Message
When I try to use this function with SubArrays, I get an error like:
ERROR: MethodError: no method matching tanh(::Vector{Float64})
The function `tanh` exists, but no method is defined for this combination of argument types.
Observations
- The code works fine with simple expressions but fails with more complex nested ones, such as the tanh call in the error above.
- Applying broadcasting directly to a hand-written version of the expression works:
@views input1 = eachslice(input, dims=1)[1]
@views input2 = eachslice(input, dims=1)[2]
eval(:(@__dot__((+)((tanh)((*)(input1, 2)), (sin)((*)(input2, 3))))))
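One possible explanation for the difference between the two observations, stated as an assumption rather than a confirmed diagnosis: `@.` only dots calls whose callee is a Symbol, and leaves calls alone when a function object has been spliced into the `Expr`. The printed form `(tanh)(...)` in the generated function would be consistent with `toexpr` emitting function objects, but that part is an inference from the output above. The sketch below uses only Base and hand-built expressions (the names `sym_call` and `obj_call` are illustrative) to show the two cases side by side.

```julia
# A global array so that `eval` can resolve the symbol `x`.
x = [0.1, 0.2, 0.3]

# Callee is the Symbol :tanh, as produced by parsing source text.
sym_call = :(tanh(x * 2))

# Callee is the function object `tanh` itself, spliced into the Expr.
# Assumption: this mimics what the symbolic code generator emits.
obj_call = Expr(:call, tanh, Expr(:call, *, :x, 2))

# Expand `@.` on both expressions without evaluating them.
dotted_sym = macroexpand(@__MODULE__, :(@. $sym_call))
dotted_obj = macroexpand(@__MODULE__, :(@. $obj_call))

# The Symbol version becomes a broadcast expression (head :.),
# while the function-object version is left as a plain call (head :call).
println(dotted_sym.head)
println(dotted_obj.head)

# Evaluating the dotted Symbol version broadcasts elementwise.
ok = eval(dotted_sym)

# Evaluating the untouched version calls tanh on a whole Vector,
# which is expected to throw a MethodError.
failed = try
    eval(dotted_obj)
    false
catch err
    err isa MethodError
end
```

If this is what happens here, it would also explain why the first output does not fail: `+` and `*` have non-broadcast methods for the vector and scalar arguments involved, whereas `tanh` has no method for a `Vector`.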
Question
How can I properly apply broadcasting to complex expressions involving SubArrays in generated code? Is there a way to ensure that every operation in a nested expression is broadcast?
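One way this could be approached, sketched under the assumption that the generated expressions contain function objects in call position: instead of relying on `@.`, walk the `Expr` and rewrite every call into the explicit broadcast form `f.(args...)`. The helper name `dot_calls` is hypothetical and not part of SymbolicUtils or Base, and the test expression is built by hand rather than by `toexpr`.

```julia
# Hypothetical helper: rewrite every call f(args...) into f.(args...),
# whether f is a Symbol or a function object spliced into the Expr.
dot_calls(x) = x  # symbols, numbers, and function objects pass through unchanged

function dot_calls(ex::Expr)
    # Recurse into the arguments first so nested calls are rewritten too.
    args = map(dot_calls, ex.args)
    if ex.head === :call
        # Expr(:., f, Expr(:tuple, args...)) is the AST of f.(args...).
        return Expr(:., args[1], Expr(:tuple, args[2:end]...))
    end
    return Expr(ex.head, args...)
end

# Hand-built expression with function objects as callees, mirroring
# (+)((tanh)((*)(b, 3)), (sin)((*)(a, 2))).
a = [1.0, 2.0]
b = [0.5, 1.5]
ex = Expr(:call, +,
          Expr(:call, tanh, Expr(:call, *, :b, 3)),
          Expr(:call, sin, Expr(:call, *, :a, 2)))

# Evaluate the rewritten expression; every operation is now broadcast.
result = eval(dot_calls(ex))
```

In `build_f1`, this would presumably replace `:($nm = @. $expr)` with `:($nm = $(dot_calls(expr)))`, though that substitution is untested against the actual `toexpr` output. Note that the helper dots every call indiscriminately, so any call that should stay scalar would need to be excluded explicitly.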