Hello everyone. For my work with PINNs, I am trying to use ReverseDiff to compute the gradient of a loss function that itself contains a ForwardDiff gradient. This is easy enough with explicitly parametrized, destructured Flux networks, but the constant destructuring and reconstruction of the model makes that approach extremely slow, as well as type-unstable.
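For context, this is roughly the Flux pattern I am moving away from (a simplified sketch; the names `fnet` and `fgradx` are mine, and I've shrunk the network for brevity):

```julia
using Flux, ForwardDiff, ReverseDiff

fmodel = Flux.Chain(Flux.Dense(2 => 12, tanh), Flux.Dense(12 => 1))
p0, fre = Flux.destructure(fmodel)

# Rebuild the model from the flat parameter vector on every call --
# this reconstruction is exactly the overhead that makes it so slow.
fnet(x, p) = first(fre(p)(x))
fgradx(x, p) = ForwardDiff.gradient(x -> fnet(x, p), x)

# ReverseDiff over the inner ForwardDiff gradient, w.r.t. the parameters.
ReverseDiff.gradient(p -> first(fgradx([1f0, 1f0], p)), p0)
```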

Naturally, I went for Lux. However, ReverseDiff over ForwardDiff doesn't seem to work properly there. Take the following code as an MWE: it is easy enough to compute the gradient with respect to the inputs x with ForwardDiff. However, taking the ReverseDiff gradient of `gradxNet` with respect to the parameters does not work.

What makes me think this is a bug is that it works fine for destructured Lux models. In the code, you'll see that `netR` should be exactly the same as `net`, just with added reconstruction overhead, and yet using it for the calculations works (it's just very slow).

```julia
using Lux, Random, ForwardDiff, ReverseDiff, ComponentArrays, Optimisers

# Set up the Lux network
model = Chain(Dense(2 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 12, tanh),
              Dense(12 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x0 = [1f0, 1f0]

# Destructure the net
ps = ps |> ComponentArray
pr, re = Optimisers.destructure(ps)
ps == re(pr) # true
ps == re(ps) # true

# Normal and restructured models
net(x, p) = first(first(model(x, p, st)))
net(x0, ps) # works
netR(x, p) = first(first(model(x, re(p), st)))
netR(x0, pr) # works

# ForwardDiff gradients with respect to x
gradxNet(x, p) = ForwardDiff.gradient(x -> net(x, p), x)
gradxNet(x0, ps) # works
gradxNetR(x, p) = ForwardDiff.gradient(x -> netR(x, p), x)
gradxNetR(x0, pr) # works

# ReverseDiff gradients with respect to the parameters
ReverseDiff.gradient(p -> first(gradxNetR(x0, p)), pr) # works
ReverseDiff.gradient(p -> first(gradxNetR(x0, p)), ps) # works
ReverseDiff.gradient(p -> first(gradxNet(x0, p)), ps)  # does not work
```

The error message for the last line reads:

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#7#8"{ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:348, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_4 = ViewAxis(349:361, Axis(weight = ViewAxis(1:12, ShapedAxis((1, 12), NamedTuple())), bias = ViewAxis(13:13, ShapedAxis((1, 1), NamedTuple())))))}}}}}, Float32}, Float32, 2})
```
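In the meantime, the workaround is simply to route everything through the restructured model, so that ReverseDiff only ever sees the flat parameter vector (reusing the MWE's names; `lossR` is just a label I'm adding here):

```julia
# Workaround for now: differentiate through the restructured model only.
# This works per the MWE above, but pays the reconstruction cost on every call.
lossR(p) = first(gradxNetR(x0, p))
ReverseDiff.gradient(lossR, pr)
```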

Judging from the stack trace, something in the tracked forward pass is trying to convert a `ForwardDiff.Dual` (coming from the inner gradient over `x`) to a plain `Float32`, and no such method exists, but I can't tell where that conversion happens. I would create an issue on GitHub, but it's unclear which of the pieces (Lux, ReverseDiff, ForwardDiff, or ComponentArrays) is failing here. Could @avikpal or @mcabbott perhaps shed some light on this issue?