Problems defining rrule for Flux layer

Hello,

I am having trouble implementing the rrule for a Flux layer. I want to call Chain() with other Flux defined layers and then @train to update the parameters. The layer I’m defining is different, but I post a simple example with the same issue.
The problem is that it returns nothing rather than the actual gradient when calling Grads(…) with the parameters.

using Flux, ForwardDiff, Zygote
using Flux: @functor
import ChainRules
using ChainRules: NoTangent

struct SampleLayer{TW, TM, Tσ}
   weight::TW
   m::TM #some object we use but don't differentiate
   σ::Tσ
end

function SampleLayer(out, in)
   σ(x) = x .^ 2 #some non-linearity
   return SampleLayer(randn(in,out), "model", σ)
end

@functor SampleLayer

function (y::SampleLayer)(x)
   E = y.σ(y.weight * x)
   return E
end

function ChainRules.rrule(y::SampleLayer, x)
   E = y.weight * x
   function adj(dp)
      _ , gσ = Flux.pullback(y.σ, E)
      grad = [dp[i] * gσ(x)[1][i] for i in 1:length(dp)] #trick to multiply Fill object
      @show grad
      return (grad, NoTangent())
   end
   return y.σ(E), adj
end

sl = SampleLayer(1, 3)

p = params(sl)
gs = gradient(() -> sum(sl([3])), p)
display(gs[p[1]])

# grad = [0.4950147520312078, -9.274855711982319, 24.52348537535905]
# nothing

Interestingly the gradient is returned, but in a different place. What am I doing wrong?

display(gs.grads)
# IdDict{Any, Any} with 2 entries:
#   [0.0275008; -0.51527; 1.36242] => nothing
#   :(Main.sl)                     => [0.495015, -9.27486, 24.5235]

@cortner

Somehow this returns the expected result. Does anyone know why this works? What’s the difference?

using Flux, ForwardDiff, Zygote
using Flux: @functor
import ChainRules
using ChainRules: NoTangent

struct SampleLayer{TW, TM, Tσ}
   weight::TW
   m::TM #some object we use but don't differentiate
   σ::Tσ
end

function SampleLayer(out, in)
   σ(x) = x .^ 2 #some non-linearity
   return SampleLayer(randn(in,out), "model", σ)
end

@functor SampleLayer

(y::SampleLayer)(x) = _eval_sample_layer(y.σ, y.weight, x)

_eval_sample_layer(σ, W, x) = σ.( W * x )


function ChainRules.rrule(::typeof(_eval_sample_layer), σ, W, x)
   E = W * x

   function adj(dp)
      _ , gσ = Flux.pullback(σ, E)
      grad = [dp[i] * gσ(x)[1][i] for i in 1:length(dp)] #trick to multiply Fill object
      @show grad
      return (NoTangent(), NoTangent(), grad, NoTangent())
   end

   return σ(E), adj
end


sl = SampleLayer(1, 3)

p = params(sl)
gs = Flux.gradient(() -> sum(sl([3])), p)
display(gs[p[1]])

#3-element Vector{Float64}:
#   9.986785761065718
# -10.171464214539265
#  12.292822799338161

display(gs.grads)

# IdDict{Any, Any} with 1 entry:
#  [0.554821; -0.565081; 0.6… => [9.98679, -10.1715, 12.2928]