Creating AD-friendly functions

Hello everyone.
I’m in the process of creating a chained composition of functions, similar to a Flux or Lux network, which I’ll later differentiate in order to solve differential equations. Following the discussion in this thread, I was able to construct such a composition, and it works well; however, it performs very poorly under differentiation.

Take the following sample code, adapted from the other thread:

using BenchmarkTools

n = 12
l = 6  

act(x) = cos(x) # Activation function analogue
funs = tuple([act for i in 1:l]...)
x0 = rand(n) # Input
W0 = tuple([rand(n,n) for i in 1:l]...) # Weight matrices analogue

# Construct the individual layers: layer i maps (x, M) to Funs[i].(M*x)
function buildLayers(Funs)
    return [(x,M) -> M*x .|> Funs[i] for i in 1:length(Funs)]
end

# Construct the network: fold the input x through every layer, giving layer i the weights T[i]
function Net(Funs)
    layers = buildLayers(Funs)
    return net(x,T) = foldl(|>, [y -> layers[i](y,T[i]) for i in 1:length(Funs)], init = x)
end

myNet = Net(funs)
@benchmark myNet(x0,W0)
@code_warntype myNet(x0,W0)

The resulting function runs fast and is type stable:

Results
BenchmarkTools.Trial: 10000 samples with 18 evaluations.
 Range (min … max):  944.444 ns … 40.394 μs  ┊ GC (min … max): 0.00% … 94.33%
 Time  (median):       1.011 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.167 μs ±  1.217 μs  ┊ GC (mean ± σ):  3.04% ±  2.95%

  ▂▆██▆▃▁     ▁▁                                ▁▁▁▁           ▂
  ███████▇▅▅▄▆███▇▇▇▆▅▆▆▄▄▅▅▄▄▄▅▃▅▅▅▆▆▇▇▇██████▇████████▇▇▇▇▇▆ █
  944 ns        Histogram: log(frequency) by time      2.23 μs <

 Memory estimate: 2.59 KiB, allocs estimate: 13.
MethodInstance for (::var"#net#28"{NTuple{6, typeof(act)}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}})(::Vector{Float64}, ::NTuple{6, Matrix{Float64}})
  from (::var"#net#28")(x, T) @ Main Untitled-1:20
Arguments
  #self#::var"#net#28"{NTuple{6, typeof(act)}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
  x::Vector{Float64}
  T::NTuple{6, Matrix{Float64}}
Locals
  #26::var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
Body::Vector{Float64}
1 ─ %1  = Main.:(var"#26#29")::Core.Const(var"#26#29")
│   %2  = Core.typeof(T)::Core.Const(NTuple{6, Matrix{Float64}})
│   %3  = Core.getfield(#self#, :layers)::NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}
│   %4  = Core.typeof(%3)::Core.Const(NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}})
│   %5  = Core.apply_type(%1, %2, %4)::Core.Const(var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}})        
│   %6  = Core.getfield(#self#, :layers)::NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}
│         (#26 = %new(%5, T, %6))
│   %8  = #26::var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
│   %9  = Core.getfield(#self#, :Funs)::Core.Const((act, act, act, act, act, act))
│   %10 = Main.length(%9)::Core.Const(6)
│   %11 = (1:%10)::Core.Const(1:6)
│   %12 = Base.Generator(%8, %11)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}}, Any[var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}, Core.Const(1:6)])     
│   %13 = Base.collect(%12)::Vector{var"#27#30"{Int64, NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}}
│   %14 = (:init,)::Core.Const((:init,))
│   %15 = Core.apply_type(Core.NamedTuple, %14)::Core.Const(NamedTuple{(:init,)})
│   %16 = Core.tuple(x)::Tuple{Vector{Float64}}
│   %17 = (%15)(%16)::NamedTuple{(:init,), Tuple{Vector{Float64}}}
│   %18 = Core.kwcall(%17, Main.foldl, Main.:|>, %13)::Vector{Float64}
└──       return %18

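For comparison, it runs faster (median 1.01 μs vs. 1.27 μs) and makes fewer allocations (13 vs. 25) than a nearly identical Lux network, although it uses slightly more memory (2.59 KiB vs. 2.36 KiB):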

using Random, Lux 

luxLayerTuple = tuple([Dense(n => n, act) for i in 1:l]...)
luxNet = Chain(luxLayerTuple)
rng = Random.default_rng()
pp, st = Lux.setup(rng,luxNet)
@benchmark luxNet(x0,pp,st)
@code_warntype luxNet(x0,pp,st)
Results
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.200 μs … 201.660 μs  ┊ GC (min … max): 0.00% … 97.77%
 Time  (median):     1.270 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.626 μs ±   4.177 μs  ┊ GC (mean ± σ):  5.56% ±  2.19%

  ██▄▁  ▃▂ ▂▃▄▃▃▂▂▁  ▂▃▁                                      ▂
  ████▆█████████████████▇█▇▇▆▇▆▆▅▆▆▆▆▆▇▅▅▆▄▅▄▅▃▅▃▃▅▄▅▂▄▅▂▄▄▂▄ █
  1.2 μs       Histogram: log(frequency) by time       4.3 μs <

 Memory estimate: 2.36 KiB, allocs estimate: 25.

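However, when it comes to differentiation, the performance is abysmal: differentiating myNet with respect to its parameters takes roughly 4 to 5 times as long as differentiating the Lux network (median 99.8 μs vs. 22.3 μs), with about three times as many allocations (718 vs. 244):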

using Zygote

myGrad(x,W) = Zygote.gradient(T -> myNet(x,T) |> first,W)
@benchmark myGrad(x0,W0)

luxGrad(x,p,s) = Zygote.gradient(P -> luxNet(x,P,s) |> first |> first, p)
@benchmark luxGrad(x0,pp,st)
Results (myGrad first, luxGrad second)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   92.800 μs …   3.223 ms  ┊ GC (min … max): 0.00% … 91.39%
 Time  (median):      99.800 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   126.026 μs ± 126.080 μs  ┊ GC (mean ± σ):  4.66% ±  4.63%

  ▄█▇▅▃▂▂▂▁▂▂▁▁▁▁▁     ▁▂▃▄▅▄▂▁▁                                ▂
  ████████████████████████████████▇▇▇▇▇▆▇▇▇▇▇▆▆▆▅▅▆▆▄▅▆▅▅▅▄▄▅▆▆ █
  92.8 μs       Histogram: log(frequency) by time        252 μs <

 Memory estimate: 111.77 KiB, allocs estimate: 718.
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  17.600 μs …   3.595 ms  ┊ GC (min … max):  0.00% … 98.85%
 Time  (median):     22.300 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   42.295 μs ± 133.607 μs  ┊ GC (mean ± σ):  11.70% ±  3.94%

  █▇▅▃▂         ▁▂▂▁                                           ▁
  ███████▇▇▆▅▄▆▇██████▇▇▅▆▄▃▃▄▄▄▄▄▅▄▄▄▄▅▄▄▅▃▅▅▄▄▃▅▅▅▅▅▄▃▃▅▄▃▃▄ █
  17.6 μs       Histogram: log(frequency) by time       316 μs <

 Memory estimate: 71.55 KiB, allocs estimate: 244.

Naturally, both Flux and Lux networks have been brilliantly optimized for this sort of thing. Still, I would like to understand what makes a function AD-friendly, as it’s clear that speed and type stability are not the only factors at play here.

How else can I optimize this framework in order to make it suitable for differentiation?
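To make the question concrete for this particular chain, here is a hand-written reverse pass. It is a sketch that assumes every layer is `cos.(M*y)` and that the scalar objective is the first output entry, as in `myGrad`; it is not how Zygote works internally. It only shows the minimum a reverse pass has to do here: one forward sweep that stores each layer input and pre-activation `M*y`, then one backward sweep of matrix-vector and outer products. Work an AD tool adds on top of this (for example, pullbacks for the closures, the `foldl` and the `collect` inside `net`) is a plausible source of the extra time, but that is a hypothesis to check with a profiler. The names `forward_cache` and `manual_grad` are hypothetical.

```julia
using LinearAlgebra

# Forward pass that keeps what the backward pass needs:
# the input to every layer and every pre-activation z = M*y.
function forward_cache(x, W)
    ys = Vector{typeof(x)}(undef, length(W) + 1)  # layer inputs; ys[end] is the output
    zs = Vector{typeof(x)}(undef, length(W))      # pre-activations
    ys[1] = x
    for (i, M) in enumerate(W)
        zs[i] = M * ys[i]
        ys[i + 1] = cos.(zs[i])
    end
    return ys, zs
end

# Reverse pass for the scalar objective L = first(output).
# For y_next = cos.(z) with z = M*y:
#   dz = -sin.(z) .* dy_next,  dM = dz * y',  dy = M' * dz
function manual_grad(x, W)
    ys, zs = forward_cache(x, W)
    dy = zero(ys[end])
    dy[1] = 1.0                                  # dL/d(output) for L = first(output)
    dW = Vector{eltype(W)}(undef, length(W))     # one gradient matrix per layer
    for i in length(W):-1:1
        dz = -sin.(zs[i]) .* dy
        dW[i] = dz * ys[i]'                      # outer product
        dy = W[i]' * dz                          # propagate to the previous layer
    end
    return Tuple(dW)
end
```

With the definitions from the post, `manual_grad(x0, W0)` returns a tuple of six 12×12 matrices, and comparing it with `myGrad(x0, W0)[1]` should serve as a quick sanity check.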

Arrays of functions (as opposed to tuples or compositions) are generally not type stable, because every Julia function has its own type. (Why not just compose your functions directly rather than calling foldl on arrays of functions?)
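One way to compose without arrays, sketched under the assumption that every layer has the form `f.(M*y)`: keep the activations and weights in tuples and peel off one element per recursive call with `first` and `Base.tail`. Because tuple lengths are part of their types, the compiler can unroll the recursion, and no Vector of closures is built on each call. The function name `chain` is hypothetical, and whether this alone closes the gap with Lux under Zygote is something to benchmark rather than assume.

```julia
# Base case: no layers left, return the accumulated value.
chain(y, funs::Tuple{}, Ws::Tuple{}) = y

# Apply the first layer, then recurse on the remaining ones.
# Each call sees tuples of known length, so the result type is inferable.
function chain(y, funs::Tuple, Ws::Tuple)
    f, M = first(funs), first(Ws)
    return chain(f.(M * y), Base.tail(funs), Base.tail(Ws))
end

act(x) = cos(x)
funs = ntuple(_ -> act, 6)           # tuple of activations
W0 = ntuple(_ -> rand(12, 12), 6)    # tuple of weight matrices
x0 = rand(12)
chain(x0, funs, W0)
```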


Hey steven, thanks for the input. I have tried making buildLayers return a tuple, but the benchmarks show very little difference (the speed is the same and the @code_warntype output looks similar). My implementation might be naive, however:

function buildLayers(Funs)
    return tuple([(x,M) -> M*x .|> Funs[i] for i in 1:length(Funs)]...)
end
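One plausible reason the tuple version barely changes anything: `tuple([...]...)` still builds an intermediate Vector before splatting it, and `net` still collects a fresh Vector of closures on every call regardless of what `buildLayers` returns. A sketch that avoids the first issue, assuming `Funs` is already a tuple, uses `map`, which returns a tuple when given a tuple. The name `buildLayersTuple` is hypothetical.

```julia
# map over a Tuple returns a Tuple, so no intermediate Vector is built.
# Layer i computes Funs[i].(M * x), the same as M*x .|> Funs[i] in the post.
buildLayersTuple(Funs::Tuple) = map(f -> ((x, M) -> f.(M * x)), Funs)

act(x) = cos(x)
layers = buildLayersTuple(ntuple(_ -> act, 3))
```

The layers then still have to be applied without collecting them into an array inside `net`, for example by recursing over the tuple; otherwise the per-call allocation comes back.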

Regarding your second point: I need the Net constructor to handle arbitrarily deep compositions and to pass each layer its correct inputs.

Since each layer receives external inputs in addition to the output of the preceding layer, I don’t know how to build this composition ahead of time in a more type-stable way. As an example, my final implementation should look something like this:

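function NetFinal(Funs)
    layers = buildLayersFinal(Funs)  # buildLayersFinal is not shown; each layer takes (y, t, A, B)
    len = length(Funs)
    # layer i receives the running value y, the external input t, T[i] and T[len+1-i]
    return net(t,x,T) = foldl(|>, [y -> layers[i](y,t,T[i],T[len+1-i]) for i in 1:len], init = x)
end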
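The external input `t` and the two weight lookups can also be threaded through a tuple recursion instead of an array of closures. This sketch assumes `T` is a tuple with exactly `len` entries, so `T[len+1-i]` for `i in 1:len` walks it backwards, which is `reverse(T)`. It also assumes each layer has the signature `layer(y, t, A, B)`. Since `buildLayersFinal` is not shown in the post, `applyFinal`, `NetFinalTuple` and `toylayer` are hypothetical stand-ins.

```julia
# Base case: all layers applied.
applyFinal(y, t, layers::Tuple{}, As::Tuple{}, Bs::Tuple{}) = y

# Layer i receives the running value y, the external input t, T[i] and T[len+1-i].
function applyFinal(y, t, layers::Tuple, As::Tuple, Bs::Tuple)
    y2 = first(layers)(y, t, first(As), first(Bs))
    return applyFinal(y2, t, Base.tail(layers), Base.tail(As), Base.tail(Bs))
end

# Hypothetical constructor: T must be a tuple with one entry per layer.
NetFinalTuple(layers::Tuple) = (t, x, T) -> applyFinal(x, t, layers, T, reverse(T))

# Toy layer with the assumed signature (y, t, A, B).
toylayer(y, t, A, B) = cos.(A * y .+ t .* (B * y))

n, len = 4, 3
netF = NetFinalTuple(ntuple(_ -> toylayer, len))
T = ntuple(_ -> rand(n, n), len)
x = rand(n)
netF(0.5, x, T)
```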

I’m really not sure how to achieve this without iterating over some collection such as an Array.