I am trying to implement a form of Direct Feedback Alignment (DFA) that allows for custom derivatives, based on this paper, using Flux.
The code works correctly and fast if I pass the data points one at a time, but it becomes incredibly slow when I pass the data in batches through a DataLoader. I based my work on @Rasmus_Hoier's implementation.
The custom layers are defined as:
```julia
using Flux, ForwardDiff, ChainRulesCore

struct DenseFA{F,G,S,T}
    weight::S
    bias::T
    dfa_weight::S   # fixed random feedback matrix B
    σ::F
    g::G            # derivative of the activation σ
end

function DenseFA(in::Integer, out::Integer, outy::Integer, σ::Function=identity;
                 initW=Flux.glorot_uniform, initb=Flux.glorot_uniform, dσ=nothing)
    # fall back to a ForwardDiff derivative of σ when no analytic dσ is supplied
    g = isnothing(dσ) ? (x -> ForwardDiff.derivative(σ, x)) : dσ
    return DenseFA(initW(out, in), initb(out), initW(out, outy), σ, g)
end

(m::DenseFA)(x) = nonlin_DFA(matmul_DFA(m.weight, x, m.dfa_weight) .+ m.bias, m.σ, m.g)

Flux.@functor DenseFA (weight, bias)   # only weight and bias are trainable
```
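A single layer can then be used like a regular `Dense` layer; for example (sizes made up):

```julia
layer = DenseFA(4, 3, 2, tanh)   # in = 4, out = 3, outy = 2 (size of the network output)
x = randn(Float32, 4, 16)        # batch of 16 samples
y = layer(x)                     # 3×16 matrix
```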
```julia
# the forward pass is a plain matrix product; B only matters in the custom pullback
function matmul_DFA(W, X, B)
    return W * X
end

function ChainRulesCore.rrule(::typeof(matmul_DFA), W::AbstractMatrix, X::AbstractMatrix, B::AbstractMatrix)
    y = W * X
    function matmul_blocked_∂x_pullback(ΔΩ::AbstractMatrix)
        # project the incoming error through the fixed feedback matrix B;
        # the tangent for X is blocked (NoTangent), so nothing backpropagates
        # to earlier layers through the forward weights
        ∂W = @thunk(B * ΔΩ * X')
        return (NoTangent(), ∂W, NoTangent(), NoTangent())
    end
    return y, matmul_blocked_∂x_pullback
end
```
```julia
# the forward pass just broadcasts the activation; g is only used in the pullback
function nonlin_DFA(X, f, g)
    return f.(X)
end

function ChainRulesCore.rrule(::typeof(nonlin_DFA), X::AbstractMatrix, f::Function, df::Function)
    y = f.(X)
    @debug("Compute rrule nonlin_DFA")
    function nonlin_pullback(ΔΩ::AbstractMatrix)
        # elementwise product with the activation derivative: g(a) ⊙ ΔΩ
        ∂f = @thunk(df.(X) .* ΔΩ)
        return (NoTangent(), ∂f, NoTangent(), NoTangent())
    end
    return y, nonlin_pullback
end
```
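To check that the custom pullbacks are actually being picked up, I run a small gradient call directly through Zygote (sizes made up). Note that both rrules dispatch on `X::AbstractMatrix`, so a single sample passed as a `Vector` takes Zygote's default adjoints instead:

```julia
using Zygote

layer = DenseFA(4, 3, 3, tanh)
xbatch = randn(Float32, 4, 8)    # matrix input: the custom rrules fire
grads = Zygote.gradient(m -> sum(abs2, m(xbatch)), layer)

xone = randn(Float32, 4)         # vector input: falls back to Zygote's own rules
gone = Zygote.gradient(m -> sum(abs2, m(xone)), layer)
```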
If I define a chain of DenseFA layers and train it passing the data element by element (i.e. with a batch size of 1), the code works perfectly, but with larger batches it slows down to a halt.
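Roughly, the training setup looks like this (the loss, optimiser, sizes, and the `xs`/`ys` arrays are placeholders; the real code uses Flux's DataLoader for the batched case):

```julia
model = Chain(DenseFA(784, 128, 10, relu), DenseFA(128, 10, 10))
loss(x, y) = Flux.mse(model(x), y)
ps = Flux.params(model)
opt = Descent(0.01)

# fast: one data point at a time (x is an AbstractVector)
for (x, y) in zip(eachcol(xs), eachcol(ys))
    Flux.train!(loss, ps, [(x, y)], opt)
end

# slows to a halt: batches of columns (x is an AbstractMatrix)
loader = Flux.DataLoader((xs, ys), batchsize = 64)
Flux.train!(loss, ps, loader, opt)
```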
The first rrule is an extension of the matmul_blocked rule from the Bender.jl package; it allows me to pass the derivative of the error to the previous layers instead of the gradient from the upstream layer, so that
$$e^{(l)} = \left[B^{(l)} e^{(L)}\right] \odot g\!\left(a^{(l)}\right)$$
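where, if I read the paper correctly, the resulting weight update is

$$\delta W^{(l)} = -\, e^{(l)} \, {h^{(l-1)}}^{\top},$$

with $h^{(l-1)}$ the previous layer's activation, i.e. the `X` that appears as `X'` in the `∂W` thunk above.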