Hello everyone.
I’m trying to build a constructor that (essentially) takes in an array of functions [f1, f2, f3] and returns their composition, i.e. a single function x -> x |> f1 |> f2 |> f3. In this regard, what I want is similar to Flux’s or Lux’s Chain, but later my usage will branch into some functionality these libraries do not support, so I’m trying to recreate some of this machinery adapted to my own context.
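Just to pin down the semantics I’m after (compose is only a placeholder name here):
compose(fs) = x -> foldl(|>, fs; init = x)
# compose((f1, f2, f3))(x) == x |> f1 |> f2 |> f3 == f3(f2(f1(x)))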
A naive implementation would go as follows:
using Random, BenchmarkTools
n = 12
l = 6
input = rand(n)
f(x) = tanh(x) # Activation function analogue
funs = [f for i in 1:l] # Vector with activation functions
T = [rand(n,n) for i in 1:l] # Weight matrices analogue
myMul(M,x) = M*x
# Construct the individual layers
function buildFoldable(Tensor, Funs)
    return [y -> myMul(Tensor[i], y) .|> y -> Funs[i](y) for i in 1:length(Tensor)]
end
# Construct the network
function Net(Tensor, Funs)
    layers = buildFoldable(Tensor, Funs)
    return net(x) = foldl(|>, [x, layers...])
end
myNet = Net(T,funs)
@benchmark myNet(input)
Perhaps unsurprisingly, this is very slow (7 times slower than a similar Flux network) and type-unstable:
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
Range (min … max): 5.667 μs … 454.150 μs ┊ GC (min … max): 0.00% … 95.06%
Time (median): 6.333 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.305 μs ± 7.581 μs ┊ GC (mean ± σ): 1.55% ± 1.65%
▆█▇▇▇▆▃▄▅▆▇▅▂▂▂▁ ▁▁▁ ▁ ▁▁▁ ▁ ▁▁▂▂▂▂▁▁▁ ▃
█████████████████▇█▇█▇███▇▇▆████████████▇████████████▇▆▆▅▅▅ █
5.67 μs Histogram: log(frequency) by time 13.7 μs <
Memory estimate: 2.55 KiB, allocs estimate: 28.
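For what it’s worth, this is how I’ve been checking for the instability (nothing beyond the standard tooling):
using InteractiveUtils
@code_warntype myNet(input)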
I suspect this has to do with interactions between different scopes and the compiler having to fetch f and T during runtime, but I don’t really know how to stop that from happening, aside perhaps from declaring T and funs as consts. Still, the ideal would be for the value of T and funs to be “imprinted” into net when it’s compiled, but I’m not sure how to do that.
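For concreteness, something like the following is roughly what I mean by “imprinting”: capture each weight matrix and activation in its own closure, keep the layers in a Tuple so their concrete types are preserved, and fold them together. (makelayer and NetTuple are just names I made up, and I haven’t verified that this actually fixes the performance problem.)
# One layer closes over its own weight matrix M and activation σ
makelayer(M, σ) = x -> σ.(M * x)

function NetTuple(Tensor, Funs)
    layers = Tuple(makelayer(M, σ) for (M, σ) in zip(Tensor, Funs))
    # Compose right-to-left so the result applies layer 1 first:
    # foldl(∘, reverse(layers))(x) == x |> layers[1] |> layers[2] |> ...
    return foldl(∘, reverse(layers))
end

myNet2 = NetTuple(T, funs)
myNet2(input)
The hope is that the layers then live in the type of the returned callable rather than in non-const globals, but I don’t know whether this is the idiomatic way to do it.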
Can anyone share some advice? I feel like understanding this would also deepen my understanding of Julia as a whole.