Checkpointing and adjoints for custom kernels

I’ve written a finite-difference PDE solver that uses KernelAbstractions.jl to apply the stencil update at each time step.

In my previous life, when I wrote this kind of code in C/C++, I’d hand-write the adjoints and then hook them up to an AD package. When memory became a constraint, I’d also implement some naive checkpointing.
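
For reference, here is roughly what the hand-rolled approach looks like. It is a minimal sketch in Python rather than Julia or C++ so that it runs standalone. It assumes a hypothetical 1D reaction-diffusion stencil with zero ghost cells and the loss J = 0.5 * sum(u_N^2). `step_inplace` stands in for the in-place kernel and `step_adjoint` for its hand-written adjoint; both names are illustrative. The reverse pass needs the state from before each step, which is why an in-place update forces you to store (or recompute) the history.

```python
# Sketch only: a hypothetical 1D reaction-diffusion stencil with zero ghost cells,
#   u[i] <- u[i] + r*(u[i-1] - 2*u[i] + u[i+1]) - dt*u[i]**2
# standing in for an in-place KernelAbstractions kernel.

def step_inplace(u, r, dt):
    """Forward step: overwrites u, like an in-place kernel would."""
    old = u[:]  # previous values are still needed while new ones are written
    n = len(u)
    for i in range(n):
        left = old[i - 1] if i > 0 else 0.0
        right = old[i + 1] if i < n - 1 else 0.0
        u[i] = old[i] + r * (left - 2.0 * old[i] + right) - dt * old[i] ** 2

def step_adjoint(u_prev, lam, r, dt):
    """Hand-written adjoint: maps dJ/du_new to dJ/du_prev.
    Needs u_prev (the state before the step) because the reaction term is nonlinear."""
    n = len(lam)
    out = [0.0] * n
    for j in range(n):
        left = lam[j - 1] if j > 0 else 0.0
        right = lam[j + 1] if j < n - 1 else 0.0
        out[j] = lam[j] * (1.0 - 2.0 * r - 2.0 * dt * u_prev[j]) + r * (left + right)
    return out

def loss_and_grad_store_all(u0, nsteps, r, dt):
    """Return J = 0.5*sum(u_N**2) and dJ/du0, storing every intermediate state."""
    u = u0[:]
    history = []
    for _ in range(nsteps):
        history.append(u[:])  # O(nsteps) memory: the cost checkpointing avoids
        step_inplace(u, r, dt)
    J = 0.5 * sum(x * x for x in u)
    lam = u[:]  # dJ/du_N
    for u_prev in reversed(history):
        lam = step_adjoint(u_prev, lam, r, dt)
    return J, lam
```

Storing every state costs memory proportional to the number of time steps, which is what motivates checkpointing.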

I’ve looked through the various SciML packages, and there seems to be a lot of machinery in place to automate adjoint sensitivity analysis and even checkpointing. But it also seems like I have to go “all in” and use the full SciML ecosystem to discretize the PDE, etc.

Am I wrong? Is there a way to keep my existing kernel code and wrap it with some of the SciML packages just to handle the adjoint/backpropagation and the checkpointing? What’s the usual way people handle this?

To give some context, my stencil code modifies my data arrays in place, so anything built on Zygote probably won’t work, since Zygote doesn’t support array mutation (although I hope I’m wrong).

Thanks!

One alternative is Enzyme.jl, which integrates well with KernelAbstractions.jl and, unlike Zygote, can differentiate through in-place mutation, combined with Checkpointing.jl. Setting up that machinery is more “manual”, but it can bring a good performance gain and should be GPU-compatible as well.
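
As a rough illustration of the kind of schedule Checkpointing.jl automates, here is a minimal Python sketch of periodic (“naive”) checkpointing around a hand-written adjoint. It assumes a hypothetical in-place 1D reaction-diffusion step; `step_inplace`, `step_adjoint` and `every` are illustrative names, not Enzyme’s or Checkpointing.jl’s API. In the Julia setup, Enzyme would roughly take the place of `step_adjoint` and Checkpointing.jl the place of the outer loop.

```python
# Sketch only: periodic checkpointing of a time loop. The stencil below is a
# hypothetical stand-in for an in-place kernel and its hand-written adjoint.

def step_inplace(u, r, dt):
    """Forward step (overwrites u): diffusion plus a -dt*u**2 reaction term."""
    old = u[:]
    n = len(u)
    for i in range(n):
        left = old[i - 1] if i > 0 else 0.0
        right = old[i + 1] if i < n - 1 else 0.0
        u[i] = old[i] + r * (left - 2.0 * old[i] + right) - dt * old[i] ** 2

def step_adjoint(u_prev, lam, r, dt):
    """Map dJ/du_new to dJ/du_prev; needs the state before the step."""
    n = len(lam)
    out = [0.0] * n
    for j in range(n):
        left = lam[j - 1] if j > 0 else 0.0
        right = lam[j + 1] if j < n - 1 else 0.0
        out[j] = lam[j] * (1.0 - 2.0 * r - 2.0 * dt * u_prev[j]) + r * (left + right)
    return out

def loss_and_grad_checkpointed(u0, nsteps, r, dt, every):
    """J = 0.5*sum(u_N**2) and dJ/du0, keeping only every `every`-th state."""
    u = u0[:]
    checkpoints = {}
    for t in range(nsteps):
        if t % every == 0:
            checkpoints[t] = u[:]  # about nsteps/every snapshots instead of nsteps
        step_inplace(u, r, dt)
    J = 0.5 * sum(x * x for x in u)
    lam = u[:]  # dJ/du_N
    # Walk the segments backwards: recompute the states inside each segment
    # from its checkpoint, then run the adjoint back through them.
    for start in sorted(checkpoints, reverse=True):
        stop = min(start + every, nsteps)
        u = checkpoints[start][:]
        segment = []
        for _ in range(start, stop):
            segment.append(u[:])
            step_inplace(u, r, dt)
        for u_prev in reversed(segment):
            lam = step_adjoint(u_prev, lam, r, dt)
    return J, lam
```

At most about nsteps/every snapshots plus one segment of `every` states are alive at once. The cost is roughly one extra forward sweep, and choosing `every` near the square root of nsteps balances the two. Schedules such as Revolve reduce memory further at the cost of more recomputation.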

Great suggestion, thanks!