Reverse rule in Enzyme for an implicit function

Suppose H is the CDF of a distribution, A and B are real numbers, and p is defined implicitly by

p = H(p A + (1 - p) B)
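For reference, differentiating this equation implicitly (implicit function theorem) gives the derivatives that any custom rule would have to reproduce. In the sketch below, h = H' is the density, θ stands for whatever parameters H has (notation introduced here, not part of the original setup), and p̄ is the cotangent of p coming back in reverse mode. It assumes h exists at x and that 1 - h(x)(A - B) ≠ 0, i.e. the root is non-degenerate.

$$
\begin{aligned}
x &= pA + (1-p)B, \qquad F(p, A, B, \theta) = H(x;\theta) - p = 0,\\
\frac{\partial F}{\partial p} &= h(x)(A - B) - 1, \quad
\frac{\partial F}{\partial A} = h(x)\,p, \quad
\frac{\partial F}{\partial B} = h(x)\,(1-p), \quad
\frac{\partial F}{\partial \theta} = \frac{\partial H(x;\theta)}{\partial \theta},\\
\frac{dp}{dz} &= -\frac{\partial F/\partial z}{\partial F/\partial p}
  = \frac{\partial F/\partial z}{1 - h(x)(A - B)}, \qquad z \in \{A, B, \theta\},\\
\bar z &= \lambda\,\frac{\partial F}{\partial z}, \qquad \lambda = \frac{\bar p}{1 - h(x)(A - B)}.
\end{aligned}
$$

So the reverse pass only needs ∂F/∂p and a vector-Jacobian product of F with respect to everything except p, scaled by λ. Both are derivatives of `condition`, not of the root finder.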

This can be implemented simply in Julia as

using Roots, Distributions

struct Problem{T,S}
    H::T
    A::S
    B::S
end

function condition(problem::Problem, p)
    (; H, A, B) = problem
    cdf(H, p * A + (1 - p) * B) - p
end

solvep(problem::Problem) = find_zero(Base.Fix1(condition, problem), (0.0, 1.0))

However, I want to

  1. AD through solvep using Enzyme.jl (actually, I am ADing a larger function that calls solvep),

  2. preferably in a way that works with a generic H, so I don’t have to commit to knowing its parameters, etc.

My understanding is that I would have to manually call AD on the H(...) expression, but I don’t know how to hook into that. Any suggestions are welcome, even if they are not full solutions.
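Below is a minimal, package-free sketch of what such a reverse pass would compute. It is not an Enzyme rule. It assumes a logistic CDF (location `mu`, scale `s`) as a stand-in for a generic H and uses a hand-written bisection instead of Roots.jl. The names `logistic_cdf`, `solve_p`, and `solve_p_pullback` are hypothetical. The partials of F are written out by hand here so the result can be checked against finite differences of the solver.

```julia
# Package-free sketch (NOT an Enzyme rule): reverse pass for
#     F(p; A, B, mu, s) = H(p*A + (1 - p)*B) - p = 0
# via the implicit function theorem. The logistic CDF is an ASSUMED stand-in
# for a generic H; all names are hypothetical. In an Enzyme rule the partials
# of F would presumably come from differentiating `condition` instead.

logistic_cdf(x, mu, s) = 1 / (1 + exp(-(x - mu) / s))
logistic_pdf(x, mu, s) = (c = logistic_cdf(x, mu, s); c * (1 - c) / s)

F(p, A, B, mu, s) = logistic_cdf(p * A + (1 - p) * B, mu, s) - p

# Plain bisection on [0, 1]: F(0) = H(B) > 0 and F(1) = H(A) - 1 < 0.
function solve_p(A, B, mu, s; tol = 1e-14)
    lo, hi = 0.0, 1.0
    while hi - lo > tol
        mid = (lo + hi) / 2
        if F(mid, A, B, mu, s) > 0
            lo = mid   # root lies in [mid, hi]
        else
            hi = mid   # root lies in [lo, mid]
        end
    end
    return (lo + hi) / 2
end

# Pullback: given the output cotangent `pbar`, return the cotangents of
# (A, B, mu, s). lambda = -pbar / (dF/dp); each cotangent is lambda * dF/dz.
function solve_p_pullback(p, A, B, mu, s, pbar)
    x = p * A + (1 - p) * B
    h = logistic_pdf(x, mu, s)        # h = H'(x)
    dFdp = h * (A - B) - 1            # must be nonzero (non-degenerate root)
    lambda = -pbar / dFdp
    dFdA = h * p
    dFdB = h * (1 - p)
    dFdmu = -h                        # dH/dmu for the logistic CDF
    dFds = -h * (x - mu) / s          # dH/ds for the logistic CDF
    return (lambda * dFdA, lambda * dFdB, lambda * dFdmu, lambda * dFds)
end

# Compare against central finite differences of the solver itself.
A, B, mu, s = 0.3, -0.5, 0.1, 1.2
p = solve_p(A, B, mu, s)
grads = solve_p_pullback(p, A, B, mu, s, 1.0)

h_fd = 1e-6
args = [A, B, mu, s]
fd = map(1:4) do i
    up = copy(args); up[i] += h_fd
    dn = copy(args); dn[i] -= h_fd
    (solve_p(up...) - solve_p(dn...)) / (2 * h_fd)
end
```

The design point is that the solver is never differentiated: only F is. For the actual problem, the hand-written partials would be replaced by a vector-Jacobian product of `condition` with respect to `problem`, which is what keeps H generic. That could, for example, be computed by calling Enzyme inside an `EnzymeRules.reverse` method defined for `solvep`. The exact rule signatures differ between Enzyme.jl versions, so this part is left as an assumption to check against its custom-rules documentation.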

It's not finished yet, I think, but it will be: FAQ · ImplicitDifferentiation.jl. It works with Zygote.jl and ForwardDiff.jl, and you may be able to define the rule with it and build the Enzyme one afterwards.

As stated in the question, I want to work with Enzyme.jl. Also, cf
