Hello everyone.
I’m trying to build a constructor that (essentially) takes in an array of functions [f1, f2, f3] and returns a function that applies them in sequence, i.e. x |> f1 |> f2 |> f3. What I want is similar to Flux’s or Lux’s Chain, but my usage will eventually branch into functionality these libraries don’t support, so I’m trying to recreate some of this machinery adapted to my own context.
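Concretely, the goal (with f1, f2, f3 as placeholder names) is something like this sketch:

# Hypothetical sketch of the goal: the returned function should behave like
# x |> f1 |> f2 |> f3, i.e. the composition f3 ∘ f2 ∘ f1.
compose(fs) = foldl(∘, reverse(fs))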
A naive implementation would go as follows:
using Random, BenchmarkTools
n = 12
l = 6
input = rand(n)
f(x) = tanh(x) # Activation function analogue
funs = [f for i in 1:l] # Vector of activation functions
T = [rand(n,n) for i in 1:l] # Weight matrices analogue
myMul(M,x) = M*x
# Construct the individual layers
function buildFoldable(Tensor, Funs)
    return [y -> myMul(Tensor[i], y) .|> z -> Funs[i](z) for i in 1:length(Tensor)]
end
# Construct the network
function Net(Tensor, Funs)
    layers = buildFoldable(Tensor, Funs)
    return net(x) = foldl(|>, [x, layers...])
end
myNet = Net(T,funs)
@benchmark myNet(input)
Perhaps unsurprisingly, this is very slow (7 times slower than a similar Flux network) and type-unstable:
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
Range (min … max): 5.667 μs … 454.150 μs ┊ GC (min … max): 0.00% … 95.06%
Time (median): 6.333 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.305 μs ± 7.581 μs ┊ GC (mean ± σ): 1.55% ± 1.65%
[histogram: log(frequency) by time, 5.67 μs … 13.7 μs]
Memory estimate: 2.55 KiB, allocs estimate: 28.
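For reference, the instability can be inspected with @code_warntype (outside the REPL this needs using InteractiveUtils):

# Inspect inference for the call; in this setup the foldl over the Vector built
# by [x, layers...] comes back inferred as Any.
@code_warntype myNet(input)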
I suspect this has to do with how the closures interact with the surrounding scope and with the compiler having to fetch funs and T at runtime, but I don’t really know how to stop that from happening, aside perhaps from declaring T and funs as const. Still, ideally the values of T and funs would be “imprinted” into net when it is compiled, but I’m not sure how to do that.
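To make the intent concrete, here is a rough sketch of the kind of thing I imagine (MyChain, makelayer and applylayers are just placeholder names, and I haven’t verified this is the right approach): capture each weight matrix and activation in its own closure, store the layers in a Tuple so their concrete types are part of the network’s type, and apply them recursively.

makelayer(M, g) = x -> g.(M * x)  # one layer: broadcast the activation over M * x

struct MyChain{L<:Tuple}
    layers::L
end
MyChain(Tensor, Funs) = MyChain(Tuple(makelayer(M, g) for (M, g) in zip(Tensor, Funs)))

# Apply the layers one by one; recursion over the Tuple keeps every step concretely typed.
applylayers(::Tuple{}, x) = x
applylayers(fs::Tuple, x) = applylayers(Base.tail(fs), first(fs)(x))

(c::MyChain)(x) = applylayers(c.layers, x)

# myChain = MyChain(T, funs); @benchmark myChain($input)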
Can anyone share some advice? I feel like understanding this would also deepen my understanding of Julia as a whole.