I’m trying to calculate gradients through a chain of NN applications with some other differentiable operations in between, but Flux fails with the error in the title. Here is a minimal example that shows the first iteration of the chain plus the first steps of the second iteration:
    using Flux
    using Flux.Tracker

    net = Dense(4, 1)
    s = rand(Float32, 4)
    a = net(s)
    s1, s2, s3, s4 = s
    ns = [s1 + a, s2, s3, s4]
    s = ns
    a = net(s) # error
The problem seems to come from the (s1 + a) entry in ns, but I don’t see why that would be an issue. I need gradients with respect to the NN parameters to flow back through a, and in subsequent calls also back through s1 (not in the first call shown here, though).