# Custom rrule for Feedback Alignment

Hi,
I am trying to use a custom `rrule` to implement Feedback Alignment learning. Feedback Alignment works like backpropagation (BP), except that the error signals are propagated backwards through fixed random feedback weights B instead of the forward weights W. This is considered more biologically plausible than using W in both the forward and the backward pass.
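
As a rough illustration of the difference (a toy sketch, not code from this thread), here is one backward step for a two-layer linear network in plain Julia. The sizes, the seed and the squared-error loss `0.5 * sum(abs2, y - target)` are arbitrary assumptions; `B2` plays the role of the fixed random feedback matrix B.

```julia
using Random

# Toy two-layer linear network y = W2 * (W1 * x), no bias or nonlinearity
Random.seed!(1)
n_in, n_hid, n_out, batch = 5, 4, 3, 8
W1 = randn(n_hid, n_in)
W2 = randn(n_out, n_hid)
B2 = randn(n_out, n_hid)        # fixed random feedback matrix, same shape as W2

x = randn(n_in, batch)
target = randn(n_out, batch)

# Forward pass (identical for BP and FA)
h = W1 * x
y = W2 * h

# Error at the output for the loss 0.5 * sum(abs2, y - target)
δ_out = y - target

# Output-layer weight gradient: the same in BP and FA
∂W2 = δ_out * h'

# Hidden-layer error: BP sends δ_out back through W2', FA through B2'
δ_hid_bp = W2' * δ_out
δ_hid_fa = B2' * δ_out

# Hidden-layer weight updates
∂W1_bp = δ_hid_bp * x'
∂W1_fa = δ_hid_fa * x'
```

Only the hidden-layer error differs: BP uses `W2'` and FA uses `B2'`. That single substitution is what the custom `rrule` has to inject (`∂X = B' * ΔΩ`).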

I made a simplified layer, `DenseFA` (without nonlinearity or bias, for simplicity), along with a custom `rrule`. The `rrule` seems to work when called on its own, but when I use Flux's `gradient` function it is no longer used.

Boilerplate: imports and layer definition
```julia
using Flux; using Flux: @functor, glorot_uniform
using ChainRulesCore; using ChainRulesCore: NoTangent, @thunk

# DenseFA struct
struct DenseFA{M1<:Matrix, M2<:Matrix}
    W::M1 # Weights used in the forward pass
    B::M2 # Weights used in the backwards pass
    function DenseFA(W::M1, B::M2) where {M1<:Matrix, M2<:Matrix}
        new{M1,M2}(W, B)
    end
end

# Initialize DenseFA struct
function DenseFA(in::Integer, out::Integer; init = glorot_uniform)
    W = init(out, in)
    B = init(out, in)
    return DenseFA(W, B)
end

@functor DenseFA

(a::DenseFA)(X) = my_matmul(a.W, X, a.B)

function my_matmul(W, X, B)
    return W * X
end

function Base.show(io::IO, l::DenseFA)
    print(io, "DenseFA(", size(l.W, 2), ", ", size(l.W, 1), ")")
end
```
The rrule and the gradient computation
```julia
function rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
    y = my_matmul(W, X, B)

    println("==========Using rrule with fixed random feedback weights==========")

    function times_pullback(ΔΩ)
        ∂W = @thunk(ΔΩ * X')
        ∂X = @thunk(B' * ΔΩ) # Use random feedback weight matrix B
        return (NoTangent(), ∂W, ∂X, NoTangent())
    end
    return y, times_pullback
end

Flux.trainable(a::DenseFA) = (a.W,)

model = Chain(DenseFA(50, 60), DenseFA(60, 40), DenseFA(40, 2))

# dummy data (Batchsize 64)
x = rand(Float32, 50, 64);
y = rand(Float32, 2, 64)

loss = Flux.Losses.mse
opt = Descent(0.1)
ps = Flux.params(model);
gs = gradient(() -> loss(model(x), y), ps)
```

I could use advice on why the rrule is not being used and how I can make Flux use it.


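What you’re missing is qualifying (or explicitly importing) `rrule`. As written, `function rrule(...)` defines a brand-new function called `rrule` in `Main`; `using ChainRulesCore` makes the name visible, but it does not let you add methods to it without qualification. Instead, as you did for `Base.show`, you need to add methods to the existing `ChainRulesCore.rrule`.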
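
A minimal, package-free sketch of the same effect, using a hypothetical module `Lib` whose function `rule` stands in for `ChainRulesCore.rrule`: an unqualified definition creates a separate function in `Main` that `Lib`'s own code never calls, while a qualified definition extends the original.

```julia
module Lib
    export rule
    rule(x) = "generic fallback"      # owned by Lib, playing the role of ChainRulesCore.rrule
    apply_rule(x) = rule(x)           # Lib's own code always calls Lib.rule
end
using .Lib

# Unqualified definition: `rule` has not been used in Main yet, so this creates
# a brand-new function Main.rule instead of extending Lib.rule
rule(x::Int) = "new function in Main"

@assert parentmodule(rule) === Main
@assert length(methods(rule)) == 1                 # cf. "rrule (generic function with 1 method)"
@assert Lib.apply_rule(1) == "generic fallback"    # Lib never sees Main.rule

# Qualified definition: adds a method to Lib.rule itself
Lib.rule(x::Int) = "extension of Lib.rule"
@assert Lib.apply_rule(1) == "extension of Lib.rule"
```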

```julia
julia> rrule
rrule (generic function with 1 method)

julia> function ChainRulesCore.rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
           y = my_matmul(W, X, B)
           ...
```
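
For reference, a self-contained sketch of the qualified rule plus a quick check that it is actually picked up. It assumes ChainRulesCore and Zygote (the AD package behind `Flux.gradient` by default) are installed, calls Zygote directly instead of Flux, drops the `println`, widens the argument types to `AbstractMatrix`, and `unthunk`s the incoming cotangent; those are my own choices rather than anything the fix above requires.

```julia
using ChainRulesCore
using Zygote   # assumption: Zygote is installed; Flux.gradient uses it by default
using Random

my_matmul(W, X, B) = W * X   # B is ignored in the forward pass

# Qualified definition: extends ChainRulesCore.rrule, so Zygote will find it
function ChainRulesCore.rrule(::typeof(my_matmul), W::AbstractMatrix, X::AbstractMatrix, B::AbstractMatrix)
    y = my_matmul(W, X, B)
    function my_matmul_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)                # materialize the incoming cotangent
        ∂W = @thunk(Δ * X')            # usual weight gradient
        ∂X = @thunk(B' * Δ)            # feedback alignment: B' instead of W'
        return (NoTangent(), ∂W, ∂X, NoTangent())   # B is fixed, so no tangent
    end
    return y, my_matmul_pullback
end

Random.seed!(0)
W = randn(3, 4); B = randn(3, 4); X = randn(4, 5)

# With loss = sum(output), the cotangent ΔΩ is a 3×5 array of ones
gW, gX, gB = Zygote.gradient((W, X, B) -> sum(my_matmul(W, X, B)), W, X, B)
```

Because the loss is `sum`, `gX` should match `B' * ones(3, 5)` rather than the BP value `W' * ones(3, 5)`, and Zygote reports `nothing` for `B` because the rule returns `NoTangent()` for it.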

Thanks, I had completely overlooked that!
Now everything works as expected.