Zygote: dropgrad for all non-differentiable numerical function arguments (e.g., Int)

Is there a way to tell Zygote to automatically apply Zygote.dropgrad() to every numerical function argument that is non-differentiable (such as an Int)?


julia> import Zygote

julia> # naively defining my function and using Zygote to calculate the gradient
       f(n::Integer,x::Real) = sin(n*x)
f (generic function with 1 method)

julia> Zygote.gradient(f, 100, π)
(3.141592653589793, 100.0)

Zygote differentiates with respect to n, even though n is an integer and f can’t be differentiated with respect to it.

I would like Zygote to apply Zygote.dropgrad() to every argument that is an Integer, like n, exactly as if I had written:

julia> f(n::Integer,x::Real) = sin(Zygote.dropgrad(n) * x)
f (generic function with 1 method)

julia> Zygote.gradient(f, 100, π)
(nothing, 100.0)

It would not be crazy to regard all integers as categorical and only floats as differentiable, but most of Julia isn’t fussy about such things, and I think Zygote just follows that convention and promotes integers to floats:

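julia> using Zygote, ForwardDiff

julia> sin(1) == sin(1.0)
true

julia> gradient(sin, 1) == gradient(sin, 1.0)  # this seems very Julian
true

julia> sin(ForwardDiff.Dual(1,true))
Dual{Nothing}(0.8414709848078965,0.5403023058681398)

julia> gradient(x -> abs(sin(x + 0*im)), 1)  # complex gradient for a real input: this we should fix
(0.5403023058681398 + 0.0im,)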

Treating all integers as non-differentiable would, for instance, break almost every example in Zygote’s README, which I think is evidence that it would be surprising.

But I think you’re proposing only that a function whose signature specifies ::Integer (or similar) should give no gradient for that argument. That would probably be a good idea; someone would just need to figure out how to implement it.
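Until something like that exists in Zygote itself, one user-level approximation is to wrap the function so that every Integer argument passes through Zygote.dropgrad before the call. This is only a sketch: nodiff and drop_integer_grads are hypothetical helper names, not part of Zygote, and the approach relies on Zygote differentiating map over a tuple of arguments.

```julia
import Zygote

# Hypothetical helpers (not Zygote API): Integer arguments go through
# Zygote.dropgrad, so their gradient comes back as `nothing`;
# everything else is passed through unchanged.
nodiff(a::Integer) = Zygote.dropgrad(a)
nodiff(a) = a

# Hypothetical wrapper: call `f` with all Integer arguments treated as constants.
drop_integer_grads(f) = (args...) -> f(map(nodiff, args)...)

f(n::Integer, x::Real) = sin(n * x)
g = drop_integer_grads(f)

Zygote.gradient(g, 100, π)  # gradient for n should be `nothing`, as with the hand-written dropgrad
```

Note that this checks the runtime type of each argument rather than the method signature, so it would also drop gradients for integers you did want to treat as numbers.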

julia> f(n::Integer,x::Real) = sin(n*x);  # has no specific rule

julia> gradient(f, 7, 0)
(0.0, 7.0)      # now
(nothing, 7.0)  # proposed

julia> f(ForwardDiff.Dual(7,true), 0)  # disallowed here
ERROR: MethodError: no method matching f(::ForwardDiff.Dual{Nothing, Int64, 1}, ::Int64)

julia> gradient(getindex, [5,4,3], 2)  # this has a rule
([0, 1, 0], nothing)
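For a single function, the proposed behaviour can already be had by giving the function its own rule, as getindex has. A minimal sketch, assuming ChainRulesCore is available in your environment (Zygote uses ChainRules rrules when one matches): the pullback returns NoTangent() for n, which Zygote reports as nothing.

```julia
import Zygote
using ChainRulesCore: ChainRulesCore, NoTangent

f(n::Integer, x::Real) = sin(n * x)

# Hand-written reverse rule: no tangent for the function itself or for n;
# for x, the ordinary derivative n*cos(n*x) scaled by the incoming ȳ.
function ChainRulesCore.rrule(::typeof(f), n::Integer, x::Real)
    y = f(n, x)
    f_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ * n * cos(n * x))
    return y, f_pullback
end

Zygote.gradient(f, 7, 0)  # Zygote turns NoTangent() into `nothing` for n
```

The drawback is that such a rule has to be written by hand for every function, which is exactly the boilerplate an automatic, signature-based approach would remove.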