Direct Feedback Alignment

I am trying to implement a form of Direct Feedback Alignment (DFA) that allows for custom derivatives, based on this paper, using Flux.

The code works correctly and quickly if I pass the data points one at a time, but it becomes incredibly slow if I pass the data in batches using a DataLoader. I based my work on @Rasmus_Hoier's implementation.

The custom layers are defined as

using Flux, ForwardDiff
using ChainRulesCore  # provides rrule, NoTangent, @thunk

struct DenseFA{F,G,S,T}
    weight::S      # forward weight matrix W
    bias::T        # bias vector b
    dfa_weight::S  # fixed random feedback matrix B
    σ::F           # activation function
    g::G           # derivative of the activation
end

function DenseFA(in::Integer, out::Integer, outy::Integer, σ::Function=identity;
                 initW=Flux.glorot_uniform, initb=Flux.glorot_uniform, dσ=nothing)
    # Fall back to ForwardDiff for the activation derivative when no dσ is given.
    # A single binding via a ternary avoids conditionally defining a named inner
    # function in one branch and assigning to the same name in the other, which
    # is a known pitfall of Julia's local scoping.
    g = isnothing(dσ) ? (x -> ForwardDiff.derivative(σ, x)) : dσ
    return DenseFA(initW(out, in), initb(out), initW(out, outy), σ, g)
end
# Forward pass: custom matmul and nonlinearity so that the rrules below are used.
(m::DenseFA)(x) = nonlin_DFA(matmul_DFA(m.weight, x, m.dfa_weight) .+ m.bias, m.σ, m.g)

# Only weight and bias are trainable; the feedback matrix dfa_weight stays fixed.
Flux.@functor DenseFA (weight, bias)

# Plain matmul in the forward pass; the feedback matrix B only enters via the rrule.
function matmul_DFA(W, X, B)
    return W * X
end

function ChainRulesCore.rrule(::typeof(matmul_DFA), W::AbstractMatrix, X::AbstractMatrix, B::AbstractMatrix)
    y = W * X
    function matmul_blocked_∂x_pullback(ΔΩ::AbstractMatrix)
        # The weight cotangent is built from the fixed feedback matrix B rather
        # than from backpropagated values; the cotangent for X is blocked.
        ∂W = @thunk(B * ΔΩ * X')
        return (NoTangent(), ∂W, NoTangent(), NoTangent())
    end
    return y, matmul_blocked_∂x_pullback
end

# Forward: apply the activation elementwise; g is only needed by the rrule.
function nonlin_DFA(X, f, g)
    return f.(X)
end

function ChainRulesCore.rrule(::typeof(nonlin_DFA), X::AbstractMatrix, f::Function, df::Function)
    y = f.(X)
    @debug("Compute rrule nonlin_DFA")
    function nonlin_pullback(ΔΩ::AbstractMatrix)
        # Elementwise product of the supplied derivative with the cotangent.
        ∂f = @thunk(df.(X) .* ΔΩ)
        return (NoTangent(), ∂f, NoTangent(), NoTangent())
    end
    return y, nonlin_pullback
end
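
To make the batched vs. unbatched cases concrete, a single layer is called like this (a sketch; the sizes are placeholders, not my actual network):

layer = DenseFA(784, 128, 10, relu)   # in = 784, out = 128, outy = 10
layer(rand(Float32, 784))             # one sample at a time: fast
layer(rand(Float32, 784, 64))         # a batch as a 784×64 matrix: training with these slows down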

If I define a chain of DenseFA layers and train it by passing the data element by element (i.e. with a batch size of 1), the code works perfectly, but with larger batches it slows to a halt.
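
For reference, the training setup is essentially the following sketch (the sizes, loss, and optimiser are placeholders, X_train / Y_train stand for the dataset, and I am assuming the implicit-parameter Flux.train! API together with Flux's DataLoader):

using Flux: DataLoader, train!, params

model = Chain(DenseFA(784, 128, 10, relu), DenseFA(128, 10, 10))
loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)
opt = Descent(0.01)

# batchsize = 1 trains quickly...
train!(loss, params(model), DataLoader((X_train, Y_train); batchsize=1), opt)

# ...but batchsize = 64 slows to a halt
train!(loss, params(model), DataLoader((X_train, Y_train); batchsize=64), opt)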

The first rrule is an extension of the matmul_blocked rule from the Bender.jl package; it lets me pass the derivative of the error to the previous layers instead of the value of the upstream layer, so that

$e^{(l)} = \left[B^{(l)} e^{(L)}\right] \odot g\left(a^{(l)}\right)$

where $e^{(L)}$ is the error at the output layer, $B^{(l)}$ is the fixed feedback matrix (dfa_weight), $a^{(l)}$ are the pre-activations, and $g$ is the derivative of the activation.
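
One way I check what the two rrules actually compose to on a batch is the sketch below (assuming Zygote as Flux's AD backend; the sizes, and the choice outy == out so that B * ΔΩ is dimensionally valid, are mine):

using Zygote

out, inp, bs = 5, 7, 4
m = DenseFA(inp, out, out, tanh)   # outy == out
X = randn(Float32, inp, bs)

gs = Zygote.gradient(() -> sum(m(X)), Flux.params(m))

# The two pullbacks compose to ∂W = B * (g.(a) .* ē) * X', where a are the
# pre-activations and ē is the output cotangent (all ones for a plain sum).
a  = m.weight * X .+ m.bias
∂W = m.dfa_weight * m.g.(a) * X'
gs[m.weight] ≈ ∂W                  # true when both custom rrules fire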