I’m planning on training some neural networks using ReverseDiff.jl (which I think makes sense because I’ll be training lots of small networks on the CPU), but I have a basic question about how ReverseDiff is intended to be used.
In the ReverseDiff examples we’re given a function of several arguments and shown how to differentiate with respect to all of them. But the outputs of a neural network are a function of (1) the weights and biases, and (2) the inputs to the network. I only need to differentiate with respect to the weights and biases, not the inputs. So is there a way with ReverseDiff to differentiate with respect to some of a function’s arguments, but not others?
Make use of a closure to create a function that takes as input only the parameters you want to differentiate with respect to.
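A minimal sketch of the closure approach, assuming a hypothetical one-layer network with a squared-error loss `loss(W, b, x, y)` (the model and names are illustrative, not from the thread). The anonymous function captures `x` and `y`, so ReverseDiff only sees `W` and `b` as inputs:

```julia
using ReverseDiff

# Hypothetical one-layer network with squared-error loss:
# W, b are trainable parameters; x (input) and y (target) are data.
loss(W, b, x, y) = sum(abs2.(tanh.(W * x .+ b) .- y))

W, b = randn(2, 3), randn(2)   # parameters
x, y = randn(3), randn(2)      # one training example

# The closure captures x and y, so only W and b are differentiated.
# Passing a tuple of arrays returns a tuple with one gradient per element.
gW, gb = ReverseDiff.gradient((W, b) -> loss(W, b, x, y), (W, b))
```

Here `gW` has the shape of `W` and `gb` the shape of `b`; no gradient is computed or returned for `x` or `y`.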
That does seem to be the intended solution, but judging from the Stack Overflow discussion "Julia ReverseDiff: how to take a gradient w.r.t. only a subset of inputs?", pre-compiled tapes are not supported if you do this, which is really unfortunate.
Note that to compute the gradient with respect to any of the inputs, you usually need to traverse the whole computational graph anyway, so skipping the gradients for the other inputs will have little effect on total run time in most cases.
Re precompiled tapes, is there a reason you are interested specifically in ReverseDiff?
It looks like computing the gradients with respect to the inputs as well, and simply ignoring them, might be the way to go then.
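A sketch of that workaround, reusing the same hypothetical `loss(W, b, x, y)`: record a `GradientTape` over all four arguments, compile it, and discard the gradients for `x` and `y`. Because the data are tape inputs rather than values captured by a closure, the compiled tape can be rerun on new training examples. This assumes the function's control flow does not depend on input values and that array sizes stay fixed, which compiled tapes require.

```julia
using ReverseDiff

# Same hypothetical one-layer network and squared-error loss as above.
loss(W, b, x, y) = sum(abs2.(tanh.(W * x .+ b) .- y))

W, b = randn(2, 3), randn(2)
x, y = randn(3), randn(2)

# Record the tape over ALL arguments (parameters and data), then compile it.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(loss, (W, b, x, y)))

# Preallocated outputs, one per input; the x and y gradients are computed but ignored.
results = (similar(W), similar(b), similar(x), similar(y))
ReverseDiff.gradient!(results, tape, (W, b, x, y))

# gW and gb alias results[1] and results[2], so the next gradient! call overwrites them.
gW, gb = results[1], results[2]
```

For the next training example, call `ReverseDiff.gradient!(results, tape, (W, b, x_new, y_new))` with the same compiled tape.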
I have no strong reason to prefer ReverseDiff; it just seemed like an established package that would be suitable for training lots of small networks on the CPU. Is there another option that would make my life easier?
If you already have experience with ReverseDiff and there are no more blockers, then there’s no real reason to switch. If you encounter more issues, though, you may explore other AD packages such as Zygote (perhaps the most established package at the moment) or Yota (which shares the idea of a tape with ReverseDiff).
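For comparison, a minimal Zygote sketch of the same idea, again using the hypothetical `loss(W, b, x, y)` from above. `Zygote.gradient(f, args...)` returns one gradient per argument passed to `f`, so the closure restricts differentiation to `W` and `b`, and no tape needs to be recorded or compiled:

```julia
using Zygote

# Same hypothetical one-layer network and squared-error loss.
loss(W, b, x, y) = sum(abs2.(tanh.(W * x .+ b) .- y))

W, b = randn(2, 3), randn(2)
x, y = randn(3), randn(2)

# The closure captures x and y; Zygote returns a tuple (∂loss/∂W, ∂loss/∂b).
gW, gb = Zygote.gradient((W, b) -> loss(W, b, x, y), W, b)
```
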
Thank you - Zygote seems far more convenient than ReverseDiff, and I was able to get up and running with it very quickly. I don’t know how it will compare speed-wise, but worrying about pre-compiled tapes and such was probably premature optimisation anyway - I’ll stick with Zygote for now and see how it goes.