I’ve written a PDE finite-difference solver that uses KernelAbstractions.jl to update my stencil each time step.
In my previous life, when I’d write this stuff in C/C++, I’d roll my own adjoints and then hook them up to an AD package. And if I really needed to, I’d implement some naive checkpointing to keep memory usage within limits.
I’ve looked through the various SciML packages, and it seems like there’s a lot of machinery in place to automate adjoint sensitivity analysis and even checkpointing. But it also seems like I have to go “all in” and use the full SciML ecosystem to discretize the PDE, etc.
Am I wrong? Is there a way to use my existing kernel code, and wrap it with some of the SciML packages just to handle the adjoint/backpropagation and the checkpointing? What’s the canonical way people typically address this?
To give some context, my stencil code modifies my data arrays in place, so anything that relies on Zygote, which doesn’t support array mutation, probably wouldn’t work (although I hope I’m wrong).
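To make the question concrete, here is a rough sketch of the pattern I mean. It is a toy 1D stand-in in plain Julia, not my actual KernelAbstractions.jl kernel: `step!`, `step_adjoint!`, and `alpha` are made-up names for illustration, and the explicit diffusion update with fixed boundary values is just an assumed example stencil. The last few lines are the usual dot-product check that the hand-written adjoint really is the transpose of the forward step.

```julia
using LinearAlgebra  # only needed for `dot` in the consistency check

# Hypothetical in-place forward step: explicit diffusion update on the
# interior points, boundary values copied through unchanged.
function step!(unew, u, alpha)
    n = length(u)
    unew[1] = u[1]
    unew[n] = u[n]
    for i in 2:n-1
        unew[i] = u[i] + alpha * (u[i-1] - 2 * u[i] + u[i+1])
    end
    return unew
end

# Hand-written adjoint of step!: given unewbar = dL/d(unew), overwrite
# ubar with dL/du. Each forward read of u[j] becomes a scatter-add here.
function step_adjoint!(ubar, unewbar, alpha)
    n = length(ubar)
    fill!(ubar, zero(eltype(ubar)))
    ubar[1] += unewbar[1]
    ubar[n] += unewbar[n]
    for i in 2:n-1
        ubar[i-1] += alpha * unewbar[i]
        ubar[i]   += (1 - 2 * alpha) * unewbar[i]
        ubar[i+1] += alpha * unewbar[i]
    end
    return ubar
end

# Dot-product check: <v, A*u> should equal <A'*v, u> up to round-off.
u = [0.1 * i^2 for i in 1:8]
v = sin.(1:8)
lhs = dot(v, step!(similar(u), u, 0.1))
rhs = dot(step_adjoint!(similar(u), v, 0.1), u)
@assert isapprox(lhs, rhs; rtol = 1e-12)
```

What I’d like is to hand a pair like this (or just the forward kernel) to a package that takes care of the reverse sweep over time steps and the checkpointing, without rewriting the discretization itself.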
Thanks!