Custom rule for an implicit function in Enzyme

Consider an implicit function y = f(x) defined via

g(x, y) = 0

where x, y \in \mathbb{R}^n. The user supplies f!(y, x) and g!(r, x, y) as callables, which are wrapped in a

struct SquareImplicitFunction{F,G}
    f!::F
    g!::G
end

function (ℐ::SquareImplicitFunction)(y, x)
    ℐ.f!(y, x)
    nothing
end
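
For reference, the rule being implemented comes from differentiating the defining relation g(x, f(x)) = 0 with respect to x:

\partial_x g + \partial_y g \, \frac{dy}{dx} = 0
\quad \Longrightarrow \quad
\frac{dy}{dx} = -\left(\partial_y g\right)^{-1} \partial_x g

so forward mode needs one linear solve with \partial_y g per tangent, and reverse mode one solve with (\partial_y g)^\top per cotangent.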

The user should be able to AD through ℐ(y, x) calls in Enzyme, using forward and reverse mode.

I have coded up an MWE package here, which kind of works (simple tests run), but I have some open questions about it. Note that the questions all pertain to the code, which I am not copy-pasting here as it is >100 LOC (with plenty of docstrings); if you want to help, please look at it there.

Q1 I figured out AD while making the callable Const(g!), but I am wondering if there is a way to have AD work if it was somehow parameterized. Eg for the trivial example

struct G!{T<:Real}
    p::T
end
(g!::G!)(r, x, y) = (r .= x .+ g!.p .- y; nothing)
make_f!(g!::G!) = (y, x) -> (y .= x .+ g!.p; nothing)

it would be nice to be able to perturb p and have AD work. But, frankly, I am conceptually lost on how to do this.

Q2 I wonder if there is a way to cache the Jacobian (for the same x and y) in forward and reverse mode. If I just save it on the tape in reverse mode, would it be reused for various cotangents? What about forward mode?

Q3 With the function signatures above, I am not sure what combinations of Const and Duplicated I should plan for. Can it ever happen that, of x and y, one is Const and the other is Duplicated?

Any other comments are welcome too. As it may be apparent from the code I have a limited understanding of AD and Enzyme.jl. And yes, I am aware of ImplicitDifferentiation.jl, and plan to contribute to it, but I want to understand things first in a simplified context.


Should I just save it on the tape in reverse mode

Yeah, you can do that, but the question is always whether it is worth computing the entire Jacobian, or whether Jacobian-vector products (JVPs) are sufficient. You can't cache in forward mode (except for batch evaluation).
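
To illustrate reusing a cached result across cotangents, here is a plain LinearAlgebra sketch of the idea (the Jacobians J_y = ∂g/∂y and J_x = ∂g/∂x are hypothetical dense values at the solution, not taken from the actual code):

```julia
using LinearAlgebra

# Hypothetical Jacobians of g at the solution (x, y).
J_y = [2.0 1.0; 0.0 3.0]
J_x = [1.0 0.0; 0.0 1.0]

# Factor ∂g/∂y once; this factorization is the object worth saving on the tape.
F = lu(J_y)

# Reverse mode: each cotangent ȳ then costs only one linear solve,
# via x̄ = -J_xᵀ (J_yᵀ \ ȳ), reusing the same factorization.
pullback(ȳ) = -J_x' * (F' \ ȳ)

x̄1 = pullback([1.0, 0.0])
x̄2 = pullback([0.0, 1.0])   # no refactorization needed
```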

But, frankly, I am conceptually lost on how to do this.

On the Enzyme level, this just means that you need to pass Duplicated(g!), thus providing a storage location for the tangent/shadow value.
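
To make this concrete, here is a minimal forward-mode sketch (assuming a recent Enzyme.jl; the toy G! below matches the g!(r, x, y) residual signature from the original post, and the shadow struct carries the tangent of p):

```julia
using Enzyme

struct G!{T<:Real}
    p::T
end
# Residual form: writes g(x, y) = x + p - y into r.
(g!::G!)(r, x, y) = (r .= x .+ g!.p .- y; nothing)

g!  = G!(2.0)
dg! = G!(1.0)               # shadow: tangent of p is 1, i.e. perturb p
r, dr = zeros(3), zeros(3)
x, dx = [1.0, 2.0, 3.0], zeros(3)
y, dy = zeros(3), zeros(3)

# Annotating the callable itself as Duplicated differentiates w.r.t. p.
autodiff(Forward, Duplicated(g!, dg!),
         Duplicated(r, dr), Duplicated(x, dx), Duplicated(y, dy))
# dr now holds ∂r/∂p · ṗ = ones(3), since dx = dy = 0
```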

Can it ever happen that, of x and y, one is Const and the other is Duplicated?

I think you could have Const(y) and Duplicated(x) (perhaps even vice versa); that just means that y has been marked inactive and x is active.
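
One mixed combination is easy to try, sketched below (a toy f!, assuming a recent Enzyme.jl). Note the other direction, Const on the mutated output y with Duplicated(x), would mean writing derivative data into constant memory, which Enzyme typically rejects at runtime:

```julia
using Enzyme

f!(y, x) = (y .= 2 .* x; nothing)

x = [1.0, 2.0]
y, dy = zeros(2), zeros(2)

# x marked inactive, y still Duplicated: the primal runs, the tangent stays zero.
autodiff(Forward, Const(f!), Duplicated(y, dy), Const(x))
# y == [2.0, 4.0], dy == [0.0, 0.0]
```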

At some point I did try to add rules to ImplicitDifferentiation (Support Enzyme through EnzymeRules by vchuravy · Pull Request #186 · JuliaDecisionFocusedLearning/ImplicitDifferentiation.jl · GitHub), but I have forgotten how far I got.


I’d love for you or someone else to pick up that PR! The most urgent thing would be to add tests, and then we can see what works and what doesn’t 🙂


I would be happy to do it, but I need to learn a lot about AD in general, and Enzyme specifically, before that. I am doing so in my own repo for now.
