One way to calculate only the Jacobians you want, plus the Hessians, quickly and without scalar indexing, repeated calculations, or mutation, is to use the dual-number implementation in ForwardDiff.jl, maybe? The catch is that the Zygote gradient of this still does not work, for reasons unrelated to mutation, so if you need derivatives of these derivatives with respect to the parameters (say, for something like a PINN) you are out of luck there.
```julia
using Zygote, ForwardDiff, Enzyme
using ForwardDiff: Dual, partials

C = 4               # input dimension
D = 8               # number of input columns (batch size)
A = 2               # output dimension
T = Float32

model(M, x) = tanh.(M * x)
M = randn(T, A, C)
v = randn(T, C, D)

# One seed tuple per input component: the n-th tuple has a 1 in slot n and 0 elsewhere.
Partial_indicators = [Tuple([n == m ? 1 : 0 for m = 1:C]) for n = 1:C]

dv  = Dual{T}.(v, Partial_indicators)   # first-order duals
d2v = Dual{T}.(dv, Partial_indicators)  # nested duals, for second derivatives

# Gradients of each output component with respect to each input component, shape C x A x D
grads = getindex.(reshape(partials.(model(M, dv)), 1, A, D), 1:C)
# Hessians, shape C x C x A x D
hess = getindex.(reshape(partials.(getindex.(reshape(partials.(model(M, d2v)), 1, A, D), 1:C)), 1, C, A, D), 1:C)
```
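As a sanity check (my addition, not from the original snippet), you can compare a single column against ForwardDiff.jacobian and ForwardDiff.hessian; note that grads and hess put the input dimensions first, so the Jacobian comes out transposed:

```julia
# Sanity check (not part of the original snippet): compare one column against ForwardDiff's own API.
j = 1
J = ForwardDiff.jacobian(x -> model(M, x), v[:, j])        # A x C Jacobian for column j
@assert grads[:, :, j] ≈ transpose(J)
H1 = ForwardDiff.hessian(x -> model(M, x)[1], v[:, j])     # C x C Hessian of the first output
@assert hess[:, :, 1, j] ≈ H1
```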
However, I presume you want to train with them afterwards, since you are using Flux, and in that case the fact that these Jacobians and Hessians are not differentiable with Zygote could be troublesome. Doing a for loop and calculating these values one by one is a non-starter, since Zygote's own gradient code contains mutations, so you can't differentiate a Zygote gradient with Zygote again. I've found that Enzyme still works, though:
```julia
# Scalar loss built from the Hessian entries, so Enzyme can differentiate it w.r.t. M.
function loss(M, d2v)
    return sum(getindex.(reshape(partials.(getindex.(reshape(partials.(model(M, d2v)), 1, A, D), 1:C)), 1, C, A, D), 1:C))
end

Enzyme.gradient(Reverse, M -> loss(M, d2v), M)
```
So if you are interested in training afterwards, one idea could be to use ForwardDiff duals to calculate the partial derivatives you are interested in, and then do the training with Enzyme as the AD backend.
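For example, a minimal training sketch under those assumptions (my own, not from the original post; the learning rate and step count are made up, and depending on your Enzyme version gradient may return the array directly or wrap it in a one-element tuple):

```julia
# Minimal gradient-descent sketch (my assumption of how the pieces fit together, untested end to end).
η = 1f-2                                          # hypothetical learning rate
for step in 1:100
    dM = Enzyme.gradient(Reverse, m -> loss(m, d2v), M)
    dM isa Tuple && (dM = dM[1])                  # newer Enzyme releases return a tuple of gradients
    M .-= η .* dM                                 # update the weights in place
end
```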
EDIT: Also, if you only want the diagonals of the Hessians, you could split d2v into C different matrices where the only nonzero partial in the duals is the index of interest (see the sketch below), maybe? Alternatively, you could avoid all of this hacky nonsense I cobbled together by using numerical derivatives instead of exact ones.
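Here is a rough sketch of that diagonal-only idea (my interpretation, untested; it reuses model, M, v, T, C, A, D from above): seed one input direction at a time, so each nested Dual only carries a single partial.

```julia
# Diagonal Hessian entries only (untested sketch of the idea above).
diag_hess = zeros(T, C, A, D)                     # diag_hess[c, a, d] = second derivative of out[a, d] w.r.t. v[c, d]
for c in 1:C
    seed = [(m == c ? one(T) : zero(T),) for m in 1:C]   # single-component seeds, nonzero only in row c
    dvc  = Dual{T}.(v, seed)                              # first-order dual along direction c
    d2vc = Dual{T}.(dvc, seed)                            # nested dual along the same direction
    out  = model(M, d2vc)                                 # A x D nested duals
    diag_hess[c, :, :] .= (x -> partials(partials(x)[1])[1]).(out)
end
```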