Greetings.
I am currently trying to use `@scalar_rule` to define a custom gradient for a very important family of functions in physics and engineering, the spherical Bessel functions, in this case specifically the one of order 0. Although this function is holomorphic, and in fact entire, I usually write it with a special case at the origin:
```julia
# sin(x)/x is 0/0 at x = 0, so the origin is filled in with its limit, 1
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
```
This creates problems when trying to use automatic differentiation, as ChainRules.jl and Flux.jl do not handle that special point well, even though the derivative exists and is continuous everywhere, including at the origin. In the past, I have worked around this issue by adding an eps to the numerator and denominator so that 0 never needs special treatment, but I'd like to do this properly. I found what I thought was the solution to my problem in the macro `@scalar_rule`: I defined the derivative explicitly as a separate function and created the rule:
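```julia
using ChainRules, Flux
import ChainRules: @scalar_rule

# j0(x) = sin(x)/x, with the removable singularity at x = 0 filled in by its limit, 1
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
# j0'(x) = (x*cos(x) - sin(x))/x^2, whose limit at x = 0 is 0
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)
# Tell ChainRules that the derivative of j0 is dj0
@scalar_rule(j0(x), (dj0(x)))
```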
For real inputs, this solved my problem. It behaves exactly as I expected, bypassing the conditional branch I put in:
```julia
julia> gradient(j0, 0.0)
(0.0,)

julia> gradient(j0, 1.0)
(-0.30116867893975674,)
```
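As a package-free cross-check of these values (a sketch, not part of the setup above), the hand-written `dj0` can be compared against a central finite difference of `j0` at a few real points. The helper `central_diff` and the step `h = 1e-6` are illustrative choices of my own; the definitions of `j0` and `dj0` are repeated so the snippet runs on its own:

```julia
# Same definitions as above, repeated so this snippet is self-contained
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)

# Hypothetical helper: central difference (f(x + h) - f(x - h)) / 2h approximates f'(x)
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Print x, the analytic derivative, and the numerical estimate side by side
for x in (0.0, 1.0, 2.5)
    println(x, "  ", dj0(x), "  ", central_diff(j0, x))
end
```

At `x = 0.0` the two evaluations of `j0` are mirror images of an even function, so the estimate there is consistent with the derivative being 0.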
However, in my research I often have to compute this function and its derivative for complex input. For that case, the documentation for `@scalar_rule` says the function is assumed to be holomorphic, which `j0` is, so I expected the results to carry over with no issue. Instead, any attempt to compute a derivative at a complex point results in an error:
```julia
julia> gradient(j0, 0.0im)
ERROR: Output is complex, so the gradient is not defined.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] sensitivity(y::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:66
 [3] gradient(f::Function, args::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:97
 [4] top-level scope
   @ REPL[18]:1
 [5] top-level scope
   @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
```
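The derivative function itself does seem to be correct at complex points. Because `j0` is holomorphic, its complex derivative should not depend on the direction from which the point is approached, so a minimal plain-Julia sketch (no AD involved) can compare difference quotients along the real and imaginary axes with `dj0(z)`. The point `z = 0.7 + 0.4im` and the step `h = 1e-6` are arbitrary assumptions chosen for illustration:

```julia
# Same definitions as above, repeated so this snippet is self-contained
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)

z = 0.7 + 0.4im   # arbitrary complex test point
h = 1e-6          # step size for the difference quotients

# Central difference quotient with the step taken along the real axis
along_real = (j0(z + h) - j0(z - h)) / (2h)
# Central difference quotient with the step taken along the imaginary axis
along_imag = (j0(z + im*h) - j0(z - im*h)) / (2im*h)

# For a holomorphic function both quotients should agree with dj0(z)
println(along_real, "  ", along_imag, "  ", dj0(z))
```

If the three printed values agree to many digits, the problem is not in `dj0` but in how `gradient` treats the complex output, which is what the error message itself points to.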
Any idea what exactly my mistake might have been? I have been racking my brain over this for some time.
For the record, my system runs Arch Linux and my Julia version is 1.8.5. I also updated all my packages right before running the test above. Any help would be much appreciated.