Batched Hessian-Vector Product (on the GPU)

Hello,

I am trying to implement a sort of ‘batched’ HVP. Here’s the setup: consider a collection of particle positions x_{i,\alpha,\beta,\gamma}, where i = 1,\ldots,d is a coordinate index and \alpha,\beta,\gamma index particle identity/system replicate. Let f be a scalar function of particle position that also depends on some parameters p_\gamma, i.e. f(\vec x_{\alpha,\beta,\gamma}; p_\gamma). I would like to compute the quantities (in index notation):

\texttt{grad}_{i,\alpha,\beta,\gamma} = \frac{\partial}{\partial x_{i,\alpha,\beta,\gamma}} f(\vec x_{\alpha,\beta,\gamma}), \\ \texttt{hvp}_{i,\alpha,\beta,\gamma} = \sum_{j=1}^d v_{j,\alpha,\beta,\gamma} \frac{\partial}{\partial x_{i,\alpha,\beta,\gamma}} \frac{\partial}{\partial x_{j,\alpha,\beta,\gamma}} f(\vec x_{\alpha,\beta,\gamma}),

where \vec v_{\alpha,\beta,\gamma} are some vectors (one for each particle/replicate).
Here’s some example code where I successfully implemented the first:

using CUDA, Zygote
n_dims = 3; n_α = 100; n_β = 100; n_γ = 100;

# Gaussian-like scalar function of a single particle's position
function f_single(x::AbstractVector{T}, μ::AbstractVector{T}) where {T<:Number}
    return exp(-sum(abs2, x - μ))
end

# The same function evaluated for every (α, β, γ) replicate at once;
# μ is broadcast along the α and β axes
function f_batch(x::AbstractArray{T,4}, μ::AbstractArray{T,2}) where {T<:Number}
    return exp.(-dropdims(sum(abs2, x .- reshape(μ, size(μ, 1), 1, 1, size(μ, 2)), dims=1), dims=1))
end

x = CUDA.randn(n_dims, n_α, n_β, n_γ); 
μs = CUDA.randn(n_dims, n_γ)
f_batch(x, μs)::AbstractArray{Float32,3}
f_single(view(x,:,1,1,1), view(μs,:,1))::Float32

Zygote.gradient(x -> sum(f_batch(x, μs)), x)[1]  # CuArray{Float32, 4, CUDA.DeviceMemory} of size (n_dims, n_α, n_β, n_γ)
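For validating candidate HVP implementations, I can at least write down a central finite-difference reference. Since f_batch is separable across the replicate indices, each replicate's gradient depends only on its own coordinates, so differencing the full batched gradient along v gives exactly the per-replicate contraction I defined above (a sketch, using the f_batch, x, μs from before; hvp_fd is a name I made up):

```julia
# Finite-difference reference for the batched HVP (validation only, not
# for production): hvp ≈ (grad(x + ε v) - grad(x - ε v)) / (2ε).
# The Hessian of sum(f_batch) is block-diagonal across (α, β, γ), so this
# matches the per-replicate HVP elementwise.
function hvp_fd(x, v; ε = 1f-3)
    g(y) = Zygote.gradient(z -> sum(f_batch(z, μs)), y)[1]
    return (g(x .+ ε .* v) .- g(x .- ε .* v)) ./ (2ε)
end

v = CUDA.randn(n_dims, n_α, n_β, n_γ)
hvp_fd(x, v)  # same shape as x
```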

I can’t figure out how to successfully implement the HVP, though. It seems that the built-in methods all essentially compute something like:

\sum_{j,\alpha',\beta',\gamma'} v_{j,\alpha',\beta',\gamma'} \frac{\partial}{\partial x_{i,\alpha,\beta,\gamma}} \frac{\partial}{\partial x_{j,\alpha',\beta',\gamma'}} f(\vec x_{\alpha',\beta',\gamma'})

which is not what I want. I would greatly appreciate any recommendations on how to proceed. For my more complicated use case it’s important that this runs on the GPU and is efficient, since I will be computing these quantities for large arrays in high dimensions, thousands of times. Additionally, in my full use case f_batch is typically much more efficient than broadcasting f_single over replicates, so it would be ideal to work with f_batch if possible. I am open to Zygote, Enzyme, or any other package that works and is efficient.
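For concreteness, here is a sketch of the kind of nested-gradient attempt I mean (hvp_naive is a hypothetical name; v is a CuArray shaped like x). Differentiating the global inner product grad(x) ⋅ v realizes, in general, the full contraction over (j, α′, β′, γ′) written above rather than the per-replicate one; and as far as I can tell, Zygote-over-Zygote second derivatives are fragile on the GPU anyway:

```julia
v = CUDA.randn(n_dims, n_α, n_β, n_γ)

# Naive nested-gradient attempt: differentiate the *global* inner product
# between the gradient and v. For a general (non-separable) f this sums
# over all replicate indices, not just the matching (α, β, γ) block.
hvp_naive(x, v) = Zygote.gradient(
    z -> sum(v .* Zygote.gradient(y -> sum(f_batch(y, μs)), z)[1]),
    x)[1]
```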

I appreciate any advice anyone might offer. I’m really stuck! Thank you in advance!