Getting ChainRules to work with custom Dense and Chain regularization functions

I have run into a few issues when implementing a regularizer function and its ChainRules pullback for a mildly customized Flux Dense layer and Chain. For simplicity and testing purposes, I wrote a basic L2 regularizer for the weights of a standard Dense layer:

```julia
using ChainRulesCore
using Flux
using Random

# L2 regularizer: sum of squared weights of a Dense layer
function weightregularization(nn::Dense)
    return sum((nn.weight).^2.0)
end

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        # d/dW sum(W.^2) = 2W; the bias and activation get no gradient
        pullb = Tangent{Dense}(weight=project_w(ȳ * 2.0*nn.weight), bias=ZeroTangent(), σ=NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end
```

It only partially works. Calling gradient works fine, but something is still wrong in the pullback definition, because ChainRulesTestUtils crashes with the custom pullback. It seems to try to compute the pullback of the σ field of the Dense struct, even though the pullback returns NoTangent() for that field. And if I have a custom layer with more fields (which should be treated as constants that define the model), how can I make sure that ChainRulesTestUtils does not try to evaluate their pullback?

```julia
nn = Dense(randn(1,2), randn(1), tanh)
gr = gradient(weightregularization, nn) # Works
test_rrule(weightregularization,nn) # Crashes with MethodError: no method matching zero(::typeof(tanh))
```
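One way to check the pullback without letting a test harness perturb the activation field is to finite-difference only the weight matrix and compare the result with the rule. The sketch below assumes FiniteDifferences.jl is installed; the helper `reg_of_weight` is a hypothetical name introduced for illustration. It also swaps `.^2.0` for `sum(abs2, ...)` and types the tangent on the concrete primal type `typeof(nn)` rather than the UnionAll `Dense`. This is a manual workaround, not a ChainRulesTestUtils feature.

```julia
using ChainRulesCore
using Flux
using FiniteDifferences

# Sum of squared weights of a Dense layer (bias and activation are ignored).
weightregularization(nn::Dense) = sum(abs2, nn.weight)

function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        # Tangent typed on the concrete primal type; σ is not differentiable.
        ∂nn = Tangent{typeof(nn)}(weight = project_w(2 .* unthunk(ȳ) .* nn.weight),
                                  bias = ZeroTangent(), σ = NoTangent())
        return NoTangent(), ∂nn
    end
    return y, weightregularization_pullback
end

nn = Dense(randn(1, 2), randn(1), tanh)

# Hypothetical helper: the regularizer as a function of the weight matrix only,
# so finite differences never touch the activation function.
reg_of_weight(W) = weightregularization(Dense(W, nn.bias, nn.σ))

_, pb = rrule(weightregularization, nn)
∂W_rule = pb(1.0)[2].weight
∂W_fd = FiniteDifferences.grad(central_fdm(5, 1), reg_of_weight, nn.weight)[1]

@show isapprox(∂W_rule, ∂W_fd; rtol = 1e-6)
```

The same pattern extends to layers with more constant fields: wrap the layer in a function of only the differentiable arrays and finite-difference that.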

test_rrule also crashes with `TypeError: in new, expected Vector{Float32}, got a value of type Vector{Float64}` if the Dense layer weights are Float32 instead of Float64, even though I use ProjectTo in the pullback function.
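A quick check of what ProjectTo does with mixed precision: the weight tangent itself does come back as Float32, even when the cotangent and the literal `2.0` are Float64. That suggests (an assumption, not verified here) that the `TypeError` comes from elsewhere in test_rrule, for example from rebuilding a Float32 `Dense` out of Float64 perturbed values on the finite-difference side. The variable names are illustrative only.

```julia
using ChainRulesCore

W32 = randn(Float32, 1, 2)     # Float32 weights, as in the failing case
project_w = ProjectTo(W32)

ȳ = 1.0                        # a Float64 cotangent, e.g. supplied by a test harness
raw = ȳ * 2.0 * W32            # Float64 scalar times Float32 array promotes to Float64
projected = project_w(raw)     # ProjectTo converts back to the primal's element type

@show eltype(raw)
@show eltype(projected)
@show eltype(2 .* W32)         # an integer literal keeps Float32 without projection
```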

Finally, the following custom regularization and pullback for a Chain make both gradient and test_rrule crash:

```julia
# Sum of squared weights over all Dense layers of a Chain
function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight.^2.0)
    end
    return a
end

function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        totalpullback = []
        N = length(ch)
        for i = 1:N
            # note: `nn` here is the global Dense layer defined earlier, not ch[i]
            project_w = ProjectTo(nn.weight)
            push!(totalpullback, Tangent{Dense}(weight=project_w(ȳ * 2.0*ch[i].weight), bias=ZeroTangent(), σ=NoTangent()))
        end
        pullb = Tangent{Chain{T}}(totalpullback...)
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end
```

```julia
l1 = Dense(randn(2,2), randn(2), tanh)
l2 = Dense(randn(1,2), randn(1), tanh)
ch = Chain(l1,l2)
gr = gradient(totalregularization, ch)  # Crashes with:
```

```
MethodError: no method matching canonicalize(::Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, Tuple{Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}}})
```

```julia
test_rrule(totalregularization,ch) # Crashes with:
```

```
Got exception outside of a @test
  return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, Tuple{Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}
```

What am I doing wrong? My understanding is that a custom reverse rule for a regularization function of a Chain must be defined through a structural tangent, that is, the outermost type must be something like Tangent{Chain{T}}. Similarly, the tangents of the layers must have a type like Tangent{Dense{S}}.
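A structural tangent mirrors the primal's fields by name, so it helps to look at what those fields are. `Chain` has a single field, `layers`, holding a tuple of layers, and `Dense` has `weight`, `bias` and `σ`. The snippet below (the helper `dense_tangent` is a hypothetical name) builds the nested tangent with that shape. A positional, Tuple-backed `Tangent` is intended for Tuple primals, which is consistent with the `canonicalize` MethodError above when two positional entries are given for a `Chain`.

```julia
using ChainRulesCore
using Flux

ch = Chain(Dense(randn(2, 2), randn(2), tanh), Dense(randn(1, 2), randn(1), tanh))

@show fieldnames(typeof(ch))            # one field, so one entry in Chain's tangent
@show fieldnames(typeof(ch.layers[1]))  # weight, bias, σ

# Hypothetical helper: tangent of sum(abs2, W) for one Dense layer.
dense_tangent(l) = Tangent{typeof(l)}(weight = 2 .* l.weight, bias = ZeroTangent(), σ = NoTangent())

# Nesting: Tangent{Chain}(layers = Tangent{Tuple}(Tangent{Dense}(...), Tangent{Dense}(...)))
∂ch = Tangent{typeof(ch)}(layers = Tangent{typeof(ch.layers)}(map(dense_tangent, ch.layers)...))
```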

There seem to be some discussions of similar issues with implementing custom rules for Dense and Chain, like , but I cannot tell whether this is currently straightforward to achieve.

That issue is about callable structs, which do not apply to your example. I would say that if the rrule works outside of test_rrule, this is likely a limitation in ChainRulesTestUtils and may deserve a GitHub issue.