Fast Hessian and Gradient for PINNs using Enzyme/Zygote

I am working on a problem that involves solving partial differential equations with neural networks (physics-informed neural networks, PINNs).
The current bottleneck in my code is computing the gradient and Hessian of the network with respect to its inputs for a batch of data. My current implementation uses Zygote, but it seems highly inefficient because I compute the Hessian one data point (observation) at a time.
I welcome all advice, including a rewrite using Enzyme.
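Because the network processes each column of the batch independently (assuming no layers such as BatchNorm that mix samples), the batched Hessian has a simple structure. Writing $f$ for the scalar-output network, $x_i$ for column $i$ of the batch $X$, and $L$ for the summed output `sum(model(X))`:

$$
L(X) = \sum_{i=1}^{BS} f(x_i), \qquad \frac{\partial L}{\partial x_i} = \nabla f(x_i), \qquad \frac{\partial^2 L}{\partial x_i \, \partial x_k} = 0 \quad (i \neq k)
$$

$$
\nabla^2 L(X) = \operatorname{blockdiag}\big(\nabla^2 f(x_1), \dots, \nabla^2 f(x_{BS})\big), \qquad
\nabla^2 L(X)\,(e_j, \dots, e_j) = \big(\nabla^2 f(x_1)\, e_j, \dots, \nabla^2 f(x_{BS})\, e_j\big)
$$

A Hessian-vector product of $L$ with the direction that is $e_j$ in every column therefore returns the $j$-th column of every per-sample Hessian at once, so $M$ batched products contain the same information as $BS$ separate per-sample Hessian calls.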

A minimal implementation.


using Flux, Random
using Zygote: gradient, hessian  # hessian comes from Zygote, Flux's AD backend

# Set the seed
Random.seed!(123) 

# Define the constants used in the model
M  = 1    # Number of models / state variables (also the input dimension)
BS = 512  # Batch size

# Define neural networks for each unknown function
model_F = [Chain(
    Dense(M => 64, tanh),
    Dense(64 => 64, tanh),
    Dense(64 => 1)
) |> f64 for _ in 1:M]

Y = 0.01 .+ (100 - 0.01) .* rand(M, BS)  # Generate new samples, uniform on [0.01, 100)

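function get_HessGrad(M, BS, model, Y)
    grads = zeros(M, BS)
    hess  = zeros(M, M, BS)
    # First-order partials: one reverse pass over the whole batch. Since each
    # column of Y is processed independently, column i of the gradient of
    # sum(model(Y)) is the gradient for sample i.
    grads[:, :] = gradient(x -> sum(model(x)), Y)[1]
    # Second-order partials: one Hessian call per sample (the bottleneck)
    for i in 1:BS
        y = @view Y[:, i]
        hess[:, :, i] = hessian(x -> sum(model(x)), y)
    end
    return grads, hess
end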

grads, hess = get_HessGrad(M, BS, model_F[1], Y)
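One possible way to remove the per-sample loop is sketched below, assuming (as for a `Chain` of `Dense` layers) that the model maps each column of its input independently. It swaps the loop for forward-over-reverse Hessian-vector products: `ForwardDiff.derivative` differentiates the Zygote gradient of `sum(model(X))` along a direction that is `e_j` in every column, which yields the `j`-th column of every sample's Hessian in one pass over the batch. The function name `batched_grad_hess` is hypothetical, and the sketch adds ForwardDiff as a dependency.

```julia
using Flux, Zygote, ForwardDiff, Random

# Hypothetical helper: per-sample gradients (M x BS) and Hessians (M x M x BS).
# Assumes model(X) treats each column of X independently (no BatchNorm etc.),
# so the Hessian of sum(model(X)) w.r.t. X is block-diagonal.
function batched_grad_hess(model, X::AbstractMatrix{<:Real})
    M, BS = size(X)
    # Reverse mode: column i of this gradient is the gradient for sample i.
    gradfun(Z) = Zygote.gradient(z -> sum(model(z)), Z)[1]
    grads = gradfun(X)
    hess = zeros(eltype(grads), M, M, BS)
    for j in 1:M
        # Direction that is the unit vector e_j in every column.
        E = zeros(eltype(X), M, BS)
        E[j, :] .= 1
        # Forward mode over the gradient: d/dt gradfun(X + t*E) at t = 0 is the
        # Hessian-vector product; column i equals column j of sample i's Hessian.
        HE = ForwardDiff.derivative(t -> gradfun(X .+ t .* E), zero(eltype(X)))
        hess[:, j, :] .= HE
    end
    return grads, hess
end

# Usage on a small model with a 2-dimensional input, to show the M x M blocks.
Random.seed!(0)
M2, BS2 = 2, 8
m = Chain(Dense(M2 => 16, tanh), Dense(16 => 1)) |> f64
X = randn(M2, BS2)
g, H = batched_grad_hess(m, X)
```

The loop runs over the input dimension `M` rather than the batch size, so for `M = 1` as in the example above a single forward-over-reverse pass over the batch replaces the 512 per-sample `hessian` calls. The block-diagonal assumption is essential: a layer that couples samples would make the extracted blocks incorrect.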