What are best practices for debugging when encountering errors with DiffEqFlux problems?
I’m working on creating a differentiable ODE simulator with DiffEqFlux. It’s an amazing and powerful tool! However, debugging gradient computations through ODEs is often harder than debugging standard Julia code. For example:
- Error messages can be massive. A single type name can span 2000+ lines in a stacktrace.
- Debugger.jl is hard to use here: it runs the DiffEqFlux + Zygote code through Julia's interpreter, which slows execution dramatically.
- When using Infiltrator.jl or printouts during differentiation, the intermediate types used for differentiation (e.g., dual numbers and adjoints) can be tricky to map back to the original parameters of the simulation.
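One partial workaround for the last point, when forward-mode AD is involved: ForwardDiff.jl's `ForwardDiff.value` strips the `Dual` wrapper and returns plain numbers unchanged, so a printout can show the ordinary parameter value next to its AD type. A minimal sketch, assuming ForwardDiff.jl is installed; `dynamics` and the parameter `k` are hypothetical stand-ins rather than DiffEqFlux code, and this only helps where ForwardDiff duals actually reach the function:

```julia
using ForwardDiff

# Hypothetical dynamics function; p holds the parameters being differentiated.
function dynamics(u, p)
    k = p[1]
    # ForwardDiff.value returns the primal value of a Dual (and a plain number as-is),
    # so this prints the ordinary parameter value instead of the Dual internals.
    println("k = ", ForwardDiff.value(k), "  (AD type: ", typeof(k), ")")
    return -k * u
end

# Differentiate w.r.t. the parameter vector; d/dk of (-k * 2.0) is -2.0.
g = ForwardDiff.gradient(p -> dynamics(2.0, p), [0.5])
```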
I would love to know any tips or tricks to isolate and identify bugs when differentiating through DiffEqFlux ODEs / DAEs / etc. Here’s what I’ve tried so far with some success:
- Double-check that the model is actually differentiable.
- Check that I can take the gradients of individual sub-functions used within my ODE dynamics function.
- Comment out chunks of the dynamics function and incrementally uncomment them until the error returns.
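The sub-function check above can be made quantitative by comparing a reverse-mode gradient against a central finite-difference estimate; a mismatch, or a `nothing` gradient (which Zygote returns when the output does not depend on an input), points at the offending piece. A minimal sketch, assuming Zygote.jl is installed; `drag`, `loss`, and `fd_gradient` are hypothetical names for illustration, not part of DiffEqFlux:

```julia
using Zygote

# Hypothetical sub-function from the dynamics (stand-in for your own code).
drag(v, c) = c * v * abs(v)

# Central finite difference, written by hand to avoid extra dependencies.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# Scalar loss that exercises the sub-function over a few sample inputs.
loss(p) = sum(drag.([1.0, -2.0, 3.0], p[1]))

ad = Zygote.gradient(loss, [0.3])[1]
fd = fd_gradient(loss, [0.3])

# A `nothing` here would mean Zygote saw no dependence on the parameters.
@assert ad !== nothing "gradient is nothing: loss does not depend on p"
@assert isapprox(ad, fd; rtol = 1e-5) "AD and finite-difference gradients disagree"
println("AD: ", ad, "  FD: ", fd)
```

Running the same comparison on each sub-function in turn pairs naturally with the comment-out bisection.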
Any additional suggestions?
(This isn’t intended as criticism of DiffEqFlux specifically; it’s a fantastic tool and I’m excited to use it going forward! As a fairly new user, I’m simply trying to speed up my debugging process.)