Hello,
I am having trouble implementing an rrule for a Flux layer. I want to use the layer inside Chain() together with other Flux-defined layers and then call Flux.train! to update the parameters. My actual layer is different, but below is a minimal example with the same issue.
The problem is that indexing the returned Grads object with the parameters gives nothing instead of the actual gradient.
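For context, the eventual usage would look roughly like this (the Dense layer, dimensions, and loss here are placeholders, not my real model):
model = Chain(Dense(1, 3), SampleLayer(3, 3))
loss(x, y) = sum(abs2, model(x) .- y)
data = [([3.0], zeros(3))]
Flux.train!(loss, params(model), data, Descent(0.1))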
using Flux, ForwardDiff, Zygote
using Flux: @functor
import ChainRules
using ChainRules: NoTangent
struct SampleLayer{TW, TM, Tσ}
    weight::TW
    m::TM  # some object we use but don't differentiate
    σ::Tσ
end
function SampleLayer(out, in)
    σ(x) = x .^ 2  # some non-linearity
    return SampleLayer(randn(out, in), "model", σ)
end
@functor SampleLayer
function (y::SampleLayer)(x)
    E = y.σ(y.weight * x)
    return E
end
function ChainRules.rrule(y::SampleLayer, x)
    E = y.weight * x
    function adj(dp)
        # pullback of the non-linearity at E
        _, gσ = Flux.pullback(y.σ, E)
        # dp .* σ'(E) .* x elementwise; the indexing is a trick to multiply the Fill object dp
        grad = [dp[i] * gσ(x)[1][i] for i in 1:length(dp)]
        @show grad
        # first element: tangent for the layer y; second: tangent for x
        return (grad, NoTangent())
    end
    return y.σ(E), adj
end
sl = SampleLayer(3, 1)
p = params(sl)
gs = gradient(() -> sum(sl([3])), p)
display(gs[p[1]])
# grad = [0.4950147520312078, -9.274855711982319, 24.52348537535905]  (printed by @show inside adj)
# nothing  (the result of display(gs[p[1]]))
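As a sanity check, the printed values can be verified with ForwardDiff by differentiating sum(σ(weight * x)) with respect to the weight directly:
fd = ForwardDiff.gradient(w -> sum((w * [3.0]) .^ 2), sl.weight)
display(vec(fd))  # should match the grad printed by @show above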
Interestingly, the gradient is computed, but it ends up in a different place: it is attached to the global binding :(Main.sl) rather than to the weight array. What am I doing wrong?
display(gs.grads)
# IdDict{Any, Any} with 2 entries:
# [0.0275008; -0.51527; 1.36242] => nothing
# :(Main.sl) => [0.495015, -9.27486, 24.5235]