How efficient is the current Julia AD and optimization landscape in comparison to, let's say, PyTorch?
For example, in PyTorch I can do:
```python
# optimize with respect to both arr1 and arr2
arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(True)
optimizer = torch.optim.Adam([arr1, arr2])
loss = f(arr1, arr2)
loss.backward()

# only with respect to arr1
arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(False)
optimizer = torch.optim.Adam([arr1])
loss = f(arr1, arr2)
loss.backward()
```
I guess PyTorch stores the computational graph, so `loss.backward()` avoids unnecessary computation in the second case, where only `arr1` requires gradients.
But Optim.jl only accepts one array, so I had to do:
```julia
# full
arr = ComponentArray(arr1=arr1, arr2=arr2)
loss = f(arr)
gradient(f, arr) ## put it in Optim

# only with respect to arr1
# manual closure so that arr2 is not an optimizable argument
f2(arr1) = f(arr1, arr2)
gradient(f2, arr1) ## put it in Optim
```
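By "put it in Optim" I mean roughly the following (just a sketch, assuming `f`, `arr1`, and `arr2` are defined as above; `f2` and `g!` are helper names I made up):

```julia
using Optim, Zygote

# only optimize arr1; arr2 is captured by the closure and treated as a constant
f2(x) = f(x, arr2)

# in-place gradient for Optim, computed with Zygote
g!(G, x) = copyto!(G, Zygote.gradient(f2, x)[1])

result = optimize(f2, g!, arr1, LBFGS())
```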
But does Zygote avoid the unnecessary AD computations for `arr2`? Also, how do I easily exclude some arrays from optimization, as in PyTorch? Is there a native way in Julia? Optim.jl always expects a single array, so I feel like that doesn't work. Is it simply not possible to provide a construct as convenient as PyTorch's?
This is more of an open question, but it has always held me back from starting to develop a toolbox for my field, since I don't know how to solve this challenge nicely. Maybe I can learn from reading some Lux.jl source code? I'm especially worried about expensive, unnecessary calculations in the gradient pass.
Any hints would be appreciated