I don’t know if I am late to the discussion or not, but I hope this will be a useful addition.
Dispatch in a stand-alone function in Python can only be done via explicit type checks. So, to resolve all the type checks, the user must install every package implementing the types the function’s author cared to check for.
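To make the point concrete, here is a minimal sketch of what such explicit type-check dispatch typically looks like (the `total` function and its branches are hypothetical, purely for illustration): the module has to import every package whose types it wants to special-case, even if a given caller never uses them.

```python
import array

import numpy as np  # needed only so the type check below can be written at all

def total(vec):
    # Explicit type checks stand in for dispatch: every specialized
    # container type needs its own branch and its own import.
    if isinstance(vec, np.ndarray):
        return np.sum(vec)              # fast path for NumPy arrays
    if isinstance(vec, (list, array.array)):
        return sum(vec)                 # built-in sum for plain containers
    return sum(x for x in vec)          # fallback: iterable interface only
```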
Another issue with Python’s freestanding functions is performance. A function that relies only on the generic interface may be grossly inefficient. For example, here is a generic function to compute the standard deviation of an array:
```python
from math import sqrt

def mean_py(vec):
    total = 0.0
    n = 0
    for x in vec:
        total += x
        n += 1
    return total / n

def stddev_py(vec):
    # Relies only on the iterable interface of vec.
    ave = mean_py(vec)
    disp = 0.0
    n = 0
    for x in vec:
        diff = x - ave
        disp += diff * diff
        n += 1
    return sqrt(disp / (n - 1))
```
While it works fine for any iterable object, for NumPy arrays a different implementation will be much faster:
```python
import numpy as np

def stddev_numpy(vec):
    # Uses NumPy's vectorized operations instead of a Python-level loop.
    n = len(vec)
    ave = np.sum(vec) / n
    diff = vec - ave
    disp = np.dot(diff, diff)
    return sqrt(disp / (n - 1))
```
Now, some performance tests on my computer for both functions:
```python
from random import random
import array

x = [random() for _ in range(100_000)]
x_np = np.array(x)
xa = array.array('d', x)

%timeit stddev_py(x)
# 11.6 ms ± 357 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit stddev_numpy(x)
# 19.4 ms ± 1.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit stddev_py(x_np)
# 38.3 ms ± 1.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit stddev_numpy(x_np)
# 206 μs ± 83.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit stddev_py(xa)
# 12.6 ms ± 156 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit stddev_numpy(xa)
# 274 μs ± 68.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
So, using NumPy slows things down when the data are in a plain list, while the generic implementation is much slower for NumPy arrays and stdlib arrays.
And for that reason all the ML frameworks in Python have to reimplement the NumPy interface, because `np.sum` won’t be efficient for PyTorch or JAX arrays.
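The usual Python workaround is the explicit type check described above: wrap the two implementations in a single entry point and branch on the input type. The sketch below is only an illustration; every additional array type (PyTorch, JAX, …) would need yet another branch, and the module needs the NumPy import even for callers who only ever pass plain lists.

```python
def stddev(vec):
    # Explicit type check instead of dispatch: take the NumPy-optimized
    # path only when the input is known to be an ndarray.
    if isinstance(vec, np.ndarray):
        return stddev_numpy(vec)
    # Otherwise fall back to the generic, interface-only implementation.
    return stddev_py(vec)
```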
In Julia, a generic `stddev` function will be efficient as long as optimized `sum` and `dot` methods are available for a given data type. Or, if we want to, we can define an optimized `stddev` for a specific type, like a range.