Automatic differentiation of a complex output w.r.t. real inputs takes much longer than the function evaluation

I am using Zygote to differentiate a complex-valued output with respect to 3 real inputs. The function evaluation itself takes 0.088371 seconds (146.12 k allocations: 10.071 MiB, 99.23% compilation time),

but the gradient evaluation takes 8.398555 seconds (29.47 M allocations: 1.482 GiB, 11.13% gc time, 98.91% compilation time: 2% of which was recompilation). That is roughly 95 times slower.


using Zygote

# params, predict (the NN) and BesselJ0 are defined elsewhere.
function GF_xyz(source, image)
    function wave_func(source, image)
        dx = source[1] - image[1]
        dy = source[2] - image[2]
        vv = source[3] + image[3]
        dvv = source[3] - image[3]

        X = sqrt(dx^2 + dy^2 + dvv^2)
        Y = -vv

        result = complex(predict(params, [X, Y])[1]) + 2 * π * im * exp(-Y * BesselJ0(X))

        return result
    end

    # Differentiate the real and imaginary parts separately, then recombine.
    # Zygote.jacobian returns a tuple with one Jacobian per argument.
    gradient_source_imag = Zygote.jacobian(source -> imag(wave_func(source, image)), source)
    gradient_source_real = Zygote.jacobian(source -> real(wave_func(source, image)), source)
    return gradient_source_real .+ im .* gradient_source_imag
end

predict (a neural network: a loop of matrix-vector multiplications and activation functions) and BesselJ0 are defined manually elsewhere. (The Bessel J0 from SpecialFunctions.jl did not work with Zygote: I kept getting a segmentation fault in Julia 1.9.4.)

I do not think the predict function or the BesselJ0 function is the problem. As far as I can tell it is Zygote, since the function evaluation itself is fast enough.

My main motivation is to use AD to differentiate the function I approximated. But from these timings, it seems I will have to approximate the derivatives of the function separately, which defeats the purpose of AD.

  1. Is it the case that Zygote's Jacobian for a complex output is simply not fast enough?

This percentage is usually not very accurate, but notice that the timer is warning you that these runs were dominated by compilation time. If you ran your call again, it would likely not require any compilation and would be much faster. See the performance tips. If it's still showing significant compilation for repeated calls within a session, you're probably doing too much testing in global scope. Also consider using BenchmarkTools.jl for more accurate benchmarks.
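To make that concrete, here is a minimal sketch of how one might separate compilation time from run time. toy_wave and toy_grad are hypothetical stand-ins for the wave_func and GF_xyz in the question (the real predict and BesselJ0 are not shown in the thread), so only the timing pattern carries over, not the numbers.

```julia
using Zygote, BenchmarkTools

# Hypothetical stand-in for wave_func: complex output, 3 real inputs.
toy_wave(s) = complex(sum(abs2, s)) + 2π * im * exp(-s[3]) * cos(s[1])

# Differentiate real and imaginary parts separately, as in the question.
function toy_grad(s)
    g_re = Zygote.gradient(x -> real(toy_wave(x)), s)[1]
    g_im = Zygote.gradient(x -> imag(toy_wave(x)), s)[1]
    return g_re .+ im .* g_im
end

s = [0.3, 0.4, 0.5]

t_first  = @elapsed toy_grad(s)   # first call: includes compilation
t_second = @elapsed toy_grad(s)   # second call: already compiled

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")

# BenchmarkTools runs the call many times and excludes compilation.
# The `$` interpolates the global `s` so it is not benchmarked as an untyped global.
@btime toy_grad($s)
```

If the second call and the @btime result are both small, the 8 seconds in the question were almost entirely one-off compilation rather than the cost of the gradient itself.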

You might want to try Bessels.jl, which has pure Julia Bessel function implementations that are usually faster than the ones in SpecialFunctions.jl.
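If you go the Bessels.jl route, one possible way to keep Zygote from tracing through the Bessel implementation is to hand it the derivative directly, using the identity J0'(x) = -J1(x). This is only a sketch: the wrapper name j0 is made up here, and it assumes ChainRulesCore's @scalar_rule together with Bessels.besselj0 and Bessels.besselj1. It has not been tried against the predict network from the question.

```julia
using Bessels, ChainRulesCore, Zygote

# Hypothetical thin wrapper, so the rule below is attached to our own function.
j0(x::Real) = Bessels.besselj0(x)

# Tell ChainRules-aware AD (including Zygote) that d/dx J0(x) = -J1(x).
@scalar_rule j0(x) -(Bessels.besselj1(x))

# Zygote now uses the rule instead of differentiating through besselj0.
dj0 = Zygote.gradient(j0, 1.0)[1]
println(dj0)
```

With a rule like this in place, Zygote only has to compile the rest of the function, which may also cut down the first-call compilation time.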

For my use case, I ended up just using a simple implementation of my own, which is also pure Julia.