Currently the derivative of round is 0. I would like it to be 1 instead, and I would also like that to work when broadcast. So I wrote a custom my_round and tried to write the new rule with adjoint and @scalar_rule, but it does not fully work; in particular it never works with a CUDA array. Is there any workaround? (For the record, I'm trying to implement quantization-aware training on a simple dense network and miserably failing.)
Assuming that you are using Zygote (still Flux's default), then what you may be missing is that it differentiates broadcasts using ForwardDiff. This is not affected by ChainRules's @scalar_rule
. Instead, you would need a round(::Dual)
method… something like this?
julia> using Zygote, ForwardDiff, JLArrays  # JLArrays's jl() imitates a GPU array on the CPU

julia> Zygote.gradient(x -> sum(round.(x ./ 10; digits=2)), randn(3))  # that's a bug
ERROR: MethodError: no method matching iterate(::Nothing)

julia> Zygote.gradient(x -> sum(round.(x ./ 10; digits=2)), jl(randn(3)))  # with GPU array
ERROR: MethodError: no method matching round(::ForwardDiff.Dual{Nothing, Float64, 1}, ::RoundingMode{:Nearest})

julia> Base.round(x::ForwardDiff.Dual; kw...) = ForwardDiff.Dual(round(ForwardDiff.value(x); kw...), ForwardDiff.partials(x))  # round the value, pass the partials through unchanged

julia> Zygote.gradient(x -> sum(round.(x ./ 10; digits=2)), jl(randn(3)))  # now works
([0.1, 0.1, 0.1],)
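Note that the Dual method rounds only the value and passes the partials through unchanged, which is exactly the straight-through derivative of 1 you asked for (the 0.1s above are that 1 times the 1/10 from x ./ 10). Continuing the same session, a hypothetical fake-quantize step for the QAT case then differentiates the same way (fakequant and the step 1/4 are made up for illustration):

julia> fakequant(x) = round.(x .* 4) ./ 4;  # hypothetical fake quantizer with step 1/4

julia> Zygote.gradient(x -> sum(fakequant(x)), jl(randn(3)))  # each entry should come out as 4 * 1 * (1/4) = 1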
Adding that last line makes it work fine even on the GPU, thank you.