Holomorphic function not being recognized as such by ChainRules.jl


I am currently trying to use @scalar_rule to define a custom gradient for a very important family of functions in physics and engineering, the spherical Bessel functions, in this case specifically the one of order 0. Although it is holomorphic, and in fact entire, I usually write this function with a special case at the origin:

j0(x::Number) = iszero(x) ? one(x) : sin(x)/x 

This creates problems when trying to use automatic differentiation, as ChainRules.jl and Flux.jl do not function well on that exception point, despite a derivative existing and being continuous everywhere, including the origin. In the past, I have skirted around this issue by adding an eps to the numerator and denominator to avoid making an exception for 0, but I’d like to do this properly. I found what I thought was the solution for my problem in the macro @scalar_rule. I defined the derivative explicitly as a separate function and created the rule:

using ChainRules, Flux
import ChainRules: @scalar_rule
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x 
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)
@scalar_rule(j0(x), (dj0(x)))

For real inputs, this solved my problem. It performs exactly as I expected it to, evading the conditional clause I put in:

julia> gradient(j0, 0.0)
julia> gradient(j0, 1.0)

However, in my research, I often have to calculate this function and its derivative for complex input. For that case, the documentation for @scalar_rule says that the function is assumed to be holomorphic, which j0 is, so I expected the rule to carry over with no issue. However, any attempt to calculate a derivative at a complex point results in an error:

julia> gradient(j0, 0.0im)
ERROR: Output is complex, so the gradient is not defined.
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] sensitivity(y::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:66
 [3] gradient(f::Function, args::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:97
 [4] top-level scope
   @ REPL[18]:1
 [5] top-level scope
   @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52

Any idea what exactly my mistake might be? I have been racking my brain over this for some time.
For the record, my system is running Arch Linux, and my Julia version is 1.8.5. I also updated all my packages right before running the test above. Any help would be much appreciated.

This is not the fault of ChainRules, but of Zygote.gradient (Zygote.jl is the actual library doing the work when you call Flux.gradient). Because gradient isn't given a seed, it has to choose one itself based on the output of the function. Zygote is very conservative here: as the error message says, it only does this for real output and refuses when the output is complex.

Instead of using gradient, you can use the lower-level pullback function and provide your own seed.

y, back = pullback(j0, 0.0im)
back(one(y))  # call the pullback with an explicit seed of 1

I’m not familiar with taking derivatives of holomorphic functions, but the "Complex numbers" page of the ChainRules documentation may be helpful here.
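
A minimal sketch of how the pullback approach could be wrapped up for complex input, assuming the j0 and dj0 definitions from the question. The helper name holomorphic_derivative is hypothetical (it is not part of Zygote or ChainRules), and the final conj assumes the ChainRules convention that the pullback of a holomorphic scalar function multiplies the seed by the conjugate of the derivative:

```julia
using ChainRulesCore, Zygote

# Definitions from the question, with the removable singularity handled explicitly.
j0(x::Number) = iszero(x) ? one(x) : sin(x) / x
dj0(x::Number) = iszero(x) ? zero(x) : (x * cos(x) - sin(x)) / x^2
ChainRulesCore.@scalar_rule(j0(x), dj0(x))

# Hypothetical helper (not a library function): derivative of a
# holomorphic scalar function f at the point z.
function holomorphic_derivative(f, z::Number)
    y, back = Zygote.pullback(f, z)
    # Seeding with 1 is assumed to give conj(f'(z)); conjugate again to get f'(z).
    return conj(back(one(y))[1])
end

z = 1.0 + 2.0im
holomorphic_derivative(j0, z) ≈ dj0(z)   # compare against the analytic derivative
holomorphic_derivative(j0, 0.0im)        # the origin, where the special case applies
```

Seeding with one(y) rather than a literal 1 keeps the seed's type consistent with the output, so the same helper should also work for real input.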

This did indeed work and now it shows expected behavior. It is rather inconvenient not being able to use the gradient function, but I’ll make do.
Thank you very much for your help.