Custom rrule for Feedback Alignment

Hi,
I am trying to use a custom rrule to implement Feedback Alignment learning (which is just like backprop, except that the error signals are propagated backwards through fixed random feedback weights B; this is supposed to be more biologically plausible than using W in both the forward and the backward pass).
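To make the idea concrete, here is a rough sketch (with dimensions matching the first layer below, but otherwise not my actual code) of how the backward signal through a single linear layer y = W * X differs between the two schemes:

Δ = randn(Float32, 60, 64)   # upstream error signal (out × batch)
W = randn(Float32, 60, 50)   # forward weights
B = randn(Float32, 60, 50)   # fixed random feedback weights, same shape as W

∂X_bp = W' * Δ               # backprop: the error flows back through W'
∂X_fa = B' * Δ               # feedback alignment: the error flows back through B'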

I made a simplified layer DenseFA (without a nonlinearity or bias, for simplicity), along with a custom rrule. The rrule works when called on its own, but when I compute the gradient through Flux's gradient function it is no longer used.

Boilerplate: imports and layer definition
using Flux; using Flux: @functor, glorot_uniform
using ChainRulesCore; using ChainRulesCore: NoTangent, @thunk

# DenseFA struct
struct DenseFA{M1<:Matrix, M2<:Matrix}
    W::M1 # Weights used in the forward pass
    B::M2 # Weights used in the backwards pass
    function DenseFA(W::M1, B::M2) where {M1<:Matrix, M2<:Matrix}
        new{M1,M2}(W, B)
    end
end

# Initialize DenseFA struct
function DenseFA(in::Integer, out::Integer; init = glorot_uniform)
    W = init(out, in)
    B = init(out, in)
    return DenseFA(W, B)
end

@functor DenseFA

(a::DenseFA)(X) = my_matmul(a.W, X, a.B)

# The forward pass only uses W; B is carried along so the custom rrule can see it
function my_matmul(W, X, B)
    return W * X
end

function Base.show(io::IO, l::DenseFA)
    print(io, "DenseFA(", size(l.W, 2), ", ", size(l.W, 1), ")")
end

The rrule and the gradient computation
function rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
    y = my_matmul(W, X, B)

    println("==========Using rrule with fixed random feedback weights==========")

    function times_pullback(ΔΩ)
        ∂W = @thunk(ΔΩ * X')
        ∂X = @thunk(B' * ΔΩ) # use the random feedback weight matrix B instead of W'
        return (NoTangent(), ∂W, ∂X, NoTangent())
    end
    return y, times_pullback
end

# Only W is trained; the feedback matrix B stays fixed
Flux.trainable(a::DenseFA) = (a.W,)

model = Chain(DenseFA(50, 60), DenseFA(60, 40), DenseFA(40, 2))

# dummy data (Batchsize 64)
x = rand(Float32, 50, 64);
y = rand(Float32, 2, 64)

# compute gradient
loss = Flux.Losses.mse
opt = Descent(0.1)
ps = Flux.params(model);
gs = gradient(() -> loss(model(x), y), ps)
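For reference, by "works when called on its own" I mean something like calling the rrule directly, which does go through my definition (the banner prints):

y1, back = rrule(my_matmul, model[1].W, x, model[1].B)
back(ones(Float32, 60, 64))   # ∂X comes from B' * ΔΩ, as intended

But the banner never prints when I call gradient above.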

I could use advice on why the rrule is not being used and how I can make Flux use it.


What you’re missing is qualifying (or importing) rrule. At present this defines a brand-new function of that name; instead, as you did for show, you need to add a method to the existing ChainRulesCore.rrule.

julia> rrule
rrule (generic function with 1 method)

julia> function ChainRulesCore.rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
          y = my_matmul(W, X, B)
          ...
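If you prefer the import route, it would look something like this (in a fresh session, before any unqualified rrule method is defined):

import ChainRulesCore: rrule   # now the bare name refers to ChainRulesCore.rrule

function rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
   ...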

Thanks, I had completely overlooked that!
Now everything works as expected :smiley: