Hello,
I am trying to implement a sort of ‘batched’ HVP. Here’s the setup: consider some collection of particle positions x_{i,\alpha,\beta,\gamma}, where i = 1,\ldots,d is a coordinate index and \alpha,\beta,\gamma index particle identity/system replicate. Let f be a scalar function of particle position that also depends on some parameters p_\gamma, i.e. f(\vec x_{\alpha,\beta,\gamma}; p_\gamma). I would like to compute the quantities (in index notation)

$$g_{i,\alpha,\beta,\gamma} = \frac{\partial f(\vec x_{\alpha,\beta,\gamma}; p_\gamma)}{\partial x_{i,\alpha,\beta,\gamma}} \quad\text{and}\quad h_{i,\alpha,\beta,\gamma} = \sum_{j} \frac{\partial^2 f(\vec x_{\alpha,\beta,\gamma}; p_\gamma)}{\partial x_{i,\alpha,\beta,\gamma}\,\partial x_{j,\alpha,\beta,\gamma}}\, v_{j,\alpha,\beta,\gamma},$$

where \vec v_{\alpha,\beta,\gamma} are some vectors (one for each particle/replicate).
Here’s some example code where I successfully implemented the first quantity (the batched gradient):
using CUDA, Zygote
n_dims = 3; n_α = 100; n_β = 100; n_γ = 100;
function f_single(x::AbstractVector{T}, μ::AbstractVector{T}) where {T<:Number}
return exp(-sum(abs2,x - μ))
end
function f_batch(x::AbstractArray{T,4}, μ::AbstractArray{T,2}) where {T<:Number}
return exp.(-dropdims(sum(abs2,x .- reshape(μ, size(μ,1), 1,1, size(μ,2)),dims=1),dims=1))
end
x = CUDA.randn(n_dims, n_α, n_β, n_γ);
μs = CUDA.randn(n_dims, n_γ)
f_batch(x, μs)::AbstractArray{Float32,3}
f_single(view(x,:,1,1,1), view(μs,:,1))::Float32
Zygote.gradient(x -> sum(f_batch(x, μs)), x)[1]  # CuArray{Float32, 4, CUDA.DeviceMemory} of size (n_dims, n_α, n_β, n_γ)
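As a quick sanity check (a sketch, using the definitions above on small sizes), the batched gradient agrees with the per-particle gradient of f_single:

```julia
# Hypothetical sanity check: batched gradient vs. per-particle gradient.
# Moves one particle's slice to the CPU so the comparison is cheap to inspect.
g_batch = Zygote.gradient(x -> sum(f_batch(x, μs)), x)[1]
g_one = Zygote.gradient(y -> f_single(y, Array(μs[:, 1])), Array(x[:, 1, 1, 1]))[1]
@assert Array(g_batch)[:, 1, 1, 1] ≈ g_one
```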
I can’t figure out how to successfully implement the HVP, though. It seems that all built-in methods essentially compute something like

$$\sum_{j,\alpha',\beta',\gamma'} \frac{\partial^2 F}{\partial x_{i,\alpha,\beta,\gamma}\,\partial x_{j,\alpha',\beta',\gamma'}}\, v_{j,\alpha',\beta',\gamma'}, \qquad F(x) = \sum_{\alpha,\beta,\gamma} f(\vec x_{\alpha,\beta,\gamma}; p_\gamma),$$

i.e. the Hessian of the summed scalar contracted over the fully flattened x, which is not what I want. I would greatly appreciate any recommendations on how to proceed. For my more complicated use case, it’s important that this runs on the GPU and is efficient, since I will be computing with large arrays in high dimensions, thousands of times. Additionally, for my full use case, f_batch is typically more efficient than broadcasting f_single, so it would be ideal to work with f_batch if possible. I am open to Zygote, Enzyme, or any other package that works and is efficient.
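In case it helps pin down the target quantity, here is a naive CPU reference implementation of the batched HVP I’m after (a hypothetical helper, not from my real code: it loops over particles and calls ForwardDiff.hessian per particle, so it is far too slow for my actual problem, but it defines the ground truth):

```julia
# Naive CPU reference for the desired batched HVP (ground truth only, not performant).
using ForwardDiff

function hvp_reference(x::Array{T,4}, μs::Array{T,2}, v::Array{T,4}) where {T<:Number}
    out = similar(x)
    for γ in axes(x, 4), β in axes(x, 3), α in axes(x, 2)
        μ = μs[:, γ]
        # Same formula as f_single, written as a closure so ForwardDiff's Dual
        # inputs don't clash with f_single's (T, T) signature.
        g(y) = exp(-sum(abs2, y - μ))
        H = ForwardDiff.hessian(g, x[:, α, β, γ])   # d × d Hessian for this particle
        out[:, α, β, γ] = H * v[:, α, β, γ]          # contract with this particle's vector
    end
    return out
end
```

(Here x, μs, and v are plain CPU Arrays; any GPU solution just needs to match this output.)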
I appreciate any advice anyone might offer. I’m really stuck! Thank you in advance!