Hello everyone. For my work with PINNs, I am trying to use ReverseDiff to compute the gradient of a loss function that itself contains a ForwardDiff gradient. This is easy enough with explicitly parametrized, destructured Flux networks, but the constant destructuring and reconstruction of the model makes that approach extremely slow, as well as type-unstable.
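For context, this is roughly the Flux pattern I am moving away from (a simplified sketch; the names `fnet` and `fgradx` are mine, and I've shrunk the network for brevity):

```julia
using Flux, ForwardDiff, ReverseDiff

fmodel = Flux.Chain(Flux.Dense(2 => 12, tanh), Flux.Dense(12 => 1))
p0, fre = Flux.destructure(fmodel)

# Rebuild the model from the flat parameter vector on every call --
# this reconstruction is exactly the overhead that makes it so slow.
fnet(x, p) = first(fre(p)(x))
fgradx(x, p) = ForwardDiff.gradient(x -> fnet(x, p), x)

# ReverseDiff over the inner ForwardDiff gradient, w.r.t. the parameters.
ReverseDiff.gradient(p -> first(fgradx([1f0, 1f0], p)), p0)
```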

Naturally, I went for Lux. However, ReverseDiff over ForwardDiff doesn't seem to work properly there. Take the following code as an MWE: it is easy enough to compute the gradient with respect to the inputs x with ForwardDiff. However, taking the ReverseDiff gradient of `gradxNet` with respect to the parameters does not work.

What makes me think this is a bug is that it works fine for destructured Lux models. In the code, you'll see that `netR` should be exactly the same as `net`, just with added reconstruction overhead, and yet using it for the calculations works (it's just very slow).

```julia
using Lux, Random, ForwardDiff, ReverseDiff, ComponentArrays, Optimisers

# Set up the Lux network
model = Chain(Dense(2 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x0 = [1f0, 1f0]

# Destructure the net
ps = ps |> ComponentArray
pr, re = Optimisers.destructure(ps)
ps == re(pr) # true
ps == re(ps) # true

# Normal and restructured models
net(x, p) = first(first(model(x, p, st)))
net(x0, ps) # works
netR(x, p) = first(first(model(x, re(p), st)))
netR(x0, pr) # works

# ForwardDiff gradients with respect to x
gradxNet(x, p) = ForwardDiff.gradient(x -> net(x, p), x)
gradxNet(x0, ps) # works
gradxNetR(x, p) = ForwardDiff.gradient(x -> netR(x, p), x)
gradxNetR(x0, pr) # works

# ReverseDiff gradients with respect to the parameters
ReverseDiff.gradient(p -> first(gradxNetR(x0, p)), pr) # works
ReverseDiff.gradient(p -> first(gradxNetR(x0, p)), ps) # works
ReverseDiff.gradient(p -> first(gradxNet(x0, p)), ps)  # does not work
```

The error message for the last line reads:

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 2})
```
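In the meantime, the workaround is simply to route everything through the restructured model, so that ReverseDiff only ever sees the flat parameter vector (reusing the MWE's names; `lossR` is just a label I'm adding here):

```julia
# Workaround for now: differentiate through the restructured model only.
# This works per the MWE above, but pays the reconstruction cost on every call.
lossR(p) = first(gradxNetR(x0, p))
ReverseDiff.gradient(lossR, pr)
```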

Judging from the stack trace, something in the tracked forward pass is trying to convert a `ForwardDiff.Dual` (coming from the inner gradient over `x`) to a plain `Float32`, and no such method exists, but I can't tell where that conversion happens. I would create an issue on GitHub, but it's unclear which of the pieces (Lux, ReverseDiff, ForwardDiff, or ComponentArrays) is failing here. Could @avikpal or @mcabbott perhaps shed some light on this issue?