Hi,
I am trying to use a custom rrule to implement Feedback Alignment learning, which works just like backpropagation except that the error signals are propagated backwards through fixed random feedback weights B instead of the transpose of the forward weights W, which is supposed to be more biologically plausible.
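Concretely, for a linear layer Y = W * X with upstream error ΔΩ, backprop sends W' * ΔΩ to the layer below, while Feedback Alignment sends B' * ΔΩ through the fixed random B; the weight gradient ΔΩ * X' is the same in both cases.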
I made a simplified layer DenseFA (without nonlinearity and bias, for simplicity), along with a custom rrule. The rrule seems to work when called on its own, but when I use Flux's gradient function it is no longer used.
Boilerplate: imports and layer definition
using Flux; using Flux: @functor, glorot_uniform
using ChainRulesCore; using ChainRulesCore: NoTangent, @thunk
# DenseFA struct
struct DenseFA{M1<:Matrix, M2<:Matrix}
    W::M1  # Weights used in the forward pass
    B::M2  # Weights used in the backwards pass
    function DenseFA(W::M1, B::M2) where {M1<:Matrix, M2<:Matrix}
        new{M1,M2}(W, B)
    end
end
# Initialize DenseFA struct
function DenseFA(in::Integer, out::Integer; init = glorot_uniform)
    W = init(out, in)
    B = init(out, in)
    return DenseFA(W, B)
end
@functor DenseFA
(a::DenseFA)(X) = my_matmul(a.W, X, a.B)
function my_matmul(W, X, B)
    return W * X  # B is unused in the forward pass; it is only captured by the rrule below
end
function Base.show(io::IO, l::DenseFA)
    print(io, "DenseFA(", size(l.W, 2), ", ", size(l.W, 1), ")")
end
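As a quick sanity check of the forward pass (the layer and batch size here are throwaway values, not the model below):

layer = DenseFA(50, 60)          # W and B are both 60×50
h = layer(rand(Float32, 50, 8))  # forward pass only uses W
size(h)                          # (60, 8)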
The rrule and the gradient computation
function rrule(::typeof(my_matmul), W::Matrix, X::Matrix, B::Matrix)
    y = my_matmul(W, X, B)
    println("==========Using rrule with fixed random feedback weights==========")
    function times_pullback(ΔΩ)
        ∂W = @thunk(ΔΩ * X')
        ∂X = @thunk(B' * ΔΩ)  # Use random feedback weight matrix B
        return (NoTangent(), ∂W, ∂X, NoTangent())
    end
    return y, times_pullback
end
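Calling the rrule directly behaves as I expect; this is roughly how I verified that it works on its own (the small test matrices are made up):

using ChainRulesCore: unthunk
Wt, Xt, Bt = rand(3, 4), rand(4, 2), rand(3, 4)
yt, pb = rrule(my_matmul, Wt, Xt, Bt)  # prints the banner
_, ∂Wt, ∂Xt, _ = pb(ones(3, 2))
unthunk(∂Xt) ≈ Bt' * ones(3, 2)        # true: the pullback uses B, not W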
Flux.trainable(a::DenseFA) = (a.W,)  # only W is trained; B stays fixed
model = Chain(DenseFA(50, 60), DenseFA(60, 40), DenseFA(40, 2))
# dummy data (batch size 64)
x = rand(Float32, 50, 64);
y = rand(Float32, 2, 64)
# compute gradient
loss = Flux.Losses.mse
opt = Descent(0.1)
ps = Flux.params(model);
gs = gradient(() -> loss(model(x), y), ps)
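When I run this, the banner from the rrule never prints and the gradients look like plain backprop. For example, this (hypothetical) single-layer check of the input gradient returns true, meaning Zygote used W' rather than B':

layer = DenseFA(5, 3)
xs = rand(Float32, 5, 2)
gx = gradient(x -> sum(layer(x)), xs)[1]
gx ≈ layer.W' * ones(Float32, 3, 2)  # true ⇒ the custom rrule was not used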
I could use advice on why the rrule is not being used and how I can make Flux use it.