Let’s say we have a function that uses Base intrinsics like `exp`:

```julia
logistic(x) = 1 / (1 + exp(-x))
```
If we broadcast it over CuArrays, e.g.:

```julia
A = cu(rand(5, 10))
logistic.(A)
```
we get the expected warning (and even crashes in more complicated cases):

```
┌ Warning: calls to Base intrinsics might be GPU incompatible
│  exception =
│   You called exp(x::T) where T<:Union{Float32, Float64} in Base.Math at special/exp.jl:75, maybe you intended to call exp(x::Float32) in CUDAnative at /home/dfdx/.julia/packages/CUDAnative/Mdd3w/src/device/libdevice.jl:90 instead?
│   Stacktrace:
│    [1] exp at special/exp.jl:75
│    [2] #23 at /home/dfdx/.julia/packages/GPUArrays/t8tJB/src/broadcast.jl:49
└ @ CUDAnative ~/.julia/packages/CUDAnative/Mdd3w/src/compiler/irgen.jl:68
5×10 CuArray{Float32,2}:
...
```
I can of course rewrite the `logistic` function for CuArrays specifically:

```julia
logistic(x) = 1 / (1 + CUDAnative.exp(-x))
```
But then it’s not generic anymore.
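To illustrate the loss of genericity (a sketch, assuming `CUDAnative.exp` only has device-side methods and so is not applicable on the host):

```julia
using CUDAnative

logistic(x) = 1 / (1 + CUDAnative.exp(-x))

# Fine when broadcast over a CuArray, but over a plain Array
# the device-only CUDAnative.exp can no longer be called:
logistic.(rand(Float32, 5))  # errors on the CPU
```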
Another option is to override `broadcasted` for each function, e.g. something like:

```julia
Base.Broadcast.broadcasted(::typeof(logistic), x) = 1 ./ (1 .+ exp.(-x))
Base.Broadcast.broadcasted(::typeof(logistic), x::CuArray) = 1 ./ (1 .+ CUDAnative.exp.(-x))
```
But this would be pretty hard to extend to all such functions.

My current plan is to use Cassette.jl to rewrite all intrinsic functions on the fly; however, I may be missing some obviously simpler solution. Am I?
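For concreteness, a minimal sketch of what I have in mind with Cassette (hypothetical: the `CUDACtx` context and `cudacall` helper are names I made up, and only `exp` is rerouted here — a real pass would cover `log`, `sin`, etc.):

```julia
using Cassette, CuArrays, CUDAnative

Cassette.@context CUDACtx

# Under CUDACtx, reroute Base.exp to the device intrinsic.
Cassette.overdub(::CUDACtx, ::typeof(Base.exp), x::Union{Float32,Float64}) =
    CUDAnative.exp(x)

# Wrap a function so every call inside it runs overdubbed,
# relying on Cassette's recursive call interception:
cudacall(f) = (xs...) -> Cassette.overdub(CUDACtx(), f, xs...)

logistic(x) = 1 / (1 + exp(-x))
A = cu(rand(5, 10))
cudacall(logistic).(A)
```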