Difficulties evaluating PDEs involving Laplacians with AD


I keep running into problems using automatic differentiation (AD) to evaluate the Helmholtz equation and the wave equation. My main research area is acoustic signal processing, specifically environmental audio, and how embedding the physics of the problem into estimation can give models greater representational power.
I understand that AD and the symbolic tools in SciML are closely related to my objective. However, my use case requires evaluating these PDEs at arbitrary points in the domain, not only on grids. On the CPU these calculations are easy, and I have already confirmed that it is simple to evaluate, say, the Helmholtz equation at a given point with AD. I hope I am being clear. As an example of what I mean, I wrote a short snippet using a function I know satisfies the Helmholtz equation, the zeroth-order spherical Bessel function j₀(r) = sin(r)/r:

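```julia
using Flux, LinearAlgebra, ChainRulesCore
using ChainRulesCore: @scalar_rule

# Zeroth-order spherical Bessel function j0(x) = sin(x)/x,
# with the removable singularity at x = 0 handled explicitly
function j0(x::T)::T where {T<:Number}
    if iszero(x)
        return one(x)
    else
        return sin(x)/x
    end
end

# First derivative of j0
function dj0(x::T)::T where {T<:Number}
    if iszero(x)
        return zero(x)
    else
        return (x*cos(x) - sin(x))/(x^2)
    end
end

# Second derivative of j0
function d2j0(x::T)::T where {T<:Number}
    if iszero(x)
        return -one(x)/3
    else
        return ((2 - x^2)*sin(x) - 2*x*cos(x))/(x^3)
    end
end

# Custom rules so reverse mode uses the closed-form derivatives above
@scalar_rule(dj0(x), d2j0(x))
@scalar_rule(j0(x), dj0(x))

x = rand(3)

# Diagonal of the Hessian of h(x) = j0(|x|); diaghessian returns a tuple
hess = Flux.diaghessian(x -> j0(norm(x)), x)[1]

# Helmholtz residual ∇²h + h (Laplacian = sum of the Hessian diagonal)
residual = sum(hess) + j0(norm(x))
```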

Summing this Hessian diagonal to get the Laplacian and adding h gives 0.0 (up to floating-point round-off), as expected, since h(x) = j₀(|x|) solves the most basic Helmholtz equation ∇²h + h = 0. I have also made x a CUDA array and run the same test with the same result, and applied the same operation to a Chain, again without errors (with one caveat, described below). However, try as I might, I cannot write a version of this code that satisfies the requirements of Flux, CUDA, et al. well enough to actually put it to use.
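Why the residual should vanish: for a radial function h(x) = f(r) with r = |x| in three dimensions, the Laplacian reduces to an ordinary differential expression in r, and substituting the same closed-form derivatives used in `dj0` and `d2j0` cancels term by term (for r > 0; at r = 0 the limits 1, 0 and -1/3 give the same result by continuity).

```latex
\begin{aligned}
\nabla^2 h &= f''(r) + \frac{2}{r}\,f'(r), \qquad h(\mathbf{x}) = f(r),\; r = \lVert \mathbf{x} \rVert,\; f = j_0, \\
f'(r) &= \frac{r\cos r - \sin r}{r^2}, \qquad f''(r) = \frac{(2 - r^2)\sin r - 2r\cos r}{r^3}, \\
\nabla^2 h + h &= \frac{(2 - r^2)\sin r - 2r\cos r + 2(r\cos r - \sin r)}{r^3} + \frac{\sin r}{r}
= -\frac{\sin r}{r} + \frac{\sin r}{r} = 0 .
\end{aligned}
```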
I want to evaluate these PDEs for batches of data in the usual Flux layout (a matrix with one sample per column), but the program fails because diaghessian only accepts scalar-valued functions. Looping over the samples one at a time produces warnings that I shouldn't use scalar indexing, yet using GPUs is essential for training the model in a feasible time frame.
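A minimal sketch of one way around the scalar-output restriction for a whole batch, using nested ForwardDiff derivatives instead of diaghessian. It assumes each column of the batch is processed independently: shifting every column along coordinate i by the same scalar t and differentiating twice with respect to t then yields ∂²u/∂x_i² for all samples at once, and summing over the d coordinates gives every sample's Laplacian in d forward-over-forward passes with no per-sample loop. The names `batched_laplacian` and `u` are hypothetical. This is a CPU sketch; I have not checked whether Dual numbers pass through CuArrays and Flux layers without falling back to scalar indexing.

```julia
using ForwardDiff

# Zeroth-order spherical Bessel function (removable singularity at 0)
j0(r) = iszero(r) ? one(r) : sin(r) / r

# Hypothetical batched field: one 3-D point per column of X, returns a 1×N row
u(X) = j0.(sqrt.(sum(abs2, X; dims = 1)))

# Hypothetical helper: Laplacian of every column of X at once.
# g(t) = u(X .+ t .* e_i) moves all columns together along coordinate i;
# because output column k depends only on input column k, g''(0) holds
# ∂²u_k/∂x_{ik}² for every sample k simultaneously.
function batched_laplacian(u, X::AbstractMatrix{T}) where {T<:AbstractFloat}
    lap = zero(u(X))
    for i in axes(X, 1)
        e = T.(axes(X, 1) .== i)          # unit vector e_i as 0/1 values of type T
        g = t -> u(X .+ t .* e)           # shift every column by t along e_i
        d2 = ForwardDiff.derivative(s -> ForwardDiff.derivative(g, s), zero(T))
        lap = lap .+ d2                   # accumulate ∂²u/∂x_i² over coordinates
    end
    return lap
end

X = randn(3, 16)
lap = batched_laplacian(u, X)             # 1×16: one Laplacian per sample
helmholtz_residual = lap .+ u(X)          # ∇²u + u, should be ≈ 0
```

The design choice here is forward-over-forward along coordinate directions: the cost is d nested passes over the full batch, independent of the batch size N, whereas a per-sample Hessian would loop N times.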
Running the same diaghessian call on a Flux Chain doesn't even work unless I index the output:

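```julia
j = Chain(
    Dense(3, 5, tanh),
    Dense(5, 1, tanh)
)
# Fails: j(x) returns a 1-element vector, not a scalar
hess = Flux.diaghessian(x -> j(x), x)[1]
# Works on the CPU once the output itself is indexed: Flux.diaghessian(x -> j(x)[1], x)[1]
```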

```
ERROR: Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] sensitivity(y::Vector{ForwardDiff.Dual{Nothing, Float32, 3}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:67
  [3] gradient(f::Function, args::Vector{ForwardDiff.Dual{Nothing, Float32, 3}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
  [4] (::Zygote.var"#134#137"{Int64, Val{1}, var"#267#268", Tuple{Vector{Float32}}})(x::Vector{ForwardDiff.Dual{Nothing, Float32, 3}})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/grad.jl:260
  [5] forward_diag(f::Zygote.var"#134#137"{Int64, Val{1}, var"#267#268", Tuple{Vector{Float32}}}, x::Vector{Float32}, #unused#::Val{3})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/forward.jl:62
  [6] forward_diag(f::Function, x::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/forward.jl:75
  [7] #133
    @ ~/.julia/packages/Zygote/SuKWp/src/lib/grad.jl:260 [inlined]
  [8] ntuple
    @ ./ntuple.jl:19 [inlined]
  [9] diaghessian(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/grad.jl:257
 [10] top-level scope
    @ ~/Desktop/Main/lib/sound_field_kernels.jl:35
```

If I do index the output (x -> j(x)[1]), the Hessian diagonal can be calculated, but I can't move this to the GPU because indexing the output is scalar indexing, and it also makes working on batches of data quite hard. diaghessian is reachable through Flux (the stack trace shows it is defined in Zygote), so I do not understand what I might be doing wrong for it not to work with this format. Laplacians feature in a variety of PDEs, so I, with admittedly limited knowledge of machine learning in Julia, expected diaghessian to handle this case as well.
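A minimal sketch of a batch-level workaround that keeps diaghessian, assuming the model never mixes samples (true for Dense/Chain as above, but not for layers such as BatchNorm in training mode). Summing the 1×N output gives the scalar diaghessian requires; since no output column depends on another sample's input, all cross-sample second derivatives vanish, so entry (i, k) of the returned diagonal is ∂²u_k/∂x_{ik}², and summing over rows gives one Laplacian per sample. `chain_laplacian` is a hypothetical name, and `Dense(3 => 5, tanh)` is the constructor syntax of recent Flux versions. I have not profiled this or tried it on a CuArray; diaghessian seeds dual partials per entry of X, so the work may grow faster than linearly with batch size, which suggests it suits small batches unless profiling shows otherwise.

```julia
using Flux

# Same architecture as the Chain above (3 inputs → 1 output),
# applied to a 3×N batch with one sample per column
model = Chain(Dense(3 => 5, tanh), Dense(5 => 1, tanh))

# Hypothetical helper: per-sample Laplacian of a column-independent model
function chain_laplacian(model, X::AbstractMatrix)
    # Scalar objective: sum of all outputs; cross-sample terms are zero
    H = Flux.diaghessian(Z -> sum(model(Z)), X)[1]  # same shape as X
    return sum(H; dims = 1)                         # 1×N: one Laplacian per column
end

X = rand(Float32, 3, 8)
lap = chain_laplacian(model, X)
```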
I have researched this a lot (Tullio.jl, for instance, helped me solve many of the issues I had writing structs and functions that don't mutate arrays), but computing PDEs involving second derivatives seems to be a brick wall for me. I have independently looked up how to write GPU kernels (nothing to do with my professional research, only idle curiosity; it is a great feature of the language that I have been enjoying a lot), but I don't think something like this should require writing my own kernel.
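If nested AD through the network proves hard to run on the GPU, one kernel-free fallback is a different technique from AD, named plainly here: second-order central finite differences. The sketch below assumes a batched field `u` mapping a d×N matrix to a 1×N row, and uses only whole-array broadcasts and calls to `u`, the style of code that GPU arrays support without scalar indexing. `fd_laplacian` and its step `h` are hypothetical; the sketch is written against CPU arrays and I have not verified it on a CuArray. The default step eps(T)^(1/4) roughly balances truncation against round-off; in Float32 that still leaves errors on the order of 1e-3.

```julia
# Hypothetical helper: central-difference Laplacian for a batched field `u`
# mapping a d×N matrix (one point per column) to a 1×N row. Every step is a
# whole-array broadcast or a call to `u`; no element of X is indexed alone.
function fd_laplacian(u, X::AbstractMatrix{T}; h = sqrt(sqrt(eps(T)))) where {T<:AbstractFloat}
    u0 = u(X)
    lap = zero(u0)
    for i in axes(X, 1)
        # `axes(X, 1) .== i` fuses into the same broadcast and shifts
        # every column by ±h along coordinate i
        up = u(X .+ h .* (axes(X, 1) .== i))
        um = u(X .- h .* (axes(X, 1) .== i))
        lap = lap .+ (up .- 2 .* u0 .+ um) ./ h^2
    end
    return lap
end

# Zeroth-order spherical Bessel function and a batched field built from it
j0(r) = iszero(r) ? one(r) : sin(r) / r
u(X) = j0.(sqrt.(sum(abs2, X; dims = 1)))

X = randn(3, 16)
# Small but not exactly zero: O(h^2) truncation error plus round-off
residual = fd_laplacian(u, X) .+ u(X)
```

Unlike the AD route, this needs 2d + 1 evaluations of the model per batch and no derivative rules at all, at the price of approximation error that depends on `h` and the floating-point type.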

I am using Julia 1.9 on an Arch Linux machine. Any help would be appreciated.