using StaticArrays, LinearAlgebra, BenchmarkTools, StructArrays, LoopVectorization, VectorizationBase
# Scalar precision used throughout the file (coordinates, radii, hit distances).
const T = Float32
# A 3-component static vector of `T`, used for positions and directions.
const Point = SVector{3, T}
"""
    Ray(; origin = Point(0, 0, 0), direction = Point(0, 1, 0))

A ray given by its `origin` and `direction`, both `Point`s.

Built with `@kwdef`, so both fields may be passed as keywords. By default the
ray starts at the origin and points along +y.
"""
@kwdef struct Ray
    origin::Point = Point(0, 0, 0)
    direction::Point = Point(0, 1, 0)
end
"""
    Sphere(centre::Point, radius::T)

A sphere described by its `centre` point and its `radius`.
"""
struct Sphere
centre::Point
radius::T
end
"""
    Scene(; Sphere = Any[])

Container for the objects in a scene.

The single field is named `Sphere` and holds the sphere collection. In this file
that collection is a `StructArray` of `Sphere`s. Its concrete type is the type
parameter `S`. Note that the default value is an untyped, empty `Any[]`.
"""
@kwdef struct Scene{S}
    Sphere::S = Any[]
end
"""
    StructArrays.staticschema(::Type{Point})

Describe a `Point` to StructArrays as three named components `x`, `y` and `z`.
Each component has the point's element type, so an unwrapped StructArray stores
one array per coordinate.
"""
function StructArrays.staticschema(::Type{Point})
    coord = eltype(Point)
    return NamedTuple{(:x, :y, :z), NTuple{3, coord}}
end
"""
    StructArrays.component(m::Point, key::Symbol)

Fetch the coordinate named `key` (`:x`, `:y` or `:z`) from the point `m`.
"""
function StructArrays.component(m::Point, key::Symbol)
    return getproperty(m, key)
end
"""
    StructArrays.createinstance(::Type{Point}, args...)

Rebuild a `Point` from its component values. The varargs tuple is handed directly
to the `Point` constructor.
"""
function StructArrays.createinstance(::Type{Point}, args...)
    return Point(args)
end
# Benchmark data: N spheres (296 = 37 * 8).
# Centres lie on the diagonal (v, v, v) for v evenly spaced in [-100, 100].
# Radii grow linearly from 5 to 50.
N = 296
spheres = Sphere.([Point(1, 1, 1) * v for v in range(-100, 100, N)], range(5, 50, N))
# Store the spheres as a StructArray.
# `unwrap` recursively splits non-Real field types, so `centre` becomes x/y/z
# component arrays (via the `staticschema` method for `Point`).
scene = Scene(StructArray(spheres, unwrap = T -> !(T<:Real)));
# Default ray: starts at the origin and points along +y.
ray = Ray()
# maybecompute(neg_half_b, quarter_discriminant, tmax, tmin) -> Vec
#
# Per-lane candidate hit distance for the half-b form of the ray/sphere
# quadratic. Each lane returns its root when the discriminant is positive and the
# root lies beyond `tmin`; otherwise it returns `tmax`. The nearer root is
# preferred when it is greater than `tmin`.
#
# If no lane has a positive discriminant, the sqrt is skipped and `tmax` is
# returned unchanged.
@inline function maybecompute(
    neg_half_b::VectorizationBase.Vec{W,T},
    quarter_discriminant::VectorizationBase.Vec{W,T},
    tmax::VectorizationBase.Vec{W,T},
    tmin::VectorizationBase.Vec{W,T}
) where {W,T}
    hits = quarter_discriminant > 0
    VectorizationBase.vany(hits) || return tmax
    # Lanes with a non-positive discriminant yield a NaN/meaningless root here.
    # They are discarded by the `hits` mask below.
    half_width = sqrt(quarter_discriminant)
    near = neg_half_b - half_width
    far = neg_half_b + half_width
    candidate = ifelse(near > tmin, near, far)
    accept = hits & (tmin < candidate)
    return ifelse(accept, candidate, tmax)
end
# Unrolled method of `maybecompute`: the first three arguments are `VecUnroll`s,
# i.e. several `Vec`s at once. The `Vec` method is applied piecewise via
# `VectorizationBase.fmap`, and the results are re-wrapped in a `VecUnroll`.
# The single `tmin` vector is presumably passed unchanged to every piece by
# `fmap`.
@inline function maybecompute(x::VecUnroll, y::VecUnroll, z::VecUnroll, tmin::VectorizationBase.Vec{W,T}) where {W, T}
    pieces = VectorizationBase.fmap(
        maybecompute,
        VectorizationBase.data(x),
        VectorizationBase.data(y),
        VectorizationBase.data(z),
        tmin,
    )
    return VecUnroll(pieces)
end
# findSceneIntersection_maybecompute(ray, hittable_list, tmin, tmax) -> (t, index)
#
# LoopVectorization (`@turbo`) version of the nearest ray/sphere intersection.
# It loops over `hittable_list.Sphere`, a StructArray whose `centre` is split into
# x/y/z component arrays.
# Returns the smallest accepted hit distance and the index of that sphere.
# If nothing beats the initial `tmax`, it returns `tmax` with index 0.
#
# The quadratic is solved in half-b form with the `a` coefficient omitted.
# NOTE(review): this presumably assumes `ray.direction` has unit length; confirm.
function findSceneIntersection_maybecompute(ray, hittable_list, tmin::T, tmax::T)
# Typed local: the loop index `i` is converted to Int32 when assigned below.
besti::Int32 = 0
# `tmin` is passed to `maybecompute` as an explicit SIMD vector.
# NOTE(review): the width 8 is hard-coded. The `Vec{W,T}` methods of
# `maybecompute` require all arguments to share `W`, so this assumes `@turbo`
# picks 8 Float32 lanes (as in the <8 x float> IR in the post). Confirm on other
# hardware, e.g. AVX-512.
x = VectorizationBase.Vec{8,T}(tmin)
# `tmax` and `besti` are loop-carried values. In the IR they appear as vector
# accumulators that are combined after the loop.
@turbo for i in eachindex(hittable_list.Sphere)
cox = hittable_list.Sphere.centre.x[i] - ray.origin.x
coy = hittable_list.Sphere.centre.y[i] - ray.origin.y
coz = hittable_list.Sphere.centre.z[i] - ray.origin.z
neg_half_b = ray.direction.x * cox + ray.direction.y * coy + ray.direction.z * coz
c = cox^2 + coy^2 + coz^2 - hittable_list.Sphere.radius[i]^2
quarter_discriminant = neg_half_b^2 - c
# Candidate distance per lane. It equals the running `tmax` where there is no
# valid hit.
t = maybecompute(neg_half_b, quarter_discriminant, tmax, x)
newMinT = t < tmax
tmax = ifelse(newMinT, t, tmax)
besti = ifelse(newMinT, i, besti)
end
return tmax, besti
end
# Call once.
# Then benchmark with `$`-interpolated arguments so the global bindings are not
# part of the measurement.
findSceneIntersection_maybecompute(ray, scene, T(1e-4), T(Inf))
@benchmark findSceneIntersection_maybecompute($ray, $scene, $(T(1e-4)), $(T(Inf)))
using SIMD
# getBits(mask) -> unsigned integer
#
# Pack a SIMD.jl `Bool` mask into an integer with one bit per lane.
# Lane `k` (1-based) becomes bit `k - 1`, so `trailing_zeros(getBits(mask)) + 1`
# is the first true lane. When printed in binary, the lanes therefore appear in
# reverse order.
#
# The result type is the narrowest of UInt8/UInt16/UInt32/UInt64 that holds N
# bits, and unused high bits are zero. For N == 8 the result is a `UInt8`.
# Previously the return type was hard-coded to `UInt8` even though the IR
# returned `i$N`, which is invalid for any N other than 8.
@generated function getBits(mask::SIMD.Vec{N, Bool}) where N
    nbits = max(8, nextpow(2, N))
    nbits <= 64 || return :(throw(ArgumentError("getBits supports at most 64 lanes")))
    U = (UInt8, UInt16, UInt32, UInt64)[trailing_zeros(nbits) - 2]
    # `Bool` lanes arrive as `<N x i8>`.
    # Truncate each lane to i1, then reinterpret the lanes as one N-bit integer.
    ir = """
    %mask = trunc <$N x i8> %0 to <$N x i1>
    %bits = bitcast <$N x i1> %mask to i$N
    """
    if nbits > N
        # Widen to a width Julia has an integer type for.
        ir *= "%res = zext i$N %bits to i$nbits\nret i$nbits %res\n"
    else
        ir *= "ret i$N %bits\n"
    end
    return :(
        $(Expr(:meta, :inline));
        Base.llvmcall($ir, $U, Tuple{SIMD.LVec{N, Bool}}, mask.data)
    )
end
# hor_min(x) -> scalar
#
# Horizontal minimum of an 8-lane SIMD.jl vector, computed with a three-step
# shuffle/min tree. Each step folds the upper live lanes onto the lower ones
# (8 -> 4 -> 2 -> 1), and the answer ends up in lane 1.
# Uses fast-math `min`, so NaN handling is not IEEE-guaranteed.
function hor_min(x::SIMD.Vec{8, T}) where T
    @fastmath begin
        upper4 = shufflevector(x, Val((4, 5, 6, 7, :undef, :undef, :undef, :undef)))
        min4 = min(upper4, x)
        upper2 = shufflevector(min4, Val((2, 3, :undef, :undef, :undef, :undef, :undef, :undef)))
        min2 = min(upper2, min4)
        upper1 = shufflevector(min2, Val((1, :undef, :undef, :undef, :undef, :undef, :undef, :undef)))
        min1 = min(upper1, min2)
    end
    return @inbounds min1[1]
end
# Workaround: an 8 x Int32 vector add that carries fast-math flags drops the flags
# and forwards to the plain integer `add`.
# Presumably this is needed because `laneIndices += N` inside the
# `@fastmath function findSceneIntersection` reaches this intrinsic with
# fast-math flags, which LLVM does not accept on integer `add`.
# NOTE(review): this assumes `FastMathFlags{128}` is the flag set produced by
# `@fastmath` in the installed SIMD.jl version; confirm.
# This method extends a SIMD.jl function on SIMD.jl/Base types (type piracy) and
# only covers NTuple{8, VecElement{Int32}}.
SIMD.Intrinsics.add(x::NTuple{8, VecElement{Int32}}, y::NTuple{8, VecElement{Int32}}, ::SIMD.Intrinsics.FastMathFlags{128}) = SIMD.Intrinsics.add(x, y)
# findSceneIntersection(r, hittable_list, tmin, tmax) -> (t, index)
#
# Hand-vectorised (SIMD.jl) nearest ray/sphere intersection over the spheres in
# `hittable_list.Sphere`, processed 8 at a time.
# Returns the smallest hit distance `t` with `tmin < t < tmax` and the 1-based
# `Int32` index of that sphere. When nothing is hit it returns `(tmax, Int32(0))`,
# with `tmax` converted to `T`.
#
# With `oc = centre - r.origin`, the quadratic is solved in half-b form:
#   neg_half_b = d ⋅ oc
#   c = |oc|^2 - radius^2
#   quarter_discriminant = neg_half_b^2 - c
#   roots = neg_half_b ∓ sqrt(quarter_discriminant)
# The `a` coefficient is omitted.
# NOTE(review): this presumably assumes `r.direction` has unit length; confirm.
#
# Earlier versions looped `while lane.i <= length(...)` with `@inbounds` 8-wide
# loads. Whenever the length was not a multiple of 8 (including lists shorter
# than 8), that loop read past the end of the arrays. Now only full chunks are
# vectorised and the remainder is handled by a scalar loop.
@fastmath function findSceneIntersection(r, hittable_list, tmin, tmax)
    N = 8  # lane count; `hor_min` and `laneIndices` below are written for 8 lanes
    objects = hittable_list.Sphere
    len = length(objects)
    hitT = SIMD.Vec{N, T}(tmax)
    laneIndices = SIMD.Vec{N, Int32}(Int32.((1, 2, 3, 4, 5, 6, 7, 8)))
    minIndex = SIMD.Vec{N, Int32}(0)
    lane = VecRange{N}(1)
    # Vector loop over *full* chunks of N spheres only.
    @inbounds @fastmath while lane.i + N - 1 <= len
        cox = objects.centre.x[lane] - r.origin.x
        coy = objects.centre.y[lane] - r.origin.y
        coz = objects.centre.z[lane] - r.origin.z
        neg_half_b = r.direction.x * cox + r.direction.y * coy
        neg_half_b += r.direction.z * coz
        c = cox*cox + coy*coy
        c += coz*coz
        c -= objects.radius[lane] * objects.radius[lane]
        quarter_discriminant = neg_half_b^2 - c
        isDiscriminantPositive = quarter_discriminant > 0
        if any(isDiscriminantPositive)
            # Lanes with a non-positive discriminant yield a NaN/meaningless root.
            # They are masked off by `isDiscriminantPositive` below.
            @fastmath sqrtd = sqrt(quarter_discriminant)
            root = neg_half_b - sqrtd
            root2 = neg_half_b + sqrtd
            t = vifelse(root > tmin, root, root2)
            newMinT = isDiscriminantPositive & (tmin < t) & (t < hitT)
            hitT = vifelse(newMinT, t, hitT)
            minIndex = vifelse(newMinT, laneIndices, minIndex)
        end
        # Under `@fastmath` this relies on the 8 x Int32 `SIMD.Intrinsics.add`
        # method defined in this file.
        laneIndices += N
        lane += N
    end
    minHitT = hor_min(hitT)
    minIdx = Int32(0)
    if minHitT < tmax
        # Index stored in the first lane that holds the minimum.
        @inbounds minIdx = minIndex[trailing_zeros(getBits(hitT == minHitT)) + 1]
    end
    # Scalar tail: the last `len % N` spheres, or all of them when `len < N`.
    # The strict `<` keeps an earlier (vectorised) hit on ties.
    @inbounds for i in lane.i:len
        cox = objects.centre.x[i] - r.origin.x
        coy = objects.centre.y[i] - r.origin.y
        coz = objects.centre.z[i] - r.origin.z
        neg_half_b = r.direction.x * cox + r.direction.y * coy + r.direction.z * coz
        c = cox*cox + coy*coy + coz*coz - objects.radius[i] * objects.radius[i]
        quarter_discriminant = neg_half_b^2 - c
        quarter_discriminant > 0 || continue
        sqrtd = sqrt(quarter_discriminant)
        root = neg_half_b - sqrtd
        t = root > tmin ? root : neg_half_b + sqrtd
        if tmin < t && t < minHitT
            minHitT = t
            minIdx = Int32(i)
        end
    end
    return minHitT, minIdx
end
# Call once.
# Then benchmark with `$`-interpolated arguments so the global bindings are not
# part of the measurement.
findSceneIntersection(ray, scene, T(1e-4), T(Inf))
@benchmark findSceneIntersection($ray, $scene, $(T(1e-4)), $(T(Inf)))
The median/mean time was 105 ns for `findSceneIntersection_maybecompute` and 82 ns for `findSceneIntersection`. I don't really see why `findSceneIntersection_maybecompute` is slower. Maybe reducing the vector of `besti` values to a scalar is slow, though I would be surprised if that alone cost 20 ns. `findSceneIntersection_maybecompute` also has an extra `ifelse`, and I'm a bit confused by what's going on in the LLVM IR for it:
L194: ; preds = %L185, %L127
%value_phi24 = phi <8 x float> [ %res.i320, %L185 ], [ %value_phi1454, %L127 ]
%m.i317 = fcmp reassoc nsz arcp contract ogt <8 x float> %res.i336, zeroinitializer
%res.i316 = sext <8 x i1> %m.i317 to <8 x i32>
%res.i314 = bitcast <8 x i32> %res.i316 to <8 x float>
%41 = call i32 @llvm.x86.avx.vtestz.ps.256(<8 x float> %res.i314, <8 x float> %res.i314)
%.not428 = icmp eq i32 %41, 0
br i1 %.not428, label %L206, label %L219
L206: ; preds = %L194
%res.i313 = call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %res.i336)
%res.i312 = fsub nsz contract <8 x float> %res.i363, %res.i313
%res.i311 = fadd nsz contract <8 x float> %res.i363, %res.i313
%m.i309 = fcmp reassoc nsz arcp contract ogt <8 x float> %res.i312, %29
%res.i308 = select reassoc nsz arcp contract <8 x i1> %m.i309, <8 x float> %res.i312, <8 x float> %res.i311
%m.i305 = fcmp reassoc nsz arcp contract olt <8 x float> %29, %res.i308
%combinedmask1.i304429 = and <8 x i1> %m.i305, %m.i317
%res.i303 = select reassoc nsz arcp contract <8 x i1> %combinedmask1.i304429, <8 x float> %res.i308, <8 x float> %value_phi2456
br label %L219
L219: ; preds = %L206, %L194
%value_phi26 = phi <8 x float> [ %res.i303, %L206 ], [ %value_phi2456, %L194 ]
%m.i300 = fcmp reassoc nsz arcp contract olt <8 x float> %value_phi24, %value_phi1454
%m.i298 = fcmp reassoc nsz arcp contract olt <8 x float> %value_phi26, %value_phi2456
%res.i297 = select reassoc nsz arcp contract <8 x i1> %m.i300, <8 x float> %value_phi24, <8 x float> %value_phi1454
%res.i295 = select reassoc nsz arcp contract <8 x i1> %m.i298, <8 x float> %value_phi26, <8 x float> %value_phi2456
%42 = trunc i64 %value_phi453 to i32
%ie.i290 = insertelement <8 x i32> undef, i32 %42, i64 0
%v.i291 = shufflevector <8 x i32> %ie.i290, <8 x i32> undef, <8 x i32> zeroinitializer
%res.i292 = add nsw <8 x i32> %v.i291, <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
%43 = add i32 %42, 8
%ie.i287 = insertelement <8 x i32> undef, i32 %43, i64 0
%v.i288 = shufflevector <8 x i32> %ie.i287, <8 x i32> undef, <8 x i32> zeroinitializer
%res.i289 = add nsw <8 x i32> %v.i288, <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
%res.i286 = select <8 x i1> %m.i300, <8 x i32> %res.i292, <8 x i32> %value_phi20458
%res.i284 = select <8 x i1> %m.i298, <8 x i32> %res.i289, <8 x i32> %value_phi21459
%res.i282 = add nuw nsw i64 %value_phi453, 16
%.not460 = icmp sgt i64 %res.i282, %res.i281
br i1 %.not460, label %L234, label %L127
L234: ; preds = %L219
%m.i279 = fcmp reassoc nsz arcp contract olt <8 x float> %res.i297, %res.i295
%res.i278 = select reassoc nsz arcp contract <8 x i1> %m.i279, <8 x float> %res.i297, <8 x float> %res.i295
%res.i272 = select <8 x i1> %m.i279, <8 x i32> %res.i286, <8 x i32> %res.i284
br label %L242
Block `L219` seems to contain a lot of extra instructions that the SIMD.jl version doesn't have.