ReverseDiff over ForwardDiff behaves strangely with Lux networks (Issue)

Hello everyone. For my work with PINNs, I am trying to use ReverseDiff to calculate the gradient of a loss function that itself contains a ForwardDiff gradient. It is easy enough to do this with explicitly parametrized, destructured Flux networks, but the constant destructuring and reconstruction of the model makes that approach extremely slow, as well as type-unstable.
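
For context, the Flux pattern I mean looks roughly like this (a minimal sketch with made-up names, using Flux.destructure; the repeated re(p) reconstruction on every call is what makes it so slow):

import Flux, ForwardDiff, ReverseDiff

#Toy network, destructured into a flat parameter vector plus a rebuilder
fmodel = Flux.Chain(Flux.Dense(2 => 12, tanh), Flux.Dense(12 => 1))
p0, re = Flux.destructure(fmodel)

fnet(x, p) = first(re(p)(x))                               #rebuilds the model on every call
gradx(x, p) = ForwardDiff.gradient(x -> fnet(x, p), x)     #inner gradient with respect to x
ReverseDiff.gradient(p -> first(gradx([1f0, 1f0], p)), p0) #outer gradient with respect to the parameters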

Naturally, I went for Lux. However, ReverseDiff over ForwardDiff does not seem to work properly. Take the following code as an MWE: it is easy enough to calculate the gradient with respect to the input x with ForwardDiff, but taking the ReverseDiff gradient of gradxNet with respect to the parameters does not work.

What makes me think this is a bug is that it works fine for destructured Lux models. In the code below, you'll see that netR should be exactly the same as net, just with added reconstruction overhead, and yet using it for the calculations works (it's just very slow).

using Lux, Random, ForwardDiff, ReverseDiff, ComponentArrays, Optimisers

#Setup Lux Network
model = Chain(Dense(2 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x0 = [1f0,1f0]

#Destructure the Net
ps = ps |> ComponentArray 
pr, re = Optimisers.destructure(ps)

ps == re(pr) #true
ps == re(ps) #true

#Normal and Restructured models
net(x,p) = first(first(model(x, p, st)))
net(x0,ps) #Works
netR(x,p) = first(first(model(x, re(p), st)))
netR(x0,pr) #Works

#ForwardDiff gradients with respect to x
gradxNet(x,p) = ForwardDiff.gradient(x -> net(x,p),x)
gradxNet(x0,ps) #Works 
gradxNetR(x,p) = ForwardDiff.gradient(x -> netR(x,p),x) 
gradxNetR(x0,pr) #Works

#ReverseDiff gradients with respect to the parameters
ReverseDiff.gradient(p -> first(gradxNetR(x0,p)),pr) #works
ReverseDiff.gradient(p -> first(gradxNetR(x0,p)),ps) #works
ReverseDiff.gradient(p -> first(gradxNet(x0,p)),ps)  #does not work

The error message for the last line reads:

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 2})
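
For what it's worth, the MethodError itself just says that something backed by Float32 storage is being asked to hold a ForwardDiff.Dual. The same failure can be reproduced in isolation (a made-up minimal example, not the actual call site inside ReverseDiff):

import ForwardDiff
v = zeros(Float32, 2)
v[1] = ForwardDiff.Dual(1f0, 1f0) #ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{Nothing, Float32, 1})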

I would create an issue on GitHub, but it's unclear which of the pieces is failing here. Could @avikpal or @mcabbott perhaps shed some light on this issue?

An update: as far as I can tell, nested differentiation with ReverseDiff does not play nicely with Lux parameters, maybe because they're structured as a NamedTuple (here wrapped in a ComponentArray). Both reverse-over-forward and forward-over-reverse fail, and reverse-over-reverse returns an all-zero gradient.

using Lux, Random, ComponentArrays
import ForwardDiff, ReverseDiff, AbstractDifferentiation as AD

#Network Setup 
model = Chain(Dense(1 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray
x0 = [1f0]

#Differentiation Setup
fw = AD.ForwardDiffBackend()
rv = AD.ReverseDiffBackend()

#Gradients with respect to the input x
net(x,p) = first(first(model(x, p, st)))
grad_xFW(x,p) = first(AD.gradient(fw, x -> net(x,p),x))
grad_xRV(x,p) = first(AD.gradient(rv, x -> net(x,p),x))

#Gradients with respect to the parameters p
AD.gradient(fw, p -> first(grad_xFW(x0,p)),ps) #Forward-over-forward: works, but is slow
AD.gradient(rv, p -> first(grad_xFW(x0,p)),ps) #Reverse-over-forward: does not work
AD.gradient(fw, p -> first(grad_xRV(x0,p)),ps) #Forward-over-reverse: does not work
AD.gradient(rv, p -> first(grad_xRV(x0,p)),ps) #Reverse-over-reverse: returns all zeroes

The error messages read:

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
  [1] convert(#unused#::Type{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 1})
...

and

ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#33#34", Float32}, Float32, 10})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
  [1] convert(#unused#::Type{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#33#34", Float32}, Float32, 10})
...

The problem seems to be with ComponentArrays, i.e., everything works fine when flattening the parameters into a “real” vector and restructuring:

import Flux #for Flux.destructure; Optimisers.destructure does the same job

pflat, re = Flux.destructure(ps) #pflat is a plain Vector{Float32}

net(x, p) = first(first(model(x, re(p), st)))
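
Continuing the snippet above (and reusing grad_xFW from the earlier example), the previously failing reverse-over-forward call now goes through with the flat vector:

AD.gradient(rv, p -> first(grad_xFW(x0, p)), pflat) #works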

Somehow, a ComponentArray works fine for a single gradient, but fails when nested …
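
Concretely, by "a single gradient" I mean something like the following (sketched with a made-up name netCA so it doesn't clash with the redefined net above):

netCA(x, p) = first(first(model(x, p, st))) #ComponentArray parameters passed in directly
AD.gradient(rv, p -> netCA(x0, p), ps)      #a single reverse-mode gradient is fine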

As it turns out, it’s ReverseDiff that is unable to cope with nested differentiation of closures. In fact, it’s easy to reproduce this bug in a much simpler context that has nothing to do with Lux or ComponentArrays:

import ForwardDiff, ReverseDiff, AbstractDifferentiation as AD

n = 1
x0 = rand(n)
M0 = rand(n, n)

function proto(x,M)
    M*x |> sum
end

fw = AD.ForwardDiffBackend()
rv = AD.ReverseDiffBackend()

#Gradients with respect to x (the double first unpacks the gradient tuple, then takes the first entry)
grad_x_FW(x,M) = AD.gradient(fw, x -> proto(x,M),x) |> first |> first
grad_x_RV(x,M) = AD.gradient(rv, x -> proto(x,M),x) |> first |> first

#Nested gradients with respect to M
AD.gradient(fw, m -> grad_x_FW(x0,m),M0) #Forward-over-forward, correct
AD.gradient(rv, m -> grad_x_FW(x0,m),M0) #Reverse-over-forward, ERROR
AD.gradient(fw, m -> grad_x_RV(x0,m),M0) #Forward-over-reverse, ERROR 
AD.gradient(rv, m -> grad_x_RV(x0,m),M0) #Reverse-over-reverse, wrong
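
For reference, the expected answer is easy to work out by hand, which is how the results above can be labelled correct or wrong: proto(x, M) = sum(M*x), so the gradient with respect to x is the vector of column sums of M, grad_x_FW picks out its first entry (the sum of M's first column), and differentiating that with respect to M gives a matrix with ones in the first column and zeros elsewhere:

expected = zeros(n, n); expected[:, 1] .= 1.0                #d(sum of M's first column)/dM
first(AD.gradient(fw, m -> grad_x_FW(x0, m), M0)) ≈ expected #true: forward-over-forward matches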

The big mystery, in fact, is why using the destructured Lux and Flux models works to begin with. I’m absolutely stumped.