Hello everyone.
I’m in the process of creating a chained composition of functions, similar to a Flux or Lux network, which I’ll later differentiate in order to solve differential equations. Following the discussion in this thread, I was able to construct such a composition, and it works well; however, it performs very poorly under differentiation.
Take the following sample code, adapted from the other thread:
using BenchmarkTools
n = 12
l = 6
act(x) = cos(x) # Activation function analogue
funs = tuple([act for i in 1:l]...)
x0 = rand(n) # Input
W0 = tuple([rand(n,n) for i in 1:l]...) # Weight matrices analogue
# Construct the individual layers
function buildLayers(Funs)
    return [(x, M) -> M * x .|> Funs[i] for i in 1:length(Funs)]
end
# Construct the network
function Net(Funs)
    layers = buildLayers(Funs)
    return net(x, T) = foldl(|>, [y -> layers[i](y, T[i]) for i in 1:length(Funs)], init = x)
end
myNet = Net(funs)
@benchmark myNet(x0,W0)
@code_warntype myNet(x0,W0)
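For reference, the construction is just a left fold of the input through the (activation, weight) pairs; the following check (not part of the benchmarks, only here to show the intent) should return true:
foldl((y, (f, M)) -> f.(M * y), zip(funs, W0); init = x0) ≈ myNet(x0, W0)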
The resulting function runs fast and is type stable:
Results
BenchmarkTools.Trial: 10000 samples with 18 evaluations.
Range (min … max): 944.444 ns … 40.394 μs ┊ GC (min … max): 0.00% … 94.33%
Time (median): 1.011 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.167 μs ± 1.217 μs ┊ GC (mean ± σ): 3.04% ± 2.95%
(histogram omitted; time axis 944 ns to 2.23 μs, log(frequency))
Memory estimate: 2.59 KiB, allocs estimate: 13.
MethodInstance for (::var"#net#28"{NTuple{6, typeof(act)}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}})(::Vector{Float64}, ::NTuple{6, Matrix{Float64}})
from (::var"#net#28")(x, T) @ Main Untitled-1:20
Arguments
#self#::var"#net#28"{NTuple{6, typeof(act)}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
x::Vector{Float64}
T::NTuple{6, Matrix{Float64}}
Locals
#26::var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
Body::Vector{Float64}
1 ─ %1 = Main.:(var"#26#29")::Core.Const(var"#26#29")
│ %2 = Core.typeof(T)::Core.Const(NTuple{6, Matrix{Float64}})
│ %3 = Core.getfield(#self#, :layers)::NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}
│ %4 = Core.typeof(%3)::Core.Const(NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}})
│ %5 = Core.apply_type(%1, %2, %4)::Core.Const(var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}})
│ %6 = Core.getfield(#self#, :layers)::NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}
│ (#26 = %new(%5, T, %6))
│ %8 = #26::var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}
│ %9 = Core.getfield(#self#, :Funs)::Core.Const((act, act, act, act, act, act))
│ %10 = Main.length(%9)::Core.Const(6)
│ %11 = (1:%10)::Core.Const(1:6)
│ %12 = Base.Generator(%8, %11)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}}, Any[var"#26#29"{NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}, Core.Const(1:6)])
│ %13 = Base.collect(%12)::Vector{var"#27#30"{Int64, NTuple{6, Matrix{Float64}}, NTuple{6, var"#23#25"{Int64, NTuple{6, typeof(act)}}}}}
│ %14 = (:init,)::Core.Const((:init,))
│ %15 = Core.apply_type(Core.NamedTuple, %14)::Core.Const(NamedTuple{(:init,)})
│ %16 = Core.tuple(x)::Tuple{Vector{Float64}}
│ %17 = (%15)(%16)::NamedTuple{(:init,), Tuple{Vector{Float64}}}
│ %18 = Core.kwcall(%17, Main.foldl, Main.:|>, %13)::Vector{Float64}
└── return %18
For comparison, it runs slightly faster and with fewer allocations than a nearly identical Lux network:
using Random, Lux
luxLayerTuple = tuple([Dense(n => n, act) for i in 1:l]...)
luxNet = Chain(luxLayerTuple)
rng = Random.default_rng()
pp, st = Lux.setup(rng,luxNet)
@benchmark luxNet(x0,pp,st)
@code_warntype luxNet(x0,pp,st)
Results
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
Range (min … max): 1.200 μs … 201.660 μs ┊ GC (min … max): 0.00% … 97.77%
Time (median): 1.270 μs ┊ GC (median): 0.00%
Time (mean ± σ): 1.626 μs ± 4.177 μs ┊ GC (mean ± σ): 5.56% ± 2.19%
(histogram omitted; time axis 1.2 μs to 4.3 μs, log(frequency))
Memory estimate: 2.36 KiB, allocs estimate: 25.
However, when it comes to differentiating, the performance is abysmal: it takes over four times as long (comparing medians) to differentiate myNet with respect to its parameters as it does the Lux network:
using Zygote
myGrad(x,W) = Zygote.gradient(T -> myNet(x,T) |> first,W)
@benchmark myGrad(x0,W0)
luxGrad(x,p,s) = Zygote.gradient(P -> luxNet(x,P,s) |> first |> first, p)
@benchmark luxGrad(x0,pp,st)
Results
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 92.800 μs … 3.223 ms ┊ GC (min … max): 0.00% … 91.39%
Time (median): 99.800 μs ┊ GC (median): 0.00%
Time (mean ± σ): 126.026 μs ± 126.080 μs ┊ GC (mean ± σ): 4.66% ± 4.63%
(histogram omitted; time axis 92.8 μs to 252 μs, log(frequency))
Memory estimate: 111.77 KiB, allocs estimate: 718.
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 17.600 μs … 3.595 ms ┊ GC (min … max): 0.00% … 98.85%
Time (median): 22.300 μs ┊ GC (median): 0.00%
Time (mean ± σ): 42.295 μs ± 133.607 μs ┊ GC (mean ± σ): 11.70% ± 3.94%
(histogram omitted; time axis 17.6 μs to 316 μs, log(frequency))
Memory estimate: 71.55 KiB, allocs estimate: 244.
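(A side note on methodology: the gradient benchmarks above use the globals directly; interpolating them, e.g. @benchmark myGrad($x0, $W0) and @benchmark luxGrad($x0, $pp, $st), shouldn't change the picture at these time scales, but I mention it for completeness.)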
Naturally, both Flux and Lux networks have been brilliantly optimized for this sort of thing. Still, I would like to understand what makes a function AD-friendly, as it’s clear that speed and type-stability are not the only factors at play here.
How else can I optimize this framework in order to make it suitable for differentiation?
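For context, one alternative I have sketched but not yet benchmarked under Zygote is to recurse over the tuple of layers directly, so that no Vector of closures needs to be built inside net on every call. The names applyChain, makeChainNet and chainNet below are just placeholders, and I don't know yet whether this actually closes the gap with Lux:
# Recurse over the tuples of activations and weights directly, so no
# Vector of closures has to be built (and differentiated through) per call.
applyChain(::Tuple{}, x, ::Tuple{}) = x
applyChain(fs::Tuple, x, Ws::Tuple) =
    applyChain(Base.tail(fs), first(Ws) * x .|> first(fs), Base.tail(Ws))
makeChainNet(fs) = (x, W) -> applyChain(fs, x, W)  # mirrors Net(funs) above
chainNet = makeChainNet(funs)
chainNet(x0, W0) ≈ myNet(x0, W0)  # should agree with the closure-based version
@benchmark Zygote.gradient(T -> chainNet(x0, T) |> first, W0)
If tuple recursion is indeed the better pattern for Zygote, I'd be glad to hear why, and whether there is something more idiomatic.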