I’m trying to calculate gradients through a chain of NN applications with some other differentiable operations in between, but Flux fails with the error in the title. I’ve made a minimal example that reproduces the problem. It shows the first iteration of the chain plus the first steps of the second iteration:
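```julia
using Flux
using Flux.Tracker

net = Dense(4, 1)
s = rand(Float32, 4)

# First iteration: a depends on the network parameters
a = net(s)[1]

# Build the next state, with a added into the first component
s1, s2, s3, s4 = s
ns = [s1 + a, s2, s3, s4]
s = ns

# Second iteration
a = net(s)[1] # error
```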
It seems that the problem has something to do with the `s1 + a` entry in `ns`, but I don’t see why that would be an issue. I need gradients with respect to the NN parameters to flow back through `a`, and, in subsequent calls (though not in the first one shown here), back through `s1` as well.
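The promotion effect suspected above can be reproduced in plain Julia, without Flux. The sketch below uses `BigFloat` purely as a hypothetical stand-in for a tracked scalar: a single entry of a wider type in an array literal promotes the element type of the whole vector. Assuming Tracker’s promotion rules behave analogously, `ns` in the example would be an ordinary `Vector` of tracked scalars rather than a tracked array, which may be what the second `net(s)` call does not handle.

```julia
# Plain-Julia illustration of element-type promotion in an array literal.
# Assumption: BigFloat stands in for a "wider" scalar type such as a tracked scalar.
s = rand(Float32, 4)
s1, s2, s3, s4 = s

x = big(0.5)              # stand-in for `a`; Float32 + BigFloat yields a BigFloat
ns = [s1 + x, s2, s3, s4] # the literal promotes all four entries to BigFloat

println(eltype(s))        # Float32
println(eltype(ns))       # BigFloat
```

Only the first entry involves `x`, yet every element of `ns` ends up as a `BigFloat`.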