Laplacian of shallow Flux neural network with complex weights

Hi all,

I am trying to use Julia’s AD ecosystem to calculate what the title says: the Laplacian (trace of the Hessian matrix) of a shallow (one-layer) Flux neural network with respect to its inputs.
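The quantity in question, written out: the inputs are real, x ∈ ℝⁿ (n = 64 for one 8×8 single-channel image), and the output f(x) is a complex scalar. The Laplacian is therefore the sum of the unmixed second partials, and it acts separately on the real and imaginary parts:

```latex
\Delta f(x) = \operatorname{tr}\!\big(\nabla_x^2 f(x)\big)
            = \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2}(x),
\qquad
\Delta f = \Delta(\operatorname{Re} f) + i\,\Delta(\operatorname{Im} f).
```

Only the n diagonal entries of the Hessian are needed, not all n² of them.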

There are two twists to the story:

  • I need all of the weights to be complex for domain-specific reasons. My inputs are still real, though.
  • The “activation” function actually comes from SpecialFunctions.jl, which seems to block AD with ForwardDiff.jl (I have raised an issue on their GitHub).
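Because the activation is a single scalar function, its derivatives can be written in closed form, so AD does not necessarily have to differentiate through besselix. Assuming the activation is σ(z) = log I₀(z) (as in log_i0 below), the standard recurrences I₀′ = I₁ and I₁′(z) = I₀(z) − I₁(z)/z give:

```latex
\sigma(z) = \log I_0(z), \qquad
\sigma'(z) = \frac{I_1(z)}{I_0(z)} =: r(z), \qquad
\sigma''(z) = \frac{I_0(z) - I_1(z)/z}{I_0(z)} - r(z)^2
            = 1 - \frac{r(z)}{z} - r(z)^2 .
```

The ratio r(z) equals besselix(1, z) / besselix(0, z), because both scaled functions carry the same factor exp(−|Re z|). At z = 0 the formula for σ″ has to be replaced by its limit σ″(0) = 1/2, since r(z) ≈ z/2 near the origin.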

Here is a simplified version of the model with auxiliary functions:

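```julia
using Flux
using SpecialFunctions
using Zygote
using ForwardDiff

# log(I₀(z)) via the exponentially scaled besselix(0, z) = I₀(z) * exp(-|Re z|),
# which avoids overflow of I₀ for large |Re z|
log_i0(z) = log(besselix(0, z)) + abs(real(z))

# Glorot-normal magnitudes with uniformly random complex phases
complex_init(dims...) =
    convert.(ComplexF64, Flux.glorot_normal(dims...)) .* exp.(2π * im .* rand(Float64, dims...))

model = Chain(
    complex,                                           # real input -> complex array
    Conv((3, 3), 1 => 4, log_i0; init = complex_init), # complex 3×3 kernels, 4 output channels
    sum                                                # reduce to a single complex scalar
)

inputs = randn(8, 8, 1, 1)  # one real 8×8 single-channel image (WHCN layout)
model(inputs)
```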
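Because the network has a single layer followed by a plain sum, the Laplacian can also be written in closed form. For a holomorphic σ and real inputs, ∂²σ(z)/∂x_j² = σ″(z)·w_j², so Δf = Σ_c (Σ W_c²) · Σ_positions σ″(z_c). The plain Julia sketch below checks this. The helpers conv_valid, shallow_model and shallow_laplacian are hypothetical, and it assumes one input channel, stride 1 and no padding (with padding, border positions would not see every kernel entry). Here W is (kh, kw, c_out), i.e. the Flux weight with the singleton input-channel dimension dropped. As far as I know, Flux's Conv flips the kernel (true convolution) while conv_valid is a cross-correlation, so reproducing a Flux layer's z exactly would need the flipped kernel. The Σ W² factor is unaffected by the flip.

```julia
# Cross-correlation of a real image x with one kernel w ("valid": stride 1, no padding)
function conv_valid(x::AbstractMatrix, w::AbstractMatrix)
    kh, kw = size(w)
    T = promote_type(eltype(x), eltype(w))
    z = zeros(T, size(x, 1) - kh + 1, size(x, 2) - kw + 1)
    for j in axes(z, 2), i in axes(z, 1)
        z[i, j] = sum(view(x, i:i+kh-1, j:j+kw-1) .* w)
    end
    return z
end

# f(x) = Σ_c Σ_positions σ(z_c + b_c), mirroring Chain(complex, Conv(...), sum)
shallow_model(x, W, b, σ) =
    sum(sum(σ.(conv_valid(x, W[:, :, c]) .+ b[c])) for c in axes(W, 3))

# Exact Laplacian w.r.t. the real input, assuming σ is holomorphic:
#   Δf = Σ_c (Σ_kernel W_c .^ 2) * Σ_positions σ''(z_c + b_c)
# Note the plain square W^2 rather than abs2: the chain rule for holomorphic σ
# of real inputs gives ∂²σ(z)/∂x_j² = σ''(z) * w_j^2.
shallow_laplacian(x, W, b, d2σ) =
    sum(sum(W[:, :, c] .^ 2) * sum(d2σ.(conv_valid(x, W[:, :, c]) .+ b[c])) for c in axes(W, 3))
```

The cost is one forward pass plus evaluating σ″, independent of the number of pixels. For log_i0, σ″ would be the closed-form second derivative of log I₀.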

First derivatives with respect to the inputs are easy to get with Zygote; no issues there:

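```julia
# Assuming the model is holomorphic in its inputs, conj of the gradient of real∘model
# at a real point equals the complex derivative of model w.r.t. the real inputs
Zygote.gradient(real∘model, complex(inputs)) |> first |> conj
```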

My question is: can I get the trace of the Hessian matrix (the Laplacian) of such a model? Are there tricks or other packages I am not aware of?
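Whatever AD route ends up working, a central-difference estimate is a useful baseline for checking it. The minimal sketch below (fd_laplacian is a hypothetical helper, plain Julia, no AD) sums (f(x + h·e_k) − 2f(x) + f(x − h·e_k)) / h² over every input coordinate. It works unchanged when f returns a complex number.

```julia
# Central-difference Laplacian: Σ_k (f(x + h e_k) - 2 f(x) + f(x - h e_k)) / h^2.
# Needs 2n + 1 evaluations of f for n inputs. h ≈ eps(Float64)^(1/4) ≈ 1e-4 roughly
# balances the O(h^2) truncation error against the O(eps / h^2) rounding error.
function fd_laplacian(f, x::AbstractArray{<:Real}; h::Real = 1e-4)
    xw = float.(x)            # working copy; one entry is perturbed at a time
    f0 = f(xw)
    total = zero(f0) / h^2    # zero of the right (possibly complex) type
    for k in eachindex(xw)
        xk = xw[k]
        xw[k] = xk + h; fp = f(xw)
        xw[k] = xk - h; fm = f(xw)
        xw[k] = xk            # restore before the next direction
        total += (fp - 2f0 + fm) / h^2
    end
    return total
end
```

Since the model above starts with complex and ends with sum, something like fd_laplacian(model, inputs) should apply to it directly, at the cost of 129 forward passes for the 64 pixels.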

Something like Zygote.diaghessian looks perfect, but it uses ForwardDiff under the hood, which fails on log_i0. Even when I replace log_i0 with something less exotic, it throws other errors.
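Since the failure comes from pushing ForwardDiff's dual numbers through complex arithmetic and besselix, one possible workaround is a hand-rolled second-order forward mode ("Taylor mode"). It propagates the value and the first and second derivatives along one input direction e_k with complex coefficients, and the activation supplies its own σ′ and σ″. Summing the second derivatives over all n directions gives the Laplacian in n passes. This is a hypothetical sketch, not an existing package. Jet, Activation, dense_model and jet_laplacian are illustrative names, and the model is a dense stand-in for the conv layer.

```julia
# Value plus first and second derivative along one seeded input direction
struct Jet{T<:Number}
    v::T
    d1::T
    d2::T
end
Jet(v::Number, d1::Number, d2::Number) = Jet(promote(v, d1, d2)...)

# Rules for what an affine layer needs: sums and scaling by (complex) constants
Base.:+(a::Jet, b::Jet) = Jet(a.v + b.v, a.d1 + b.d1, a.d2 + b.d2)
Base.:+(a::Number, b::Jet) = Jet(a + b.v, b.d1, b.d2)
Base.:+(a::Jet, b::Number) = b + a
Base.:*(a::Number, b::Jet) = Jet(a * b.v, a * b.d1, a * b.d2)
Base.:*(a::Jet, b::Number) = b * a

# An activation carrying its own derivatives, so no AD runs through it:
#   (σ∘u)'' = σ''(u) * u'^2 + σ'(u) * u''
struct Activation{F,G,H}
    f::F
    df::G
    d2f::H
end
(a::Activation)(z::Number) = a.f(z)
(a::Activation)(z::Jet) = Jet(a.f(z.v), a.df(z.v) * z.d1, a.d2f(z.v) * z.d1^2 + a.df(z.v) * z.d2)

# One-layer dense model f(x) = Σ_k act(W[k, :] ⋅ x + b[k]), written with generic
# sums so it runs on plain numbers as well as on Jets
dense_model(x, W, b, act) =
    sum(act(b[k] + sum(W[k, i] * x[i] for i in eachindex(x))) for k in axes(W, 1))

# Laplacian w.r.t. a real input vector: one Jet pass per coordinate
function jet_laplacian(f, x::AbstractVector{<:Real})
    sum(eachindex(x)) do k
        # seed x_i + t*δ_ik: first derivative 1 along direction k, second derivative 0
        seeded = [Jet(float(x[i]), float(i == k), 0.0) for i in eachindex(x)]
        f(seeded).d2
    end
end
```

For log_i0 one would pass σ′ = I₁/I₀ and σ″ = 1 − (I₁/I₀)/z − (I₁/I₀)², both computable from besselix ratios (with the limit σ″(0) = 1/2 handled separately). The catch is that NNlib's convolution will not accept a custom number type like this, so the Conv layer would have to be rewritten with plain loops in the style of dense_model.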