I think the best answer is to give examples; they can even be in the form of the Julia challenge repo, since I was the one who implemented the Nim solution.
When code selection depends on compile-time types or values, I either overload or use Nim's compile-time facilities.
Example 1: statically dispatching float32/float64 to the proper cuBLAS call with when (a compile-time if):
# Vector copy
proc cublas_copy*[T: SomeFloat](
      n: int; x: ptr T; incx: int;
      y: ptr T; incy: int) {.inline.} =

  check cublasSetStream(cublasHandle0, cudaStream0)

  when T is float32:
    check cublasScopy(cublasHandle0, n.cint, x, incx.cint, y, incy.cint)
  elif T is float64:
    check cublasDcopy(cublasHandle0, n.cint, x, incx.cint, y, incy.cint)
Example 2: static dispatch/specialization of rounding to the next/previous multiple of N, depending on whether N is a power of 2:
func round_step_down*(x: Natural, step: static Natural): int {.inline.} =
  ## Round the input down to the previous multiple of "step"
  when (step and (step - 1)) == 0:
    # Step is a power of 2: use a bitmask.
    # (If the compiler cannot prove that x > 0, it will not do this optimization on its own.)
    result = x and not(step - 1)
  else:
    result = x - x mod step

func round_step_up*(x: Natural, step: static Natural): int {.inline.} =
  ## Round the input up to the next multiple of "step"
  when (step and (step - 1)) == 0:
    # Step is a power of 2: use a bitmask.
    # (If the compiler cannot prove that x > 0, it will not do this optimization on its own.)
    result = (x + step - 1) and not(step - 1)
  else:
    result = ((x + step - 1) div step) * step
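For illustration, a quick check of both branches; because step is a static parameter, the when branch is selected at compile time for each instantiation:

echo round_step_down(10, 4)   # 8   (power-of-2 branch: bitmask)
echo round_step_up(10, 4)     # 12
echo round_step_down(10, 6)   # 6   (generic branch: div/mod)
echo round_step_up(10, 6)     # 12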
Example 3: generic algorithms over different SIMD widths:
import simd # a SIMD file like https://github.com/numforge/laser/blob/2f619fdbb2496aa7a5e5538035a8d42d88db8c10/laser/simd.nim

proc `+`(a, b: m128): m128 =
  mm_add_ps(a, b)

proc `+`(a, b: m256): m256 =
  mm256_add_ps(a, b)

proc `+`(a, b: m512): m512 =
  mm512_add_ps(a, b)
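To make the "generic algorithm" part concrete, here is a minimal sketch (assuming the simd module and the + overloads above; addBuffers is a name I made up for the illustration): the same generic body works for any vector width, because + is resolved by plain overloading when the proc is instantiated.

# Hypothetical helper: element-wise addition over buffers of SIMD vectors.
# The body is identical whether V is m128, m256 or m512.
proc addBuffers[V](dst: var openArray[V], a, b: openArray[V]) =
  for i in 0 ..< dst.len:
    dst[i] = a[i] + b[i]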
Furthermore, besides specialization via when or overloading, you can detect patterns and use "super-instructions". Useful patterns are the matrix expression A*B + C, exp(m - 1) and ln(1 + p), which have dedicated implementations in BLAS for the first and in math.h (expm1, log1p) for the last two. You can do that with term-rewriting macros/templates, for example:
# Implementation of the fused operation (out-of-place)
# ------------------------------------------------------------------
proc tensor_multiplyAdd[T](
      A, B: Tensor[T],
      C: Tensor[T]): Tensor[T] =

  result = C

  if C.rank == 2:
    gemm(1.T, A, B, 1.T, result)
  elif C.rank == 1:
    gemv(1.T, A, B, 1.T, result)
  else:
    raise newException(ValueError, "Matrix-Matrix or Matrix-Vector multiplication valid only if first Tensor is a Matrix and second is a Matrix or Vector")

# Implementation of the fused operation (in-place)
# ------------------------------------------------------------------
proc tensor_multiplyAdd_inplace[T](
      A, B: Tensor[T],
      C: var Tensor[T]) =

  if C.rank == 2:
    gemm(1.T, A, B, 1.T, C)
  elif C.rank == 1:
    gemv(1.T, A, B, 1.T, C)
  else:
    raise newException(ValueError, "Matrix-Matrix or Matrix-Vector multiplication valid only if first Tensor is a Matrix and second is a Matrix or Vector")
# Pattern match for fusion operation (out-of-place)
# ------------------------------------------------------------------
template rewriteTensor_MultiplyAdd*{`*`(A,B) + C}[T](
  A, B, C: Tensor[T]): auto =
  ## Fuse ``A*B + C`` into a single operation.
  ##
  ## Operation fusion leverages the Nim compiler and should not be called explicitly.
  tensor_multiplyAdd(A, B, C)

# Pattern match for fusion operation (different arg order)
# ---------------------------------------------------------------------
template rewriteTensor_MultiplyAdd*{C + `*`(A,B)}[T](
  A, B, C: Tensor[T]): auto =
  ## Fuse ``C + A*B`` into a single operation.
  ##
  ## Operation fusion leverages the Nim compiler and should not be called explicitly.
  tensor_multiplyAdd(A, B, C)

# Pattern match for fusion operation (in-place)
# ---------------------------------------------------------------------
template rewriteTensor_MultiplyAdd_inplace*{C += `*`(A,B)}[T](
  A, B: Tensor[T], C: var Tensor[T]) =
  ## Fuse ``C += A*B`` into a single operation.
  ##
  ## Operation fusion leverages the Nim compiler and should not be called explicitly.
  tensor_multiplyAdd_inplace(A, B, C)
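Usage is then transparent. A hedged sketch of a call site (assuming an Arraymancer-like library where Tensor, toTensor and the * / + operators are defined):

# With the templates above in scope, the compiler rewrites `A * B + C`
# into tensor_multiplyAdd(A, B, C): a single GEMM call, and the temporary
# for `A * B` is never materialized.
let
  A = [[1.0, 2.0], [3.0, 4.0]].toTensor()
  B = [[5.0, 6.0], [7.0, 8.0]].toTensor()
  C = [[1.0, 1.0], [1.0, 1.0]].toTensor()
let D = A * B + C   # rewritten at compile time to tensor_multiplyAdd(A, B, C)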
A very involved example is my reimplementation of a BLAS in pure Nim (plus intrinsics), which reaches OpenBLAS and MKL-DNN performance on large matrices.
The dispatch I used there is a mix, depending on:
- types (int32, int64, float32, float64, …)
- CPU architecture (x86, ARM)
- CPU features (SSE2, SSE3, AVX, AVX2, AVX512, …)
- the number of registers
- the register size
- whether the result matrix is unit-strided
See the dispatch logic.
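To give a flavour of it without pasting the whole file, here is a simplified, hypothetical sketch (the enum, kernel names and structure are made up for illustration; this is not the actual Laser code): the architecture is selected at compile time with when defined(...), the CPU features at runtime with a plain case.

# Illustrative only: every name below is invented for this sketch.
type CpuFeatureX86 = enum
  cpuGeneric, cpuSSE2, cpuAVX2, cpuAVX512

proc gemm_generic() = echo "portable fallback kernel"
proc gemm_sse2()    = echo "SSE2 kernel"
proc gemm_avx2()    = echo "AVX2 kernel"
proc gemm_avx512()  = echo "AVX-512 kernel"

proc dispatch_gemm(feature: CpuFeatureX86) =
  when defined(amd64) or defined(i386):
    # Compile-time branch: only x86 targets consider the x86 SIMD kernels.
    case feature                     # runtime branch on detected features
    of cpuAVX512: gemm_avx512()
    of cpuAVX2:   gemm_avx2()
    of cpuSSE2:   gemm_sse2()
    else:         gemm_generic()
  else:
    # Other architectures (e.g. ARM) fall back or use their own kernels.
    gemm_generic()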
After abstracting away all those architecture, register and low-level differences, I have a very lean code generator that can be expanded quickly to any CPU/SIMD combination and provides the same speed as the assembly-coded BLAS routines of OpenBLAS and MKL experts:
import
  ./gemm_ukernel_generator, ./gemm_tiling,
  ../../simd

template float32x4_muladd_unfused(a, b, c: m128): m128 =
  mm_add_ps(mm_mul_ps(a, b), c)

ukernel_generator(
  x86_SSE,
  typ = float32,
  vectype = m128,
  nb_scalars = 4,
  simd_setZero = mm_setzero_ps,
  simd_broadcast_value = mm_set1_ps,
  simd_load_aligned = mm_load_ps,
  simd_load_unaligned = mm_loadu_ps,
  simd_store_unaligned = mm_storeu_ps,
  simd_mul = mm_mul_ps,
  simd_add = mm_add_ps,
  simd_fma = float32x4_muladd_unfused
)
import
  ./gemm_ukernel_generator, ./gemm_tiling,
  ../../simd

ukernel_generator(
  x86_AVX512,
  typ = float32,
  vectype = m512,
  nb_scalars = 16,
  simd_setZero = mm512_setzero_ps,
  simd_broadcast_value = mm512_set1_ps,
  simd_load_aligned = mm512_load_ps,
  simd_load_unaligned = mm512_loadu_ps,
  simd_store_unaligned = mm512_storeu_ps,
  simd_mul = mm512_mul_ps,
  simd_add = mm512_add_ps,
  simd_fma = mm512_fmadd_ps
)
Supporting integers, even when there is no suitable SIMD, is also very easy. Here is an int32 BLAS for an SSE2-only architecture (SSE4.1 is what brings lots of the int32 intrinsics):
type Int32x2 = array[2, int32]

func setZero_int32_sse2_fallback(): Int32x2 {.inline.} =
  discard

template set1_int32_sse2_fallback(a: int32): Int32x2 =
  [a, a]

func load_int32_sse2_fallback(mem_addr: ptr int32): Int32x2 {.inline.} =
  let p = cast[ptr UncheckedArray[int32]](mem_addr)
  [p[0], p[1]]

func store_int32_sse2_fallback(mem_addr: ptr int32, a: Int32x2) {.inline.} =
  let p = cast[ptr UncheckedArray[int32]](mem_addr)
  p[0] = a[0]
  p[1] = a[1]

template add_int32_sse2_fallback(a, b: Int32x2): Int32x2 =
  [a[0] + b[0], a[1] + b[1]]

template mul_int32_sse2_fallback(a, b: Int32x2): Int32x2 =
  [a[0] * b[0], a[1] * b[1]]

template fma_int32_sse2_fallback(a, b, c: Int32x2): Int32x2 =
  [c[0] + a[0]*b[0], c[1] + a[1]*b[1]]
ukernel_generator(
  x86_SSE2,
  typ = int32,
  vectype = Int32x2,
  nb_scalars = 2,
  simd_setZero = setZero_int32_sse2_fallback,
  simd_broadcast_value = set1_int32_sse2_fallback,
  simd_load_aligned = load_int32_sse2_fallback,
  simd_load_unaligned = load_int32_sse2_fallback,
  simd_store_unaligned = store_int32_sse2_fallback,
  simd_mul = mul_int32_sse2_fallback,
  simd_add = add_int32_sse2_fallback,
  simd_fma = fma_int32_sse2_fallback
)
On an even more generic note, I am currently writing a DSL and a corresponding deep-learning compiler that works at Nim compile time (implemented via macros, maybe later as a compiler plugin shipped as a .dll/.so) or at Nim runtime with an LLVM backend.
A preview is visible here. This is a generalized answer to the Julia challenge as well.
# Define the algorithm
# ----------------------------------------
proc foobar(a, b, c: Fn): Fn =
  # Iteration domain
  var i, j: Iter

  var bar: Fn

  # Notice that a normal language would need multiple loops
  bar[i, j] = a[i, j] + b[i, j] + c[i, j]

  # Update result
  result = bar

# Materialize the algorithm for float32
# ----------------------------------------------
generate foobar:
  proc foobar(a: Tensor[float32], b, c: Tensor[float32]): Tensor[float32]

# Test run
# ----------------------------------------------
let
  u = [[float32 1, 1, 1],
       [float32 1, 1, 1],
       [float32 1, 1, 1]].toTensor()
  v = [[float32 2, 2, 2],
       [float32 2, 2, 2],
       [float32 2, 2, 2]].toTensor()
  w = [[float32 3, 3, 3],
       [float32 3, 3, 3],
       [float32 3, 3, 3]].toTensor()

let r = foobar(u, v, w)
echo r
Edit - bonus: unlike many C++ compilers, I don't use the Visitor pattern for double dispatch in my compiler; I just use ADTs (object variants in Nim).
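For completeness, a minimal sketch of what that looks like (the node type and eval below are illustrative, not the actual compiler IR): the variant kind replaces the visitor's accept/visit pair, and dispatch is a plain case statement.

type
  ExprKind = enum ekLit, ekAdd, ekMul
  Expr = ref object
    case kind: ExprKind
    of ekLit:
      value: float
    of ekAdd, ekMul:
      lhs, rhs: Expr

proc eval(e: Expr): float =
  # Dispatch on the node kind is a plain case (nest another case on a
  # second operand's kind if you really need double dispatch).
  case e.kind
  of ekLit: e.value
  of ekAdd: eval(e.lhs) + eval(e.rhs)
  of ekMul: eval(e.lhs) * eval(e.rhs)

let ast = Expr(kind: ekAdd,
               lhs: Expr(kind: ekLit, value: 1.0),
               rhs: Expr(kind: ekMul,
                         lhs: Expr(kind: ekLit, value: 2.0),
                         rhs: Expr(kind: ekLit, value: 3.0)))
echo eval(ast)   # 7.0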