Memory issues for large computational graph

I’m working on a differentiable CFD solver that uses the immersed boundary method to solve the Navier-Stokes equations. I plan on using Enzyme.jl. My goal is to compute gradients of a custom-defined loss function with respect to the parameters of a parametric geometry, which I then want to optimize with gradient descent. However, I’m running into memory issues because of the large computational graph that results from the many iterations (i.e., time steps). Are there any effective strategies to mitigate this problem?

Thank you in advance for your suggestions!

For time-stepping adjoints, you normally don’t want to differentiate the solver directly but instead define adjoint rules over it that optimize a few things: memory, via (continuous) checkpointing, and compute time, since some of the computations can be skipped. If you use DifferentialEquations.jl for the time stepping or NonlinearSolve.jl for the nonlinear solving, these optimizations are applied automatically.
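
For what it’s worth, here’s a minimal sketch of what that looks like with the SciML stack, using a toy ODE as a stand-in for your CFD discretization (the right-hand side, loss, and parameter values are placeholders, not your actual problem). The `sensealg` keyword picks a checkpointed adjoint so the reverse pass doesn’t store the full trajectory, and `EnzymeVJP()` uses Enzyme for the vector-Jacobian products inside it:

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Toy stand-in for a spatially discretized PDE right-hand side.
function rhs!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

u0 = [1.0, 1.0]
p0 = [1.5, 1.0, 3.0, 1.0]            # placeholder "geometry" parameters
prob = ODEProblem(rhs!, u0, (0.0, 10.0), p0)

function loss(p)
    # Checkpointed interpolating adjoint: stores checkpoints instead of the
    # whole forward trajectory, with Enzyme for the internal VJPs.
    sol = solve(prob, Tsit5(); p = p, saveat = 0.1,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP(),
                                                checkpointing = true))
    return sum(abs2, Array(sol))     # placeholder loss
end

grad = Zygote.gradient(loss, p0)[1]  # adjoint handled by SciMLSensitivity
```

The gradient call goes through the adjoint machinery defined on `solve`, so the memory cost is set by the number of checkpoints rather than the number of time steps.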

Could you suggest some resources I could use to learn about time-stepping adjoints for PDEs?

That’s probably the most comprehensive these days, though it’s missing some of the newer tricks.

is also an overview, but it’s missing some of the newer optimizations. The SciMLSensitivity.jl source code really describes it in full :sweat_smile:
