I am using ChainRulesTestUtils to test an rrule for a function mulsafe I define below. The idea is that mulsafe(x,y) returns the product x .* y, but giving a strong zero if x == 0 (even if y is Inf, or NaN).
using ChainRulesCore, ChainRulesTestUtils, Test
function mylog(x::Real)
    r = log(abs(x))
    if x > 0
        return r
    else
        return oftype(r, -Inf)
    end
end
function ChainRulesCore.rrule(::typeof(mylog), x)
    function mylog_pullback(δ)
        if x > 0
            dx = 1/x
        else
            dx = oftype(1/x, Inf)
        end
        return NoTangent(), dx
    end
    return mylog(x), mylog_pullback
end
test_rrule(mylog, 1.0)
gives
test_rrule: mylog on Float64: Test Failed at /home/cossio/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox(1.0, -3.910000000000372; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
[1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/check_result.jl:24
[2] _test_cotangent(accum_cotangent::Any, ad_cotangent::Any, fd_cotangent::Any; check_inferred::Any, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:297
[3] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
[4] foreach(::Function, ::Tuple{NoTangent, Float64}, ::Tuple{NoTangent, Float64}, ::Vararg{Tuple{NoTangent, Float64}, N} where N)
@ Base ./abstractarray.jl:2142
[5] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
[6] macro expansion
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[7] test_rrule(config::RuleConfig, f::Any, args::Any; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
Test Summary: | Pass Fail Total
test_rrule: mylog on Float64 | 6 1 7
ERROR: LoadError: Some tests did not pass: 6 passed, 1 failed, 0 errored, 0 broken.
in expression starting at /home/cossio/work/julia/test_rrule_log.jl:24
I am pretty sure the derivative of log(x) is 1/x, so what is going on here?
I agree ChainRulesTestUtils is not fantastic at telling you why a test failed, or where.
It’s something we have worked on a bit and want to work on more (but limitations of the Test stdlib get in the way).
We hope to improve it somewhat, but it’s nontrivial.
What it is good at is letting you know that everything is almost certainly fine if it passes.
One thing that can help identify what went wrong and where is to run one test at a time,
e.g. in the REPL (maybe using TestEnv.jl).
Then your particular failure messages will not get lost in the output.
In your output, you have scrolled down past where it tells you exactly what is wrong.
There are about a hundred lines of that; it is most of the following screenshot:
But honestly that is all a bit much.
As is often the case (though I will admit CRTU is particularly in need of it),
looking at one test at a time in the REPL is more manageable.
(Which is why I created TestEnv.jl.)
So let’s do that:
julia> test_rrule(mulsafe, 1, 0)
test_rrule: mulsafe on Int64,Int64: Test Failed at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
Expression: ad_cotangent isa NoTangent
Evaluated: Thunk(var"#5#8"{NoTangent, Int64}(NoTangent(), 0)) isa NoTangent
Stacktrace:
[1] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
[2] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
[3] foreach(::Function, ::Tuple{NoTangent, NoTangent, NoTangent}, ::Tuple{NoTangent, Thunk{var"#5#8"{NoTangent, Int64}}, Thunk{var"#6#9"{NoTangent, Int64}}}, ::Vararg{Any, N} where N)
@ Base ./abstractarray.jl:2142
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
[5] macro expansion
@ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
test_rrule: mulsafe on Int64,Int64: Test Failed at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
Expression: ad_cotangent isa NoTangent
Evaluated: Thunk(var"#6#9"{NoTangent, Int64}(NoTangent(), 1)) isa NoTangent
Stacktrace:
[1] _test_cotangent(::NoTangent, ad_cotangent::Any, ::NoTangent; kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:303
[2] (::ChainRulesTestUtils.var"#49#53"{Bool, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})(::Any, ::Vararg{Any, N} where N)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:224
[3] foreach(::Function, ::Tuple{NoTangent, NoTangent, NoTangent}, ::Tuple{NoTangent, Thunk{var"#5#8"{NoTangent, Int64}}, Thunk{var"#6#9"{NoTangent, Int64}}}, ::Vararg{Any, N} where N)
@ Base ./abstractarray.jl:2142
[4] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:223 [inlined]
[5] macro expansion
@ /usr/local/src/julia/julia-1.6/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[6] test_rrule(::RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:194
Test Summary: | Pass Fail Total
test_rrule: mulsafe on Int64,Int64 | 5 2 7
ERROR: Some tests did not pass: 5 passed, 2 failed, 0 errored, 0 broken.
This is still a lot, but it is more manageable.
Let’s start from the beginning:
Expression: ad_cotangent isa NoTangent
Evaluated: Thunk(var"#5#8"{NoTangent, Int64}(NoTangent(), 0)) isa NoTangent
So it is saying that it expected the cotangent that came out of the AD rule to be a NoTangent.
But instead, a thunked value came out.
(You can check this by looking at where the error is coming from, according to the stack trace.)
This tends to come up fairly often in practice when using an input primal type that is not continuous.
In particular, Int.
There is an open issue to give a better error message for this special case.
There are a few problems with doing this:
Firstly, we can’t tell whether a type like Int actually represents a continuous quantity, i.e. it is just a special case of what is conceptually a float (etc.), and so we should perturb it,
or whether it actually represents an index like d in size(x, d), in which case we should not perturb it, and the correct tangent is NoTangent since it is not perturbable.
We assume it is the latter.
Secondly, finite differencing would always perturb an integer into a non-integer anyway.
(cf. Should `to_vec(::Integer)` return an empty vector · Issue #188 · JuliaDiff/FiniteDifferences.jl · GitHub)
So while you could tell it that the cotangent is not NoTangent via ⊢,
this leads to an error from FiniteDifferences.jl (it does not handle types changing well):
julia> test_rrule(mulsafe, 1⊢3.0, 0⊢4.0)
test_rrule: mulsafe on Int64,Int64: Error During Test at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
Got exception outside of a @test
InexactError: Int64(0.99)
Stacktrace:
[1] Int64
@ ./float.jl:723 [inlined]
[2] convert
@ ./number.jl:7 [inlined]
[3] setindex!
@ ./array.jl:839 [inlined]
[4] (::FiniteDifferences.var"#62#64"{Int64, ComposedFunction{ComposedFunction{ComposedFunction{typeof(first), typeof(FiniteDifferences.to_vec)}, FiniteDifferences.var"#79#80"{ChainRulesTestUtils.var"#fnew#42"{ChainRulesTestUtils.var"#call#52"{NamedTuple{(), Tuple{}}}, Tuple{typeof(mulsafe), Int64, Int64}, Tuple{Bool, Bool, Bool}}}}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}}, Vector{Int64}})(ε::Float64)
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:18
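To see the shape of that InexactError without FiniteDifferences.jl involved, here is a minimal Base-only sketch: nudging the integer 1 by a small step gives a Float64, and storing that back into Int storage (which is what the setindex!/convert frames in the stacktrace suggest happens internally) throws. The step size and variable names are illustrative assumptions, not what FiniteDifferences.jl actually uses.

```julia
# Illustrative only: mimic perturbing an Int input and writing it back into Int storage.
x = 1                       # an Int primal, like the Int inputs passed to test_rrule above
h = 0.01                    # arbitrary small step chosen for this sketch
perturbed = x - h           # promoted to Float64 (about 0.99), no longer an integer

storage = Vector{Int}(undef, 1)
err = try
    storage[1] = perturbed  # setindex! calls convert(Int, perturbed), which cannot be exact
    nothing
catch e
    e
end

println(typeof(perturbed))      # Float64
println(err isa InexactError)   # true
```

This is the same reason switching the inputs to floating point sidesteps the problem: a Float64 input can absorb the perturbation without any conversion.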
Now, what you can do, though, is just test on floating-point values.
Those are inferred to be continuous (rather than discrete, which gives NoTangent()), and they are perturbable (which is kind of the same thing).
And this is fine: you don’t have a separate dispatch for integers anyway, so the float tests exercise the same code path.
(I would be very interested in seeing any rrule that does.)
julia> test_rrule(mulsafe, 1.0, 0.0)
Test Summary: | Pass Total
test_rrule: mulsafe on Float64,Float64 | 9 9
Test.DefaultTestSet("test_rrule: mulsafe on Float64,Float64", Any[], 9, false, false)
So let’s change all your tests over to floating-point inputs and see where we get to:
# giant list of failure locations above
Test Summary:                            | Pass  Fail  Error  Total
mulsafe                                  |   74    19      1     94
  test_rrule: mulsafe on Float64,Float64 |    9                   9
  test_rrule: mulsafe on Float64,Float64 |    9                   9
  test_rrule: mulsafe on Float64,Float64 |    8     1             9
  test_rrule: mulsafe on Float64,Float64 |    9     3            12
  test_rrule: mulsafe on Float64,Float64 |    6     3             9
  test_rrule: mulsafe on Float64,Float64 |    7     5            12
  test_rrule: mulsafe on Float64,Float64 |    7     2             9
  test_rrule: mulsafe on Float64,Float64 |    7     5            12
  test_rrule: mulsafe on StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}} |    3            1      4
So what is happening here:
The rrule returned -Inf, but finite differencing came up with NaN.
This is probably because, at the end of the day, finite differencing basically computes a fancy version of $\dfrac{f(x_2) - f(x_1)}{x_2 - x_1}$ for $x_2$ some point very close to $x_1$.
But if your input point is Inf, then that boils down to arithmetic like Inf - Inf or Inf/Inf, both of which are NaN.
Finite differencing can’t really work well around nonfinite points.
I have opened an issue proposing that it just error immediately: Error if given nonfinite primal? · Issue #192 · JuliaDiff/FiniteDifferences.jl · GitHub
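To make the Inf point concrete, here is a plain forward difference in Base Julia. It is a deliberately simplified stand-in (FiniteDifferences.jl uses higher-order stencils, not this exact formula), but Inf arithmetic breaks it in the same way. The names forward_diff and square are made up for this sketch.

```julia
# Simplified one-sided difference quotient (f(x + h) - f(x)) / h; illustrative only.
forward_diff(f, x; h = 1e-6) = (f(x + h) - f(x)) / h

square(x) = x^2

println(forward_diff(square, 3.0))   # close to the true derivative 6.0
println(forward_diff(square, Inf))   # NaN: Inf + h == Inf, and Inf - Inf is NaN

# Comparing against NaN with == never succeeds, so check with isnan instead.
println(Inf / Inf == NaN)            # false
println(isnan(Inf / Inf))            # true
```

No choice of step h rescues the Inf case, which is why testing the rule directly is the more reliable option there.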
For this case you are probably better off just testing directly by invoking the rrule.
I suspect most of the other ones that are around nonfinite values are like that.
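Testing the rrule directly just means calling rrule and the pullback yourself and checking the values you expect. The original mulsafe definition and rule are not reproduced in this thread, so the sketch below uses a hypothetical mulsafe written from the description in the question, plus a hypothetical rrule whose cotangent choices are assumptions; swap in your own definitions and expected values.

```julia
using ChainRulesCore, Test

# Hypothetical mulsafe: elementwise x * y, but a strong zero wherever x == 0,
# even if y is Inf or NaN (zero(NaN) == 0.0).
_mulsafe(a, b) = iszero(a) ? zero(a * b) : a * b
mulsafe(x, y) = _mulsafe.(x, y)

# Hypothetical rrule; the cotangents below are an assumption for illustration.
function ChainRulesCore.rrule(::typeof(mulsafe), x, y)
    function mulsafe_pullback(δ)
        Δ = unthunk(δ)          # materialize a possibly-thunked cotangent
        dx = Δ .* y             # ∂(x*y)/∂x = y
        dy = mulsafe(x, Δ)      # ∂(x*y)/∂y = x, keeping the strong zero at x == 0
        return NoTangent(), dx, dy
    end
    return mulsafe(x, y), mulsafe_pullback
end

# Call the rule directly at nonfinite inputs instead of going through finite differencing.
@testset "mulsafe rrule at nonfinite points" begin
    z, pb = rrule(mulsafe, 0.0, Inf)
    @test z == 0.0                       # strong zero in the primal
    _, dx, dy = pb(1.0)
    @test dy == 0.0                      # cotangent for y stays zero
    @test dx == Inf                      # under this hypothetical rule, dx = δ * y

    z, pb = rrule(mulsafe, [0.0, 2.0], [NaN, 3.0])
    @test z == [0.0, 6.0]
    _, dx, dy = pb([1.0, 1.0])
    @test dy == [0.0, 2.0]
    @test isnan(dx[1]) && dx[2] == 3.0
end
```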
So, moving on to the end:
julia> test_rrule(mulsafe, 2.0:5.0, 1.0:4.0)
test_rrule: mulsafe on StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}},StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}: Error During Test at /home/oxinabox/.julia/packages/ChainRulesTestUtils/f5cNH/src/testers.jl:191
Got exception outside of a @test
TypeError: in new, expected Int64, got a value of type Float64
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/FiniteDifferences/W3rQO/src/to_vec.jl:0 [inlined]
[2] _force_construct
@ ~/.julia/packages/FiniteDifferences/W3rQO/src/to_vec.jl:27 [inlined]
[3] (::FiniteDifferences.var"#structtype_from_vec#29"{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}}, FiniteDifferences.var"#Tuple_from_vec#48"{NTuple{4, Int64}, NTuple{4, Int64}, NTuple{4, typeof(identity)}}, Tuple{FiniteDifferences.var"#structtype_from_vec#29"{Base.TwicePrecision{Float64}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}, FiniteDifferences.var"#structtype_from_vec#29"{Base.TwicePrecision{Float64}, FiniteDifferences.var"#Tuple_from_vec#48"{Tuple{Int64, Int64}, Tuple{Int64, Int64}, Tuple{typeof(identity), typeof(identity)}}, Tuple{FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}}, FiniteDifferences.var"#Real_from_vec#20", FiniteDifferences.var"#Real_from_vec#20"}})(v::Vector{Float64})
# huge stacktrace
Ugh, this seems like something going wrong in FiniteDifferences.jl.
I know this because the stacktrace is talking about to_vec.
So you can be sure that it’s not your fault.
I have opened an issue.
For now, you can work around it by using a different vector type, like:
julia> test_rrule(mulsafe, collect(2.0:5.0), collect(1.0:4.0))
Test Summary: | Pass Total
test_rrule: mulsafe on Vector{Float64},Vector{Float64} | 9 9
Test.DefaultTestSet("test_rrule: mulsafe on Vector{Float64},Vector{Float64}", Any[], 9, false, false)
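For the curious, a quick Base-only look at why a range is harder for to_vec than a Vector. This is my reading of the stacktrace, not something confirmed in the thread: a float range like 2.0:5.0 is a StepRangeLen whose fields include integer-typed len and offset alongside the TwicePrecision ref and step, so rebuilding it from an all-Float64 vector would try to put a Float64 into an Int field, which matches the "expected Int64, got a value of type Float64" error. collect gives a plain Vector{Float64} with no such fields.

```julia
r = 2.0:5.0
println(typeof(r))                 # a StepRangeLen{Float64, Base.TwicePrecision{Float64}, ...}
println(fieldnames(typeof(r)))     # the fields of the range struct
println(fieldtypes(typeof(r)))     # note the integer-typed len and offset fields

# collect materializes the range into an ordinary dense vector of Float64.
v = collect(r)
println(v)                         # [2.0, 3.0, 4.0, 5.0]
```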
Anyway, as you can see, CRTU is not without its rough edges.
It is really hard to make a testing library for AD that meets all 3 of:
don’t let anything wrong through (I think we are doing well at that)
let everything correct through (I think we are doing OK at that, but there is still some work to do, as you can see)
give clear and helpful messages (we are trying, but mostly failing, I will admit; this is the improved version)
Hopefully my working through how I would debug these failures helps you too.
In mylog_pullback, the returned tangent needs to be dx * δ rather than dx; then I think it should work (at least for positive x).
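Applying that fix, the pullback scales the local derivative by the incoming cotangent δ. As I understand it, this also explains the numbers in the failure: test_rrule feeds the pullback a random output cotangent, so the finite-differencing value (-3.91 here) is that δ times 1/x, while the original rule returned 1/x = 1.0 regardless of δ. A minimal sketch, which also calls unthunk on δ in case the tester passes a thunk (the check_thunked_output_tangent keyword in the stacktrace suggests it does):

```julia
using ChainRulesCore

function mylog(x::Real)
    r = log(abs(x))
    return x > 0 ? r : oftype(r, -Inf)   # same behaviour as the original if/else
end

function ChainRulesCore.rrule(::typeof(mylog), x::Real)
    function mylog_pullback(δ)
        Δ = unthunk(δ)                           # materialize a possibly-thunked cotangent
        dx = x > 0 ? 1 / x : oftype(1 / x, Inf)  # local derivative, as in the original rule
        return NoTangent(), Δ * dx               # chain rule: scale by the incoming cotangent
    end
    return mylog(x), mylog_pullback
end

y, pb = rrule(mylog, 2.0)
_, dx = pb(3.0)
println(dx)   # 1.5, i.e. 3.0 * (1 / 2.0)
```

With this change I would expect test_rrule(mylog, 1.0) to pass; the x ≤ 0 branch is a separate question.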
One possibility here would be for the tester to try evaluating the function at a non-integer point: size(x, nextfloat(float(d))) will be an error, from which you could infer that d::Int represents a non-perturbable parameter.
(But I’m not sure it’s worth the effort and complication.)
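That probe idea can be sketched in a few lines of Base Julia. The helper accepts_noninteger is hypothetical (it is not part of ChainRulesTestUtils or FiniteDifferences.jl); it re-runs the function with one argument nudged to the next float up and reports whether the call still works.

```julia
# Hypothetical helper: re-run f with argument i nudged to the next float above it.
# If that throws, treat argument i as non-perturbable (index-like), so NoTangent
# would be the expected cotangent for it.
function accepts_noninteger(f, args::Tuple, i::Integer)
    nudged = ntuple(j -> j == i ? nextfloat(float(args[j])) : args[j], length(args))
    try
        f(nudged...)
        return true
    catch
        return false   # crude: *any* error is read as "cannot perturb this argument"
    end
end

x = rand(3, 4)
println(accepts_noninteger(size, (x, 2), 2))  # false: size has no method for a Float64 dimension
println(accepts_noninteger(*, (2, 3), 1))     # true: 2.0000000000000004 * 3 works fine
```

The catch-all is exactly the kind of complication the reply worries about: a continuous function that throws for unrelated reasons, such as sqrt at a point just above -1 raising a DomainError, would be misclassified as index-like.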