Custom bijector in TuringLang Bijectors

I am new to Julia and I am trying to use the Bijectors.jl package to define a tanh bijector. Looking through the doc and the code, I came up with different versions of the following

struct Tanh <: Bijectors.Bijector end

(b::Tanh)(x::Real) = tanh(x)
(b::Tanh)(x::AbstractArray) = @. tanh(x)

(ib::Bijectors.Inverse{<: Tanh})(y::Real) = atanh(y)
(ib::Bijectors.Inverse{<: Tanh})(y::AbstractArray) = @. atanh(y)

function with_logabsdet_jacobian(b::Tanh, x::Real)
    transformed = tanh(x)
    log_det_jacobian = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
    return (result=transformed, logabsdetjac=log_det_jacobian)
end

function with_logabsdet_jacobian(b::Tanh, x::AbstractArray)
    transformed = @. tanh(x)
    log_det_jacobian = 2.0 .* (log(2.0) .- x .- softplus.(-2.0 .* x))
    return (result=transformed, logabsdetjac=log_det_jacobian)
end

Bijectors.logabsdetjac(b::Tanh, x) = last(with_logabsdet_jacobian(b, x))

The forward transformation appears to be working, however combining the bijector with a transformed distribution runs into a stack overflow when I call the logpdf function

using Random
using Distributions
using LinearAlgebra
import Bijectors
using LogExpFunctions: softplus

dist = Distributions.MvNormal(zeros(1), LinearAlgebra.I)
td = Bijectors.transformed(dist, Tanh())

y = Random.rand(td)
lp = Distributions.logpdf(td, y)
ERROR: StackOverflowError:
Stacktrace:
     [1] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213
     [2] transform(t::Bijectors.Inverse{Tanh}, x::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:92
--- the last 2 lines are repeated 39990 more times ---
 [79983] with_logabsdet_jacobian(ib::Bijectors.Inverse{Tanh}, y::Vector{Float64})
       @ Bijectors ~/.julia/packages/Bijectors/QhObI/src/interface.jl:213

Any ideas what what I’ve gotten wrong here?

The following definition appears to work although I don’t really know why. Would still appreciate an explanation.

struct Tanh <: Bijectors.Bijector end

# (b::Tanh)(x::Real) = tanh(x)
# (b::Tanh)(x::AbstractArray) = @. tanh(x)

# (ib::Bijectors.Inverse{<: Tanh})(y::Real) = atanh(y)
# (ib::Bijectors.Inverse{<: Tanh})(y::AbstractArray) = @. atanh(y)

Bijectors.transform(b::Tanh, x) = tanh(x)
Bijectors.transform(b::Tanh, x::AbstractArray) = @. tanh(x)

Bijectors.transform(ib::Bijectors.Inverse{<:Tanh}, y) = atanh(y)
Bijectors.transform(ib::Bijectors.Inverse{<:Tanh}, y::AbstractArray) = @. atanh(y)


function with_logabsdet_jacobian(b::Tanh, x::Real)
    y = tanh(x)
    ldj = 2.0 * (log(2.0) - x - softplus(-2.0 * x))
    return (result=y, logabsdetjac=ldj)
end

function with_logabsdet_jacobian(b::Tanh, x::AbstractArray)
    y = @. tanh(x)
    ldj = @. 2.0 * (log(2.0) - x - softplus(-2.0 * x))
    return (result=y, logabsdetjac=only(ldj))
end

@torfjelde