ANN: Differentiable implicit functions in Julia (optimisation, nonlinear solves and fixed point iterations)
In the spirit of https://arxiv.org/pdf/2105.15183.pdf, I have implemented a general differentiable implicit function in https://github.com/JuliaNonconvex/NonconvexUtils.jl. Feel free to take it for a spin and/or contribute more specialised applications of this feature for specific problem classes. You can find an example in the README.
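For readers new to the idea, the reduction rests on the implicit function theorem. The derivation below is a generic sketch, not taken from the package: the symbols F (the conditions), p (the parameters), x* (the solution) and λ (the adjoint variable) are notation assumed here, and ∂F/∂x is assumed to be invertible at the solution.

```latex
\begin{align*}
F\bigl(p, x^*(p)\bigr) = 0
\;&\Longrightarrow\;
\frac{\partial F}{\partial p} + \frac{\partial F}{\partial x}\,\frac{\mathrm{d}x^*}{\mathrm{d}p} = 0
\;\Longrightarrow\;
\frac{\mathrm{d}x^*}{\mathrm{d}p} = -\left(\frac{\partial F}{\partial x}\right)^{-1}\frac{\partial F}{\partial p} \\[4pt]
\text{reverse mode:}\quad
&\left(\frac{\partial F}{\partial x}\right)^{\top}\lambda = \bar{x},
\qquad
\bar{p} = -\left(\frac{\partial F}{\partial p}\right)^{\top}\lambda
\end{align*}
```

The forward solver that produces x* never appears in these formulas. Only the conditions F are differentiated, which is why the same rule covers optimisation, nonlinear solves and fixed point iterations.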
While I think it’s cool that this can be done generically, ultimately it comes down to algorithms, and it’s hard for me to see how to make this interface fully swappable and optimized without handling each of the cases. For example, NonlinearSolve.jl’s adjoints have all sorts of heuristics for swapping AD systems and for deciding whether to use Jacobian-free Newton-Krylov for the backward pass or to build the Jacobian (which can depend on size, AD, etc.). PDEs likewise have a whole slew of choices that can change performance by orders of magnitude. The heuristics need to be tailored to the different implicit function types and solvers as well, so I cannot see this as a long-term solution.
But as a fallback for situations not already covered (i.e. not fixed point iteration, nonlinear solves, or diffeqs, but things beyond those, like convex optimization until something like DiffOpt.jl is ready), this seems like a good stop-gap measure. Though unlike the Jax position paper, I think we should accept that it would be a clean but bad idea to build the entire ecosystem around this reduction.
Ya I forgot to mention. You can do this with the matrixfree kwarg in NonconvexUtils. I should probably document it. You can also specify a custom linear system solver.
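To make the matrix-free idea concrete, here is a minimal language-agnostic sketch in plain Python. It is not the NonconvexUtils API, and every name in it (`A`, `jvp_x`, `conjugate_gradient`, `pullback`) is hypothetical. It assumes an illustrative quadratic problem whose optimality conditions are F(p, x) = A x - p with A symmetric positive definite, so conjugate gradient applies. A general Jacobian would need something like GMRES instead.

```python
# Illustrative problem: x*(p) = argmin_x 0.5 * x^T A x - p^T x.
# Optimality conditions: F(p, x) = A x - p = 0, so dF/dx = A and dF/dp = -I.
A = [[4.0, 1.0, 0.0],
     [1.0, 3.0, 1.0],
     [0.0, 1.0, 2.0]]


def jvp_x(v):
    """Return (dF/dx) v. The linear solver only ever sees this product."""
    return [sum(a * vj for a, vj in zip(row, v)) for row in A]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def conjugate_gradient(matvec, b, tol=1e-12, max_iter=100):
    """Solve M x = b for symmetric positive definite M given only v -> M v."""
    x = [0.0] * len(b)
    r = list(b)          # residual b - M x for the initial guess x = 0
    d = list(r)          # search direction
    rs = dot(r, r)
    for _ in range(max_iter):
        if rs ** 0.5 < tol:
            break
        Md = matvec(d)
        alpha = rs / dot(d, Md)
        x = [xi + alpha * di for xi, di in zip(x, d)]
        r = [ri - alpha * mdi for ri, mdi in zip(r, Md)]
        rs_new = dot(r, r)
        d = [ri + (rs_new / rs) * di for ri, di in zip(r, d)]
        rs = rs_new
    return x


def pullback(x_bar):
    """Map a cotangent of x* to a cotangent of p without forming a Jacobian."""
    # Solve (dF/dx)^T lam = x_bar. A is symmetric, so jvp_x is also the
    # transposed product here.
    lam = conjugate_gradient(jvp_x, x_bar)
    # p_bar = -(dF/dp)^T lam, and dF/dp = -I, so p_bar = lam.
    return lam


x_bar = [1.0, 0.0, 0.0]
p_bar = pullback(x_bar)
```

Whether this beats building and factorising the Jacobian depends on problem size and conditioning, which is the kind of heuristic discussed above.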
> Though unlike the Jax position paper, I think we should accept that it would be a clean but bad idea to build the entire ecosystem around this reduction.
I disagree with this. DiffOpt is built using a similar approach as ImplicitFunction but the optimality conditions and forward algorithm are customised for a specific problem class. The ImplicitFunction struct can therefore save the developers of DiffOpt the effort of manually deriving the Jacobians or pullbacks of the optimality conditions. You can still customise the optimality conditions and optimisation solver for your class of problems but you can stop there, go get a cup of tea and let AD do the rest (when it doesn’t error!).
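The division of labour described here can be sketched in a few lines of plain Python. This is an illustration only, not how ImplicitFunction or DiffOpt is implemented: the `Dual` class is a toy forward-mode AD standing in for a real AD system, and `conditions`, `forward_solve` and `implicit_derivative` are hypothetical names. The user writes the conditions and picks a solver, and the partial derivatives of the conditions come from AD rather than from hand derivation.

```python
class Dual:
    """Toy forward-mode AD number: a value and its derivative."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        other = Dual.lift(other)
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def conditions(p, x):
    """User-supplied conditions F(p, x) = 0 that define x*(p)."""
    return x * x * x + x - p


def forward_solve(p, tol=1e-12):
    """User-chosen solver. Bisection works because F is increasing in x."""
    lo, hi = -1.0 - abs(p), 1.0 + abs(p)   # F(p, lo) < 0 < F(p, hi)
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if conditions(p, mid) > 0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)


def implicit_derivative(p):
    """Return x*(p) and dx*/dp = -(dF/dx)^(-1) dF/dp, with partials from AD."""
    x = forward_solve(p)
    dF_dx = conditions(p, Dual(x, 1.0)).der   # seed x, hold p fixed
    dF_dp = conditions(Dual(p, 1.0), x).der   # seed p, hold x fixed
    return x, -dF_dp / dF_dx


x_star, dx_dp = implicit_derivative(2.0)
```

For p = 2 the root is x* = 1, so the derivative is 1 / (3 * 1^2 + 1) = 0.25. Note that the bisection loop itself is never differentiated. Swapping in a different solver leaves `implicit_derivative` unchanged.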
Of course, defining a rule for a specific implicit function is always an option, but that can now be treated like defining a rule for any Julia function in the presence of AD. That is, you only need to do it if AD fails or is too slow for your problem.