Hi, I would like to define the following parametric function type:
struct myParamFun
    fun::Function
    param
end
(pf::myParamFun)(x...) = pf.fun(x..., pf.param)
which works fine on ordinary Arrays. However, when broadcasting over CuArrays, I get the error:
julia> pf=myParamFun((x,y,z,v)->x+y-z*v,2);
julia> xs=(rand(100) for _ in 1:3);
julia> using CUDA
julia> pf.(xs...);
julia> pf.(CuArray.(xs)...);
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
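For comparison, here is a minimal CPU-only sketch (plain Base Julia, no CUDA) of why the same broadcast succeeds on ordinary arrays. Inference cannot predict the element type of a call through myParamFun, but Base's CPU broadcast does not need that prediction: it builds the output container from the values it actually computes. GPU broadcasting, as the error message suggests, needs a concrete element type up front. Base.promote_op is an internal, non-exported helper used here only to query inference, and xs is collected into a Vector (an assumption, unlike the generator above) so the same data is reused.

```julia
# The struct as defined in the question.
struct myParamFun
    fun::Function
    param
end
(pf::myParamFun)(x...) = pf.fun(x..., pf.param)

pf = myParamFun((x, y, z, v) -> x + y - z * v, 2)
xs = [rand(100) for _ in 1:3]   # a Vector, so the data does not change between uses

# What inference predicts for a single element (internal helper, not public API).
inferred = Base.promote_op(pf, Float64, Float64, Float64)

# Base's CPU broadcast computes the elements first and picks the container
# element type from the actual results, so it does not depend on `inferred`.
out = pf.(xs...)

println(inferred, " ", eltype(out))   # Any Float64
```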
However, the following works without any errors:
julia> pf.fun.(CuArray.(xs)...,Ref(pf.param));
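A possible explanation for why this variant works, sketched on the CPU under the same assumptions (Base.promote_op is internal): here the function handed to broadcast is the stored closure itself, whose runtime type is concrete, and Ref(pf.param) is a Base.RefValue{Int64}, so the per-element result type can be inferred.

```julia
struct myParamFun
    fun::Function
    param
end

pf = myParamFun((x, y, z, v) -> x + y - z * v, 2)

# The closure stored in the field has a concrete (singleton) type at runtime,
# even though the field itself is declared as the abstract type Function.
println(isconcretetype(typeof(pf.fun)))                           # true

# With concrete argument types, inference finds a concrete result type.
println(Base.promote_op(pf.fun, Float64, Float64, Float64, Int))  # Float64
```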
The @code_warntype output for each of the three calls is as follows:
julia> @code_warntype pf.(CuArray.(xs)...);
MethodInstance for (::var"##dotfunction#1253#40")(::Vector{CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}})
from (::var"##dotfunction#1253#40")(x1) in Main
Arguments
#self#::Core.Const(var"##dotfunction#1253#40"())
x1::Vector{CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}
Body::Any
1 ─ %1 = Core.tuple(Main.pf)::Tuple{Any}
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x1)::Any
│   %3 = Base.materialize(%2)::Any
└──      return %3
julia> @code_warntype pf.(xs...);
MethodInstance for (::var"##dotfunction#1254#41")(::Base.Generator{UnitRange{Int64}, var"#38#39"})
from (::var"##dotfunction#1254#41")(x1) in Main
Arguments
#self#::Core.Const(var"##dotfunction#1254#41"())
x1::Base.Generator{UnitRange{Int64}, var"#38#39"}
Body::Any
1 ─ %1 = Core.tuple(Main.pf)::Tuple{Any}
│   %2 = Core._apply_iterate(Base.iterate, Base.broadcasted, %1, x1)::Any
│   %3 = Base.materialize(%2)::Any
└──      return %3
julia> @code_warntype pf.fun.(CuArray.(xs)...,Ref(pf.param));
MethodInstance for (::var"##dotfunction#1255#42")(::Vector{CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}, ::Base.RefValue{Int64})
from (::var"##dotfunction#1255#42")(x1, x2) in Main
Arguments
#self#::Core.Const(var"##dotfunction#1255#42"())
x1::Vector{CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}}
x2::Base.RefValue{Int64}
Body::Any
1 ─ %1 = Base.getproperty(Main.pf, :fun)::Any
│   %2 = Core.tuple(%1)::Tuple{Any}
│   %3 = Core.tuple(x2)::Tuple{Base.RefValue{Int64}}
│   %4 = Core._apply_iterate(Base.iterate, Base.broadcasted, %2, x1, %3)::Any
│   %5 = Base.materialize(%4)::Any
└──      return %5
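Note that Body::Any appears in all three listings, including the call that works, because each dotfunction wrapper reads the non-const global Main.pf. These listings therefore describe the wrapper, not the per-element call that ends up in the GPU kernel. A short sketch in plain Base Julia that inspects the struct's declared field types instead:

```julia
struct myParamFun
    fun::Function
    param
end

# `Function` is an abstract type, and an unannotated field defaults to `Any`,
# so the declared types say nothing concrete about what an instance holds.
fun_T = fieldtype(myParamFun, :fun)
param_T = fieldtype(myParamFun, :param)
println(fun_T, " ", param_T)                                  # Function Any
println(isconcretetype(fun_T), " ", isconcretetype(param_T))  # false false

# Abstractly typed fields are stored as references, so the struct is not a bits type.
println(isbitstype(myParamFun))                               # false
```

CUDA.jl generally expects values used inside a kernel to be isbits after conversion, which an instance of this struct cannot be.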
So is there a way to implement this parametric function type so that it broadcasts correctly over CuArrays? Thank you in advance!
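One possible approach, sketched here as an assumption rather than a verified fix: turn the field types into type parameters, so that each instance carries the concrete types of its function and its parameter. The name ParamFun and the variable pf2 are hypothetical, and the checks below run on the CPU only (Base.promote_op is internal).

```julia
# Hypothetical variant: F and P are filled in with the concrete types
# of whatever function and parameter the instance is built from.
struct ParamFun{F,P}
    fun::F
    param::P
end
(pf::ParamFun)(x...) = pf.fun(x..., pf.param)

pf2 = ParamFun((x, y, z, v) -> x + y - z * v, 2)   # ParamFun{<closure type>, Int64}
xs = [rand(100) for _ in 1:3]

# Inference can now see through the call ...
println(Base.promote_op(pf2, Float64, Float64, Float64))   # Float64
# ... and the instance is a plain bits value (a field-less closure plus an Int).
println(isbits(pf2))                                       # true

out = pf2.(xs...)
println(eltype(out))                                       # Float64
```

With CUDA.jl loaded, the corresponding call would be pf2.(CuArray.(xs)...); it is not run in this sketch, so whether it compiles on a particular setup is untested here.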
Update: even if I define my parametric function pf as a const, the broadcast still fails.
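A CPU-only sketch of why const alone may not be enough (cpf is a hypothetical name; Base.promote_op is internal): const fixes the type of the global binding, but the value's type is still myParamFun, whose fields are declared as Function and Any, so a call through it remains uninferable.

```julia
struct myParamFun
    fun::Function
    param
end
(pf::myParamFun)(x...) = pf.fun(x..., pf.param)

# The binding is now constant, so its type is known to the compiler ...
const cpf = myParamFun((x, y, z, v) -> x + y - z * v, 2)

# ... but that type is still myParamFun, with abstractly typed fields,
# so inference still cannot predict the result of a call.
println(typeof(cpf) === myParamFun)                        # true
println(Base.promote_op(cpf, Float64, Float64, Float64))   # Any
```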