Diagnosing rrule issues

I am using ChainRulesTestUtils to test an rrule for a function mulsafe, defined below. The idea is that mulsafe(x, y) returns the product x .* y, but gives a strong zero where x == 0 (even if y is Inf or NaN).

using ChainRulesCore, ChainRulesTestUtils, Test

mulsafe(x, y) = @. ifelse(iszero(x), zero(x * y), x * y)
function ChainRulesCore.rrule(::typeof(mulsafe), x, y)
    function mulsafe_pullback(δ)
        dx = @thunk(δ .* y)
        dy = @thunk(@. ifelse(iszero(x), zero(δ * x), δ * x))
        return NoTangent(), dx, dy
    end
    return mulsafe(x, y), mulsafe_pullback
end

@testset "mulsafe" begin
    @test mulsafe(1, 0) == 0
    @test mulsafe(0, 1) == 0
    @test mulsafe(0, Inf) == 0
    @test mulsafe(0, NaN) == 0
    @test isnan(mulsafe(Inf, 0))
    @test isnan(mulsafe(NaN, 0))
    @test mulsafe(Inf, 1) == Inf
    @test isnan(mulsafe(NaN, 1))
    @test mulsafe(2:5, 1:4) == mulsafe.(2:5, 1:4) == (2:5) .* (1:4)

    test_rrule(mulsafe, 1, 0)
    test_rrule(mulsafe, 0, 1)
    test_rrule(mulsafe, 0, Inf)
    test_rrule(mulsafe, 0, NaN)
    test_rrule(mulsafe, Inf, 0)
    test_rrule(mulsafe, NaN, 0)
    test_rrule(mulsafe, Inf, 1)
    test_rrule(mulsafe, NaN, 1)
    test_rrule(mulsafe, 2:5, 1:4)
end

When I run this code, I get some test failures:

I think test_rrule runs more than one test internally, but from this output I cannot tell exactly what is failing. Is there a way to figure out what's wrong?

Here is another example I don’t understand:

using ChainRulesCore, ChainRulesTestUtils, Test

function mylog(x::Real)
    r = log(abs(x))
    if x > 0
        return r
    else
        return oftype(r, -Inf)
    end
end

function ChainRulesCore.rrule(::typeof(mylog), x)
    function mylog_pullback(δ)
        if x > 0
            dx = 1/x
        else
            dx = oftype(1/x, Inf)
        end
        return NoTangent(), dx
    end
    return mylog(x), mylog_pullback
end

test_rrule(mylog, 1.0)

gives

test_rrule: mylog on Float64: Test Failed at /home/cossio/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox(1.0, -3.910000000000372; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
 [2] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:297
 [3] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
 [4] foreach(::Function, ::Tuple{NoTangent, Float64}, ::Tuple{NoTangent, Float64}, ::Vararg{Tuple{NoTangent, Float64}, N} where N)
   @ Base ./abstractarray.jl:2142
 [5] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
 [6] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [7] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
Test Summary:                | Pass  Fail  Total
test_rrule: mylog on Float64 |    6     1      7
ERROR: LoadError: Some tests did not pass: 6 passed, 1 failed, 0 errored, 0 broken.
in expression starting at /home/cossio/work/julia/test_rrule_log.jl:24

I am pretty sure the derivative of log(x) is 1/x, so what is going on here?

I agree ChainRulesTestUtils is not fantastic at telling you why a test failed, or where.
It's something we have worked on a bit and want to work on more (but limitations of the Test stdlib get in the way).
We hope we can improve it some, but it's nontrivial.
It is good at letting you know that everything is pretty much certainly fine if it passes.

One thing that can help identify what is failing, and where, is to run just one test at a time,
e.g. in the REPL (maybe using TestEnv.jl).
Then your particular failure messages will not get lost in the output.

In your output you have scrolled down past where it tells you exactly what is wrong.
There are about a hundred lines of that; it is most of the following screenshot:

But honestly that is all a bit much.
As is often the case (though I will admit CRTU is particularly bad for needing it),
looking at one test at a time in the REPL is more manageable.
(Which is why I created TestEnv.jl.)

So let’s do that:

julia> test_rrule(mulsafe, 1, 0)
test_rrule: mulsafe on Int64,Int64: Test Failed at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
  Expression: ad_cotangent isa NoTangent
   Evaluated: Thunk(var"#5#8"{NoTangent, Int64}(NoTangent(), 0)) isa NoTangent
Stacktrace:
 [1] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
 [2] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
 [3] foreach(::Function, ::Tuple{NoTangent, NoTangent, NoTangent}, ::Tuple{NoTangent, Thunk{var"#5#8"{NoTangent, Int64}}, Thunk{var"#6#9"{NoTangent, Int64}}}, ::Vararg{Any, N} where N)
   @ Base ./abstractarray.jl:2142
 [4] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
 [5] macro expansion
   @ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
test_rrule: mulsafe on Int64,Int64: Test Failed at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
  Expression: ad_cotangent isa NoTangent
   Evaluated: Thunk(var"#6#9"{NoTangent, Int64}(NoTangent(), 1)) isa NoTangent
Stacktrace:
 [1] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
 [2] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
 [3] foreach(::Function, ::Tuple{NoTangent, NoTangent, NoTangent}, ::Tuple{NoTangent, Thunk{var"#5#8"{NoTangent, Int64}}, Thunk{var"#6#9"{NoTangent, Int64}}}, ::Vararg{Any, N} where N)
   @ Base ./abstractarray.jl:2142
 [4] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
 [5] macro expansion
   @ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
Test Summary:                      | Pass  Fail  Total
test_rrule: mulsafe on Int64,Int64 |    5     2      7
ERROR: Some tests did not pass: 5 passed, 2 failed, 0 errored, 0 broken.

This is a lot, but it is more manageable.
Let's start from the beginning:

  Expression: ad_cotangent isa NoTangent
   Evaluated: Thunk(var"#5#8"{NoTangent, Int64}(NoTangent(), 0)) isa NoTangent

So it is saying that it was expecting the cotangent that came out of the AD rule to be a NoTangent,
but instead a thunked value came out.
(You can confirm that by looking at where the error is coming from according to the stack trace.)

This one tends to come up fairly often in practice from using an input primal type that is not continuous,
in particular Int.
There is an open issue to give a better error message for this special case.

There are a few problems with handling this automatically:
Firstly, we can't tell whether a type like Int actually represents a continuous quantity, i.e. it is just a special case of what is conceptually a float (etc.), in which case we should perturb it,
or whether it actually represents an index like d in size(x, d), in which case we should not perturb it, and the correct tangent is NoTangent since it is not perturbable.
We assume it is the latter.
Secondly, finite differencing would always perturb an integer to a non-integer anyway.
(cf. FiniteDifferences.jl issue #188, "Should `to_vec(::Integer)` return an empty vector")

So while you could tell it that the cotangent is not NoTangent, by attaching a tangent to each input with the `⊢` syntax,
this leads to an error from FiniteDifferences.jl (it does not handle types changing well):

julia> test_rrule(mulsafe, 1⊢3.0, 0⊢4.0)
test_rrule: mulsafe on Int64,Int64: Error During Test at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
  Got exception outside of a @test
  InexactError: Int64(0.99)
  Stacktrace:
    [1] Int64
      @ ./float.jl:723 [inlined]
    [2] convert
      @ ./number.jl:7 [inlined]
    [3] setindex!
      @ ./array.jl:839 [inlined]
    [4] (::FiniteDifferences.var"#62#64"{Int64, ComposedFunction{ComposedFunction{ComposedFunction{typeof(first), typeof(FiniteDifferences.to_vec)}, FiniteDifferences.var"#79#80"{ChainRulesTestUtils.var"#fnew#42"{ChainRulesTestUtils.var"#call#52"{NamedTuple{(), Tuple{}}}, Tuple{typeof(mulsafe), Int64, Int64}, Tuple{Bool, Bool, Bool}}}}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}}, Vector{Int64}})(ε::Float64)
      @ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:18

Now, what you can do, though, is just test on floating point values.
Those are inferred to be continuous (rather than discrete, which gives NoTangent()), and they are perturbable (which is kind of the same thing).
And this is fine: you don't have a separate dispatch for integers anyway, so you are not testing a different code path.
(I would be very interested in seeing any rrule that does do that.)

julia> test_rrule(mulsafe, 1.0, 0.0)
Test Summary:                          | Pass  Total
test_rrule: mulsafe on Float64,Float64 |    9      9
Test.DefaultTestSet("test_rrule: mulsafe on Float64,Float64", Any[], 9, false, false)

So let's go and change all your tests over and see where we get to:

# giant list of failure locations above
Test Summary:                            | Pass  Fail  Error  Total
mulsafe                                  |   74    19      1     94
  test_rrule: mulsafe on Float64,Float64 |    9                   9
  test_rrule: mulsafe on Float64,Float64 |    9                   9
  test_rrule: mulsafe on Float64,Float64 |    8     1             9
  test_rrule: mulsafe on Float64,Float64 |    9     3            12
  test_rrule: mulsafe on Float64,Float64 |    6     3             9
  test_rrule: mulsafe on Float64,Float64 |    7     5            12
  test_rrule: mulsafe on Float64,Float64 |    7     2             9
  test_rrule: mulsafe on Float64,Float64 |    7     5            12
  test_rrule: mulsafe on StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}} |    3            1      4

Not great, but progress.
So let’s move on:

julia> test_rrule(mulsafe, 0.0, Inf)
test_rrule: mulsafe on Float64,Float64: Test Failed at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox(-Inf, NaN; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
 [2] test_approx(actual::AbstractThunk, expected::Any, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:29
 [3] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:297
 [4] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
 [5] foreach(::Function, ::Tuple{NoTangent, Float64, Float64}, ::Tuple{NoTangent, Thunk{var"#10#13"{Float64, Float64}}, Thunk{var"#11#14"{Float64, Float64}}}, ::Vararg{Any, N} where N)
   @ Base ./abstractarray.jl:2142
 [6] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
 [7] macro expansion
   @ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [8] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194

So what is happening here?
The rrule returned -Inf, but finite differencing came up with NaN.
This is probably because, at the end of the day, finite differencing basically runs a fancy version of
$\dfrac{f(x_2) - f(x_1)}{x_2 - x_1}$ for $x_2$ some point very close to $x_1$.
But if your input point is Inf, that boils down to things like Inf - Inf or Inf/Inf, both of which are NaN.
FiniteDifferences.jl can't really work well around nonfinite points.
I have opened an issue proposing that it just error immediately if given a nonfinite primal (FiniteDifferences.jl issue #192, "Error if given nonfinite primal?").
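To see how that NaN can arise, here is a minimal sketch that applies a plain five-point central difference to x ↦ mulsafe(x, Inf) at x = 0. It is a simplified stand-in for what FiniteDifferences.jl does (methods such as `central_fdm(5, 1)` use a stencil of this kind but also adapt the step size); the helper names `central5` and `g` are made up for illustration. Samples left of 0 are -Inf and samples right of 0 are +Inf, so the weighted sum hits -Inf + Inf.

```julia
# mulsafe as in the question: strong zero when x == 0.
mulsafe(x, y) = @. ifelse(iszero(x), zero(x * y), x * y)

# Hypothetical helper: five-point central difference for f'(x),
# weights (1, -8, 0, 8, -1) / 12h with a fixed step h (no adaptation).
central5(f, x; h = 1e-3) = (f(x - 2h) - 8f(x - h) + 8f(x + h) - f(x + 2h)) / (12h)

g(x) = mulsafe(x, Inf)   # we want ∂/∂x at the point (0, Inf)

central5(g, 0.0)                     # NaN: -Inf - 8 * (-Inf) evaluates -Inf + Inf
central5(x -> mulsafe(x, 2.0), 1.0)  # ≈ 2.0 at a finite point, as expected
```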

For this case you are probably better off just testing directly by invoking the rrule.
I suspect most of the other failures around nonfinite values are like that.
So, moving on to the end:

julia> test_rrule(mulsafe, 2.0:5.0, 1.0:4.0)
test_rrule: mulsafe on StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}: Error During Test at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
  Got exception outside of a @test
  TypeError: in new, expected Int64, got a value of type Float64
  Stacktrace:
    [1] macro expansion
      @ ~/.julia/packages/FiniteDifferences/W3rQO/src/to_vec.jl:0 [inlined]
    [2] _force_construct
      @ ~/.julia/packages/FiniteDifferences/W3rQO/src/to_vec.jl:27 [inlined]
    [3] (::FiniteDifferences.var"#structtype_from_vec#29"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}, FiniteDifferences.var"#Tuple_from_vec#48"{NTuple{4, Int64}, NTuple{4, Int64}, NTuple{4, typeof(identity)}}, Tuple{FiniteDifferences.var"#structtype_from_vec#29"{Base.TwicePrecision{Float64}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}, FiniteDifferences.var"#structtype_from_vec#29"{Base.TwicePrecision{Float64}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}, FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}})(v::Vector{Float64})
# huge stacktrace

Urg, this seems like something going wrong in FiniteDifferences.jl.
I know this because the stacktrace is talking about to_vec,
so you can be sure that it's not your fault.
I have opened an issue.

For now you can work around by using a different vector type, like:


julia> test_rrule(mulsafe, collect(2.0:5.0), collect(1.0:4.0))
Test Summary:                                          | Pass  Total
test_rrule: mulsafe on Vector{Float64},Vector{Float64} |    9      9
Test.DefaultTestSet("test_rrule: mulsafe on Vector{Float64},Vector{Float64}", Any[], 9, false, false)
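Putting the workarounds so far together, a sketch of the finite-point part of the original test set might look like this: Float64 instead of Int inputs, and `collect`ed Vectors instead of ranges. It assumes, as in the `mylog` example, that the rule is defined as a method of `ChainRulesCore.rrule`. The nonfinite cases are left out, since finite differencing cannot check them.

```julia
using ChainRulesCore, ChainRulesTestUtils, Test

mulsafe(x, y) = @. ifelse(iszero(x), zero(x * y), x * y)

function ChainRulesCore.rrule(::typeof(mulsafe), x, y)
    function mulsafe_pullback(δ)
        dx = @thunk(δ .* y)
        dy = @thunk(@. ifelse(iszero(x), zero(δ * x), δ * x))
        return NoTangent(), dx, dy
    end
    return mulsafe(x, y), mulsafe_pullback
end

@testset "mulsafe rrule, finite Float64 inputs" begin
    # Float64 inputs are treated as continuous, so they are compared
    # against finite differences (Int inputs would expect NoTangent).
    test_rrule(mulsafe, 1.0, 0.0)
    test_rrule(mulsafe, 0.0, 1.0)
    # Vectors rather than StepRangeLen, to sidestep the to_vec problem.
    test_rrule(mulsafe, collect(2.0:5.0), collect(1.0:4.0))
end
```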

Anyway, as you can see, CRTU is not without its rough edges.
It is really hard to make a testing library for AD that meets all 3 of:

  • doesn't let anything wrong through (I think we are doing well at that)
  • lets everything correct through (I think we are doing OK at that, but there is still some work to do, as you can see)
  • gives clear and helpful messages (we are trying, but mostly failing, I will admit. This is the improved version)

Hopefully walking through how I would debug these failures helps you too.


Thanks! It’s very helpful.

Just one question: when you say "For this case you are probably better off just testing directly by invoking the rrule", what do you mean? How can I "invoke the rrule" directly?

In `mylog_pullback`, the returned tangent needs to be dx * δ; then I think it should work (at least for positive x).
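A minimal sketch of that fix, keeping the rest of the question's code as it was: the pullback scales the local derivative by the incoming cotangent δ. `unthunk` is used in case the cotangent arrives as a thunk (which `test_rrule` also tries). The x ≤ 0 branch is left unchanged, and only a positive input is tested.

```julia
using ChainRulesCore, ChainRulesTestUtils, Test

function mylog(x::Real)
    r = log(abs(x))
    if x > 0
        return r
    else
        return oftype(r, -Inf)
    end
end

function ChainRulesCore.rrule(::typeof(mylog), x)
    function mylog_pullback(δ)
        # local derivative of log at x (x ≤ 0 branch kept from the question)
        dx = x > 0 ? 1/x : oftype(1/x, Inf)
        # chain rule: multiply by the incoming cotangent δ
        return NoTangent(), dx * unthunk(δ)
    end
    return mylog(x), mylog_pullback
end

test_rrule(mylog, 1.0)
```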

One possibility here would be for the tester to try evaluating the function at a non-integer point. size(x, nextfloat(float(d))) will be an error, from which you infer that d::Int represents a non-perturbable parameter.

(But not sure it’s worth the effort & complication.)
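A rough sketch of that heuristic, purely for illustration: `accepts_noninteger` is a hypothetical helper (not part of ChainRulesTestUtils or FiniteDifferences.jl) that replaces argument `i` with the next float above it and reports whether the call still succeeds. It treats any error as "not perturbable", which is cruder than a real implementation would want to be.

```julia
# Hypothetical helper, not part of any package.
# Returns false if calling `f` fails once argument `i` is nudged to a non-integer,
# suggesting that argument is discrete (e.g. a dimension index) rather than continuous.
function accepts_noninteger(f, args::Tuple, i::Integer)
    nudged = ntuple(j -> j == i ? nextfloat(float(args[j])) : args[j], length(args))
    try
        f(nudged...)
        return true
    catch
        return false   # any error is taken as "do not perturb this argument"
    end
end

x = rand(3, 4)
accepts_noninteger(size, (x, 2), 2)   # false: size(x, nextfloat(2.0)) throws
accepts_noninteger(*, (3, 2), 1)      # true: nextfloat(3.0) * 2 works fine
```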


Right, thanks! My mistake.

BTW: thank you for raising these things.
These kinds of things help us work out what needs to be improved.


Thank you! Just a quick question, what did you mean by "invoking the rrule" directly above? For the Inf or NaN cases.

I mean like writing the code that does:

julia> y, pb = rrule(mulsafe, 0, Inf)
(0.0, var"#mulsafe_pullback#8"{Int64, Float64}(0, Inf))

julia> df, da, db = pb(1.0)
(NoTangent(), Thunk(var"#6#9"{Float64, Float64}(1.0, Inf)), Thunk(var"#7#10"{Float64, Int64}(1.0, 0)))

julia> @test y == 0.0
Test Passed

julia> @test df == NoTangent()
Test Passed

julia> @test da == Inf
Test Passed

julia> @test db == 0.0
Test Passed

Or whatever it is that is worth testing.

Or, a bit shorter, use CRTU's `test_approx` helper:

julia> tangents = pb(1.0)
(NoTangent(), Thunk(var"#6#9"{Float64, Float64}(1.0, Inf)), Thunk(var"#7#10"{Float64, Int64}(1.0, 0)))

julia> ChainRulesTestUtils.test_approx(tangents, (NoTangent(), Inf, 0.0))
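The same direct checks can be collected into a test set covering all the nonfinite cases from the question. This is a sketch, assuming the rule is defined as a method of `ChainRulesCore.rrule`; it uses `unthunk` to materialize the thunked tangents and `isequal` so that NaN compares equal to NaN. The expected values are simply what this particular rrule produces for a cotangent of 1.0; which of them are mathematically what you want is your call.

```julia
using ChainRulesCore, Test

mulsafe(x, y) = @. ifelse(iszero(x), zero(x * y), x * y)

function ChainRulesCore.rrule(::typeof(mulsafe), x, y)
    function mulsafe_pullback(δ)
        dx = @thunk(δ .* y)
        dy = @thunk(@. ifelse(iszero(x), zero(δ * x), δ * x))
        return NoTangent(), dx, dy
    end
    return mulsafe(x, y), mulsafe_pullback
end

# (x, y, expected primal, expected ∂x, expected ∂y) for cotangent δ = 1.0
cases = [
    (0.0, Inf, 0.0, Inf, 0.0),
    (0.0, NaN, 0.0, NaN, 0.0),
    (Inf, 0.0, NaN, 0.0, Inf),
    (NaN, 0.0, NaN, 0.0, NaN),
    (Inf, 1.0, Inf, 1.0, Inf),
    (NaN, 1.0, NaN, 1.0, NaN),
]

@testset "mulsafe rrule at nonfinite points" begin
    for (x, y, Ω, dx, dy) in cases
        out, pb = rrule(mulsafe, x, y)
        df, x̄, ȳ = pb(1.0)
        @test isequal(out, Ω)              # isequal treats NaN as equal to NaN
        @test df isa NoTangent
        @test isequal(unthunk(x̄), dx)      # unthunk evaluates the @thunk
        @test isequal(unthunk(ȳ), dy)
    end
end
```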