Engineering a better activation function for efficient assembly

I’m checking how complex an activation function can be before it becomes (much) slower than ReLU, the simplest kind and the standard choice before newer functions such as Mish.

Perhaps unsurprisingly, f1 is as fast as ReLU, with the same number of assembly instructions (though f1, f2 and f4 may be no good as activation functions). But f2, with 10 instructions versus 5 (not counting the return instruction), is also as fast, and so is f3 with even more instructions, and it’s as fast for all inputs.

That seems a bit dubious. Should this translate to GPUs? I’m new to them and don’t yet know how to test and benchmark on them.

julia> function f1(x)
           return x*abs(x)
       end

julia> using Plots
julia> x = y = range(-5, 5, length = 40)
julia> Plots.plot(y, f1)

julia> function f2(x)
           one = convert(typeof(x), 1.0)
           return (x+one)*abs(x)-one
       end

julia> function f3(x)
           if x > convert(typeof(x), -1.0)
               x
           else
               convert(typeof(x), -Q_rsqrt(convert(Float32, -x)))
           end
       end

julia> ReLU(x) = max(x, zero(x))  # ReLU clamps negatives to zero, so max, not min
julia> @code_native ReLU(1.0f0)  # five assembly instructions then return
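
Instruction counts alone don't settle which function is faster, so it may help to time each one over a whole array of inputs rather than a single scalar. The following is only a minimal sketch using Base Julia; the names `relu_sketch`, `sq_abs` and `time_broadcast` are made up for illustration (they mirror ReLU and f1), and for serious measurements BenchmarkTools' `@btime` is likely the better tool.

```julia
# Illustrative names: relu_sketch and sq_abs mirror ReLU and f1 from the post.
relu_sketch(x) = max(x, zero(x))
sq_abs(x) = x * abs(x)

# Time one broadcast of f over xs, writing into a preallocated output.
# Returns the best of `reps` runs, in seconds.
function time_broadcast(f::F, xs; reps = 5) where {F}
    out = similar(xs)
    out .= f.(xs)                  # warm-up call so compilation is not timed
    best = Inf
    for _ in 1:reps
        t = @elapsed (out .= f.(xs))
        best = min(best, t)
    end
    return best
end

xs = collect(range(-5f0, 5f0, length = 1_000_000))
println("relu:  ", time_broadcast(relu_sketch, xs), " s")
println("x*|x|: ", time_broadcast(sq_abs, xs), " s")
```

Taking the best of several runs reduces noise from other processes, and sweeping the input range covers both branches of piecewise functions instead of only one.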

julia> function f4(x)
           if x > convert(typeof(x), -1.0)
               x
           else
               log(log(convert(typeof(x), 2.0) - x) + convert(typeof(x), 2.0)) - convert(typeof(x), 2.0)
           end
       end

julia> using BenchmarkTools
julia> @btime f4(-216.0)  # inputs more negative than about -217 return positive values; probably not a good idea, but it doesn't seem terrible
  9.322 ns (0 allocations: 0 bytes)
-0.0006174597070864873
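
To see where f4 changes sign, the negative branch can be solved for zero: log(log(2 - x) + 2) - 2 = 0 gives x = 2 - exp(e^2 - 2). Below is a small sketch of that check; `f4_sketch` is just f4 repeated under another name so the snippet is self-contained, and `x0` is a name introduced here.

```julia
# Same definition as f4 in the post, repeated so the sketch is self-contained.
function f4_sketch(x)
    T = typeof(x)
    return x > T(-1) ? x : log(log(T(2) - x) + T(2)) - T(2)
end

# The negative branch is zero where log(2 - x) + 2 == e^2,
# i.e. at x0 = 2 - exp(e^2 - 2).
x0 = 2 - exp(exp(2.0) - 2)
println("zero crossing near x = ", x0)
println("f4(-216) = ", f4_sketch(-216.0))   # just below zero
println("f4(-218) = ", f4_sketch(-218.0))   # just above zero

# The two branches also do not meet exactly at x = -1:
println("negative branch at x = -1: ", f4_sketch(-1.0))
```

The last line shows that the logarithmic branch does not evaluate to exactly -1 at x = -1, so f4 has a small jump there, which may matter if the function is used for training.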


julia> function Q_rsqrt( number::Float32 )
               #long i;
               #float x2, y;
               threehalfs = 1.5f0; # const float threehalfs = 1.5F;

               x2 = number * 0.5f0;
               y  = number;
               i = reinterpret(UInt32, y) # i  = * ( long * ) &y;                       // evil floating point bit level hacking
               i  = 0x5f3759df - ( i >> 1 );               # // what the fuck? 
               y = reinterpret(Float32, i) # y  = * ( float * ) &i;
               y  = y * ( threehalfs - ( x2 * y * y ) );   # // 1st iteration
       #        y  = y * ( threehalfs - ( x2 * y * y ) );    // 2nd iteration, this can be removed

               return y;
       end
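
Since f3 relies on Q_rsqrt being close to 1/sqrt(x), it seems worth checking the approximation error. The sketch below repeats the same algorithm under the made-up name `fast_rsqrt` so it runs on its own, and compares it with the exact value for a few arbitrarily chosen positive inputs. With a single Newton-Raphson step, the relative error should stay below roughly 0.2%.

```julia
# Same algorithm as Q_rsqrt above, under an illustrative name.
function fast_rsqrt(x::Float32)
    i = reinterpret(UInt32, x)
    i = 0x5f3759df - (i >> 1)                  # magic-constant initial guess
    y = reinterpret(Float32, i)
    return y * (1.5f0 - 0.5f0 * x * y * y)     # one Newton-Raphson step
end

# Relative error against the exact value for a few positive inputs.
for x in (0.25f0, 1.0f0, 2.0f0, 10.0f0, 216.0f0)
    exact = 1 / sqrt(x)
    approx = fast_rsqrt(x)
    println(x, ": approx = ", approx, ", rel. error = ", abs(approx - exact) / exact)
end
```

The function is only meaningful for positive inputs, which matches how f3 calls it: the argument is -x on the branch where x is at most -1.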