How do we compute the gradient and Laplacian of a neural network on the GPU?

Hello everyone,

I managed to compute the gradient and Laplacian of a neural network on the CPU, and then tried to calculate the same quantities on the GPU, but I could not figure out how to do it.

A minimal example of the code I am working with is given below:

using Lux
using LuxCUDA
using CUDA
using ComponentArrays
using Random
using Zygote
using ForwardDiff
using LinearAlgebra

const gpud = gpu_device()

##
# Define the input spatial coordinates r = (x, y, z) for phase-field points
function generate_grid(L, N)
    x = range(-L, L, length=N)
    y = range(-L, L, length=N)
    z = range(-L, L, length=N)
    
    # Create a 3×N³ array
    points = zeros(Float32, 3, N^3)
    idx = 1
    for zi in z, yi in y, xi in x
        points[:, idx] = Float32[xi, yi, zi]
        idx += 1
    end    
    return points
end

## Set up RNG
rng = Random.default_rng()
Random.seed!(rng, 0)

## Neural Network Definition
nn = Chain(
    Dense(3, 20, σ),    # Input layer: 3 inputs (x, y, z), 20 neurons, sigmoid activation
    Dense(20, 10, σ),   # Hidden layer: 10 neurons, sigmoid activation
    Dense(10, 1, tanh)  # Output layer: 1 output, tanh activation to get values in [-1, 1]
)

# Initialize parameters and move to GPU
parameters, layer_states = Lux.setup(rng, nn)

NN(x) = nn(x, parameters, layer_states)[1]
# NN(x) = x[1]^2 + x[2]^2 + x[3]^2

## Computational Domain
N = 2 # Number of points along each axis
L = 3.0f0  # Domain extends from -L to L in each dimension
sample_points = generate_grid(L, N)

# Test the phase field function
φ = NN(sample_points)

for kk in axes(sample_points, 2)
    r = sample_points[:, kk]
    φ = NN(r)[1]

    ∇φ = Zygote.gradient(s -> NN(s)[1], r)[1]

    ∇²φ = ForwardDiff.hessian(s -> NN(s)[1], r) |> diag |> sum
    println("r = $r")
    println("φ = $φ")
    println("∇φ = $∇φ")
    println("∇²φ = $∇²φ")
    println()
end

## GPU implementation ????
gpu_parameters = parameters |> ComponentArray |> gpud

gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]

gpu_sample_points = CuArray(sample_points)

gpu_φ = gpu_NN(gpu_sample_points)

∇φ = Zygote.gradient(s -> gpu_NN(s)[1], gpu_sample_points)[1]

∇²φ = ForwardDiff.hessian(s -> gpu_NN(s)[1], gpu_sample_points) |> diag |> sum

Any suggestions are greatly appreciated! Thank you!

Hi @aligurbu,
Do you mean the gradient and Hessian? At least that’s what your code would suggest.
Can you share which part of your code doesn’t work, and what is the error you encounter?
At first glance, I would suspect the issue is that ForwardDiff.hessian computes gradients using elementwise indexing, which is forbidden by default on GPU. Have you tried Zygote.hessian?
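
Zygote.hessian has the same calling convention as ForwardDiff.hessian, so a sketch of what I have in mind, written against the NN and r from your CPU loop, would be the following (it should work on the CPU; I have not checked whether it also avoids the indexing restriction on a CuArray):

∇²φ = Zygote.hessian(s -> NN(s)[1], r) |> diag |> sum   # trace of the Hessian = Laplacian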

This isn’t an answer, just another implementation of generate_grid:

function generate_grid(L, N)
    rng = range(-L, L; length=N)
    return reshape(Float32[(x, y, z)[i] for i in 1:3, x in rng, y in rng, z in rng], 3, :)
end

It’s faster.
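
For what it's worth, the column ordering should match your loop version (x varies fastest, then y, then z), so it should be a drop-in replacement. For example, with L = 3.0f0 and N = 2 I would expect:

generate_grid(3.0f0, 2)
# 3×8 Matrix{Float32}:
#  -3.0   3.0  -3.0   3.0  -3.0   3.0  -3.0  3.0
#  -3.0  -3.0   3.0   3.0  -3.0  -3.0   3.0  3.0
#  -3.0  -3.0  -3.0  -3.0   3.0   3.0   3.0  3.0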

I want to compute the gradient and Laplacian of a neural network to define a loss function for training, but since I could not figure out how to compute the Laplacian directly, I compute it from the Hessian as follows:

∇²φ = ForwardDiff.hessian(s -> NN(s)[1], r) |> diag |> sum
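
(This is just the trace of the Hessian, since ∇²φ = ∂²φ/∂x² + ∂²φ/∂y² + ∂²φ/∂z², so tr from LinearAlgebra, which I already load above, gives the same result:)

∇²φ = tr(ForwardDiff.hessian(s -> NN(s)[1], r))   # sum of the diagonal second derivatives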

For example, if I try to compute the gradient and Hessian as follows:

gpu_parameters = parameters |> ComponentArray |> gpud
gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]
gpu_sample_points = CuArray(sample_points)
∇φ = Zygote.gradient(s -> gpu_NN(s)[1], gpu_sample_points)[1]
∇²φ = ForwardDiff.hessian(s -> gpu_NN(s)[1], gpu_sample_points) |> diag |> sum
∇²φ = Zygote.hessian(s -> gpu_NN(s)[1], gpu_sample_points) |> diag |> sum

I get the following error from both the gradient and the Hessian calls, which is the issue you alluded to:

ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.        
Such implementations *do not* execute on the GPU, but very slowly on the CPU,       
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.

If I allow scalar computation on the GPU, I get an ArgumentError from the Zygote.gradient function:

CUDA.allowscalar() do
           for kk in axes(gpu_sample_points, 2)
               r = gpu_sample_points[:, kk]
               φ = gpu_NN(r)[1]
               ∇φ = Zygote.gradient(s -> gpu_NN(s)[1], r)[1]
               # ∇²φ = ForwardDiff.hessian(s -> gpu_NN(s)[1], r) |> diag |> sum     
               println(∇²φ)
           end
       end
ERROR: ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.

Thanks a lot for sharing this implementation with me. It took me some time to wrap my head around why it is better, but I think I got it. You eliminated the allocation by using tuples instead of arrays. Is that correct?

The ArgumentError is telling you that one of the inputs is on the CPU (e.g., an Array) and the other is on the GPU (e.g., a CuArray). What is the type of r?
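
A quick way to double-check, besides typeof, is get_device from MLDataDevices (the package behind the CPUDevice/CUDADevice types in the error; this is just a sketch, and it may need to be added to your environment if Lux does not already re-export it):

using MLDataDevices: get_device

get_device(r)               # should report a CUDA device if r lives on the GPU
get_device(gpu_parameters)  # likewise for the parameters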

Yes, that’s what is puzzling me as well. The type of r does appear to be a CuArray:

julia> r = gpu_sample_points[:, 1]
       typeof(r)
CuArray{Float32, 1, CUDA.DeviceMemory}

Please post the full stacktrace after the ArgumentError.

Thanks a lot for your help. Here is the full stacktrace:

julia> gpu_parameters = parameters |> ComponentArray |> gpud
ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:60, ShapedAxis((20, 3))), bias = 61:80)), layer_2 = ViewAxis(81:290, Axis(weight = ViewAxis(1:200, ShapedAxis((10, 20))), bias = 201:210)), layer_3 = ViewAxis(291:301, Axis(weight = ViewAxis(1:10, ShapedAxis((1, 10))), bias = 11:11)))}}}(layer_1 = (weight = Float32[-0.8827754 0.36517656 -0.56002903; -0.089523196 -0.80052364 0.35755897; … ; 0.22984266 -0.23869145 0.6892575; 0.3440013 -0.76533496 -0.15416396], bias = Float32[0.473517, 0.32531556, -0.56818783, 0.06394702, -0.25244415, -0.4316282, 0.026881456, -0.059109006, 0.14427646, -0.12220892, 0.32917956, 0.060895164, 0.14029668, -0.50013363, -0.5041214, -0.453094, 0.1960066, 0.0277223, 0.08420609, -0.25414172]), layer_2 = (weight = Float32[0.16217647 0.24069801 … 0.2096802 -0.30790493; 0.10382402 0.09449814 … -0.007781879 0.27407685; … ; -0.3674191 0.26801494 … 0.18675517 -0.09126645; -0.05721156 0.30545124 … -0.3172608 0.09927955], bias = Float32[-0.13602416, -0.18653981, -0.02721949, -0.13559724, -0.18748423, -0.19375737, 0.06723717, -0.15049928, 0.001352046, -0.08790343]), layer_3 = (weight = Float32[-0.042059317 -0.37042874 … 0.17988238 0.22771344], bias = Float32[-0.16884351]))

julia> gpu_NN(x) = nn(x, gpu_parameters, layer_states)[1]
gpu_NN (generic function with 1 method)

julia> gpu_sample_points = CuArray(sample_points)
3×27 CuArray{Float32, 2, CUDA.DeviceMemory}:
 -3.0   0.0   3.0  -3.0   0.0   3.0  …   3.0  -3.0  0.0  3.0  -3.0  0.0  3.0        
 -3.0  -3.0  -3.0   0.0   0.0   0.0     -3.0   0.0  0.0  0.0   3.0  3.0  3.0        
 -3.0  -3.0  -3.0  -3.0  -3.0  -3.0      3.0   3.0  3.0  3.0   3.0  3.0  3.0        

julia> CUDA.allowscalar() do
           for kk in axes(gpu_sample_points, 2)
               r = gpu_sample_points[:, kk]
               φ = gpu_NN(r)[1]
               ∇φ = Zygote.gradient(s -> gpu_NN(s)[1], r)[1]
           end
       end
ERROR: ArgumentError: Objects are on devices with different types: CPUDevice and CUDADevice.
Stacktrace:
  [1] combine_devices(T1::Type{CPUDevice}, T2::Type{CUDADevice})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:127
  [2] macro expansion
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:205 [inlined]
  [3] unrolled_mapreduce
    @ C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:192 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{…})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:183
  [5] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices.Internal C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\internal.jl:155
  [6] get_device_type(x::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ MLDataDevices C:\Users\aligu\.julia\packages\MLDataDevices\hoL1S\src\public.jl:388
  [7] internal_operation_mode(xs::Tuple{Base.ReshapedArray{…}, CuArray{…}})
    @ LuxLib C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\traits.jl:210
  [8] ∇activation(Δ::Base.ReshapedArray{…}, out::CuArray{…}, act::typeof(tanh_fast),
 x::LuxLib.Utils.NotaNumber)
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\activation.jl:107
  [9] (::LuxLib.Impl.var"#78#81"{…})(Δ::Base.ReshapedArray{…})
    @ LuxLib.Impl C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:51  
 [10] ZBack
    @ C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:212 [inlined]
 [11] fused_dense
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\impl\dense.jl:11 [inlined]    
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [13] fused_dense_bias_activation
    @ C:\Users\aligu\.julia\packages\LuxLib\wiiF1\src\api\dense.jl:35 [inlined]     
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Base.ReshapedArray{Float32, 2, ChainRules.OneElement{…}, Tuple{}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [15] Dense
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\basic.jl:343 [inlined]    
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [17] apply
    @ C:\Users\aligu\.julia\packages\LuxCore\SN4dl\src\LuxCore.jl:155 [inlined]     
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [19] applychain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:0 [inlined] 
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [21] Chain
    @ C:\Users\aligu\.julia\packages\Lux\gmUbf\src\layers\containers.jl:480 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{ChainRules.OneElement{…}, Nothing})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [23] gpu_NN
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:103 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRules.OneElement{Float32, 1, Tuple{…}, Tuple{…}})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [25] #44
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:109 [inlined]
 [26] (::Zygote.Pullback{Tuple{var"#44#46", CuArray{…}}, Tuple{Zygote.Pullback{…}, Zygote.Pullback{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [27] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:91
 [28] gradient(f::Function, args::CuArray{Float32, 1, CUDA.DeviceMemory})
    @ Zygote C:\Users\aligu\.julia\packages\Zygote\nyzjS\src\compiler\interface.jl:148
 [29] (::var"#43#45")()
    @ Main d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:109
 [30] task_local_storage(body::var"#43#45", key::Symbol, val::GPUArraysCore.ScalarIndexing)
    @ Base .\task.jl:297
 [31] allowscalar(f::Function)
    @ GPUArraysCore C:\Users\aligu\.julia\packages\GPUArraysCore\GMsgk\src\GPUArraysCore.jl:183
 [32] top-level scope
    @ d:\Repositories\LearningJulia\MySandBoxJulia\scripts\Lux_jl\Testing.jl:105    
Some type information was truncated. Use `show(err)` to see complete types.

Do you need the ExceptionStack? Please let me know if you need any other information.

This seems like a bug. Can you open an issue with the full error (show(err))? I can patch it later today.
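
In the meantime, a possible workaround (just a sketch, not tested against your setup) is to avoid per-point loops altogether and differentiate a scalar reduction of the batched output; since each column of the input only affects its own output, the gradient of the sum recovers the per-point gradients while everything stays on the GPU:

# 3×N matrix whose k-th column is ∇φ at the k-th grid point
∇φ_all = Zygote.gradient(x -> sum(gpu_NN(x)), gpu_sample_points)[1]

The Laplacian does not batch as easily, so that part may still need the Hessian route once the bug is fixed.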

I opened an issue for the Lux.jl package.

Thanks a lot for your help.