I have run into a few issues while implementing a regularizer function and its custom pullback (via ChainRulesCore) for a mildly customized Flux Dense layer and Chain. For simplicity and testing purposes, I wrote a basic L2 regularizer for the weights of a standard Dense layer:
using ChainRulesCore
using Flux
using Random
function weightregularization(nn::Dense)
    return sum(nn.weight .^ 2.0)
end
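For reference, the regularizer itself is easy to hand-check on a small layer (the layer below is made up purely for illustration, and assumes the definition above is loaded):

```julia
using Flux

# 2-input, 1-output Dense with known weights [1.0 2.0]
nn_check = Dense([1.0 2.0], [0.0], identity)
weightregularization(nn_check)  # 1 + 4 = 5.0
```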
function ChainRulesCore.rrule(::typeof(weightregularization), nn::Dense)
    y = weightregularization(nn)
    project_w = ProjectTo(nn.weight)
    function weightregularization_pullback(ȳ)
        pullb = Tangent{Dense}(weight=project_w(ȳ * 2.0 * nn.weight), bias=ZeroTangent(), σ=NoTangent())
        return NoTangent(), pullb
    end
    return y, weightregularization_pullback
end
This only works partially. Calling gradient works fine, but something is still incorrect in the pullback definition, since ChainRulesTestUtils crashes on the custom pullback: it apparently tries to compute a tangent for the σ field of the Dense struct, even though the pullback returns NoTangent() for it. More generally, if I have a custom layer with more fields that should be treated as constants defining the model, how can I make sure that TestUtils does not try to evaluate their pullback?
nn = Dense(randn(1, 2), randn(1), tanh)
gr = gradient(weightregularization, nn) # Works
test_rrule(weightregularization, nn)    # Crashes with MethodError: no method matching zero(::typeof(tanh))
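Calling the rrule by hand (bypassing Zygote) also gives the tangent I expect, namely the seed times 2W (this snippet assumes nn and the rrule above are already defined):

```julia
y, pb = ChainRulesCore.rrule(weightregularization, nn)
_, nn_tan = pb(1.0)
nn_tan.weight ≈ 2.0 .* nn.weight  # true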
test_rrule also crashes with TypeError: in new, expected Vector{Float32}, got a value of type Vector{Float64} if the Dense layer weights are Float32s instead of Float64s, even though I use ProjectTo in the pullback function.
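(For what it's worth, ChainRulesTestUtils does allow handing test_rrule an explicit tangent with the ⊢ operator instead of letting it generate a random one, but I am not sure whether that is the intended workaround for constant fields like σ; the sketch below assumes the nn and rrule definitions above:)

```julia
using ChainRulesTestUtils

# explicit tangent for nn, with NoTangent() for the constant σ field,
# so that test_rrule does not have to generate a tangent for tanh
nn_tan = Tangent{Dense}(weight=randn(1, 2), bias=ZeroTangent(), σ=NoTangent())
test_rrule(weightregularization, nn ⊢ nn_tan)
```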
Finally, the following custom regularization function and pullback for a Chain crash both gradient and test_rrule:
function totalregularization(ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    a = 0.0
    for i in ch
        a = a + sum(i.weight .^ 2.0)
    end
    return a
end
function ChainRulesCore.rrule(::typeof(totalregularization), ch::Chain{T}) where T<:Tuple{Vararg{Dense}}
    y = totalregularization(ch)
    function totalregularization_pullback(ȳ)
        totalpullback = []
        N = length(ch)
        for i = 1:N
            project_w = ProjectTo(ch[i].weight)
            push!(totalpullback, Tangent{Dense}(weight=project_w(ȳ * 2.0 * ch[i].weight), bias=ZeroTangent(), σ=NoTangent()))
        end
        pullb = Tangent{Chain{T}}(totalpullback...)
        return NoTangent(), pullb
    end
    return y, totalregularization_pullback
end
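(The primal itself evaluates as expected on a small hand-checked chain, so the problem should really be in the pullback; the weights below are made up for illustration and assume the totalregularization definition above:)

```julia
using Flux

# (1 + 4) from the first layer plus 9 from the second: 14.0
la = Dense([1.0 2.0], [0.0], tanh)
lb = Dense(reshape([3.0], 1, 1), [0.0], tanh)
totalregularization(Chain(la, lb))  # 14.0
```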
l1 = Dense(randn(2, 2), randn(2), tanh)
l2 = Dense(randn(1, 2), randn(1), tanh)
ch = Chain(l1, l2)
gr = gradient(totalregularization, ch) # Crashes with
MethodError: no method matching canonicalize(::Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, Tuple{Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}}})
test_rrule(totalregularization, ch) # Crashes with
Got exception outside of a @test
return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}, Tuple{Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}, Tangent{Dense, NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, ZeroTangent, NoTangent}}}}}} does not match inferred return type Tuple{NoTangent, Tangent{Chain{Tuple{Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}, Dense{typeof(tanh), Matrix{Float64}, Vector{Float64}}}}}}
What am I doing wrong? My understanding is that a custom reverse rule for a regularization function of a Chain needs to be defined through a structural tangent. That is, the outermost type must be something like Tangent{Chain{T}}, and, similarly, the tangents of the layers must have a type like Tangent{Dense{S}}.