Zygote much slower than JAX for automatic differentiation of energy

You’re definitely right that we would need to apply this improvement on both sides.
However, in some cases well-written code can be much easier for Zygote to differentiate than badly written code, so it’s not necessarily zero-sum.
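To make that concrete, here is a minimal sketch on a toy function (not the energy from this thread). The names `energy_loop` and `energy_vec` are hypothetical, made up for illustration. Both compute the same sum of squares, but the loop version makes Zygote record and accumulate a separate pullback for every `x[i]`, while the vectorized version goes through a single reduction, which is usually much cheaper in reverse mode.

```julia
using Zygote

# Hypothetical toy "energy": scalar indexing inside a loop.
# Zygote can differentiate this, but it tracks every x[i] separately.
function energy_loop(x)
    s = zero(eltype(x))
    for i in eachindex(x)
        s += x[i]^2
    end
    return s
end

# Same quantity written as one reduction over the whole array.
energy_vec(x) = sum(abs2, x)

x = collect(1.0:5.0)

# Zygote.gradient returns a tuple with one entry per argument.
g_loop = Zygote.gradient(energy_loop, x)[1]
g_vec = Zygote.gradient(energy_vec, x)[1]

# Both versions should agree with the analytic gradient 2x.
@assert g_loop ≈ 2 .* x
@assert g_vec ≈ 2 .* x
```

To compare the cost, wrap each `Zygote.gradient` call in `@benchmark` with interpolated arguments, the same way as the benchmarks further down. I would expect the gap to grow with the length of `x`, but that is worth measuring rather than assuming.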

By the way, I edited my code above to an even faster version, yielding a 10x speedup on the energy computation.

Silly me, I forgot to run the actual gradient computation… I now observe the same allocations as you, and my 10x faster energy function is actually… slower to differentiate. Very frustrating indeed.

julia> @benchmark compute_energy_and_gradient($model, $ps, $st, $H, $all_configurations)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.593 s …    1.842 s  ┊ GC (min … max):  0.03% … 10.68%
 Time  (median):     1.827 s               ┊ GC (median):    10.77%
 Time  (mean ± σ):   1.754 s ± 139.300 ms  ┊ GC (mean ± σ):   8.00% ±  6.64%

  █                                                     █  █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁█ ▁
  1.59 s         Histogram: frequency by time         1.84 s <

 Memory estimate: 4.01 GiB, allocs estimate: 82.

julia> @benchmark compute_energy_and_gradient($model_fast, $ps_fast, $st_fast, $H, $all_configurations)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.782 s …   1.919 s  ┊ GC (min … max):  0.16% … 11.39%
 Time  (median):     1.840 s              ┊ GC (median):    11.87%
 Time  (mean ± σ):   1.847 s ± 68.330 ms  ┊ GC (mean ± σ):   8.09% ±  6.77%

  █                       █                               █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.78 s         Histogram: frequency by time        1.92 s <

 Memory estimate: 4.04 GiB, allocs estimate: 324.