# ReverseDiff over ForwardDiff behaves strangely with Lux networks (Issue)

Hello everyone. For my work with PINNs, I am trying to use ReverseDiff to calculate the gradient of a loss function that itself contains a ForwardDiff gradient. It is easy enough to do so with explicitly parametrized, destructured Flux networks, but the constant destructuring and reconstruction of the model makes this approach extremely slow, as well as type-unstable.
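For context, the Flux version looks roughly like this (an illustrative sketch, not my exact code; `fmodel`, `gradx`, and the small two-layer chain are placeholders):

```julia
# Sketch of the destructured-Flux approach: it works, but `re(p)` rebuilds
# the whole Chain on every call, which is slow and type-unstable.
import Flux, ForwardDiff, ReverseDiff

fmodel = Flux.Chain(Flux.Dense(2 => 12, tanh), Flux.Dense(12 => 1))
p0, re = Flux.destructure(fmodel)
xf = [1f0, 1f0]

# Inner ForwardDiff gradient with respect to the input x
gradx(x, p) = ForwardDiff.gradient(y -> first(re(p)(y)), x)

# Outer ReverseDiff gradient with respect to the parameters
ReverseDiff.gradient(p -> sum(gradx(xf, p)), p0)
```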

Naturally, I went for Lux. However, ReverseDiff over ForwardDiff does not seem to work properly. Take the following code as an MWE: it is easy enough to calculate the gradient with respect to the inputs `x` with ForwardDiff. However, taking the ReverseDiff gradient of `gradxNet` with respect to the parameters does not work.

What makes me think this is a bug is that it works fine for destructured Lux models. In the code, you’ll see that `netR` should be exactly the same as `net`, just with added reconstruction overhead, and yet using it for the calculations works (it’s just very slow).

```julia
using Lux, Random, ForwardDiff, ReverseDiff, ComponentArrays, Optimisers

# Setup Lux network
model = Chain(Dense(2 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x0 = [1f0, 1f0]

# Destructure the net
ps = ps |> ComponentArray
pr, re = Optimisers.destructure(ps)

ps == re(pr) # true
ps == re(ps) # true

# Normal and restructured models
net(x, p) = first(first(model(x, p, st)))
net(x0, ps) # works
netR(x, p) = first(first(model(x, re(p), st)))
netR(x0, pr) # works

# ForwardDiff gradients with respect to x
gradxNet(x, p) = ForwardDiff.gradient(y -> net(y, p), x)
gradxNetR(x, p) = ForwardDiff.gradient(y -> netR(y, p), x)
gradxNet(x0, ps)  # works
gradxNetR(x0, pr) # works

# ReverseDiff gradients with respect to the parameters
ReverseDiff.gradient(p -> sum(gradxNetR(x0, p)), pr) # works, but very slow
ReverseDiff.gradient(p -> sum(gradxNet(x0, p)), ps)  # MethodError
```

The error message for the last line reads:

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 2})
```

I would create an issue on GitHub, but it’s unclear which of the pieces is failing here. Could @avikpal or @mcabbott perhaps shed some light on this?

An update: As far as I can tell, nested differentiation with ReverseDiff does not play nicely with Lux parameters, maybe because they’re structured as a NamedTuple. Both Reverse-over-Forward and Forward-over-Reverse fail, and Reverse-over-Reverse returns a null gradient.

```julia
using Lux, Random, ComponentArrays
import ForwardDiff, ReverseDiff, AbstractDifferentiation as AD

# Network setup
model = Chain(Dense(1 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray
x0 = [1f0]

# Differentiation setup
fwd = AD.ForwardDiffBackend()
rev = AD.ReverseDiffBackend()

# Gradients with respect to the input x
net(x, p) = first(first(model(x, p, st)))
dnet(ad, x, p) = first(AD.gradient(ad, y -> net(y, p), x))

# Gradients with respect to the parameters p
AD.gradient(rev, q -> sum(dnet(fwd, x0, q)), ps) # Reverse-over-Forward: MethodError
AD.gradient(fwd, q -> sum(dnet(rev, x0, q)), ps) # Forward-over-Reverse: MethodError
AD.gradient(rev, q -> sum(dnet(rev, x0, q)), ps) # Reverse-over-Reverse: null gradient
```

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 1})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
 convert(#unused#::Type{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:12, ShapedAxis((12, 1), NamedTuple())), bias = ViewAxis(13:24, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(25:37, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 1})
...
```

and

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#33#34", Float32}, Float32, 10})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
 convert(#unused#::Type{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#33#34", Float32}, Float32, 10})
...
```

The problem seems to be with `ComponentArrays`, i.e., everything works fine when flattening the parameters into a “real” vector and restructuring:

```julia
import Flux # only for destructure (equivalently, Optimisers.destructure)

ps, re = Flux.destructure(ps)

net(x, p) = first(first(model(x, re(p), st)))
```
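With this flat parameter vector, the nested gradients from above go through. A quick check, reusing the `dnet` helper and the `fwd`/`rev` backends from the snippet above:

```julia
# Reverse-over-Forward with plain-Vector parameters: no MethodError.
AD.gradient(rev, q -> sum(dnet(fwd, x0, q)), ps)
```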

Somehow, a `ComponentArray` works fine for a single gradient, but fails when nested …
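For reference, the first-order case really is unproblematic. A minimal check, reusing `net`, `ps`, and `x0` from the MWE above (before the destructure):

```julia
# A single, non-nested ReverseDiff gradient over the ComponentArray
# parameters returns a gradient without any error.
ReverseDiff.gradient(p -> net(x0, p), ps)
```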


As it turns out, it’s ReverseDiff that is unable to cope with nested differentiation of closures. In fact, it’s easy to reproduce this bug in a much simpler context that has nothing to do with Lux or ComponentArrays:

```julia
import ForwardDiff, ReverseDiff, AbstractDifferentiation as AD

n = 1
x0 = Array(rand(n))
M0 = rand(n, n)

function proto(x, M)
    M * x |> sum
end

# Reverse-over-Forward on this plain closure triggers the same MethodError:
ReverseDiff.gradient(M -> sum(ForwardDiff.gradient(x -> proto(x, M), x0)), M0)
```