Hello guys,
I’m trying to build a probabilistic model in Turing.jl, and switching the AD backend to `:reversediff` is supposed to improve performance. My model calls a forward operator that contains a `tanh` function, and when I run the code, ReverseDiff fails at that line.
I replaced `Base.tanh` with a custom definition based on a Padé approximant:
```julia
tanh_pade(x) = x * (2027025 + 270270 * x^2 + 6930 * x^4 + 36 * x^6) / (2027025 + 945945 * x^2 + 51975 * x^4 + 630 * x^6 + x^8)
```
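For comparison, the same [7/8] Padé approximant can be written in Horner form in x², which computes `x * x` once and then uses only `*` and `+`, so no `^` method is ever called. This is a minimal sketch checked only on plain `Float64`/`ComplexF64` inputs; the name `tanh_pade_horner` is hypothetical, and I have not verified its behaviour or speed under ReverseDiff.jl (every scalar operation on a `TrackedReal` is still recorded on the tape).

```julia
# Hypothetical helper: [7/8] Padé approximant of tanh in Horner form in x^2.
# Only * and + are used, so no integer-power (^) method is dispatched.
function tanh_pade_horner(x)
    x2 = x * x  # x^2, computed once
    # numerator: x * (2027025 + 270270 x^2 + 6930 x^4 + 36 x^6)
    num = x * (2027025 + x2 * (270270 + x2 * (6930 + x2 * 36)))
    # denominator: 2027025 + 945945 x^2 + 51975 x^4 + 630 x^6 + x^8
    den = 2027025 + x2 * (945945 + x2 * (51975 + x2 * (630 + x2)))
    return num / den
end

# Usage: compare against Base.tanh on a real and a complex input
println(tanh_pade_horner(0.5), "  ", tanh(0.5))
println(tanh_pade_horner(0.3 + 0.2im), "  ", tanh(0.3 + 0.2im))
```

The Horner form needs fewer multiplications than expanding every power by hand, which may matter when each operation is taped.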
but it failed again, this time with the following error:
```
ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Int64 is not defined. Please use `ReverseDiff.value` instead.
Stacktrace:
 [1] convert(#unused#::Type{Int64}, t::ReverseDiff.TrackedReal{Float64, Float64, Nothing})
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/tracked.jl:261
 [2] _cpow(z::Complex{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, p::Complex{ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
   @ Base ./complex.jl:791
 [3] ^
   @ ./complex.jl:859 [inlined]
 [4] ^
   @ ./promotion.jl:444 [inlined]
 [5] ^
   @ ./complex.jl:864 [inlined]
 [6] literal_pow
   @ ./intfuncs.jl:338 [inlined]
 [7] tanh_pade(x::Complex{ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
   @ Main ~/AQ_research/Bayes/DRIVER_BAYES.jl:38
 [8] MT_mod_wait_bayes_cells(depth::Vector{Float64}, rho::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, fa::Vector{Float64})
   @ Main ~/AQ_research/Bayes/DRIVER_BAYES.jl:17
```
Redefining the custom `tanh` with explicit multiplications instead of `^` avoids the error:
```julia
function tanh_pade(x)
    # Padé approximant of tanh using only * (no ^); the last denominator term is x^8
    y = x * (2027025 + 270270 * x*x + 6930 * x*x*x*x + 36 * x*x*x*x*x*x) /
        (2027025 + 945945 * x*x + 51975 * x*x*x*x + 630 * x*x*x*x*x*x + x*x*x*x*x*x*x*x)
    return y
end
```
That version works, but it made the code very slow. At this point I can’t see the advantage of using ReverseDiff.jl: with ForwardDiff.jl I can use `Base.tanh` directly, and the code is at least 3 times faster, which is the opposite of what I expected from switching the AD backend in Turing.jl.
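When powers are expanded into repeated multiplications by hand, it is easy to drop a factor (for example writing six `x`s where `x^8` is needed). A cheap safeguard is to compare the multiply-only form against the `^` form on ordinary `Float64`/`ComplexF64` inputs, where `^` works without ReverseDiff. This is a minimal sketch; `tanh_pade_pow` and `tanh_pade_mul` are hypothetical names, and nothing here involves ReverseDiff or Turing.

```julia
# Reference form using integer powers (fine for plain Float64/ComplexF64)
tanh_pade_pow(x) = x * (2027025 + 270270 * x^2 + 6930 * x^4 + 36 * x^6) /
                   (2027025 + 945945 * x^2 + 51975 * x^4 + 630 * x^6 + x^8)

# Multiply-only form: build the even powers from x^2 without calling ^
function tanh_pade_mul(x)
    x2 = x * x
    x4 = x2 * x2
    x6 = x4 * x2
    x8 = x4 * x4
    return x * (2027025 + 270270 * x2 + 6930 * x4 + 36 * x6) /
           (2027025 + 945945 * x2 + 51975 * x4 + 630 * x6 + x8)
end

# Both forms should agree to rounding error on real and complex samples
for z in (0.0, 0.5, -1.2, 0.3 + 0.2im, -0.7 + 1.1im)
    @assert isapprox(tanh_pade_mul(z), tanh_pade_pow(z); rtol=1e-12)
end
```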
To summarize: the multiply-only `tanh_pade` works, but it is slower than just using ForwardDiff. Does anyone have an idea of what could be happening?
Thanks