Efficiently constructing a composition of functions

Hello everyone.

I’m trying to build a constructor that (essentially) takes an array of functions [f1, f2, f3] and returns their composition f1 |> f2 |> f3. In this regard, what I want is similar to Flux’s or Lux’s Chain, but my usage will later branch into functionality those libraries don’t support, so I’m recreating some of this machinery adapted to my own context.

A naive implementation would go as follows:

using Random, BenchmarkTools

n = 12
l = 6
input = rand(n)

f(x) = tanh(x) # Activation function analogue
funs = [f for i in 1:l] # Vector of activation functions

T = [rand(n,n) for i in 1:l]   # Weight matrices analogue
myMul(M,x) = M*x

#Construct the individual layers
function buildFoldable(Tensor,Funs)
    return [y -> myMul(Tensor[i], y) .|> y -> Funs[i](y) for i in 1:length(Tensor)]
end

#Construct the network
function Net(Tensor,Funs)
    layers = buildFoldable(Tensor,Funs)
    return net(x) = foldl(|>, [x, layers...])
end

myNet = Net(T,funs)
@benchmark myNet(input)

Perhaps unsurprisingly, this is very slow (7 times slower than a similar Flux network) and type-unstable:

BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  5.667 μs … 454.150 μs  ┊ GC (min … max): 0.00% … 95.06%
 Time  (median):     6.333 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.305 μs ±   7.581 μs  ┊ GC (mean ± σ):  1.55% ±  1.65%

  ▆█▇▇▇▆▃▄▅▆▇▅▂▂▂▁      ▁▁▁    ▁ ▁▁▁ ▁      ▁▁▂▂▂▂▁▁▁         ▃
  █████████████████▇█▇█▇███▇▇▆████████████▇████████████▇▆▆▅▅▅ █
  5.67 μs      Histogram: log(frequency) by time      13.7 μs <

 Memory estimate: 2.55 KiB, allocs estimate: 28.

I suspect this has to do with interactions between different scopes and the compiler having to fetch f and T during runtime, but I don’t really know how to stop that from happening, aside perhaps from declaring T and funs as consts. Still, the ideal would be for the value of T and funs to be “imprinted” into net when it’s compiled, but I’m not sure how to do that.

Can anyone share some advice? I feel like understanding this would also deepen my understanding of Julia as a whole.

The first thing to check here is type stability: if a function is not type stable, types must be determined at runtime, which causes huge slowdowns. Running @code_warntype shows:

julia> @code_warntype myNet(input)
MethodInstance for (::var"#net#13"{Vector{var"#8#11"{Int64, Vector{Matrix{Float64}}, Vector{typeof(f)}}}})(::Vector{Float64})
  from (::var"#net#13")(x) @ Main REPL[10]:4
  #self#::var"#net#13"{Vector{var"#8#11"{Int64, Vector{Matrix{Float64}}, Vector{typeof(f)}}}}
1 ─ %1 = Main.:|>::Core.Const(|>)
│   %2 = Core.tuple(x)::Tuple{Vector{Float64}}
│   %3 = Core.getfield(#self#, :layers)::Vector{var"#8#11"{Int64, Vector{Matrix{Float64}}, Vector{typeof(f)}}}
│   %4 = Core._apply_iterate(Base.iterate, Base.vect, %2, %3)::Any
│   %5 = Main.foldl(%1, %4)::Any
└──      return %5

This is type unstable because the element type of [x, layers...] is Any: x is a Vector{Float64} while the layers are functions. The type hierarchy dictates that the closest common supertype of an array and a function is Any, so the return type of the fold is inferred as Any as well.
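You can see this directly in a minimal example (hypothetical, just to illustrate the inference problem):

```julia
# Mixing an array and functions in a single Vector forces the
# element type up to Any, their only common supertype.
x = rand(3)
mixed = [x, sin, cos]
@show eltype(mixed)   # Any
```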


[x, layers...] is an abstractly typed container. Can you use a tuple instead?


The foldl with the mixed-type container is indeed the culprit. It is unnecessary, though, since foldl allows passing the initial element explicitly: net(x) = foldl(|>, layers; init = x) gives a nice speedup on my machine.
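Applied to the original code, the change is just this (a sketch; `buildFoldable` is restated here in slightly condensed form so the snippet is self-contained):

```julia
# Same layers as before: each closure multiplies by its matrix
# and broadcasts the activation function.
buildFoldable(T, funs) =
    [let M = T[i], f = funs[i]; y -> f.(M * y) end for i in 1:length(T)]

function Net(T, funs)
    layers = buildFoldable(T, funs)
    # With `init = x`, the folded container holds only closures of a
    # single concrete type, so inference stays stable.
    return net(x) = foldl(|>, layers; init = x)
end
```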


Indeed! I was not aware of the init keyword argument. Everything runs smoothly now.

Thank you all very much! I would never have guessed.