Hi all,
how efficient is the current Julia AD and optimization landscape in comparison to, let's say, PyTorch?
Python
For example, in PyTorch I can do:
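# with respect to both arr1 and arr2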
arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(True)
optimizer = torch.optim.Adam([arr1, arr2])
loss = f(arr1, arr2)
loss.backward()
# only with respect to arr1
arr1 = arr1.requires_grad_(True)
arr2 = arr2.requires_grad_(False)
optimizer = torch.optim.Adam([arr1])
loss = f(arr1, arr2)
loss.backward()
I guess PyTorch stores the computational graph for arr1 and arr2, and loss.backward() avoids the unnecessary computation for arr2 in the second case.
Julia
How do I achieve similar behavior in Julia?
I have always used ComponentArrays.jl, Zygote.jl and Optim.jl for such problems.
But Optim.jl only accepts a single array, so I had to do:
using ComponentArrays, Zygote

# full: optimize with respect to both arr1 and arr2
arr = ComponentArray(arr1=arr1, arr2=arr2)
loss = f(arr)
grad = Zygote.gradient(f, arr)[1]
## pass f and grad to Optim.jl

# only with respect to arr1
# manual closure so that arr2 is not an optimizable argument
f2(arr1) = f(arr1, arr2)
grad1 = Zygote.gradient(f2, arr1)[1]
## pass f2 and grad1 to Optim.jl (see sketch below)
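For context, the "pass it to Optim.jl" step would look roughly like this (sketch from memory, so the exact call may be slightly off):

using Optim, Zygote

# in-place gradient for Optim.jl, only with respect to arr1
g!(G, x) = copyto!(G, Zygote.gradient(f2, x)[1])
res = Optim.optimize(f2, g!, arr1, Optim.LBFGS())
arr1_opt = Optim.minimizer(res)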
But does Zygote actually avoid the unnecessary AD computations for arr2 with this closure approach? Also, how do I easily exclude some arrays from the optimization, as in PyTorch? Is there a native way to do this in Julia? Optim.jl always expects a single array, so I feel like it cannot work there directly. Is a construct as convenient as PyTorch's simply not possible?
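Would something like Optimisers.jl with freeze! be the idiomatic way to get the PyTorch-like behavior? A rough, untested sketch of what I imagine (the NamedTuple layout and the Adam settings are just placeholders):

using Zygote, Optimisers

# parameters collected in a NamedTuple; f is the two-argument loss from above
ps = (arr1 = arr1, arr2 = arr2)
state = Optimisers.setup(Optimisers.Adam(1e-3), ps)
Optimisers.freeze!(state.arr2)     # exclude arr2 from the parameter updates

# one optimization step (repeated in a loop / inside a function)
grads = Zygote.gradient(p -> f(p.arr1, p.arr2), ps)[1]
state, ps = Optimisers.update(state, ps, grads)

Though as far as I understand, Zygote would still compute the gradient for arr2 here, since freeze! only skips the update step, not the AD.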
This is more of an open question, but it keeps holding me back from developing a toolbox for my field, since I don't know how to solve this challenge nicely. Maybe I can learn from reading some Lux.jl source code? I'm especially worried about expensive, unnecessary calculations in the gradient pass.
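To make that worry concrete, this is the kind of micro-benchmark I would run to check whether the closure approach actually saves work in the backward pass (self-contained toy example; the real f, arr1 and arr2 look different, of course):

using Zygote, BenchmarkTools

# toy loss and arrays just for timing purposes
arr1 = rand(1000); arr2 = rand(1000)
f(a, b) = sum(abs2, a .* b)

@btime Zygote.gradient($f, $arr1, $arr2);         # gradient with respect to both arrays
@btime Zygote.gradient(a -> f(a, $arr2), $arr1);  # gradient with respect to arr1 only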
Any hints would be appreciated!
Felix