I am trying to use `@scalar_rule` to define a custom gradient for a very important family of functions in physics and engineering, the spherical Bessel functions, in this case specifically the one of order 0. Although this function is holomorphic, and in fact entire, I usually write it with a special case at the origin:
```julia
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
```
This causes problems with automatic differentiation, because ChainRules.jl and Flux.jl do not behave well at that special-cased point, even though the derivative exists and is continuous everywhere, including at the origin. In the past, I have skirted around this issue by adding an `eps` to the numerator and denominator so that 0 needs no special case, but I'd like to do this properly. I found what I thought was the solution in the `@scalar_rule` macro. I defined the derivative explicitly as a separate function and created the rule:
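```julia
using ChainRules, Flux
import ChainRules: @scalar_rule

# Spherical Bessel function of order 0, special-cased at the origin
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x

# Its derivative; the special case uses the limit j0'(0) = 0
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)

# Tell ChainRules to use dj0 as the derivative of j0
@scalar_rule(j0(x), (dj0(x)))
```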
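Since `@scalar_rule` generates `frule`/`rrule` methods, one way to check that the rule itself accepts complex arguments, independently of Flux or Zygote, is to call those methods directly. The following is a minimal sketch assuming only ChainRulesCore.jl (the package ChainRules.jl builds on); the test point `z` is an arbitrary choice. Under ChainRules' convention for holomorphic functions, the forward rule multiplies the input tangent by `dj0(z)`, while the pullback multiplies the incoming cotangent by `conj(dj0(z))`, so the two results are expected to be complex conjugates of each other.

```julia
using ChainRulesCore

# Definitions copied from the post.
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)
@scalar_rule(j0(x), (dj0(x)))

z = 0.5 + 0.5im  # arbitrary complex test point

# Forward mode: seeding the input tangent with 1 reads off the derivative.
# The first tangent in the tuple belongs to the function object j0 itself.
y_fwd, dy = frule((NoTangent(), one(z)), j0, z)

# Reverse mode: rrule returns the value and a pullback; the pullback
# returns one cotangent for j0 itself and one for z.
y_rev, j0_pullback = rrule(j0, z)
_, zbar = j0_pullback(one(z))

println(dy)    # expected to match dj0(z)
println(zbar)  # expected to match conj(dj0(z))
```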
For real inputs, this rule solved my problem. It behaves exactly as I expected, bypassing the conditional:
```julia
julia> gradient(j0, 0.0)
(0.0,)

julia> gradient(j0, 1.0)
(-0.30116867893975674,)
```
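These values can also be cross-checked without any AD package. The sketch below uses only Base Julia and the definitions from this post. It compares `dj0` with the closed form cos(1) - sin(1) at x = 1, with a central finite difference of `j0` (the step `h` is an arbitrary choice), and with the leading Taylor term near the origin: since j0(x) = 1 - x^2/6 + O(x^4), one expects j0'(x) = -x/3 + O(x^3), which tends to the special-cased value 0 along both the real and the imaginary axis.

```julia
# Definitions copied from the post.
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)

# At x = 1 the derivative formula reduces to cos(1) - sin(1).
println(dj0(1.0), "  ", cos(1.0) - sin(1.0))

# Central finite difference of j0 itself (step size chosen ad hoc).
h = 1e-6
fd(x) = (j0(x + h) - j0(x - h)) / (2h)
println(fd(1.0))

# Near the origin dj0(x) should track -x/3, so it approaches dj0(0) = 0
# continuously, for real and purely imaginary arguments alike.
for x in (1e-1, 1e-2, 1e-3, 1e-3im)
    println(x, "  ", dj0(x), "  ", -x/3)
end
```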
However, in my research I often have to evaluate this function and its derivative for complex inputs. For that case, the documentation for `@scalar_rule` says that the function is presumed to be holomorphic, which j0 is, so I expected the results to hold without issue. Instead, any attempt to compute a derivative at a complex point results in an error:
```julia
julia> gradient(j0, 0.0im)
ERROR: Output is complex, so the gradient is not defined.
Stacktrace:
 error(s::String)
   @ Base ./error.jl:35
 sensitivity(y::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:66
 gradient(f::Function, args::ComplexF64)
   @ Zygote ~/.julia/packages/Zygote/hLzJT/src/compiler/interface.jl:97
 top-level scope
   @ REPL:1
 top-level scope
   @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
```
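The stacktrace shows the error being raised by Zygote's `sensitivity` function, which `gradient` calls to seed the output, rather than inside the rule. A minimal sketch of a way around that check, assuming Zygote.jl is installed (it is the AD backend Flux uses), is to call `Zygote.pullback` and supply the output cotangent by hand. Given the conjugation convention ChainRules uses for holomorphic rules, the returned value is expected to be `conj(dj0(z))`, so the holomorphic derivative is its conjugate. The test point `z` is again arbitrary.

```julia
using ChainRulesCore, Zygote

# Definitions copied from the post.
j0(x::Number) = iszero(x) ? one(x) : sin(x)/x
dj0(x::Number) = iszero(x) ? zero(x) : (x*cos(x) - sin(x))/(x^2)
@scalar_rule(j0(x), (dj0(x)))

z = 1.0 + 0.5im  # arbitrary complex test point

# pullback returns the value and a closure; unlike gradient, it does not
# require a real output, so the cotangent is seeded with 1 by hand.
y, back = Zygote.pullback(j0, z)
(zbar,) = back(one(y))

println(zbar)        # expected to match conj(dj0(z))
println(conj(zbar))  # expected to match dj0(z), the holomorphic derivative
```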
Any idea what exactly my mistake might have been? I have been racking my brain over this for some time.
For the record, my system is running Arch Linux, and my Julia version is 1.8.5. I also updated all my packages right before running the test above. Any help would be much appreciated.