Efficiently computing Hessians of neural network outputs with respect to inputs

Hello everyone,

I’m currently working on a problem that involves approximating the solution of a partial differential equation using neural networks. I’ve created a script that generates a set of neural networks, computes their outputs for a batch of samples, and calculates both the first- and second-order partial derivatives of each output with respect to the corresponding input sample.

The stripped-down version of my script is below.

I’m looking for ways to optimize the part of the script that computes the gradient and Hessian. Currently, I’m iterating over each sample and calculating these values individually. While this works, I’m worried that it may not be the most efficient approach, especially when dealing with a large number of samples. I’d appreciate any advice on how to make these calculations more efficient.

Thank you in advance for your help!

using Flux
using Flux: params, hessian, gradient
using Random

# Set the seed
Random.seed!(123) 

# Define the constants used in the model
M = 2  # Number of models
BS = 100  # Batch size

# Define neural networks for each unknown function
model_F = [Chain(
    Dense(M, 64, tanh),
    Dense(64, 64, tanh),
    Dense(64, 1)
) |> f64 for _ in 1:M]
 
Y = rand(0.01:100, M, BS)  # Generate new samples (drawn from the grid 0.01, 1.01, ..., 99.01, since the range step is 1)

for id in 1:M
    model = model_F[id]
    output = model(Y)

    # Compute the first-order partial derivatives and second-order partial derivatives
    grads = zeros(M, BS)
    hess  = zeros(M, M, BS)
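    for i in 1:size(Y, 2)
        y = Y[:, i]                          # i-th sample (input vector of length M)
        g = gradient(x -> sum(model(x)), y)  # returns a 1-tuple holding the gradient
        grads[:, i] = g[1]
        g2 = hessian(x -> sum(model(x)), y)  # M×M matrix of second derivatives
        hess[:, :, i] = g2
    end
end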

Have you read through the performance tips in the docs?

First, I’d recommend using a profiler, like @profview in VS Code, to figure out which part of the code takes the most time. To do this, I suggest putting your code in a function (as suggested in the performance tips) and applying the profiler to the function call.
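
As a rough sketch of what I mean, the per-sample loop could be moved into a function and profiled like this. The function name `input_derivatives` and the variable names are just placeholders I made up, and I’m using the `Profile` standard library here so the snippet does not depend on any editor; the imports from Flux mirror the ones in your script.

```julia
using Flux
using Flux: gradient, hessian
using Profile

# Hypothetical helper: the per-sample loop from the question, wrapped in a
# function so that `model` and `Y` are arguments rather than globals.
function input_derivatives(model, Y)
    n_in, n_samples = size(Y)
    grads = zeros(n_in, n_samples)
    hess = zeros(n_in, n_in, n_samples)
    f(x) = sum(model(x))  # scalar output of the network for one sample
    for i in 1:n_samples
        y = Y[:, i]
        grads[:, i] = gradient(f, y)[1]
        hess[:, :, i] = hessian(f, y)
    end
    return grads, hess
end

# Same architecture as in the question, with 2 inputs
model = Chain(Dense(2, 64, tanh), Dense(64, 64, tanh), Dense(64, 1)) |> f64
Y = rand(2, 100)

input_derivatives(model, Y)           # run once first so compilation is not profiled
Profile.clear()
@profile input_derivatives(model, Y)  # collect samples for the second call
Profile.print(maxdepth = 12)          # text report; depth limit chosen arbitrarily
```

In the VS Code REPL you should be able to replace the last three lines with `@profview input_derivatives(model, Y)` to get a flame graph instead of the text report.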